Coverage for src / gwtransport / fronttracking / solver.py: 81%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-27 06:32 +0000

"""
Front Tracking Solver - Event-Driven Simulation Engine.
=======================================================

This module implements the main event-driven front tracking solver for
nonlinear sorption transport. The solver maintains a list of waves and
processes collision, outlet-crossing and flow-change events in chronological
order.

The algorithm:

1. Initialize waves from the inlet boundary conditions.
2. Find the next event (earliest collision, outlet crossing or flow change).
3. Advance time to that event.
4. Handle the event (create new waves, deactivate old ones).
5. Repeat until no events remain or the iteration limit is reached.

Event times and locations are computed from closed-form intersection formulas
rather than by time stepping. Inlet concentration and flow changes of at most
1e-15 are treated as "no change".
"""

19 

20import logging 

21from dataclasses import dataclass 

22from heapq import heappop, heappush 

23 

24import numpy as np 

25import pandas as pd 

26 

27from gwtransport.fronttracking.events import ( 

28 Event, 

29 EventType, 

30 find_characteristic_intersection, 

31 find_outlet_crossing, 

32 find_rarefaction_boundary_intersections, 

33 find_shock_characteristic_intersection, 

34 find_shock_shock_intersection, 

35) 

36from gwtransport.fronttracking.handlers import ( 

37 create_inlet_waves_at_time, 

38 handle_characteristic_collision, 

39 handle_flow_change, 

40 handle_outlet_crossing, 

41 handle_rarefaction_characteristic_collision, 

42 handle_rarefaction_rarefaction_collision, 

43 handle_shock_characteristic_collision, 

44 handle_shock_collision, 

45 handle_shock_rarefaction_collision, 

46) 

47from gwtransport.fronttracking.math import ( 

48 SorptionModel, 

49 compute_first_front_arrival_time, 

50) 

51 

52# Import mass balance functions for runtime verification (High Priority #3) 

53from gwtransport.fronttracking.output import ( 

54 compute_cumulative_inlet_mass, 

55 compute_cumulative_outlet_mass, 

56 compute_domain_mass, 

57) 

58from gwtransport.fronttracking.waves import ( 

59 CharacteristicWave, 

60 RarefactionWave, 

61 ShockWave, 

62 Wave, 

63) 

64 

# Logger named after this module; the solver reports progress, warnings and
# event-handling failures through it.
logger = logging.getLogger(__name__)

# --- Numerical tolerances ---------------------------------------------------
# Inlet concentration steps of at most this size do not create new waves.
EPSILON_CONCENTRATION: float = 1e-15
# An event candidate tuple must be longer than this before its trailing
# "extra" field is read.
MIN_EVENT_DATA_LENGTH: int = 5

70 

71 

@dataclass
class FrontTrackerState:
    """
    Mutable container for everything the front tracking solver knows.

    The solver appends to ``waves`` and ``events`` and advances ``t_current``
    as it runs. The remaining fields hold the problem definition it was built
    from.

    Parameters
    ----------
    waves : list of Wave
        Every wave created so far, including ones that are no longer active.
    events : list of dict
        Chronological record of the events that have been processed.
    t_current : float
        Current simulation time [days since ``tedges[0]``].
    v_outlet : float
        Outlet position, expressed as a pore volume [m³].
    sorption : SorptionModel
        Sorption model (e.g. FreundlichSorption or ConstantRetardation).
    cin : numpy.ndarray
        Inlet concentration per time bin [mass/volume].
    flow : numpy.ndarray
        Flow rate per time bin [m³/day].
    tedges : pandas.DatetimeIndex
        Time bin edges.

    Examples
    --------
    ::

        state = FrontTrackerState(
            waves=[],
            events=[],
            t_current=0.0,
            v_outlet=500.0,
            sorption=sorption,
            cin=cin,
            flow=flow,
            tedges=tedges,
        )
    """

    # Simulation history, which grows as events are processed.
    waves: list[Wave]
    events: list[dict]
    # Simulation clock.
    t_current: float
    # Fixed problem definition.
    v_outlet: float
    sorption: SorptionModel
    cin: np.ndarray
    flow: np.ndarray
    tedges: pd.DatetimeIndex

124 

125 

class FrontTracker:
    """
    Event-driven front tracking solver for nonlinear sorption transport.

    This is the main simulation engine that orchestrates wave propagation,
    event detection, and event handling. The solver maintains a list of waves
    and processes collision, outlet-crossing and flow-change events in
    chronological order.

    Parameters
    ----------
    cin : numpy.ndarray
        Inlet concentration per time bin [mass/volume]. Any array-like is
        accepted and converted with ``numpy.asarray``.
    flow : numpy.ndarray
        Flow rate per time bin [m³/day]. Any array-like is accepted and
        converted with ``numpy.asarray``.
    tedges : pandas.DatetimeIndex
        Time bin edges, of length ``len(cin) + 1``. Simulation time is
        expressed in days since ``tedges[0]``.
    aquifer_pore_volume : float
        Total pore volume [m³]; used as the outlet position ``v_outlet``.
    sorption : SorptionModel
        Sorption model (e.g. FreundlichSorption or ConstantRetardation).

    Attributes
    ----------
    state : FrontTrackerState
        Complete simulation state.
    t_first_arrival : float
        First arrival time (end of spin-up period) [days].

    Notes
    -----
    Event times and locations come from the closed-form intersection and
    outlet-crossing functions in ``gwtransport.fronttracking.events``; this
    class does no time stepping. Inlet concentration and flow changes of at
    most 1e-15 are treated as "no change".

    The spin-up period (t < t_first_arrival) is affected by unknown initial
    conditions. Results are only valid for t >= t_first_arrival.

    Examples
    --------
    ::

        tracker = FrontTracker(
            cin=cin,
            flow=flow,
            tedges=tedges,
            aquifer_pore_volume=500.0,
            sorption=sorption,
        )
        tracker.run(max_iterations=1000)
        # Access results
        print(f"Total events: {len(tracker.state.events)}")
        print(f"Active waves: {sum(1 for w in tracker.state.waves if w.is_active)}")
    """

    # Event types whose trailing "extra" candidate field carries the
    # rarefaction boundary that was hit (passed on as ``boundary_type``).
    _RAREFACTION_EVENT_TYPES = frozenset({
        EventType.RAREF_CHAR_COLLISION,
        EventType.SHOCK_RAREF_COLLISION,
        EventType.RAREF_RAREF_COLLISION,
    })

    def __init__(
        self,
        cin: np.ndarray,
        flow: np.ndarray,
        tedges: pd.DatetimeIndex,
        aquifer_pore_volume: float,
        sorption: SorptionModel,
    ):
        """
        Initialize tracker with inlet conditions and physical parameters.

        Parameters
        ----------
        cin : numpy.ndarray
            Inlet concentration per time bin [mass/volume]. Array-likes such
            as lists are converted with ``numpy.asarray``.
        flow : numpy.ndarray
            Flow rate per time bin [m³/day]. Array-likes such as lists are
            converted with ``numpy.asarray``.
        tedges : pandas.DatetimeIndex
            Time bin edges, of length ``len(cin) + 1``.
        aquifer_pore_volume : float
            Total pore volume [m³].
        sorption : SorptionModel
            Sorption model (e.g. FreundlichSorption or ConstantRetardation).

        Raises
        ------
        ValueError
            If input arrays have incompatible lengths, ``cin`` or ``flow``
            contain negative values, or ``aquifer_pore_volume`` is not positive.
        """
        # Normalize to ndarrays so the element-wise comparisons and ``.copy()``
        # below also work for lists and similar array-likes. ndarray inputs
        # pass through unchanged.
        cin = np.asarray(cin)
        flow = np.asarray(flow)

        # Validation
        if len(tedges) != len(cin) + 1:
            msg = f"tedges must have length len(cin) + 1, got {len(tedges)} vs {len(cin) + 1}"
            raise ValueError(msg)
        if len(flow) != len(cin):
            msg = f"flow must have same length as cin, got {len(flow)} vs {len(cin)}"
            raise ValueError(msg)
        if np.any(cin < 0):
            msg = "cin must be non-negative"
            raise ValueError(msg)
        if np.any(flow < 0):
            msg = "flow must be non-negative (negative flow not supported)"
            raise ValueError(msg)
        if aquifer_pore_volume <= 0:
            msg = "aquifer_pore_volume must be positive"
            raise ValueError(msg)

        # t_current is measured in days from tedges[0], so it starts at 0.0.
        # Inputs are copied so later caller-side mutation cannot affect the run.
        self.state = FrontTrackerState(
            waves=[],
            events=[],
            t_current=0.0,
            v_outlet=aquifer_pore_volume,
            sorption=sorption,
            cin=cin.copy(),
            flow=flow.copy(),
            tedges=tedges.copy(),
        )

        # Spin-up period: results before this time depend on unknown initial conditions.
        self.t_first_arrival = compute_first_front_arrival_time(cin, flow, tedges, aquifer_pore_volume, sorption)

        # Flow changes are scheduled as events that update all wave velocities.
        self._flow_change_schedule = self._detect_flow_changes()

        # Waves generated by the inlet boundary conditions.
        self._initialize_inlet_waves()

    def _days_since_start(self, i: int) -> float:
        """Return ``tedges[i]`` expressed as days elapsed since ``tedges[0]``."""
        return (self.state.tedges[i] - self.state.tedges[0]) / pd.Timedelta(days=1)

    def _initialize_inlet_waves(self):
        """
        Initialize all waves from inlet boundary conditions.

        For every time bin whose inlet concentration differs from the previous
        bin by more than ``EPSILON_CONCENTRATION``, the waves returned by
        ``create_inlet_waves_at_time`` are appended to ``state.waves``.
        """
        c_prev = 0.0  # The domain is assumed to start at zero concentration.

        for i, c_raw in enumerate(self.state.cin):
            c_new = float(c_raw)
            if abs(c_new - c_prev) > EPSILON_CONCENTRATION:
                new_waves = create_inlet_waves_at_time(
                    c_prev=c_prev,
                    c_new=c_new,
                    t=self._days_since_start(i),
                    flow=float(self.state.flow[i]),
                    sorption=self.state.sorption,
                    v_inlet=0.0,
                )
                self.state.waves.extend(new_waves)
            c_prev = c_new

    def _detect_flow_changes(self) -> list[tuple[float, float]]:
        """
        Detect all flow changes in the inlet time series.

        Scans the flow array for time bins whose flow differs from the
        previous bin. These become scheduled events that update all wave
        velocities.

        Returns
        -------
        list of tuple
            ``(t_change, flow_new)`` tuples in increasing time order.
            Times are in days since ``tedges[0]``.

        Notes
        -----
        Only changes larger than 1e-15 in absolute value are included.

        Examples
        --------
        ::

            # flow = [100, 100, 200, 50], tedges 10 days apart
            # -> [(20.0, 200.0), (30.0, 50.0)]
        """
        epsilon_flow = 1e-15
        flow = self.state.flow
        return [
            (self._days_since_start(i), flow[i])
            for i in range(1, len(flow))
            if abs(flow[i] - flow[i - 1]) > epsilon_flow
        ]

    def find_next_event(self) -> Event | None:
        """
        Find the next event (earliest in time).

        Every candidate event is pushed onto a min-heap keyed on
        ``(time, insertion_index)``, and the earliest one is returned.

        Returns
        -------
        Event or None
            Next event to process, or None if no future events exist.

        Notes
        -----
        Candidates are collected in this order. For equal times, the one
        collected first wins:

        1. The next scheduled flow change
        2. Characteristic-characteristic collisions
        3. Shock-shock collisions
        4. Shock-characteristic collisions
        5. Rarefaction-characteristic collisions (head/tail)
        6. Shock-rarefaction collisions (head/tail)
        7. Rarefaction-rarefaction collisions
        8. Outlet crossings (both head and tail for rarefactions)

        Collisions are only kept when their location lies in
        ``[0, v_outlet]``.
        """
        candidates = []
        t_now = self.state.t_current
        v_out = self.state.v_outlet

        def push(t, event_type, waves, v, extra=None):
            # Nothing is popped until the end, so len(candidates) is a strictly
            # increasing insertion index. It breaks time ties in insertion
            # order and keeps the heap from ever comparing EventType or Wave
            # objects.
            heappush(candidates, (t, len(candidates), event_type, waves, v, extra))

        active_waves = [w for w in self.state.waves if w.is_active]
        chars = [w for w in active_waves if isinstance(w, CharacteristicWave)]
        shocks = [w for w in active_waves if isinstance(w, ShockWave)]
        rarefs = [w for w in active_waves if isinstance(w, RarefactionWave)]

        # 1. Flow change (collected first so it wins time ties). Only the next
        #    future change is scheduled; all active waves are involved.
        for t_change, flow_new in self._flow_change_schedule:
            if t_change > t_now:
                push(t_change, EventType.FLOW_CHANGE, active_waves.copy(), None, flow_new)
                break

        # 2. Characteristic-characteristic collisions
        for i, w1 in enumerate(chars):
            for w2 in chars[i + 1 :]:
                result = find_characteristic_intersection(w1, w2, t_now)
                if result:
                    t, v = result
                    if 0 <= v <= v_out:
                        push(t, EventType.CHAR_CHAR_COLLISION, [w1, w2], v)

        # 3. Shock-shock collisions
        for i, w1 in enumerate(shocks):
            for w2 in shocks[i + 1 :]:
                result = find_shock_shock_intersection(w1, w2, t_now)
                if result:
                    t, v = result
                    if 0 <= v <= v_out:
                        push(t, EventType.SHOCK_SHOCK_COLLISION, [w1, w2], v)

        # 4. Shock-characteristic collisions
        for shock in shocks:
            for char in chars:
                result = find_shock_characteristic_intersection(shock, char, t_now)
                if result:
                    t, v = result
                    if 0 <= v <= v_out:
                        push(t, EventType.SHOCK_CHAR_COLLISION, [shock, char], v)

        # 5. Rarefaction-characteristic collisions
        for raref in rarefs:
            for char in chars:
                for t, v, boundary in find_rarefaction_boundary_intersections(raref, char, t_now):
                    if 0 <= v <= v_out:
                        push(t, EventType.RAREF_CHAR_COLLISION, [raref, char], v, boundary)

        # 6. Shock-rarefaction collisions
        for shock in shocks:
            for raref in rarefs:
                for t, v, boundary in find_rarefaction_boundary_intersections(raref, shock, t_now):
                    if 0 <= v <= v_out:
                        push(t, EventType.SHOCK_RAREF_COLLISION, [shock, raref], v, boundary)

        # 7. Rarefaction-rarefaction collisions
        for i, raref1 in enumerate(rarefs):
            for raref2 in rarefs[i + 1 :]:
                for t, v, boundary in find_rarefaction_boundary_intersections(raref1, raref2, t_now):
                    if 0 <= v <= v_out:
                        push(t, EventType.RAREF_RAREF_COLLISION, [raref1, raref2], v, boundary)

        # 8. Outlet crossings
        for wave in active_waves:
            if isinstance(wave, RarefactionWave):
                # Head and tail leave the domain at different times, so both
                # are scheduled (head first, as before).
                t_eval = max(t_now, wave.t_start)
                for position_at, velocity_of in (
                    (wave.head_position_at_time, wave.head_velocity),
                    (wave.tail_position_at_time, wave.tail_velocity),
                ):
                    v_edge = position_at(t_eval)
                    if v_edge is not None and v_edge < v_out:
                        speed = velocity_of()
                        if speed > 0:
                            t_cross = t_eval + (v_out - v_edge) / speed
                            if t_cross > t_now:
                                push(t_cross, EventType.OUTLET_CROSSING, [wave], v_out)
            else:
                # Characteristics and shocks: single crossing time.
                t_cross = find_outlet_crossing(wave, v_out, t_now)
                if t_cross is not None and t_cross > t_now:
                    push(t_cross, EventType.OUTLET_CROSSING, [wave], v_out)

        if not candidates:
            return None

        # Candidate layout: (t, insertion_index, event_type, waves, v, extra)
        t, _, event_type, waves, v, extra = heappop(candidates)
        return Event(
            time=t,
            event_type=event_type,
            waves_involved=waves,
            location=v,
            # FLOW_CHANGE candidates carry the new flow rate in ``extra``.
            flow_new=extra if event_type == EventType.FLOW_CHANGE else None,
            # Rarefaction collisions carry the boundary that was hit in ``extra``.
            boundary_type=extra if event_type in self._RAREFACTION_EVENT_TYPES else None,
        )

    def handle_event(self, event: Event):
        """
        Handle an event by calling appropriate handler and updating state.

        Dispatches to the event handler for ``event.event_type``, appends any
        returned waves to ``state.waves`` and records the event in
        ``state.events``.

        Parameters
        ----------
        event : Event
            Event to handle.

        Raises
        ------
        RuntimeError
            If the FLOW_CHANGE event does not have ``flow_new`` set.

        Notes
        -----
        For OUTLET_CROSSING, the record returned by ``handle_outlet_crossing``
        is stored and no waves are added. For all other event types, a dict
        with ``time``, ``type``, ``location``, ``waves_before`` and
        ``waves_after`` keys is stored.
        """
        new_waves = []

        if event.event_type == EventType.CHAR_CHAR_COLLISION:
            new_waves = handle_characteristic_collision(
                event.waves_involved[0], event.waves_involved[1], event.time, event.location
            )

        elif event.event_type == EventType.SHOCK_SHOCK_COLLISION:
            new_waves = handle_shock_collision(
                event.waves_involved[0], event.waves_involved[1], event.time, event.location
            )

        elif event.event_type == EventType.SHOCK_CHAR_COLLISION:
            new_waves = handle_shock_characteristic_collision(
                event.waves_involved[0], event.waves_involved[1], event.time, event.location
            )

        elif event.event_type == EventType.RAREF_CHAR_COLLISION:
            new_waves = handle_rarefaction_characteristic_collision(
                event.waves_involved[0],
                event.waves_involved[1],
                event.time,
                event.location,
                boundary_type=event.boundary_type,
            )

        elif event.event_type == EventType.SHOCK_RAREF_COLLISION:
            new_waves = handle_shock_rarefaction_collision(
                event.waves_involved[0],
                event.waves_involved[1],
                event.time,
                event.location,
                boundary_type=event.boundary_type,
            )

        elif event.event_type == EventType.RAREF_RAREF_COLLISION:
            new_waves = handle_rarefaction_rarefaction_collision(
                event.waves_involved[0],
                event.waves_involved[1],
                event.time,
                event.location,
                boundary_type=event.boundary_type,
            )

        elif event.event_type == EventType.OUTLET_CROSSING:
            event_record = handle_outlet_crossing(event.waves_involved[0], event.time, event.location)
            self.state.events.append(event_record)
            return  # No new waves for outlet crossing

        elif event.event_type == EventType.FLOW_CHANGE:
            # Use the waves active now rather than the snapshot in the event.
            active_waves = [w for w in self.state.waves if w.is_active]
            if event.flow_new is None:
                msg = "FLOW_CHANGE event must have flow_new set"
                raise RuntimeError(msg)
            new_waves = handle_flow_change(event.time, event.flow_new, active_waves)

        self.state.waves.extend(new_waves)

        self.state.events.append({
            "time": event.time,
            "type": event.event_type.value,
            "location": event.location,
            "waves_before": event.waves_involved,
            "waves_after": new_waves,
        })

    def run(self, max_iterations: int = 10000, *, verbose: bool = False):
        """
        Run simulation until no more events or max_iterations reached.

        Repeatedly finds the next event, advances ``state.t_current`` to it
        and handles it.

        Parameters
        ----------
        max_iterations : int, optional
            Maximum number of events to process. Default 10000. Acts as a
            safety limit against infinite event loops.
        verbose : bool, optional
            Emit progress messages through the module logger. Default False.

        Notes
        -----
        The simulation stops when:

        - No more events exist.
        - max_iterations is reached (a warning is logged).

        Entropy and rarefaction-ordering checks (``verify_physics``) run after
        every 100th event, starting with the first.

        After completion, results are available in:

        - self.state.waves: All waves (active and inactive)
        - self.state.events: Complete event history
        - self.t_first_arrival: End of spin-up period
        """
        iteration = 0

        if verbose:
            logger.info("Starting simulation at t=%.3f", self.state.t_current)
            logger.info("Initial waves: %d", len(self.state.waves))
            logger.info("First arrival time: %.3f days", self.t_first_arrival)

        while iteration < max_iterations:
            event = self.find_next_event()

            if event is None:
                if verbose:
                    logger.info("Simulation complete after %d events at t=%.6f", iteration, self.state.t_current)
                break

            self.state.t_current = event.time

            try:
                self.handle_event(event)
            except Exception:
                # Log context for the failing event, then propagate unchanged.
                logger.exception("Error handling event at t=%.3f", event.time)
                raise

            if iteration % 100 == 0:
                self.verify_physics()

            if verbose and iteration % 10 == 0:
                active = sum(1 for w in self.state.waves if w.is_active)
                logger.debug("Iteration %d: t=%.3f, active_waves=%d", iteration, event.time, active)

            iteration += 1

        if iteration >= max_iterations:
            logger.warning("Reached max_iterations=%d", max_iterations)

        if verbose:
            logger.info("Final statistics:")
            logger.info("  Total events: %d", len(self.state.events))
            logger.info("  Total waves created: %d", len(self.state.waves))
            logger.info("  Active waves: %d", sum(1 for w in self.state.waves if w.is_active))
            logger.info("  First arrival time: %.6f days", self.t_first_arrival)

    def verify_physics(self, *, check_mass_balance: bool = False, mass_balance_rtol: float = 1e-12):
        """
        Verify physical correctness of current state.

        Checks, in this order:

        - Every active shock satisfies the entropy condition
          (``ShockWave.satisfies_entropy``).
        - Every active rarefaction has head velocity > tail velocity.
        - Optionally, mass balance:
          ``mass_in_domain + mass_out_cumulative = mass_in_cumulative``.

        Parameters
        ----------
        check_mass_balance : bool, optional
            Enable mass balance verification. Default False (opt-in).
        mass_balance_rtol : float, optional
            Relative tolerance for the mass balance check. Default 1e-12.
            When no mass has entered yet, it is applied as an absolute
            tolerance. Domain mass of rarefactions is integrated with a
            midpoint rule, so a looser value may be needed when rarefactions
            are present.

        Raises
        ------
        RuntimeError
            If a physics violation is detected.

        Notes
        -----
        Mass balance equation::

            mass_in_domain(t) + mass_out_cumulative(t) = mass_in_cumulative(t)

        Inlet and outlet temporal integrals are exact for piecewise-constant
        inputs. Domain integrals are exact for constant states and use a
        midpoint rule for rarefactions.
        """
        # Entropy check for all active shocks.
        for wave in self.state.waves:
            if isinstance(wave, ShockWave) and wave.is_active and not wave.satisfies_entropy():
                msg = (
                    f"Shock at t_start={wave.t_start:.3f} violates entropy! "
                    f"c_left={wave.c_left:.3f}, c_right={wave.c_right:.3f}, "
                    f"velocity={wave.velocity:.3f}"
                )
                raise RuntimeError(msg)

        # Rarefactions must fan out: head strictly faster than tail.
        for wave in self.state.waves:
            if isinstance(wave, RarefactionWave) and wave.is_active:
                v_head = wave.head_velocity()
                v_tail = wave.tail_velocity()
                if v_head <= v_tail:
                    msg = (
                        f"Rarefaction at t_start={wave.t_start:.3f} has invalid ordering! "
                        f"head_velocity={v_head:.3f} <= tail_velocity={v_tail:.3f}"
                    )
                    raise RuntimeError(msg)

        if not check_mass_balance:
            return

        t_current = self.state.t_current

        # Mass functions work in float days since tedges[0], like the solver.
        tedges_days = (self.state.tedges - self.state.tedges[0]) / pd.Timedelta(days=1)

        mass_in_domain = compute_domain_mass(
            t=t_current,
            v_outlet=self.state.v_outlet,
            waves=self.state.waves,
            sorption=self.state.sorption,
        )
        mass_in_cumulative = compute_cumulative_inlet_mass(
            t=t_current,
            cin=self.state.cin,
            flow=self.state.flow,
            tedges_days=tedges_days,
        )
        mass_out_cumulative = compute_cumulative_outlet_mass(
            t=t_current,
            v_outlet=self.state.v_outlet,
            waves=self.state.waves,
            sorption=self.state.sorption,
            flow=self.state.flow,
            tedges_days=tedges_days,
        )

        mass_balance_error = (mass_in_domain + mass_out_cumulative) - mass_in_cumulative

        if mass_in_cumulative > 0:
            relative_error = abs(mass_balance_error) / mass_in_cumulative
        else:
            # No mass has entered yet: fall back to the absolute error.
            relative_error = abs(mass_balance_error)

        if relative_error > mass_balance_rtol:
            msg = (
                f"Mass balance violation at t={t_current:.6f}! "
                f"mass_in_domain={mass_in_domain:.6e}, "
                f"mass_out={mass_out_cumulative:.6e}, "
                f"mass_in={mass_in_cumulative:.6e}, "
                f"error={mass_balance_error:.6e}, "
                f"relative_error={relative_error:.6e} > {mass_balance_rtol:.6e}"
            )
            raise RuntimeError(msg)