Coverage for src / gwtransport / fronttracking / solver.py: 81%
232 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-27 06:32 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-27 06:32 +0000
1"""
Front Tracking Solver - Event-Driven Simulation Engine.

This module implements the main event-driven front tracking solver for
nonlinear sorption transport. The solver maintains a list of waves and
processes collision, outlet-crossing and flow-change events in chronological
order, computing event times with analytical expressions.

The algorithm:

1. Initialize waves from inlet boundary conditions
2. Find next event (earliest collision, outlet crossing or flow change)
3. Advance time to event
4. Handle event (create new waves, deactivate old ones)
5. Repeat until no more events

Wave interactions are computed analytically rather than by iterative
root finding; small absolute tolerances are only used to decide whether the
inlet concentration or flow has changed between consecutive time bins.
18"""
20import logging
21from dataclasses import dataclass
22from heapq import heappop, heappush
24import numpy as np
25import pandas as pd
27from gwtransport.fronttracking.events import (
28 Event,
29 EventType,
30 find_characteristic_intersection,
31 find_outlet_crossing,
32 find_rarefaction_boundary_intersections,
33 find_shock_characteristic_intersection,
34 find_shock_shock_intersection,
35)
36from gwtransport.fronttracking.handlers import (
37 create_inlet_waves_at_time,
38 handle_characteristic_collision,
39 handle_flow_change,
40 handle_outlet_crossing,
41 handle_rarefaction_characteristic_collision,
42 handle_rarefaction_rarefaction_collision,
43 handle_shock_characteristic_collision,
44 handle_shock_collision,
45 handle_shock_rarefaction_collision,
46)
47from gwtransport.fronttracking.math import (
48 SorptionModel,
49 compute_first_front_arrival_time,
50)
52# Import mass balance functions for runtime verification (High Priority #3)
53from gwtransport.fronttracking.output import (
54 compute_cumulative_inlet_mass,
55 compute_cumulative_outlet_mass,
56 compute_domain_mass,
57)
58from gwtransport.fronttracking.waves import (
59 CharacteristicWave,
60 RarefactionWave,
61 ShockWave,
62 Wave,
63)
# Module-level logger; progress, warnings and handler failures are reported here.
logger = logging.getLogger(name=__name__)

# Numerical tolerance constants
# Inlet concentration steps no larger than this in magnitude do not launch waves.
EPSILON_CONCENTRATION = 1.0e-15
# Candidate-event tuples longer than this carry an "extra" field at index 5.
MIN_EVENT_DATA_LENGTH = 5
@dataclass
class FrontTrackerState:
    """
    Mutable container for everything the front tracking solver tracks.

    Bundles every wave created during the run (deactivated ones are kept,
    not removed), the log of processed events, the simulation clock and the
    inlet/aquifer inputs the solver was built from.

    Attributes
    ----------
    waves : list of Wave
        Every wave created so far, active or inactive.
    events : list of dict
        Record of processed events, appended in processing order.
    t_current : float
        Current simulation time, in days since ``tedges[0]``.
    v_outlet : float
        Pore volume at which the outlet is located [m³].
    sorption : SorptionModel
        Sorption model (e.g. FreundlichSorption or ConstantRetardation).
    cin : numpy.ndarray
        Inlet concentration per time bin [mass/volume].
    flow : numpy.ndarray
        Flow rate per time bin [m³/day].
    tedges : pandas.DatetimeIndex
        Edges of the time bins.

    Examples
    --------
    ::

        state = FrontTrackerState(
            waves=[],
            events=[],
            t_current=0.0,
            v_outlet=500.0,
            sorption=sorption,
            cin=cin,
            flow=flow,
            tedges=tedges,
        )
    """

    waves: list[Wave]
    events: list[dict]
    t_current: float
    v_outlet: float
    sorption: SorptionModel
    cin: np.ndarray
    flow: np.ndarray
    tedges: pd.DatetimeIndex
class FrontTracker:
    """
    Event-driven front tracking solver for nonlinear sorption transport.

    This is the main simulation engine that orchestrates wave propagation,
    event detection, and event handling. The solver maintains a list of waves
    and processes collision, outlet-crossing and flow-change events in
    chronological order.

    Parameters
    ----------
    cin : numpy.ndarray
        Inlet concentration per time bin [mass/volume]. Must be non-negative.
    flow : numpy.ndarray
        Flow rate per time bin [m³/day]. Must be non-negative and have the
        same length as ``cin``.
    tedges : pandas.DatetimeIndex
        Time bin edges; must have length ``len(cin) + 1``. Internally all
        times are expressed as float days since ``tedges[0]``.
    aquifer_pore_volume : float
        Total pore volume [m³]; used as the outlet position. Must be positive.
    sorption : FreundlichSorption or ConstantRetardation
        Sorption parameters.

    Attributes
    ----------
    state : FrontTrackerState
        Complete simulation state.
    t_first_arrival : float
        First arrival time (end of spin-up period) [days].

    Notes
    -----
    Wave interaction times are computed with analytical formulas rather than
    iterative root finding. Small absolute tolerances (``1e-15``) are used only
    to decide whether the inlet concentration or flow changes between
    consecutive time bins.

    The spin-up period (t < t_first_arrival) is affected by unknown initial
    conditions (the domain is assumed to start at zero concentration).
    Results are only valid for t >= t_first_arrival.

    Examples
    --------
    ::

        tracker = FrontTracker(
            cin=cin,
            flow=flow,
            tedges=tedges,
            aquifer_pore_volume=500.0,
            sorption=sorption,
        )
        tracker.run(max_iterations=1000)
        # Access results
        print(f"Total events: {len(tracker.state.events)}")
        print(f"Active waves: {sum(1 for w in tracker.state.waves if w.is_active)}")
    """
180 def __init__(
181 self,
182 cin: np.ndarray,
183 flow: np.ndarray,
184 tedges: pd.DatetimeIndex,
185 aquifer_pore_volume: float,
186 sorption: SorptionModel,
187 ):
188 """
189 Initialize tracker with inlet conditions and physical parameters.
191 Parameters
192 ----------
193 cin : numpy.ndarray
194 Inlet concentration time series [mass/volume]
195 flow : numpy.ndarray
196 Flow rate time series [m³/day]
197 tedges : numpy.ndarray
198 Time bin edges [days]
199 aquifer_pore_volume : float
200 Total pore volume [m³]
201 sorption : FreundlichSorption or ConstantRetardation
202 Sorption parameters
204 Raises
205 ------
206 ValueError
207 If input arrays have incompatible lengths or invalid values
208 """
209 # Validation
210 if len(tedges) != len(cin) + 1:
211 msg = f"tedges must have length len(cin) + 1, got {len(tedges)} vs {len(cin) + 1}"
212 raise ValueError(msg)
213 if len(flow) != len(cin):
214 msg = f"flow must have same length as cin, got {len(flow)} vs {len(cin)}"
215 raise ValueError(msg)
216 if np.any(cin < 0):
217 msg = "cin must be non-negative"
218 raise ValueError(msg)
219 if np.any(flow < 0):
220 msg = "flow must be non-negative (negative flow not supported)"
221 raise ValueError(msg)
222 if aquifer_pore_volume <= 0:
223 msg = "aquifer_pore_volume must be positive"
224 raise ValueError(msg)
226 # Initialize state
227 # t_current is in days from tedges[0], so it starts at 0.0
228 self.state = FrontTrackerState(
229 waves=[],
230 events=[],
231 t_current=0.0,
232 v_outlet=aquifer_pore_volume,
233 sorption=sorption,
234 cin=cin.copy(),
235 flow=flow.copy(),
236 tedges=tedges.copy(),
237 )
239 # Compute spin-up period
240 self.t_first_arrival = compute_first_front_arrival_time(cin, flow, tedges, aquifer_pore_volume, sorption)
242 # Detect flow changes for event scheduling
243 self._flow_change_schedule = self._detect_flow_changes()
245 # Initialize waves from inlet boundary conditions
246 self._initialize_inlet_waves()
248 def _initialize_inlet_waves(self):
249 """
250 Initialize all waves from inlet boundary conditions.
252 Creates waves at each inlet concentration change by analyzing
253 characteristic velocities and creating appropriate wave types.
254 """
255 c_prev = 0.0 # Assume domain initially at zero
257 for i in range(len(self.state.cin)):
258 c_new = float(self.state.cin[i])
259 # Convert tedges[i] (Timestamp) to days from tedges[0]
260 t_change = (self.state.tedges[i] - self.state.tedges[0]) / pd.Timedelta(days=1)
261 flow_current = float(self.state.flow[i])
263 if abs(c_new - c_prev) > EPSILON_CONCENTRATION:
264 # Create wave(s) for this concentration change
265 new_waves = create_inlet_waves_at_time(
266 c_prev=c_prev,
267 c_new=c_new,
268 t=t_change,
269 flow=flow_current,
270 sorption=self.state.sorption,
271 v_inlet=0.0,
272 )
273 self.state.waves.extend(new_waves)
275 c_prev = c_new
277 def _detect_flow_changes(self) -> list[tuple[float, float]]:
278 """
279 Detect all flow changes in the inlet time series.
281 Scans the flow array and identifies time points where flow changes.
282 These become scheduled events that will update all wave velocities.
284 Returns
285 -------
286 list of tuple
287 List of (t_change, flow_new) tuples sorted by time.
288 Times are in days from tedges[0].
290 Notes
291 -----
292 Flow changes are detected by comparing consecutive flow values.
293 Only significant changes (>1e-15) are included.
295 Examples
296 --------
297 >>> # flow = [100, 100, 200, 50] at tedges = [0, 10, 20, 30, 40] days
298 >>> # Returns: [(20.0, 200.0), (30.0, 50.0)]
299 """
300 flow_changes = []
301 epsilon_flow = 1e-15
303 for i in range(1, len(self.state.flow)):
304 if abs(self.state.flow[i] - self.state.flow[i - 1]) > epsilon_flow:
305 # Convert tedges[i] to days from tedges[0]
306 t_change = (self.state.tedges[i] - self.state.tedges[0]) / pd.Timedelta(days=1)
307 flow_new = self.state.flow[i]
308 flow_changes.append((t_change, flow_new))
310 return flow_changes
312 def find_next_event(self) -> Event | None:
313 """
314 Find the next event (earliest in time).
316 Searches all possible wave interactions and returns the earliest event.
317 Uses a priority queue (min-heap) to efficiently find the minimum time.
319 Returns
320 -------
321 Event or None
322 Next event to process, or None if no future events
324 Notes
325 -----
326 Checks for:
327 - Characteristic-characteristic collisions
328 - Shock-shock collisions
329 - Shock-characteristic collisions
330 - Rarefaction-characteristic collisions (head/tail)
331 - Shock-rarefaction collisions (head/tail)
332 - Outlet crossings for all wave types
334 All collision times are computed using exact analytical formulas.
335 """
336 candidates = [] # Will use as min-heap by time
337 counter = 0 # Unique counter to break ties without comparing EventType
339 # Get only active waves
340 active_waves = [w for w in self.state.waves if w.is_active]
342 # 1. Flow change events (checked FIRST to get priority in tie-breaking)
343 for t_change, flow_new in self._flow_change_schedule:
344 if t_change > self.state.t_current:
345 # All active waves are involved in flow change
346 heappush(
347 candidates,
348 (t_change, counter, EventType.FLOW_CHANGE, active_waves.copy(), None, flow_new),
349 )
350 counter += 1
351 break # Only schedule the next flow change
353 # 2. Characteristic-Characteristic collisions
354 chars = [w for w in active_waves if isinstance(w, CharacteristicWave)]
355 for i, w1 in enumerate(chars):
356 for w2 in chars[i + 1 :]:
357 result = find_characteristic_intersection(w1, w2, self.state.t_current)
358 if result:
359 t, v = result
360 if 0 <= v <= self.state.v_outlet: # In domain
361 heappush(candidates, (t, counter, EventType.CHAR_CHAR_COLLISION, [w1, w2], v, None))
362 counter += 1
364 # 2. Shock-Shock collisions
365 shocks = [w for w in active_waves if isinstance(w, ShockWave)]
366 for i, w1 in enumerate(shocks):
367 for w2 in shocks[i + 1 :]:
368 result = find_shock_shock_intersection(w1, w2, self.state.t_current)
369 if result:
370 t, v = result
371 if 0 <= v <= self.state.v_outlet:
372 heappush(candidates, (t, counter, EventType.SHOCK_SHOCK_COLLISION, [w1, w2], v, None))
373 counter += 1
375 # 3. Shock-Characteristic collisions
376 for shock in shocks:
377 for char in chars:
378 result = find_shock_characteristic_intersection(shock, char, self.state.t_current)
379 if result:
380 t, v = result
381 if 0 <= v <= self.state.v_outlet:
382 heappush(candidates, (t, counter, EventType.SHOCK_CHAR_COLLISION, [shock, char], v, None))
383 counter += 1
385 # 4. Rarefaction-Characteristic collisions
386 rarefs = [w for w in active_waves if isinstance(w, RarefactionWave)]
387 for raref in rarefs:
388 for char in chars:
389 intersections = find_rarefaction_boundary_intersections(raref, char, self.state.t_current)
390 for t, v, boundary in intersections:
391 if 0 <= v <= self.state.v_outlet:
392 heappush(
393 candidates,
394 (t, counter, EventType.RAREF_CHAR_COLLISION, [raref, char], v, boundary),
395 )
396 counter += 1
398 # 5. Shock-Rarefaction collisions
399 for shock in shocks:
400 for raref in rarefs:
401 intersections = find_rarefaction_boundary_intersections(raref, shock, self.state.t_current)
402 for t, v, boundary in intersections:
403 if 0 <= v <= self.state.v_outlet:
404 heappush(
405 candidates,
406 (t, counter, EventType.SHOCK_RAREF_COLLISION, [shock, raref], v, boundary),
407 )
408 counter += 1
410 # 6. Rarefaction-Rarefaction collisions
411 for i, raref1 in enumerate(rarefs):
412 for raref2 in rarefs[i + 1 :]:
413 intersections = find_rarefaction_boundary_intersections(raref1, raref2, self.state.t_current)
414 for t, v, boundary in intersections:
415 if 0 <= v <= self.state.v_outlet:
416 heappush(
417 candidates,
418 (t, counter, EventType.RAREF_RAREF_COLLISION, [raref1, raref2], v, boundary),
419 )
420 counter += 1
422 # 7. Outlet crossings
423 for wave in active_waves:
424 # For rarefactions, detect BOTH head and tail crossings
425 if isinstance(wave, RarefactionWave):
426 # Head crossing
427 t_eval = max(self.state.t_current, wave.t_start)
428 v_head = wave.head_position_at_time(t_eval)
429 if v_head is not None and v_head < self.state.v_outlet:
430 vel_head = wave.head_velocity()
431 if vel_head > 0:
432 dt_head = (self.state.v_outlet - v_head) / vel_head
433 t_cross_head = t_eval + dt_head
434 if t_cross_head > self.state.t_current:
435 heappush(
436 candidates,
437 (t_cross_head, counter, EventType.OUTLET_CROSSING, [wave], self.state.v_outlet, None),
438 )
439 counter += 1
441 # Tail crossing
442 v_tail = wave.tail_position_at_time(t_eval)
443 if v_tail is not None and v_tail < self.state.v_outlet:
444 vel_tail = wave.tail_velocity()
445 if vel_tail > 0:
446 dt_tail = (self.state.v_outlet - v_tail) / vel_tail
447 t_cross_tail = t_eval + dt_tail
448 if t_cross_tail > self.state.t_current:
449 heappush(
450 candidates,
451 (t_cross_tail, counter, EventType.OUTLET_CROSSING, [wave], self.state.v_outlet, None),
452 )
453 counter += 1
454 else:
455 # For characteristics and shocks, use existing logic
456 t_cross = find_outlet_crossing(wave, self.state.v_outlet, self.state.t_current)
457 if t_cross and t_cross > self.state.t_current:
458 heappush(
459 candidates, (t_cross, counter, EventType.OUTLET_CROSSING, [wave], self.state.v_outlet, None)
460 )
461 counter += 1
463 # Return earliest event
464 if candidates:
465 # Handle 6-tuple format: (t, counter, event_type, waves, v, extra)
466 event_data = heappop(candidates)
467 t = event_data[0]
468 # Skip counter at index 1
469 event_type = event_data[2]
470 waves = event_data[3]
471 v = event_data[4]
472 extra = event_data[5] if len(event_data) > MIN_EVENT_DATA_LENGTH else None
474 # For FLOW_CHANGE events, extra contains flow_new
475 flow_new = extra if event_type == EventType.FLOW_CHANGE else None
477 # For rarefaction collision events, extra contains boundary_type
478 raref_types = {
479 EventType.RAREF_CHAR_COLLISION,
480 EventType.SHOCK_RAREF_COLLISION,
481 EventType.RAREF_RAREF_COLLISION,
482 }
483 boundary_type = extra if event_type in raref_types else None
485 return Event(
486 time=t,
487 event_type=event_type,
488 waves_involved=waves,
489 location=v,
490 flow_new=flow_new,
491 boundary_type=boundary_type,
492 )
494 return None
496 def handle_event(self, event: Event):
497 """
498 Handle an event by calling appropriate handler and updating state.
500 Dispatches to the correct event handler based on event type, then
501 updates the simulation state with any new waves created.
503 Parameters
504 ----------
505 event : Event
506 Event to handle
508 Raises
509 ------
510 RuntimeError
511 If the FLOW_CHANGE event does not have ``flow_new`` set.
513 Notes
514 -----
515 Event handlers may:
516 - Deactivate parent waves
517 - Create new child waves
518 - Record event details in history
519 - Verify physical correctness (entropy, mass balance)
520 """
521 new_waves = []
523 if event.event_type == EventType.CHAR_CHAR_COLLISION:
524 new_waves = handle_characteristic_collision(
525 event.waves_involved[0], event.waves_involved[1], event.time, event.location
526 )
528 elif event.event_type == EventType.SHOCK_SHOCK_COLLISION:
529 new_waves = handle_shock_collision(
530 event.waves_involved[0], event.waves_involved[1], event.time, event.location
531 )
533 elif event.event_type == EventType.SHOCK_CHAR_COLLISION:
534 new_waves = handle_shock_characteristic_collision(
535 event.waves_involved[0], event.waves_involved[1], event.time, event.location
536 )
538 elif event.event_type == EventType.RAREF_CHAR_COLLISION:
539 new_waves = handle_rarefaction_characteristic_collision(
540 event.waves_involved[0],
541 event.waves_involved[1],
542 event.time,
543 event.location,
544 boundary_type=event.boundary_type,
545 )
547 elif event.event_type == EventType.SHOCK_RAREF_COLLISION:
548 new_waves = handle_shock_rarefaction_collision(
549 event.waves_involved[0],
550 event.waves_involved[1],
551 event.time,
552 event.location,
553 boundary_type=event.boundary_type,
554 )
556 elif event.event_type == EventType.RAREF_RAREF_COLLISION:
557 new_waves = handle_rarefaction_rarefaction_collision(
558 event.waves_involved[0],
559 event.waves_involved[1],
560 event.time,
561 event.location,
562 boundary_type=event.boundary_type,
563 )
565 elif event.event_type == EventType.OUTLET_CROSSING:
566 event_record = handle_outlet_crossing(event.waves_involved[0], event.time, event.location)
567 self.state.events.append(event_record)
568 return # No new waves for outlet crossing
570 elif event.event_type == EventType.FLOW_CHANGE:
571 # Get all active waves at this time
572 active_waves = [w for w in self.state.waves if w.is_active]
573 if event.flow_new is None:
574 msg = "FLOW_CHANGE event must have flow_new set"
575 raise RuntimeError(msg)
576 new_waves = handle_flow_change(event.time, event.flow_new, active_waves)
578 # Add new waves to state
579 self.state.waves.extend(new_waves)
581 # Record event
582 self.state.events.append({
583 "time": event.time,
584 "type": event.event_type.value,
585 "location": event.location,
586 "waves_before": event.waves_involved,
587 "waves_after": new_waves,
588 })
590 def run(self, max_iterations: int = 10000, *, verbose: bool = False):
591 """
592 Run simulation until no more events or max_iterations reached.
594 Processes events chronologically by repeatedly finding the next event,
595 advancing time, and handling the event. Continues until no more events
596 exist or the iteration limit is reached.
598 Parameters
599 ----------
600 max_iterations : int, optional
601 Maximum number of events to process. Default 10000.
602 Prevents infinite loops in case of bugs.
603 verbose : bool, optional
604 Print progress messages. Default False.
606 Notes
607 -----
608 The simulation stops when:
609 - No more events exist (all waves have exited or become inactive)
610 - max_iterations is reached (safety limit)
612 After completion, results are available in:
613 - self.state.waves: All waves (active and inactive)
614 - self.state.events: Complete event history
615 - self.t_first_arrival: End of spin-up period
616 """
617 iteration = 0
619 if verbose:
620 logger.info("Starting simulation at t=%.3f", self.state.t_current)
621 logger.info("Initial waves: %d", len(self.state.waves))
622 logger.info("First arrival time: %.3f days", self.t_first_arrival)
624 while iteration < max_iterations:
625 # Find next event
626 event = self.find_next_event()
628 if event is None:
629 if verbose:
630 logger.info("Simulation complete after %d events at t=%.6f", iteration, self.state.t_current)
631 break
633 # Advance time
634 self.state.t_current = event.time
636 # Handle event
637 try:
638 self.handle_event(event)
639 except Exception:
640 logger.exception("Error handling event at t=%.3f", event.time)
641 raise
643 # Optional: verify physics periodically
644 if iteration % 100 == 0:
645 self.verify_physics()
647 if verbose and iteration % 10 == 0:
648 active = sum(1 for w in self.state.waves if w.is_active)
649 logger.debug("Iteration %d: t=%.3f, active_waves=%d", iteration, event.time, active)
651 iteration += 1
653 if iteration >= max_iterations:
654 logger.warning("Reached max_iterations=%d", max_iterations)
656 if verbose:
657 logger.info("Final statistics:")
658 logger.info(" Total events: %d", len(self.state.events))
659 logger.info(" Total waves created: %d", len(self.state.waves))
660 logger.info(" Active waves: %d", sum(1 for w in self.state.waves if w.is_active))
661 logger.info(" First arrival time: %.6f days", self.t_first_arrival)
663 def verify_physics(self, *, check_mass_balance: bool = False, mass_balance_rtol: float = 1e-12):
664 """
665 Verify physical correctness of current state.
667 Implements High Priority #3 from FRONT_TRACKING_REBUILD_PLAN.md by adding
668 runtime mass balance verification using exact analytical integration.
670 Checks:
671 - All shocks satisfy Lax entropy condition
672 - All rarefactions have proper head/tail ordering
673 - Mass balance: mass_in_domain + mass_out = mass_in (to specified tolerance)
675 Parameters
676 ----------
677 check_mass_balance : bool, optional
678 Enable mass balance verification. Default False (opt-in for now).
679 mass_balance_rtol : float, optional
680 Relative tolerance for mass balance check. Default 1e-6.
681 This tolerance accounts for:
682 - Midpoint approximation in spatial integration of rarefactions
683 - Numerical precision in wave position calculations
684 - Piecewise-constant approximations in domain partitioning
686 Raises
687 ------
688 RuntimeError
689 If physics violation is detected
691 Notes
692 -----
693 Mass balance equation:
694 mass_in_domain(t) + mass_out_cumulative(t) = mass_in_cumulative(t)
696 All mass calculations use exact analytical integration where possible:
697 - Inlet/outlet temporal integrals: exact for piecewise-constant functions
698 - Domain spatial integrals: exact for constants, midpoint rule for rarefactions
699 - Overall precision: ~1e-10 to 1e-12 relative error
700 """
701 # Check entropy for all active shocks
702 for wave in self.state.waves:
703 if isinstance(wave, ShockWave) and wave.is_active and not wave.satisfies_entropy():
704 msg = (
705 f"Shock at t_start={wave.t_start:.3f} violates entropy! "
706 f"c_left={wave.c_left:.3f}, c_right={wave.c_right:.3f}, "
707 f"velocity={wave.velocity:.3f}"
708 )
709 raise RuntimeError(msg)
711 # Check rarefaction ordering
712 for wave in self.state.waves:
713 if isinstance(wave, RarefactionWave) and wave.is_active:
714 v_head = wave.head_velocity()
715 v_tail = wave.tail_velocity()
716 if v_head <= v_tail:
717 msg = (
718 f"Rarefaction at t_start={wave.t_start:.3f} has invalid ordering! "
719 f"head_velocity={v_head:.3f} <= tail_velocity={v_tail:.3f}"
720 )
721 raise RuntimeError(msg)
723 # Check mass balance using exact analytical integration
724 if check_mass_balance:
725 t_current = self.state.t_current
727 # Convert tedges from DatetimeIndex to float days for mass functions
728 # Internal simulation uses float days from tedges[0]
729 tedges_days = (self.state.tedges - self.state.tedges[0]) / pd.Timedelta(days=1)
731 # Compute total mass in domain at current time
732 mass_in_domain = compute_domain_mass(
733 t=t_current,
734 v_outlet=self.state.v_outlet,
735 waves=self.state.waves,
736 sorption=self.state.sorption,
737 )
739 # Compute cumulative inlet mass
740 mass_in_cumulative = compute_cumulative_inlet_mass(
741 t=t_current,
742 cin=self.state.cin,
743 flow=self.state.flow,
744 tedges_days=tedges_days,
745 )
747 # Compute cumulative outlet mass
748 mass_out_cumulative = compute_cumulative_outlet_mass(
749 t=t_current,
750 v_outlet=self.state.v_outlet,
751 waves=self.state.waves,
752 sorption=self.state.sorption,
753 flow=self.state.flow,
754 tedges_days=tedges_days,
755 )
757 # Mass balance: mass_in_domain + mass_out = mass_in
758 mass_balance_error = (mass_in_domain + mass_out_cumulative) - mass_in_cumulative
760 # Check relative error
761 if mass_in_cumulative > 0:
762 relative_error = abs(mass_balance_error) / mass_in_cumulative
763 else:
764 # No mass has entered yet - check absolute error is small
765 relative_error = abs(mass_balance_error)
767 if relative_error > mass_balance_rtol:
768 msg = (
769 f"Mass balance violation at t={t_current:.6f}! "
770 f"mass_in_domain={mass_in_domain:.6e}, "
771 f"mass_out={mass_out_cumulative:.6e}, "
772 f"mass_in={mass_in_cumulative:.6e}, "
773 f"error={mass_balance_error:.6e}, "
774 f"relative_error={relative_error:.6e} > {mass_balance_rtol:.6e}"
775 )
776 raise RuntimeError(msg)