Coverage for src / gwtransport / fronttracking / solver.py: 81%
232 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-04 18:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-04 18:51 +0000
1"""
2Front Tracking Solver - Event-Driven Simulation Engine.
4========================================================
6This module implements the main event-driven front tracking solver for
7nonlinear sorption transport. The solver maintains a list of waves and
8processes collision events chronologically using exact analytical calculations.
10The algorithm:
111. Initialize waves from inlet boundary conditions
122. Find next event (earliest collision or outlet crossing)
133. Advance time to event
144. Handle event (create new waves, deactivate old ones)
155. Repeat until no more events
17All calculations are exact analytical with machine precision.
18"""
20import logging
21from dataclasses import dataclass
22from heapq import heappop, heappush
23from typing import Optional
25import numpy as np
26import pandas as pd
28from gwtransport.fronttracking.events import (
29 Event,
30 EventType,
31 find_characteristic_intersection,
32 find_outlet_crossing,
33 find_rarefaction_boundary_intersections,
34 find_shock_characteristic_intersection,
35 find_shock_shock_intersection,
36)
37from gwtransport.fronttracking.handlers import (
38 create_inlet_waves_at_time,
39 handle_characteristic_collision,
40 handle_flow_change,
41 handle_outlet_crossing,
42 handle_rarefaction_characteristic_collision,
43 handle_rarefaction_rarefaction_collision,
44 handle_shock_characteristic_collision,
45 handle_shock_collision,
46 handle_shock_rarefaction_collision,
47)
48from gwtransport.fronttracking.math import (
49 ConstantRetardation,
50 FreundlichSorption,
51 compute_first_front_arrival_time,
52)
54# Import mass balance functions for runtime verification (High Priority #3)
55from gwtransport.fronttracking.output import (
56 compute_cumulative_inlet_mass,
57 compute_cumulative_outlet_mass,
58 compute_domain_mass,
59)
60from gwtransport.fronttracking.waves import (
61 CharacteristicWave,
62 RarefactionWave,
63 ShockWave,
64 Wave,
65)
# Numerical tolerance constants
# Absolute threshold on |c_new - c_prev|: FrontTracker._initialize_inlet_waves only
# creates waves for inlet concentration steps strictly larger than this value.
EPSILON_CONCENTRATION = 1e-15  # Tolerance for concentration changes
# FrontTracker.find_next_event reads the ``extra`` payload (index 5) of a candidate
# tuple only when the tuple has more than this many elements.
MIN_EVENT_DATA_LENGTH = 5  # Minimum length of event_data tuple before accessing extra field
@dataclass
class FrontTrackerState:
    """
    Complete state of the front tracking simulation.

    This dataclass holds all information about the current simulation state,
    including all waves (active and inactive), event history, and simulation
    parameters. It is a plain mutable container: the solver appends to
    ``waves`` and ``events`` and overwrites ``t_current`` as it advances.

    Parameters
    ----------
    waves : list of Wave
        All waves created during simulation (includes inactive waves;
        filter on ``wave.is_active`` to get the live ones)
    events : list of dict
        Event history with details about each event, in processing order
    t_current : float
        Current simulation time [days from tedges[0]]
    v_outlet : float
        Outlet position [m³]
    sorption : FreundlichSorption or ConstantRetardation
        Sorption parameters
    cin : numpy.ndarray
        Inlet concentration time series [mass/volume], one value per time bin
    flow : numpy.ndarray
        Flow rate time series [m³/day], one value per time bin
    tedges : pandas.DatetimeIndex
        Time bin edges (one more entry than ``cin``/``flow``)

    Examples
    --------
    ::

        state = FrontTrackerState(
            waves=[],
            events=[],
            t_current=0.0,
            v_outlet=500.0,
            sorption=sorption,
            cin=cin,
            flow=flow,
            tedges=tedges,
        )
    """

    # Field order is part of the constructor interface; do not reorder.
    waves: list[Wave]
    events: list[dict]
    t_current: float
    v_outlet: float
    sorption: FreundlichSorption | ConstantRetardation
    cin: np.ndarray
    flow: np.ndarray
    tedges: pd.DatetimeIndex
126class FrontTracker:
127 """
128 Event-driven front tracking solver for nonlinear sorption transport.
130 This is the main simulation engine that orchestrates wave propagation,
131 event detection, and event handling. The solver maintains a list of waves
132 and processes collision events chronologically.
134 Parameters
135 ----------
136 cin : numpy.ndarray
137 Inlet concentration time series [mass/volume]
138 flow : numpy.ndarray
139 Flow rate time series [m³/day]
    tedges : pandas.DatetimeIndex
        Time bin edges; must contain ``len(cin) + 1`` entries
142 aquifer_pore_volume : float
143 Total pore volume [m³]
144 sorption : FreundlichSorption or ConstantRetardation
145 Sorption parameters
147 Attributes
148 ----------
149 state : FrontTrackerState
150 Complete simulation state
151 t_first_arrival : float
152 First arrival time (end of spin-up period) [days]
154 Examples
155 --------
156 ::
158 tracker = FrontTracker(
159 cin=cin,
160 flow=flow,
161 tedges=tedges,
162 aquifer_pore_volume=500.0,
163 sorption=sorption,
164 )
165 tracker.run(max_iterations=1000)
166 # Access results
167 print(f"Total events: {len(tracker.state.events)}")
168 print(f"Active waves: {sum(1 for w in tracker.state.waves if w.is_active)}")
170 Notes
171 -----
    The solver uses exact analytical calculations for wave interactions rather
    than iterative methods. The only tolerances applied in this class are tiny
    absolute thresholds (1e-15) used to decide whether consecutive inlet
    concentration or flow values differ at all.
176 The spin-up period (t < t_first_arrival) is affected by unknown initial
177 conditions. Results are only valid for t >= t_first_arrival.
178 """
180 def __init__(
181 self,
182 cin: np.ndarray,
183 flow: np.ndarray,
184 tedges: pd.DatetimeIndex,
185 aquifer_pore_volume: float,
186 sorption: FreundlichSorption | ConstantRetardation,
187 ):
188 """
189 Initialize tracker with inlet conditions and physical parameters.
191 Parameters
192 ----------
193 cin : numpy.ndarray
194 Inlet concentration time series [mass/volume]
195 flow : numpy.ndarray
196 Flow rate time series [m³/day]
197 tedges : numpy.ndarray
198 Time bin edges [days]
199 aquifer_pore_volume : float
200 Total pore volume [m³]
201 sorption : FreundlichSorption or ConstantRetardation
202 Sorption parameters
204 Raises
205 ------
206 ValueError
207 If input arrays have incompatible lengths or invalid values
208 """
209 # Validation
210 if len(tedges) != len(cin) + 1:
211 msg = f"tedges must have length len(cin) + 1, got {len(tedges)} vs {len(cin) + 1}"
212 raise ValueError(msg)
213 if len(flow) != len(cin):
214 msg = f"flow must have same length as cin, got {len(flow)} vs {len(cin)}"
215 raise ValueError(msg)
216 if np.any(cin < 0):
217 msg = "cin must be non-negative"
218 raise ValueError(msg)
219 if np.any(flow <= 0):
220 msg = "flow must be positive"
221 raise ValueError(msg)
222 if aquifer_pore_volume <= 0:
223 msg = "aquifer_pore_volume must be positive"
224 raise ValueError(msg)
226 # Initialize state
227 # t_current is in days from tedges[0], so it starts at 0.0
228 self.state = FrontTrackerState(
229 waves=[],
230 events=[],
231 t_current=0.0,
232 v_outlet=aquifer_pore_volume,
233 sorption=sorption,
234 cin=cin.copy(),
235 flow=flow.copy(),
236 tedges=tedges.copy(),
237 )
239 # Compute spin-up period
240 self.t_first_arrival = compute_first_front_arrival_time(cin, flow, tedges, aquifer_pore_volume, sorption)
242 # Detect flow changes for event scheduling
243 self._flow_change_schedule = self._detect_flow_changes()
245 # Initialize waves from inlet boundary conditions
246 self._initialize_inlet_waves()
248 def _initialize_inlet_waves(self):
249 """
250 Initialize all waves from inlet boundary conditions.
252 Creates waves at each inlet concentration change by analyzing
253 characteristic velocities and creating appropriate wave types.
254 """
255 c_prev = 0.0 # Assume domain initially at zero
257 for i in range(len(self.state.cin)):
258 c_new = float(self.state.cin[i])
259 # Convert tedges[i] (Timestamp) to days from tedges[0]
260 t_change = (self.state.tedges[i] - self.state.tedges[0]) / pd.Timedelta(days=1)
261 flow_current = float(self.state.flow[i])
263 if abs(c_new - c_prev) > EPSILON_CONCENTRATION:
264 # Create wave(s) for this concentration change
265 new_waves = create_inlet_waves_at_time(
266 c_prev=c_prev,
267 c_new=c_new,
268 t=t_change,
269 flow=flow_current,
270 sorption=self.state.sorption,
271 v_inlet=0.0,
272 )
273 self.state.waves.extend(new_waves)
275 c_prev = c_new
277 def _detect_flow_changes(self) -> list[tuple[float, float]]:
278 """
279 Detect all flow changes in the inlet time series.
281 Scans the flow array and identifies time points where flow changes.
282 These become scheduled events that will update all wave velocities.
284 Returns
285 -------
286 list of tuple
287 List of (t_change, flow_new) tuples sorted by time.
288 Times are in days from tedges[0].
290 Notes
291 -----
292 Flow changes are detected by comparing consecutive flow values.
293 Only significant changes (>1e-15) are included.
295 Examples
296 --------
297 >>> # flow = [100, 100, 200, 50] at tedges = [0, 10, 20, 30, 40] days
298 >>> # Returns: [(20.0, 200.0), (30.0, 50.0)]
299 """
300 flow_changes = []
301 epsilon_flow = 1e-15
303 for i in range(1, len(self.state.flow)):
304 if abs(self.state.flow[i] - self.state.flow[i - 1]) > epsilon_flow:
305 # Convert tedges[i] to days from tedges[0]
306 t_change = (self.state.tedges[i] - self.state.tedges[0]) / pd.Timedelta(days=1)
307 flow_new = self.state.flow[i]
308 flow_changes.append((t_change, flow_new))
310 return flow_changes
312 def find_next_event(self) -> Optional[Event]:
313 """
314 Find the next event (earliest in time).
316 Searches all possible wave interactions and returns the earliest event.
317 Uses a priority queue (min-heap) to efficiently find the minimum time.
319 Returns
320 -------
321 Event or None
322 Next event to process, or None if no future events
324 Notes
325 -----
326 Checks for:
327 - Characteristic-characteristic collisions
328 - Shock-shock collisions
329 - Shock-characteristic collisions
330 - Rarefaction-characteristic collisions (head/tail)
331 - Shock-rarefaction collisions (head/tail)
332 - Outlet crossings for all wave types
334 All collision times are computed using exact analytical formulas.
335 """
336 candidates = [] # Will use as min-heap by time
337 counter = 0 # Unique counter to break ties without comparing EventType
339 # Get only active waves
340 active_waves = [w for w in self.state.waves if w.is_active]
342 # 1. Flow change events (checked FIRST to get priority in tie-breaking)
343 for t_change, flow_new in self._flow_change_schedule:
344 if t_change > self.state.t_current:
345 # All active waves are involved in flow change
346 heappush(
347 candidates,
348 (t_change, counter, EventType.FLOW_CHANGE, active_waves.copy(), None, flow_new),
349 )
350 counter += 1
351 break # Only schedule the next flow change
353 # 2. Characteristic-Characteristic collisions
354 chars = [w for w in active_waves if isinstance(w, CharacteristicWave)]
355 for i, w1 in enumerate(chars):
356 for w2 in chars[i + 1 :]:
357 result = find_characteristic_intersection(w1, w2, self.state.t_current)
358 if result:
359 t, v = result
360 if 0 <= v <= self.state.v_outlet: # In domain
361 heappush(candidates, (t, counter, EventType.CHAR_CHAR_COLLISION, [w1, w2], v, None))
362 counter += 1
364 # 2. Shock-Shock collisions
365 shocks = [w for w in active_waves if isinstance(w, ShockWave)]
366 for i, w1 in enumerate(shocks):
367 for w2 in shocks[i + 1 :]:
368 result = find_shock_shock_intersection(w1, w2, self.state.t_current)
369 if result:
370 t, v = result
371 if 0 <= v <= self.state.v_outlet:
372 heappush(candidates, (t, counter, EventType.SHOCK_SHOCK_COLLISION, [w1, w2], v, None))
373 counter += 1
375 # 3. Shock-Characteristic collisions
376 for shock in shocks:
377 for char in chars:
378 result = find_shock_characteristic_intersection(shock, char, self.state.t_current)
379 if result:
380 t, v = result
381 if 0 <= v <= self.state.v_outlet:
382 heappush(candidates, (t, counter, EventType.SHOCK_CHAR_COLLISION, [shock, char], v, None))
383 counter += 1
385 # 4. Rarefaction-Characteristic collisions
386 rarefs = [w for w in active_waves if isinstance(w, RarefactionWave)]
387 for raref in rarefs:
388 for char in chars:
389 intersections = find_rarefaction_boundary_intersections(raref, char, self.state.t_current)
390 for t, v, boundary in intersections:
391 if 0 <= v <= self.state.v_outlet:
392 heappush(
393 candidates,
394 (t, counter, EventType.RAREF_CHAR_COLLISION, [raref, char], v, boundary),
395 )
396 counter += 1
398 # 5. Shock-Rarefaction collisions
399 for shock in shocks:
400 for raref in rarefs:
401 intersections = find_rarefaction_boundary_intersections(raref, shock, self.state.t_current)
402 for t, v, boundary in intersections:
403 if 0 <= v <= self.state.v_outlet:
404 heappush(
405 candidates,
406 (t, counter, EventType.SHOCK_RAREF_COLLISION, [shock, raref], v, boundary),
407 )
408 counter += 1
410 # 6. Rarefaction-Rarefaction collisions
411 for i, raref1 in enumerate(rarefs):
412 for raref2 in rarefs[i + 1 :]:
413 intersections = find_rarefaction_boundary_intersections(raref1, raref2, self.state.t_current)
414 for t, v, boundary in intersections:
415 if 0 <= v <= self.state.v_outlet:
416 heappush(
417 candidates,
418 (t, counter, EventType.RAREF_RAREF_COLLISION, [raref1, raref2], v, boundary),
419 )
420 counter += 1
422 # 7. Outlet crossings
423 for wave in active_waves:
424 # For rarefactions, detect BOTH head and tail crossings
425 if isinstance(wave, RarefactionWave):
426 # Head crossing
427 t_eval = max(self.state.t_current, wave.t_start)
428 v_head = wave.head_position_at_time(t_eval)
429 if v_head is not None and v_head < self.state.v_outlet:
430 vel_head = wave.head_velocity()
431 if vel_head > 0:
432 dt_head = (self.state.v_outlet - v_head) / vel_head
433 t_cross_head = t_eval + dt_head
434 if t_cross_head > self.state.t_current:
435 heappush(
436 candidates,
437 (t_cross_head, counter, EventType.OUTLET_CROSSING, [wave], self.state.v_outlet, None),
438 )
439 counter += 1
441 # Tail crossing
442 v_tail = wave.tail_position_at_time(t_eval)
443 if v_tail is not None and v_tail < self.state.v_outlet:
444 vel_tail = wave.tail_velocity()
445 if vel_tail > 0:
446 dt_tail = (self.state.v_outlet - v_tail) / vel_tail
447 t_cross_tail = t_eval + dt_tail
448 if t_cross_tail > self.state.t_current:
449 heappush(
450 candidates,
451 (t_cross_tail, counter, EventType.OUTLET_CROSSING, [wave], self.state.v_outlet, None),
452 )
453 counter += 1
454 else:
455 # For characteristics and shocks, use existing logic
456 t_cross = find_outlet_crossing(wave, self.state.v_outlet, self.state.t_current)
457 if t_cross and t_cross > self.state.t_current:
458 heappush(
459 candidates, (t_cross, counter, EventType.OUTLET_CROSSING, [wave], self.state.v_outlet, None)
460 )
461 counter += 1
463 # Return earliest event
464 if candidates:
465 # Handle 6-tuple format: (t, counter, event_type, waves, v, extra)
466 event_data = heappop(candidates)
467 t = event_data[0]
468 # Skip counter at index 1
469 event_type = event_data[2]
470 waves = event_data[3]
471 v = event_data[4]
472 extra = event_data[5] if len(event_data) > MIN_EVENT_DATA_LENGTH else None
474 # For FLOW_CHANGE events, extra contains flow_new
475 flow_new = extra if event_type == EventType.FLOW_CHANGE else None
477 # For rarefaction collision events, extra contains boundary_type
478 _raref_types = {
479 EventType.RAREF_CHAR_COLLISION,
480 EventType.SHOCK_RAREF_COLLISION,
481 EventType.RAREF_RAREF_COLLISION,
482 }
483 boundary_type = extra if event_type in _raref_types else None
485 return Event(
486 time=t,
487 event_type=event_type,
488 waves_involved=waves,
489 location=v,
490 flow_new=flow_new,
491 boundary_type=boundary_type,
492 )
494 return None
496 def handle_event(self, event: Event):
497 """
498 Handle an event by calling appropriate handler and updating state.
500 Dispatches to the correct event handler based on event type, then
501 updates the simulation state with any new waves created.
503 Parameters
504 ----------
505 event : Event
506 Event to handle
508 Notes
509 -----
510 Event handlers may:
511 - Deactivate parent waves
512 - Create new child waves
513 - Record event details in history
514 - Verify physical correctness (entropy, mass balance)
515 """
516 new_waves = []
518 if event.event_type == EventType.CHAR_CHAR_COLLISION:
519 new_waves = handle_characteristic_collision(
520 event.waves_involved[0], event.waves_involved[1], event.time, event.location
521 )
523 elif event.event_type == EventType.SHOCK_SHOCK_COLLISION:
524 new_waves = handle_shock_collision(
525 event.waves_involved[0], event.waves_involved[1], event.time, event.location
526 )
528 elif event.event_type == EventType.SHOCK_CHAR_COLLISION:
529 new_waves = handle_shock_characteristic_collision(
530 event.waves_involved[0], event.waves_involved[1], event.time, event.location
531 )
533 elif event.event_type == EventType.RAREF_CHAR_COLLISION:
534 new_waves = handle_rarefaction_characteristic_collision(
535 event.waves_involved[0],
536 event.waves_involved[1],
537 event.time,
538 event.location,
539 boundary_type=event.boundary_type,
540 )
542 elif event.event_type == EventType.SHOCK_RAREF_COLLISION:
543 new_waves = handle_shock_rarefaction_collision(
544 event.waves_involved[0],
545 event.waves_involved[1],
546 event.time,
547 event.location,
548 boundary_type=event.boundary_type,
549 )
551 elif event.event_type == EventType.RAREF_RAREF_COLLISION:
552 new_waves = handle_rarefaction_rarefaction_collision(
553 event.waves_involved[0],
554 event.waves_involved[1],
555 event.time,
556 event.location,
557 boundary_type=event.boundary_type,
558 )
560 elif event.event_type == EventType.OUTLET_CROSSING:
561 event_record = handle_outlet_crossing(event.waves_involved[0], event.time, event.location)
562 self.state.events.append(event_record)
563 return # No new waves for outlet crossing
565 elif event.event_type == EventType.FLOW_CHANGE:
566 # Get all active waves at this time
567 active_waves = [w for w in self.state.waves if w.is_active]
568 if event.flow_new is None:
569 msg = "FLOW_CHANGE event must have flow_new set"
570 raise RuntimeError(msg)
571 new_waves = handle_flow_change(event.time, event.flow_new, active_waves)
573 # Add new waves to state
574 self.state.waves.extend(new_waves)
576 # Record event
577 self.state.events.append({
578 "time": event.time,
579 "type": event.event_type.value,
580 "location": event.location,
581 "waves_before": event.waves_involved,
582 "waves_after": new_waves,
583 })
585 def run(self, max_iterations: int = 10000, *, verbose: bool = False):
586 """
587 Run simulation until no more events or max_iterations reached.
589 Processes events chronologically by repeatedly finding the next event,
590 advancing time, and handling the event. Continues until no more events
591 exist or the iteration limit is reached.
593 Parameters
594 ----------
595 max_iterations : int, optional
596 Maximum number of events to process. Default 10000.
597 Prevents infinite loops in case of bugs.
598 verbose : bool, optional
599 Print progress messages. Default False.
601 Notes
602 -----
603 The simulation stops when:
604 - No more events exist (all waves have exited or become inactive)
605 - max_iterations is reached (safety limit)
607 After completion, results are available in:
608 - self.state.waves: All waves (active and inactive)
609 - self.state.events: Complete event history
610 - self.t_first_arrival: End of spin-up period
611 """
612 iteration = 0
614 if verbose:
615 logging.info("Starting simulation at t=%.3f", self.state.t_current)
616 logging.info("Initial waves: %d", len(self.state.waves))
617 logging.info("First arrival time: %.3f days", self.t_first_arrival)
619 while iteration < max_iterations:
620 # Find next event
621 event = self.find_next_event()
623 if event is None:
624 if verbose:
625 logging.info("Simulation complete after %d events at t=%.6f", iteration, self.state.t_current)
626 break
628 # Advance time
629 self.state.t_current = event.time
631 # Handle event
632 try:
633 self.handle_event(event)
634 except Exception:
635 logging.exception("Error handling event at t=%.3f", event.time)
636 raise
638 # Optional: verify physics periodically
639 if iteration % 100 == 0:
640 self.verify_physics()
642 if verbose and iteration % 10 == 0:
643 active = sum(1 for w in self.state.waves if w.is_active)
644 logging.debug("Iteration %d: t=%.3f, active_waves=%d", iteration, event.time, active)
646 iteration += 1
648 if iteration >= max_iterations:
649 logging.warning("Reached max_iterations=%d", max_iterations)
651 if verbose:
652 logging.info("Final statistics:")
653 logging.info(" Total events: %d", len(self.state.events))
654 logging.info(" Total waves created: %d", len(self.state.waves))
655 logging.info(" Active waves: %d", sum(1 for w in self.state.waves if w.is_active))
656 logging.info(" First arrival time: %.6f days", self.t_first_arrival)
658 def verify_physics(self, *, check_mass_balance: bool = False, mass_balance_rtol: float = 1e-12):
659 """
660 Verify physical correctness of current state.
662 Implements High Priority #3 from FRONT_TRACKING_REBUILD_PLAN.md by adding
663 runtime mass balance verification using exact analytical integration.
665 Checks:
666 - All shocks satisfy Lax entropy condition
667 - All rarefactions have proper head/tail ordering
668 - Mass balance: mass_in_domain + mass_out = mass_in (to specified tolerance)
670 Parameters
671 ----------
672 check_mass_balance : bool, optional
673 Enable mass balance verification. Default False (opt-in for now).
674 mass_balance_rtol : float, optional
675 Relative tolerance for mass balance check. Default 1e-6.
676 This tolerance accounts for:
677 - Midpoint approximation in spatial integration of rarefactions
678 - Numerical precision in wave position calculations
679 - Piecewise-constant approximations in domain partitioning
681 Raises
682 ------
683 RuntimeError
684 If physics violation is detected
686 Notes
687 -----
688 Mass balance equation:
689 mass_in_domain(t) + mass_out_cumulative(t) = mass_in_cumulative(t)
691 All mass calculations use exact analytical integration where possible:
692 - Inlet/outlet temporal integrals: exact for piecewise-constant functions
693 - Domain spatial integrals: exact for constants, midpoint rule for rarefactions
694 - Overall precision: ~1e-10 to 1e-12 relative error
695 """
696 # Check entropy for all active shocks
697 for wave in self.state.waves:
698 if isinstance(wave, ShockWave) and wave.is_active and not wave.satisfies_entropy():
699 msg = (
700 f"Shock at t_start={wave.t_start:.3f} violates entropy! "
701 f"c_left={wave.c_left:.3f}, c_right={wave.c_right:.3f}, "
702 f"velocity={wave.velocity:.3f}"
703 )
704 raise RuntimeError(msg)
706 # Check rarefaction ordering
707 for wave in self.state.waves:
708 if isinstance(wave, RarefactionWave) and wave.is_active:
709 v_head = wave.head_velocity()
710 v_tail = wave.tail_velocity()
711 if v_head <= v_tail:
712 msg = (
713 f"Rarefaction at t_start={wave.t_start:.3f} has invalid ordering! "
714 f"head_velocity={v_head:.3f} <= tail_velocity={v_tail:.3f}"
715 )
716 raise RuntimeError(msg)
718 # Check mass balance using exact analytical integration
719 if check_mass_balance:
720 t_current = self.state.t_current
722 # Convert tedges from DatetimeIndex to float days for mass functions
723 # Internal simulation uses float days from tedges[0]
724 tedges_days = (self.state.tedges - self.state.tedges[0]) / pd.Timedelta(days=1)
726 # Compute total mass in domain at current time
727 mass_in_domain = compute_domain_mass(
728 t=t_current,
729 v_outlet=self.state.v_outlet,
730 waves=self.state.waves,
731 sorption=self.state.sorption,
732 )
734 # Compute cumulative inlet mass
735 mass_in_cumulative = compute_cumulative_inlet_mass(
736 t=t_current,
737 cin=self.state.cin,
738 flow=self.state.flow,
739 tedges_days=tedges_days,
740 )
742 # Compute cumulative outlet mass
743 mass_out_cumulative = compute_cumulative_outlet_mass(
744 t=t_current,
745 v_outlet=self.state.v_outlet,
746 waves=self.state.waves,
747 sorption=self.state.sorption,
748 flow=self.state.flow,
749 tedges_days=tedges_days,
750 )
752 # Mass balance: mass_in_domain + mass_out = mass_in
753 mass_balance_error = (mass_in_domain + mass_out_cumulative) - mass_in_cumulative
755 # Check relative error
756 if mass_in_cumulative > 0:
757 relative_error = abs(mass_balance_error) / mass_in_cumulative
758 else:
759 # No mass has entered yet - check absolute error is small
760 relative_error = abs(mass_balance_error)
762 if relative_error > mass_balance_rtol:
763 msg = (
764 f"Mass balance violation at t={t_current:.6f}! "
765 f"mass_in_domain={mass_in_domain:.6e}, "
766 f"mass_out={mass_out_cumulative:.6e}, "
767 f"mass_in={mass_in_cumulative:.6e}, "
768 f"error={mass_balance_error:.6e}, "
769 f"relative_error={relative_error:.6e} > {mass_balance_rtol:.6e}"
770 )
771 raise RuntimeError(msg)