Coverage for src / gwtransport / fronttracking / solver.py: 81%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-04 18:51 +0000

"""
Front Tracking Solver - Event-Driven Simulation Engine.
========================================================

This module implements the main event-driven front tracking solver for
nonlinear sorption transport. The solver maintains a list of waves and
processes events (wave collisions, outlet crossings, and flow changes) in
chronological order using exact analytical calculations.

The algorithm:

1. Initialize waves from the inlet boundary conditions.
2. Find the next event (earliest collision, outlet crossing, or flow change).
3. Advance time to that event.
4. Handle the event (create new waves, deactivate old ones).
5. Repeat until no events remain or the iteration limit is reached.

All calculations use exact analytical expressions and are accurate to
machine precision.
"""

19 

20import logging 

21from dataclasses import dataclass 

22from heapq import heappop, heappush 

23from typing import Optional 

24 

25import numpy as np 

26import pandas as pd 

27 

28from gwtransport.fronttracking.events import ( 

29 Event, 

30 EventType, 

31 find_characteristic_intersection, 

32 find_outlet_crossing, 

33 find_rarefaction_boundary_intersections, 

34 find_shock_characteristic_intersection, 

35 find_shock_shock_intersection, 

36) 

37from gwtransport.fronttracking.handlers import ( 

38 create_inlet_waves_at_time, 

39 handle_characteristic_collision, 

40 handle_flow_change, 

41 handle_outlet_crossing, 

42 handle_rarefaction_characteristic_collision, 

43 handle_rarefaction_rarefaction_collision, 

44 handle_shock_characteristic_collision, 

45 handle_shock_collision, 

46 handle_shock_rarefaction_collision, 

47) 

48from gwtransport.fronttracking.math import ( 

49 ConstantRetardation, 

50 FreundlichSorption, 

51 compute_first_front_arrival_time, 

52) 

53 

54# Import mass balance functions for runtime verification (High Priority #3) 

55from gwtransport.fronttracking.output import ( 

56 compute_cumulative_inlet_mass, 

57 compute_cumulative_outlet_mass, 

58 compute_domain_mass, 

59) 

60from gwtransport.fronttracking.waves import ( 

61 CharacteristicWave, 

62 RarefactionWave, 

63 ShockWave, 

64 Wave, 

65) 

66 

# --- Numerical tolerances -------------------------------------------------
# Absolute threshold below which two consecutive inlet concentrations are
# considered equal (so no inlet wave is created between them).
EPSILON_CONCENTRATION = 1e-15
# Minimum length a candidate-event tuple must exceed before its trailing
# ``extra`` field may be read.
MIN_EVENT_DATA_LENGTH = 5

70 

71 

@dataclass
class FrontTrackerState:
    """
    Mutable container for the full state of a front tracking simulation.

    Bundles every wave created so far (active or not), the history of
    processed events, the current simulation clock, and the inputs the
    simulation was built from.

    Attributes
    ----------
    waves : list of Wave
        Every wave created during the simulation, including waves that are
        no longer active.
    events : list of dict
        Chronological record of the processed events.
    t_current : float
        Current simulation time [days since ``tedges[0]``].
    v_outlet : float
        Outlet position [m³].
    sorption : FreundlichSorption or ConstantRetardation
        Sorption model used by the simulation.
    cin : numpy.ndarray
        Inlet concentration per time bin [mass/volume].
    flow : numpy.ndarray
        Flow rate per time bin [m³/day].
    tedges : pandas.DatetimeIndex
        Edges of the time bins.

    Examples
    --------
    ::

        state = FrontTrackerState(
            waves=[],
            events=[],
            t_current=0.0,
            v_outlet=500.0,
            sorption=sorption,
            cin=cin,
            flow=flow,
            tedges=tedges,
        )
    """

    # Field order defines the positional constructor signature; keep it stable.
    waves: list[Wave]  # all waves, active and inactive
    events: list[dict]  # event history records
    t_current: float  # days since tedges[0]
    v_outlet: float  # outlet position [m³]
    sorption: FreundlichSorption | ConstantRetardation
    cin: np.ndarray  # inlet concentration per bin
    flow: np.ndarray  # flow rate per bin [m³/day]
    tedges: pd.DatetimeIndex  # time bin edges

124 

125 

class FrontTracker:
    """
    Event-driven front tracking solver for nonlinear sorption transport.

    This is the main simulation engine that orchestrates wave propagation,
    event detection, and event handling. The solver maintains a list of waves
    and processes collision, outlet-crossing, and flow-change events in
    chronological order.

    Parameters
    ----------
    cin : numpy.ndarray
        Inlet concentration time series [mass/volume], one value per time bin.
    flow : numpy.ndarray
        Flow rate time series [m³/day], one value per time bin.
    tedges : pandas.DatetimeIndex
        Time bin edges; must have length ``len(cin) + 1``. Internally, all
        times are expressed as float days since ``tedges[0]``.
    aquifer_pore_volume : float
        Total pore volume [m³]; used as the outlet position.
    sorption : FreundlichSorption or ConstantRetardation
        Sorption parameters

    Attributes
    ----------
    state : FrontTrackerState
        Complete simulation state
    t_first_arrival : float
        First arrival time (end of spin-up period) [days]

    Examples
    --------
    ::

        tracker = FrontTracker(
            cin=cin,
            flow=flow,
            tedges=tedges,
            aquifer_pore_volume=500.0,
            sorption=sorption,
        )
        tracker.run(max_iterations=1000)
        # Access results
        print(f"Total events: {len(tracker.state.events)}")
        print(f"Active waves: {sum(1 for w in tracker.state.waves if w.is_active)}")

    Notes
    -----
    Wave interactions are computed with exact analytical expressions rather
    than iterative methods. Small absolute tolerances (1e-15) are used only
    to decide whether consecutive inlet concentrations or flow rates differ.

    The spin-up period (t < t_first_arrival) is affected by unknown initial
    conditions. Results are only valid for t >= t_first_arrival.
    """

    # Event types whose ``extra`` payload is the boundary identifier returned
    # by ``find_rarefaction_boundary_intersections``.
    _RAREFACTION_EVENT_TYPES = frozenset({
        EventType.RAREF_CHAR_COLLISION,
        EventType.SHOCK_RAREF_COLLISION,
        EventType.RAREF_RAREF_COLLISION,
    })

    # Module-scoped logger. The ``logging.info(...)``-style convenience
    # functions call ``basicConfig()`` on the root logger when it has no
    # handlers, which a library must not do on behalf of the application.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        cin: np.ndarray,
        flow: np.ndarray,
        tedges: pd.DatetimeIndex,
        aquifer_pore_volume: float,
        sorption: FreundlichSorption | ConstantRetardation,
    ):
        """
        Initialize tracker with inlet conditions and physical parameters.

        Parameters
        ----------
        cin : numpy.ndarray
            Inlet concentration time series [mass/volume]. Any array-like is
            accepted and converted with ``numpy.asarray``.
        flow : numpy.ndarray
            Flow rate time series [m³/day]. Any array-like is accepted and
            converted with ``numpy.asarray``.
        tedges : pandas.DatetimeIndex
            Time bin edges, length ``len(cin) + 1``.
        aquifer_pore_volume : float
            Total pore volume [m³]
        sorption : FreundlichSorption or ConstantRetardation
            Sorption parameters

        Raises
        ------
        ValueError
            If input arrays have incompatible lengths, ``cin`` is negative,
            ``flow`` is not strictly positive, or ``aquifer_pore_volume`` is
            not strictly positive.
        """
        # ndarray inputs pass through unchanged; lists/Series become ndarrays so
        # that positional indexing and ``.copy()`` behave uniformly below.
        cin = np.asarray(cin)
        flow = np.asarray(flow)

        # Validation
        if len(tedges) != len(cin) + 1:
            msg = f"tedges must have length len(cin) + 1, got {len(tedges)} vs {len(cin) + 1}"
            raise ValueError(msg)
        if len(flow) != len(cin):
            msg = f"flow must have same length as cin, got {len(flow)} vs {len(cin)}"
            raise ValueError(msg)
        if np.any(cin < 0):
            msg = "cin must be non-negative"
            raise ValueError(msg)
        if np.any(flow <= 0):
            msg = "flow must be positive"
            raise ValueError(msg)
        if aquifer_pore_volume <= 0:
            msg = "aquifer_pore_volume must be positive"
            raise ValueError(msg)

        # t_current is in days from tedges[0], so it starts at 0.0
        self.state = FrontTrackerState(
            waves=[],
            events=[],
            t_current=0.0,
            v_outlet=aquifer_pore_volume,
            sorption=sorption,
            cin=cin.copy(),
            flow=flow.copy(),
            tedges=tedges.copy(),
        )

        # End of the spin-up period
        self.t_first_arrival = compute_first_front_arrival_time(cin, flow, tedges, aquifer_pore_volume, sorption)

        # Flow changes become scheduled events that update all wave velocities
        self._flow_change_schedule = self._detect_flow_changes()

        # Waves from the inlet boundary conditions
        self._initialize_inlet_waves()

    def _initialize_inlet_waves(self):
        """
        Create the inlet waves implied by the inlet concentration series.

        The domain is assumed to start at zero concentration. Whenever the
        concentration of a time bin differs from the previous bin by more than
        ``EPSILON_CONCENTRATION``, ``create_inlet_waves_at_time`` is called at
        the inlet (v=0) at the start of that bin, and the waves it returns are
        appended to the state. The choice of wave type is made by that helper.
        """
        t0 = self.state.tedges[0]
        one_day = pd.Timedelta(days=1)
        c_prev = 0.0  # Assume domain initially at zero

        # tedges[:-1] holds the start time of each bin, aligned with cin and flow.
        for c_bin, t_edge, q_bin in zip(self.state.cin, self.state.tedges[:-1], self.state.flow):
            c_new = float(c_bin)
            if abs(c_new - c_prev) > EPSILON_CONCENTRATION:
                new_waves = create_inlet_waves_at_time(
                    c_prev=c_prev,
                    c_new=c_new,
                    t=(t_edge - t0) / one_day,
                    flow=float(q_bin),
                    sorption=self.state.sorption,
                    v_inlet=0.0,
                )
                self.state.waves.extend(new_waves)
            c_prev = c_new

    def _detect_flow_changes(self) -> list[tuple[float, float]]:
        """
        Detect all flow changes in the inlet time series.

        Compares consecutive flow values and records the start time of every
        bin whose flow differs from the previous bin by more than 1e-15.

        Returns
        -------
        list of tuple
            List of (t_change, flow_new) tuples in time order. Times are in
            days from tedges[0]; ``flow_new`` is a Python float.

        Examples
        --------
        >>> # flow = [100, 100, 200, 50] at tedges = [0, 10, 20, 30, 40] days
        >>> # Returns: [(20.0, 200.0), (30.0, 50.0)]
        """
        epsilon_flow = 1e-15
        t0 = self.state.tedges[0]
        one_day = pd.Timedelta(days=1)
        flow = self.state.flow

        # tedges[1:][k] is the start of bin k + 1, i.e. the bin whose flow is flow[1:][k].
        return [
            ((t_edge - t0) / one_day, float(q_new))
            for t_edge, q_old, q_new in zip(self.state.tedges[1:], flow[:-1], flow[1:])
            if abs(q_new - q_old) > epsilon_flow
        ]

    def _rarefaction_edge_outlet_crossing(self, position, velocity, t_eval):
        """
        Return when one rarefaction edge (head or tail) reaches the outlet.

        Parameters
        ----------
        position : callable
            ``head_position_at_time`` or ``tail_position_at_time`` of the wave.
        velocity : callable
            ``head_velocity`` or ``tail_velocity`` of the wave.
        t_eval : float
            Time at which the edge position is evaluated.

        Returns
        -------
        float or None
            Crossing time strictly after the current time, or None when the
            edge is undefined, already at or past the outlet, not moving
            forward, or would cross no later than the current time.
        """
        v_outlet = self.state.v_outlet
        v_edge = position(t_eval)
        if v_edge is not None and v_edge < v_outlet:
            speed = velocity()
            if speed > 0:
                t_cross = t_eval + (v_outlet - v_edge) / speed
                if t_cross > self.state.t_current:
                    return t_cross
        return None

    def find_next_event(self) -> Optional[Event]:
        """
        Find the next event (earliest in time).

        Collects every candidate event among the active waves and returns the
        earliest one. Events at the same time are resolved in discovery order,
        so a flow change wins any tie.

        Returns
        -------
        Event or None
            Next event to process, or None if no future events

        Notes
        -----
        Checks, in this order:

        - The next scheduled flow change
        - Characteristic-characteristic collisions
        - Shock-shock collisions
        - Shock-characteristic collisions
        - Rarefaction-characteristic collisions (head/tail)
        - Shock-rarefaction collisions (head/tail)
        - Rarefaction-rarefaction collisions
        - Outlet crossings for all wave types (head and tail for rarefactions)

        Collisions are only kept if they occur inside the domain
        (0 <= v <= v_outlet).
        """
        t_now = self.state.t_current
        v_outlet = self.state.v_outlet

        # Min-heap of (time, sequence, event_type, waves, location, extra).
        # ``sequence`` equals the number of entries pushed before this one
        # (nothing is popped until the end), so ties in time are broken by
        # discovery order and EventType members are never compared.
        candidates: list[tuple] = []

        def schedule(t, event_type, waves, location, extra=None):
            heappush(candidates, (t, len(candidates), event_type, waves, location, extra))

        def schedule_if_in_domain(t, location, event_type, waves, extra=None):
            if 0 <= location <= v_outlet:
                schedule(t, event_type, waves, location, extra)

        active_waves = [w for w in self.state.waves if w.is_active]
        chars = [w for w in active_waves if isinstance(w, CharacteristicWave)]
        shocks = [w for w in active_waves if isinstance(w, ShockWave)]
        rarefs = [w for w in active_waves if isinstance(w, RarefactionWave)]

        # 1. Next flow change (pushed first so it gets priority in tie-breaking).
        #    The schedule is in time order, so the first future entry is the next one.
        next_change = next(((t, q) for t, q in self._flow_change_schedule if t > t_now), None)
        if next_change is not None:
            t_change, flow_new = next_change
            # All active waves are involved in a flow change
            schedule(t_change, EventType.FLOW_CHANGE, list(active_waves), None, flow_new)

        # 2. Characteristic-characteristic collisions
        for i, w1 in enumerate(chars):
            for w2 in chars[i + 1 :]:
                result = find_characteristic_intersection(w1, w2, t_now)
                if result:
                    t, v = result
                    schedule_if_in_domain(t, v, EventType.CHAR_CHAR_COLLISION, [w1, w2])

        # 3. Shock-shock collisions
        for i, w1 in enumerate(shocks):
            for w2 in shocks[i + 1 :]:
                result = find_shock_shock_intersection(w1, w2, t_now)
                if result:
                    t, v = result
                    schedule_if_in_domain(t, v, EventType.SHOCK_SHOCK_COLLISION, [w1, w2])

        # 4. Shock-characteristic collisions
        for shock in shocks:
            for char in chars:
                result = find_shock_characteristic_intersection(shock, char, t_now)
                if result:
                    t, v = result
                    schedule_if_in_domain(t, v, EventType.SHOCK_CHAR_COLLISION, [shock, char])

        # 5. Rarefaction-characteristic collisions (extra = boundary identifier)
        for raref in rarefs:
            for char in chars:
                for t, v, boundary in find_rarefaction_boundary_intersections(raref, char, t_now):
                    schedule_if_in_domain(t, v, EventType.RAREF_CHAR_COLLISION, [raref, char], boundary)

        # 6. Shock-rarefaction collisions (extra = boundary identifier)
        for shock in shocks:
            for raref in rarefs:
                for t, v, boundary in find_rarefaction_boundary_intersections(raref, shock, t_now):
                    schedule_if_in_domain(t, v, EventType.SHOCK_RAREF_COLLISION, [shock, raref], boundary)

        # 7. Rarefaction-rarefaction collisions (extra = boundary identifier)
        for i, raref1 in enumerate(rarefs):
            for raref2 in rarefs[i + 1 :]:
                for t, v, boundary in find_rarefaction_boundary_intersections(raref1, raref2, t_now):
                    schedule_if_in_domain(t, v, EventType.RAREF_RAREF_COLLISION, [raref1, raref2], boundary)

        # 8. Outlet crossings
        for wave in active_waves:
            if isinstance(wave, RarefactionWave):
                # Head and tail move at different speeds, so each may cross separately.
                t_eval = max(t_now, wave.t_start)
                edges = (
                    (wave.head_position_at_time, wave.head_velocity),
                    (wave.tail_position_at_time, wave.tail_velocity),
                )
                for position, velocity in edges:
                    t_cross = self._rarefaction_edge_outlet_crossing(position, velocity, t_eval)
                    if t_cross is not None:
                        schedule(t_cross, EventType.OUTLET_CROSSING, [wave], v_outlet)
            else:
                t_cross = find_outlet_crossing(wave, v_outlet, t_now)
                if t_cross is not None and t_cross > t_now:
                    schedule(t_cross, EventType.OUTLET_CROSSING, [wave], v_outlet)

        if not candidates:
            return None

        t, _sequence, event_type, waves, v, extra = heappop(candidates)
        return Event(
            time=t,
            event_type=event_type,
            waves_involved=waves,
            location=v,
            # FLOW_CHANGE carries the new flow rate in ``extra``
            flow_new=extra if event_type == EventType.FLOW_CHANGE else None,
            # Rarefaction collisions carry the boundary identifier in ``extra``
            boundary_type=extra if event_type in self._RAREFACTION_EVENT_TYPES else None,
        )

    def handle_event(self, event: Event):
        """
        Handle an event by calling the matching handler and updating state.

        Dispatches on ``event.event_type``, appends any waves returned by the
        handler to ``state.waves``, and appends a record of the event to
        ``state.events``. Outlet crossings are recorded with the record built
        by ``handle_outlet_crossing`` and create no new waves.

        Parameters
        ----------
        event : Event
            Event to handle

        Raises
        ------
        RuntimeError
            If a FLOW_CHANGE event has no ``flow_new``.

        Notes
        -----
        Handlers may deactivate parent waves and return new child waves.
        """
        kind = event.event_type
        waves = event.waves_involved
        new_waves = []

        if kind == EventType.CHAR_CHAR_COLLISION:
            new_waves = handle_characteristic_collision(waves[0], waves[1], event.time, event.location)

        elif kind == EventType.SHOCK_SHOCK_COLLISION:
            new_waves = handle_shock_collision(waves[0], waves[1], event.time, event.location)

        elif kind == EventType.SHOCK_CHAR_COLLISION:
            new_waves = handle_shock_characteristic_collision(waves[0], waves[1], event.time, event.location)

        elif kind == EventType.RAREF_CHAR_COLLISION:
            new_waves = handle_rarefaction_characteristic_collision(
                waves[0], waves[1], event.time, event.location, boundary_type=event.boundary_type
            )

        elif kind == EventType.SHOCK_RAREF_COLLISION:
            new_waves = handle_shock_rarefaction_collision(
                waves[0], waves[1], event.time, event.location, boundary_type=event.boundary_type
            )

        elif kind == EventType.RAREF_RAREF_COLLISION:
            new_waves = handle_rarefaction_rarefaction_collision(
                waves[0], waves[1], event.time, event.location, boundary_type=event.boundary_type
            )

        elif kind == EventType.OUTLET_CROSSING:
            event_record = handle_outlet_crossing(waves[0], event.time, event.location)
            self.state.events.append(event_record)
            return  # No new waves for outlet crossing

        elif kind == EventType.FLOW_CHANGE:
            if event.flow_new is None:
                msg = "FLOW_CHANGE event must have flow_new set"
                raise RuntimeError(msg)
            # Use the waves active now, not the snapshot taken when the event was found
            active_waves = [w for w in self.state.waves if w.is_active]
            new_waves = handle_flow_change(event.time, event.flow_new, active_waves)

        self.state.waves.extend(new_waves)

        self.state.events.append({
            "time": event.time,
            "type": event.event_type.value,
            "location": event.location,
            "waves_before": event.waves_involved,
            "waves_after": new_waves,
        })

    def run(self, max_iterations: int = 10000, *, verbose: bool = False):
        """
        Run simulation until no more events or max_iterations reached.

        Repeatedly finds the next event, advances ``state.t_current`` to its
        time, and handles it. Every 100th processed event (starting with the
        first) is followed by ``verify_physics()``.

        Parameters
        ----------
        max_iterations : int, optional
            Maximum number of events to process. Default 10000.
            Prevents infinite loops in case of bugs.
        verbose : bool, optional
            Emit progress messages through this module's logger. Default False.

        Raises
        ------
        Exception
            Any exception raised while handling an event is logged and re-raised.
        RuntimeError
            If ``verify_physics`` detects a violation.

        Notes
        -----
        The simulation stops when:
        - No more events exist (all waves have exited or become inactive)
        - max_iterations is reached (a warning is logged)

        After completion, results are available in:
        - self.state.waves: All waves (active and inactive)
        - self.state.events: Complete event history
        - self.t_first_arrival: End of spin-up period
        """
        logger = self._logger
        iteration = 0

        if verbose:
            logger.info("Starting simulation at t=%.3f", self.state.t_current)
            logger.info("Initial waves: %d", len(self.state.waves))
            logger.info("First arrival time: %.3f days", self.t_first_arrival)

        while iteration < max_iterations:
            event = self.find_next_event()

            if event is None:
                if verbose:
                    logger.info("Simulation complete after %d events at t=%.6f", iteration, self.state.t_current)
                break

            self.state.t_current = event.time

            try:
                self.handle_event(event)
            except Exception:
                logger.exception("Error handling event at t=%.3f", event.time)
                raise

            # Periodic consistency checks (entropy and rarefaction ordering)
            if iteration % 100 == 0:
                self.verify_physics()

            if verbose and iteration % 10 == 0:
                active = sum(1 for w in self.state.waves if w.is_active)
                logger.debug("Iteration %d: t=%.3f, active_waves=%d", iteration, event.time, active)

            iteration += 1

        if iteration >= max_iterations:
            logger.warning("Reached max_iterations=%d", max_iterations)

        if verbose:
            logger.info("Final statistics:")
            logger.info(" Total events: %d", len(self.state.events))
            logger.info(" Total waves created: %d", len(self.state.waves))
            logger.info(" Active waves: %d", sum(1 for w in self.state.waves if w.is_active))
            logger.info(" First arrival time: %.6f days", self.t_first_arrival)

    def verify_physics(self, *, check_mass_balance: bool = False, mass_balance_rtol: float = 1e-12):
        """
        Verify physical correctness of current state.

        Checks:
        - All active shocks satisfy the entropy condition (``satisfies_entropy()``)
        - All active rarefactions have head velocity strictly greater than tail velocity
        - Optionally, mass balance: mass_in_domain + mass_out = mass_in

        Parameters
        ----------
        check_mass_balance : bool, optional
            Enable mass balance verification. Default False (opt-in for now).
        mass_balance_rtol : float, optional
            Tolerance for the mass balance check. Default 1e-12. The error is
            taken relative to the cumulative inlet mass. While that mass is not
            yet positive, the absolute error is compared against this value.

        Raises
        ------
        RuntimeError
            If physics violation is detected

        Notes
        -----
        Mass balance equation:
        mass_in_domain(t) + mass_out_cumulative(t) = mass_in_cumulative(t)

        According to the mass helpers' design, inlet/outlet temporal integrals
        are exact for piecewise-constant series, while the domain integral uses
        a midpoint rule inside rarefactions. A looser ``mass_balance_rtol`` may
        therefore be needed when rarefactions are present.
        """
        self._verify_shock_entropy()
        self._verify_rarefaction_ordering()
        if check_mass_balance:
            self._verify_mass_balance(mass_balance_rtol)

    def _verify_shock_entropy(self):
        """Raise RuntimeError if any active shock violates the entropy condition."""
        for wave in self.state.waves:
            if isinstance(wave, ShockWave) and wave.is_active and not wave.satisfies_entropy():
                msg = (
                    f"Shock at t_start={wave.t_start:.3f} violates entropy! "
                    f"c_left={wave.c_left:.3f}, c_right={wave.c_right:.3f}, "
                    f"velocity={wave.velocity:.3f}"
                )
                raise RuntimeError(msg)

    def _verify_rarefaction_ordering(self):
        """Raise RuntimeError if any active rarefaction's head is not faster than its tail."""
        for wave in self.state.waves:
            if isinstance(wave, RarefactionWave) and wave.is_active:
                v_head = wave.head_velocity()
                v_tail = wave.tail_velocity()
                if v_head <= v_tail:
                    msg = (
                        f"Rarefaction at t_start={wave.t_start:.3f} has invalid ordering! "
                        f"head_velocity={v_head:.3f} <= tail_velocity={v_tail:.3f}"
                    )
                    raise RuntimeError(msg)

    def _verify_mass_balance(self, mass_balance_rtol: float):
        """Raise RuntimeError if domain + outlet mass differs from inlet mass beyond tolerance."""
        t_current = self.state.t_current

        # The mass helpers work in float days since tedges[0], like the solver itself
        tedges_days = (self.state.tedges - self.state.tedges[0]) / pd.Timedelta(days=1)

        mass_in_domain = compute_domain_mass(
            t=t_current,
            v_outlet=self.state.v_outlet,
            waves=self.state.waves,
            sorption=self.state.sorption,
        )
        mass_in_cumulative = compute_cumulative_inlet_mass(
            t=t_current,
            cin=self.state.cin,
            flow=self.state.flow,
            tedges_days=tedges_days,
        )
        mass_out_cumulative = compute_cumulative_outlet_mass(
            t=t_current,
            v_outlet=self.state.v_outlet,
            waves=self.state.waves,
            sorption=self.state.sorption,
            flow=self.state.flow,
            tedges_days=tedges_days,
        )

        mass_balance_error = (mass_in_domain + mass_out_cumulative) - mass_in_cumulative

        if mass_in_cumulative > 0:
            relative_error = abs(mass_balance_error) / mass_in_cumulative
        else:
            # No mass has entered yet: fall back to the absolute error
            relative_error = abs(mass_balance_error)

        if relative_error > mass_balance_rtol:
            msg = (
                f"Mass balance violation at t={t_current:.6f}! "
                f"mass_in_domain={mass_in_domain:.6e}, "
                f"mass_out={mass_out_cumulative:.6e}, "
                f"mass_in={mass_in_cumulative:.6e}, "
                f"error={mass_balance_error:.6e}, "
                f"relative_error={relative_error:.6e} > {mass_balance_rtol:.6e}"
            )
            raise RuntimeError(msg)