Converting Complex ODE Model from as_op to Native PyMC ODE Solvers

Hi all,

I’m working on a physics model that solves a system of ODEs describing element transport in a fluid system. Currently, I’m using as_op to wrap a NumPy-based ODE solver, but I’d like to convert this to use PyMC’s native ODE capabilities (pm.ode.DifferentialEquation or sunode) for better gradient computation and sampling efficiency.

I tried to use pm.ode.DifferentialEquation following instructions from LLMs such as Gemini and Claude, but the more questions I asked the LLM helpers, the more bugs I ran into.

So my question is: should I stick with as_op for this type of problem, or is there a clean way to make this work with native solvers?

Any guidance on the best approach would be greatly appreciated! The current as_op version works but is quite slow for MCMC sampling.

Environment

  • PyMC version: 5.12.0
  • Python 3.9
  • Sunode 0.6.0

Current Working Code (Simplified)

import numpy as np
import pymc as pm
from scipy.integrate import solve_ivp
from pytensor.compile.ops import as_op
import pytensor.tensor as pt

def physics_model(k0, Z0, W0, Z_val):
    """Return physical properties at depth Z_val in the system"""
    # Physical constants
    rho = 3000       # density
    delta_rho = 500  # density difference
    mu = 10          # viscosity [Pa·s]
    phi0 = 0.01      # reference porosity
    g = 9.8          # gravity
    
    F = 0.23/Z0 * Z_val
    Gamma = rho * W0 * 0.23 / Z0
    Phi = (-1 + np.sqrt(1 + 4*F*(delta_rho * g * k0) / (W0 * mu * phi0**2))) / (2*(delta_rho * g * k0) / (W0 * mu * phi0**2))
    w = W0 * F / Phi
    W = W0 * (1-F) / (1-Phi)
    return Gamma, Phi, W, w

def solve_concentration_odes(K, Omega, k0, Z0, W0, c_s_0, c_l_0_complex, z_vals):
    """Solve coupled ODEs for element concentrations in the system"""
    
    def ode_system(z, y):
        c_s, c_l = y
        # Get physical properties at depth z
        Gamma_z, Phi_z, W_z, w_z = physics_model(k0, Z0, W0, z)
        rho = 3000  # density
        
        # Complex coupled ODEs with depth-dependent coefficients
        dc_s_dz = -(c_s/K - c_s) * Gamma_z / ((1-Phi_z) * rho * W_z)
        dc_l_dz = +(c_s/K - c_l) * Gamma_z / (Phi_z * rho * w_z) - 1j*Omega*c_l/w_z
        return [dc_s_dz, dc_l_dz]
    
    solution = solve_ivp(
        fun=ode_system,
        t_span=(z_vals[0], z_vals[-1]),
        y0=[c_s_0, c_l_0_complex],
        t_eval=z_vals,
        method='RK45'
    )
    
    return solution.y[1]  # return liquid concentration

def _model_function(lambda_val, c_s_0_val, fluct_amp_val):
    """Model function that takes parameters and returns mean/std"""
    
    # Physical parameters
    k0 = 1e-12  # reference permeability
    Z0 = 7.0e4  # system depth
    W0 = 4 * 0.01/(365*24*3600)  # reference velocity
    
    # Calculate frequency from wavelength
    Omega = 2 * np.pi * W0 / lambda_val
    
    # Solve ODEs with depth-dependent physics
    z_vals = np.linspace(1e-6, Z0, 100)
    c_l_complex = solve_concentration_odes(
        K=0.08, Omega=Omega, k0=k0, Z0=Z0, W0=W0,
        c_s_0=c_s_0_val, 
        c_l_0_complex=fluct_amp_val + 0j,
        z_vals=z_vals
    )
    
    # Generate time series from complex solution
    c_l_top = c_l_complex[-1]  # concentration at top

    tmax = 2*np.pi/Omega
    times = np.linspace(0, tmax, 100)
    time_series = (c_l_top * np.exp(1j * Omega * times)).real
    
    return np.mean(time_series), np.std(time_series)

# Wrap with as_op
@as_op(itypes=[pt.dscalar, pt.dscalar, pt.dscalar], 
       otypes=[pt.dscalar, pt.dscalar])
def pymc_model(lambda_pm, c_s_0_pm, fluct_amp_pm):
    return _model_function(lambda_pm, c_s_0_pm, fluct_amp_pm)

# PyMC model
with pm.Model() as model:
    lambda_0 = pm.Uniform('lambda_0', 1000, 8000)
    c_s_0 = pm.Uniform('c_s_0', 15, 30)
    fluct_amp = pm.Uniform('fluct_amp', 1, 4)
    
    mean_pred, std_pred = pymc_model(lambda_0, c_s_0, fluct_amp)
    
    # Likelihood
    pm.Normal('obs_mean', mu=mean_pred, sigma=0.1, observed=25.3)
    pm.HalfNormal('obs_std', sigma=std_pred, observed=1.24)
    
    trace = pm.sample(1000, tune=500)
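A side note on the post-processing at the end of _model_function: because times covers exactly one period, Omega * times[k] = 2πk/99, so the first 99 samples are evenly spaced over a full cycle and the last one repeats the first. The mean and standard deviation of the time series therefore have a closed form in c_l_top alone: mean = Re(c)/100 and mean of squares = (99·|c|²/2 + Re(c)²)/100. The sketch below checks this numerically; the helper names are hypothetical and not part of the original code.

```python
import numpy as np

def time_series_stats_numeric(c_top, omega, n=100):
    """Mean/std computed as in _model_function: one period, n points including both ends."""
    times = np.linspace(0, 2 * np.pi / omega, n)
    series = (c_top * np.exp(1j * omega * times)).real
    return np.mean(series), np.std(series)

def time_series_stats_closed_form(c_top, n=100):
    """Closed form for the same time grid.

    The first n-1 samples cover one full period evenly, so their values sum to zero
    and their squares sum to (n-1)|c|^2/2; the last sample repeats the first, Re(c).
    """
    re = np.real(c_top)
    abs2 = np.abs(c_top) ** 2
    mean = re / n
    mean_sq = ((n - 1) * abs2 / 2 + re**2) / n
    std = np.sqrt(mean_sq - mean**2)  # population std, matching np.std's default ddof=0
    return mean, std
```

With this, only the ODE solve would need a gradient when moving to NUTS; the time-series step reduces to an elementary expression in Re(c_l_top) and |c_l_top|.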

cc: @aseyboldt

I don’t think there’s anything wrong with your current method. It won’t be able to compute the gradient of the posterior density, so you can’t use samplers like NUTS, but since you only seem to have 4 parameters, Metropolis or SMC should probably still work (but make sure to check convergence carefully!). You’ll need many more than 1000 draws though, and also more tuning. Metropolis draws are not nearly as efficient as NUTS draws. You can also help the sampler by rescaling your variables a bit, for instance:

lambda_0_raw = pm.Uniform("lambda_0_raw", 0, 1)
lambda_0 = pm.Deterministic("lambda_0", lambda_0_raw * 7000 + 1000)
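The same affine rescaling can be applied to all three parameters. Below is a minimal NumPy sketch of the mapping, using the bounds from the Uniform priors in the model above; BOUNDS, to_physical and to_unit are hypothetical helpers for illustration. Because a Uniform(0, 1) variable pushed through x ↦ low + x·(high − low) is again uniform on [low, high], the prior itself is unchanged; only the scale the sampler works on changes.

```python
import numpy as np

# Bounds taken from the Uniform priors in the original model
BOUNDS = {
    "lambda_0": (1000.0, 8000.0),
    "c_s_0": (15.0, 30.0),
    "fluct_amp": (1.0, 4.0),
}

def to_physical(name, raw):
    """Map a value on [0, 1] to the prior range of `name` (same form as lambda_0_raw * 7000 + 1000)."""
    low, high = BOUNDS[name]
    return low + raw * (high - low)

def to_unit(name, value):
    """Inverse map, e.g. for choosing initial values on the unit scale."""
    low, high = BOUNDS[name]
    return (value - low) / (high - low)
```

In the PyMC model, each parameter would then be declared as pm.Uniform(name + "_raw", 0, 1) and converted with pm.Deterministic, as in the two-line snippet above.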

If you want to use gradient-based samplers, I’d currently go for sunode or diffrax; both should work fine. I haven’t had much time to work on sunode for a while now, however, and I think diffrax is, for the most part, the better option by now. With diffrax you’ll still have to wrap the solver manually in an Op, though (unless someone finishes this PR). For an example, see PYMC Labs | Blog/aaz05zflou9tru7gm475dxwl
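Whichever gradient-capable route is chosen (sunode, diffrax or pm.ode.DifferentialEquation), it will most likely need the complex liquid concentration split into real and imaginary parts, since, as far as I know, these solver interfaces work with real-valued state vectors. Writing c_l = a + i·b, the liquid equation becomes two real equations: da/dz = (c_s/K − a)·Γ/(Φρw) + Ω·b/w and db/dz = −b·Γ/(Φρw) − Ω·a/w. The sketch below does this with SciPy only and checks it against the complex formulation from the original post. It is not an integration with any of those libraries; solve_liquid, rhs_real and rhs_complex are hypothetical names, and the tolerances are assumptions.

```python
import numpy as np
from scipy.integrate import solve_ivp

RHO = 3000.0  # density, as in the original post

def physics_model(k0, Z0, W0, Z_val):
    """Same depth-dependent properties as in the original post."""
    delta_rho, mu, phi0, g = 500.0, 10.0, 0.01, 9.8
    F = 0.23 / Z0 * Z_val
    Gamma = RHO * W0 * 0.23 / Z0
    a = delta_rho * g * k0 / (W0 * mu * phi0**2)
    Phi = (-1 + np.sqrt(1 + 4 * F * a)) / (2 * a)
    w = W0 * F / Phi
    W = W0 * (1 - F) / (1 - Phi)
    return Gamma, Phi, W, w

def rhs_complex(z, y, K, Omega, k0, Z0, W0):
    """Original formulation with a complex state [c_s, c_l]."""
    c_s, c_l = y
    Gamma, Phi, W, w = physics_model(k0, Z0, W0, z)
    dc_s = -(c_s / K - c_s) * Gamma / ((1 - Phi) * RHO * W)
    dc_l = (c_s / K - c_l) * Gamma / (Phi * RHO * w) - 1j * Omega * c_l / w
    return [dc_s, dc_l]

def rhs_real(z, y, K, Omega, k0, Z0, W0):
    """Same system with a real state [c_s, Re(c_l), Im(c_l)]."""
    c_s, c_re, c_im = y
    Gamma, Phi, W, w = physics_model(k0, Z0, W0, z)
    exch = Gamma / (Phi * RHO * w)  # liquid exchange coefficient
    dc_s = -(c_s / K - c_s) * Gamma / ((1 - Phi) * RHO * W)
    d_re = (c_s / K - c_re) * exch + Omega * c_im / w
    d_im = -c_im * exch - Omega * c_re / w
    return [dc_s, d_re, d_im]

def solve_liquid(z_vals, c_s_0, c_l_0, params, real_form=True, tol=1e-8):
    """Return c_l(z) as a complex array; params = (K, Omega, k0, Z0, W0)."""
    span = (z_vals[0], z_vals[-1])
    if real_form:
        y0 = [c_s_0, np.real(c_l_0), np.imag(c_l_0)]
        sol = solve_ivp(rhs_real, span, y0, t_eval=z_vals, args=params, rtol=tol, atol=tol)
        return sol.y[1] + 1j * sol.y[2]
    y0 = np.array([c_s_0, c_l_0], dtype=complex)
    sol = solve_ivp(rhs_complex, span, y0, t_eval=z_vals, args=params, rtol=tol, atol=tol)
    return sol.y[1]
```

With the definitions in physics_model, Phi·w = W0·F, so the liquid exchange coefficient simplifies to 1/z. This makes the system stiff near z = 0 and explains why the integration starts at a small positive depth; a native solver would face the same issue.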

Thanks for the helpful response! You’re right that for my current 3-parameter model, the as_op approach works fine with Metropolis sampling. I appreciate the variable rescaling tip as well.

The current code is a very preliminary version. Since I plan to expand it to more parameters eventually, I’m interested in exploring gradient-based options. I tried sunode following LLM suggestions but ran into issues getting it to work properly.

Regarding diffrax - the link you shared appears to be invalid on my end. Could you point me to other tutorials or examples for using diffrax with PyMC? I’d love to explore this approach for future model extensions.

Thanks again for the guidance!

I fixed the link in the above post.
Looks like a fun problem, by the way! Is this some geochemical model? Reminds me a bit of reactive transport problems, but not quite 🙂

Great! Thx for the new link. I will go through this example thoroughly and apply diffrax to my code.

Amazing guess, it is a geochemical model! I am new to PyMC and never expected to meet a core developer who is so familiar with geochemical theory, haha!