
Differential Equations as a Pytorch Neural Network Layer | by Kevin Hannay | Apr, 2023



Differential equations are the mathematical foundation for most of modern science. They describe the state of a system using an equation for the rate of change (differential). It is remarkable how many systems can be well described by equations of this form. For example, the physical laws describing motion, electromagnetism and quantum mechanics all take this form. More broadly, differential equations describe chemical reaction rates through the law of mass action, neuronal firing and disease spread through the SIR model.

The deep learning revolution has brought with it a new set of tools for performing large-scale optimizations over enormous datasets. In this post, we will see how you can use these tools to fit the parameters of a custom differential equation layer in PyTorch.

What is the problem we are trying to solve?

Let’s say we have some time series data y(t) that we want to model with a differential equation. The data takes the form of a set of observations yᵢ at times tᵢ. Based on some domain knowledge of the underlying system we can write down a differential equation to approximate the system.

In its most general form, the model is:

dy/dt = f(y, t; θ)

General ordinary differential equation system

where y is the state of the system, t is time, and θ are the parameters of the model. In this post we will assume that the parameters θ are unknown and we want to learn them from the data.
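One common way to make "learn θ from the data" precise is a least-squares fit, which is also what the mean squared error loss used later in this post amounts to. As a sketch (the symbols ŷ, y₀, N, and L are notation introduced here for illustration, not taken from the original post):

```latex
\hat{y}(t_i;\theta) = y_0 + \int_{t_0}^{t_i} f\big(\hat{y}(s;\theta),\, s;\, \theta\big)\, ds,
\qquad
L(\theta) = \frac{1}{N}\sum_{i=1}^{N} \big\lVert \hat{y}(t_i;\theta) - y_i \big\rVert^2,
\qquad
\theta^{\ast} = \arg\min_{\theta} L(\theta).
```

Here ŷ(tᵢ; θ) is the model solution evaluated at the observation times, so evaluating L(θ) requires solving the differential equation, and minimizing it by gradient descent requires differentiating through that solve.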

All of the code for this post is available on GitHub or as a Colab notebook, so there is no need to copy and paste if you want to follow along.

Let’s import the libraries we will need for this post. The only non-standard machine learning library we will use is torchdiffeq, which implements numerical differential equation solvers in PyTorch.

import torch
import torch.nn as nn
from torchdiffeq import odeint  # differentiable ODE solvers
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from typing import Callable, List, Tuple, Union, Optional
from pathlib import Path

Models

The first step of our modeling process is to define the model. For differential equations this means we must choose a form for the function f(y,t;θ) and a way to represent the parameters θ. We also need to do this in a way that is compatible with pytorch.

This means we need to encode our function as a torch.nn.Module class. As you will see this is pretty easy and only requires defining two methods. Let’s get started with the first of our three example models.

Van der Pol Oscillator (VDP)

We can define a differential equation system using the torch.nn.Module class where the parameters are created using the torch.nn.Parameter declaration. This lets pytorch know that we want to accumulate gradients for those parameters. We can also include fixed parameters (parameters that we don’t want to fit) by just not wrapping them with this declaration.
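To make the learnable-versus-fixed distinction concrete, here is a minimal sketch. The Decay model and its rate and offset names are hypothetical and exist only for this illustration; the only PyTorch calls used are nn.Parameter and named_parameters.

```python
import torch
import torch.nn as nn


class Decay(nn.Module):
    """Hypothetical toy model dy/dt = -rate * y + offset, for illustration only."""

    def __init__(self, rate: float, offset: float):
        super().__init__()
        # Wrapped in nn.Parameter: registered with the module and seen by the optimizer.
        self.rate = nn.Parameter(torch.tensor(rate))
        # Plain tensor attribute: used in forward() but never returned by parameters().
        self.offset = torch.tensor(offset)

    def forward(self, t, state):
        return -self.rate * state + self.offset


model = Decay(rate=0.5, offset=1.0)
learnable_names = [name for name, _ in model.named_parameters()]
print(learnable_names)  # only 'rate' is listed
```

If a fixed tensor should also follow the model when it is moved between devices, register_buffer is the usual alternative to a plain attribute.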

The first example we will use is the classic VDP oscillator which is a nonlinear oscillator with a single parameter μ. The differential equations for this system are:

dx/dt = μ(x − x³/3 − y)
dy/dt = x/μ

VDP oscillator equations

where x and y are the state variables. The VDP model is used to model everything from electronic circuits to cardiac arrhythmias and circadian rhythms. We can define this system in pytorch as follows:

class VDP(nn.Module):
    """
    Define the Van der Pol oscillator as a PyTorch module.
    """
    def __init__(self,
                 mu: float,  # Stiffness parameter of the VDP oscillator
                 ):
        super().__init__()
        self.mu = torch.nn.Parameter(torch.tensor(mu))  # make mu a learnable parameter

    def forward(self,
                t: float,  # time index
                state: torch.Tensor,  # state of the system; the last dimension indexes (x, y)
                ) -> torch.Tensor:  # return the derivative of the state
        """
        Define the right hand side of the VDP oscillator.
        """
        x = state[..., 0]  # the ellipsis keeps any leading batch dimensions
        y = state[..., 1]
        dX = self.mu*(x - 1/3*x**3 - y)
        dY = 1/self.mu*x
        # trick to make sure our return value has the same shape as the input
        dfunc = torch.zeros_like(state)
        dfunc[..., 0] = dX
        dfunc[..., 1] = dY
        return dfunc

    def __repr__(self):
        """Print the parameters of the model."""
        return f" mu: {self.mu.item()}"

You only need to define the __init__ method (init) and the forward method. I added a string method __repr__ to pretty print the parameter. The key point here is how we can translate from the differential equation to torch code in the forward method. This method needs to define the right-hand side of the differential equation.
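To illustrate what a solver does with a right-hand-side function of the form forward(t, state), here is a minimal fixed-step sketch in plain Python. It uses the classical fourth-order Runge-Kutta scheme and is not torchdiffeq's implementation; the helper names vdp_rhs, rk4_step, and integrate are made up for this example.

```python
import math
from typing import Callable, List

State = List[float]


def vdp_rhs(t: float, state: State, mu: float = 0.5) -> State:
    """Right-hand side of the VDP system, mirroring the forward method above."""
    x, y = state
    return [mu * (x - x**3 / 3 - y), x / mu]


def rk4_step(f: Callable[[float, State], State], t: float, y: State, h: float) -> State:
    """Advance y by one classical fourth-order Runge-Kutta step of size h."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
    k4 = f(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def integrate(f, y0: State, t0: float, t1: float, n_steps: int) -> List[State]:
    """Return the trajectory [y(t0), ..., y(t1)] on a uniform time grid."""
    h = (t1 - t0) / n_steps
    traj = [list(y0)]
    for i in range(n_steps):
        traj.append(rk4_step(f, t0 + i * h, traj[-1], h))
    return traj


# Sanity check on dy/dt = -y, whose exact solution is exp(-t).
decay = integrate(lambda t, y: [-y[0]], [1.0], 0.0, 1.0, 100)
print(abs(decay[-1][0] - math.exp(-1.0)))  # small discretization error

# A VDP trajectory started from a small initial condition.
vdp_traj = integrate(vdp_rhs, [0.01, 0.01], 0.0, 30.0, 3000)
print(vdp_traj[-1])
```

The solver only ever calls the right-hand side, which is why defining forward is all a model needs in order to be integrated.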

Let’s see how we can integrate this model using the odeint method from torchdiffeq:

vdp_model = VDP(mu=0.5)

# Create a time vector, this is the time axis of the ODE
ts = torch.linspace(0,30.0,1000)
# Create a batch of initial conditions
batch_size = 30
# Creates some random initial conditions
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2*torch.randn((batch_size,2))

# Solve the ODE, odeint comes from torchdiffeq
sol = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
plt.plot(ts, sol[:,:,0], lw=0.5);
plt.title("Time series of the VDP oscillator");
plt.xlabel("time");
plt.ylabel("x");

[Figure: time series of x for the 30 VDP trajectories]

Here is a phase plane plot of the solution (a phase plane plot is a parametric plot of the dynamical state).

# Check the solution
plt.plot(sol[:,:,0], sol[:,:,1], lw=0.5);
plt.title("Phase plot of the VDP oscillator");
plt.xlabel("x");
plt.ylabel("y");

[Figure: phase plane plot of the VDP oscillator]

The colors indicate the 30 separate trajectories in our batch. odeint returns the solution as a torch tensor with dimensions (time_points, batch_size, state_dimension); here it was converted to a NumPy array for plotting, which keeps the same shape.

sol.shape
(1000, 30, 2)

Lotka-Volterra Predator-Prey Equations

As another example we create a module for the Lotka-Volterra predator-prey equations. In the Lotka-Volterra (LV) predator-prey model, there are two primary variables: the population of prey (x) and the population of predators (y). The model is defined by the following equations:

dx/dt = αx − βxy
dy/dt = −δy + γxy

Lotka-Volterra equations for predator prey dynamics

In addition to the primary variables, there are also four parameters that are used to describe various ecological factors in the model:

  • α represents the intrinsic growth rate of the prey population in the absence of predators.
  • β represents the predation rate of the predators on the prey.
  • δ represents the death rate of the predator population in the absence of prey.
  • γ represents the efficiency with which the predators convert the consumed prey into new predator biomass.
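A quick worked check of what these parameters imply, using the sign convention of the LotkaVolterra module in this post (δ multiplies the −y term and γ multiplies the xy term). Setting both derivatives to zero gives a coexistence equilibrium at x* = δ/γ and y* = α/β. The function name lotka_volterra_rhs is made up for this sketch.

```python
def lotka_volterra_rhs(x, y, alpha=1.5, beta=1.0, delta=3.0, gamma=1.0):
    """Right-hand side with the same signs as the module's forward method."""
    dx = alpha * x - beta * x * y
    dy = -delta * y + gamma * x * y
    return dx, dy


alpha, beta, delta, gamma = 1.5, 1.0, 3.0, 1.0  # default values used in this post
x_star = delta / gamma   # prey level at which the predator population is stationary
y_star = alpha / beta    # predator level at which the prey population is stationary
print(x_star, y_star)                        # 3.0 1.5
print(lotka_volterra_rhs(x_star, y_star))    # (0.0, 0.0)
```

Trajectories started away from this point cycle around it, which is the oscillation visible in the time series plots of this model.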

Together, these variables and parameters describe the dynamics of predator-prey interactions in an ecosystem and are used to mathematically model the changes in the populations of prey and predators over time. Here is this system as a torch.nn.Module:

class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear, differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.
    """
    def __init__(self,
                 alpha: float = 1.5,  # The alpha parameter of the Lotka-Volterra system
                 beta: float = 1.0,  # The beta parameter of the Lotka-Volterra system
                 delta: float = 3.0,  # The delta parameter of the Lotka-Volterra system
                 gamma: float = 1.0  # The gamma parameter of the Lotka-Volterra system
                 ) -> None:
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([alpha, beta, delta, gamma]))

    def forward(self, t, state):
        x = state[..., 0]  # prey population
        y = state[..., 1]  # predator population
        sol = torch.zeros_like(state)

        # coefficients are stored together in the model_params tensor
        alpha, beta, delta, gamma = self.model_params
        sol[..., 0] = alpha*x - beta*x*y
        sol[..., 1] = -delta*y + gamma*x*y
        return sol

    def __repr__(self):
        alpha, beta, delta, gamma = self.model_params.tolist()
        return f" alpha: {alpha}, beta: {beta}, delta: {delta}, gamma: {gamma}"

This follows the same pattern as the first example; the main difference is that we now have four parameters and store them as a model_params tensor. Here is the integration and plotting code for the predator-prey equations.

lv_model = LotkaVolterra() #use default parameters
ts = torch.linspace(0,30.0,1000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[3,3]]) + 0.50*torch.randn((batch_size,2))
sol = odeint(lv_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Check the solution

plt.plot(ts, sol[:,:,0], lw=0.5);
plt.title("Time series of the Lotka-Volterra system");
plt.xlabel("time");
plt.ylabel("x");

[Figure: time series of the Lotka-Volterra system]

Now a phase plane plot of the system:

[Figure: phase plane plot of the Lotka-Volterra system]

Lorenz system

The last example we will use is the Lorenz equations which are famous for their beautiful plots illustrating chaotic dynamics. They originally came from a reduced model for fluid dynamics and take the form:

dx/dt = σ(y − x)
dy/dt = x(ρ − z) − y
dz/dt = xy − βz

where x, y, and z are the state variables, and σ, ρ, and β are the system parameters.

class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.
    """
    def __init__(self,
                 sigma: float = 10.0,  # The sigma parameter of the Lorenz system
                 rho: float = 28.0,  # The rho parameter of the Lorenz system
                 beta: float = 8.0/3,  # The beta parameter of the Lorenz system
                 ):
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([sigma, rho, beta]))

    def forward(self, t, state):
        x = state[..., 0]  # the last dimension indexes (x, y, z)
        y = state[..., 1]
        z = state[..., 2]
        sol = torch.zeros_like(state)

        sigma, rho, beta = self.model_params  # coefficients are stored in model_params
        sol[..., 0] = sigma*(y - x)
        sol[..., 1] = x*(rho - z) - y
        sol[..., 2] = x*y - beta*z
        return sol

    def __repr__(self):
        sigma, rho, beta = self.model_params.tolist()
        return f" sigma: {sigma}, rho: {rho}, beta: {beta}"

This shows how to integrate this system and plot the results. This system (at these parameter values) shows chaotic dynamics so initial conditions that start off close together diverge from one another exponentially.

lorenz_model = Lorenz()
ts = torch.linspace(0,50.0,3000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[1.0,0.0,0.0]]) + 0.10*torch.randn((batch_size,3))
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()

# Check the solution
plt.plot(ts[:2000], sol[:2000,:,0], lw=0.5);
plt.title("Time series of the Lorenz system");
plt.xlabel("time");
plt.ylabel("x");

[Figure: time series of the Lorenz system]

Here we show the famous butterfly plot (phase plane plot) for the first set of initial conditions in the batch.

[Figure: butterfly (phase plane) plot of the Lorenz system]
The Lorenz equations show chaotic dynamics and trace out a beautiful strange attractor.

Data

Now that we can define the differential equation models in PyTorch we need to create some data to be used in training. This is where things start to get really neat as we see our first glimpse of being able to hijack deep learning machinery for fitting the parameters. Really we could just use a tensor of data directly, but this is a nice way to organize the data. It will also be useful if you have some experimental data that you want to use.

Torch provides the Dataset class for loading in data. To use it you just need to create a subclass and define two methods: a __len__ method that returns the number of data points, and a __getitem__ method that returns the data point at a given index. In case you are wondering, these are the methods that underlie len(array) and the array[0] subscript access for Python lists.

The rest of the boilerplate code is defined in the parent class torch.utils.data.Dataset. We will see the power of these methods when we define a training loop.

class SimODEData(Dataset):
    """
    A very simple dataset class for simulating ODEs
    """
    def __init__(self,
                 ts: List[torch.Tensor],  # List of time points as tensors
                 values: List[torch.Tensor],  # List of dynamical state values (tensor) at each time point
                 true_model: Union[torch.nn.Module,None] = None,
                 ) -> None:
        self.ts = ts
        self.values = values
        self.true_model = true_model

    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.ts[index], self.values[index]
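The two methods are ordinary Python protocol hooks, so their effect can be seen without torch at all. In this sketch PairedSeries is a hypothetical stand-in that stores plain lists instead of tensors but defines the same two methods as SimODEData.

```python
from typing import List, Tuple


class PairedSeries:
    """Torch-free stand-in (hypothetical) with the same two methods as SimODEData."""

    def __init__(self, ts: List[List[float]], values: List[List[float]]) -> None:
        self.ts = ts
        self.values = values

    def __len__(self) -> int:
        return len(self.ts)            # called by len(series)

    def __getitem__(self, index: int) -> Tuple[List[float], List[float]]:
        return self.ts[index], self.values[index]   # called by series[index]


series = PairedSeries(ts=[[0.0, 1.0], [0.0, 1.0]],
                      values=[[2.0, 1.5], [3.0, 2.5]])
print(len(series))   # 2
print(series[1])     # ([0.0, 1.0], [3.0, 2.5])

# Integer indexing is also enough for plain iteration: Python calls
# __getitem__ with 0, 1, 2, ... until an IndexError is raised.
pairs = [item for item in series]
```

A data loader needs no more than this: the length tells it which indices exist, and indexing fetches the samples it groups into a batch.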

Next let’s create a quick generator function to generate some simulated data to test the algorithms on. In a real use case the data would be loaded from a file or database, but for this example we will just generate some data. In fact, I recommend that you always start with generated data to make sure your code is working before you try to load real data.

def create_sim_dataset(model: nn.Module,  # model to simulate from
                       ts: torch.Tensor,  # Time points to simulate for
                       num_samples: int = 10,  # Number of samples to generate
                       sigma_noise: float = 0.1,  # Noise level to add to the data
                       initial_conditions_default: torch.Tensor = torch.tensor([0.0, 0.0]),  # Default initial conditions
                       sigma_initial_conditions: float = 0.1,  # Noise level to add to the initial conditions
                       ) -> SimODEData:
    ts_list = []
    states_list = []
    dim = initial_conditions_default.shape[0]
    for i in range(num_samples):
        x0 = sigma_initial_conditions * torch.randn((1,dim)).detach() + initial_conditions_default
        ys = odeint(model, x0, ts).squeeze(1).detach()
        ys += sigma_noise*torch.randn_like(ys)
        ys[0,:] = x0  # Set the first value to the initial condition
        ts_list.append(ts)
        states_list.append(ys)
    return SimODEData(ts_list, states_list, true_model=model)

This just takes in a differential equation model with some initial states and generates some time-series data from it (adding some Gaussian noise). This data is then passed into our custom dataset container.

Training Loop

Next we will create a wrapper function for a pytorch training loop. Training means we want to update the model parameters to increase the alignment with the data (or decrease the cost function).

One of the tricks for this from deep learning is to not use all the data before taking a gradient step. Part of this is necessity for using enormous datasets as you can’t fit all of that data inside a GPU’s memory, but this also can help the gradient descent algorithm avoid getting stuck in local minima.

The training loop in words:

  • Divide the dataset into mini-batches, which are subsets of the entire dataset. You usually want to choose these randomly.
  • Iterate through the mini-batches
  • Generate the predictions using the current model parameters
  • Calculate the loss (here we will use the mean squared error)
  • Calculate the gradients, using backpropagation.
  • Update the parameters using a gradient descent step. Here we use the Adam optimizer.
  • Each full pass through the dataset is called an epoch.
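The steps above can be sketched on a toy problem in plain Python. This is an illustration under simplifying assumptions, not the training code of this post: the model is dy/dt = −θy with y(0) = 1, its known solution exp(−θt) stands in for a numerical ODE solve, the gradient is derived by hand instead of by backpropagation, and a plain gradient descent step replaces Adam. All names and values are made up for the example.

```python
import math
import random

random.seed(0)

theta_true = 0.3
ts = [0.1 * k for k in range(1, 51)]              # observation times in (0, 5]
ys = [math.exp(-theta_true * t) for t in ts]      # noise-free observations


def batch_loss_and_grad(theta, batch):
    """Mean squared error over a batch and its derivative with respect to theta."""
    loss, grad = 0.0, 0.0
    for t, y in batch:
        pred = math.exp(-theta * t)               # model prediction at time t
        loss += (pred - y) ** 2
        grad += 2.0 * (pred - y) * (-t * pred)    # chain rule: d pred / d theta = -t * pred
    return loss / len(batch), grad / len(batch)


theta, lr, batch_size, epochs = 0.1, 0.1, 10, 100  # deliberately wrong starting value
data = list(zip(ts, ys))
for epoch in range(epochs):                        # one epoch = one full pass over the data
    random.shuffle(data)                           # random mini-batches each epoch
    running_loss = 0.0
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        loss, grad = batch_loss_and_grad(theta, batch)
        theta -= lr * grad                         # plain gradient descent step
        running_loss += loss

print(f"recovered theta = {theta:.4f} (true value {theta_true})")
```

The structure (shuffle, slice into batches, predict, compute loss and gradient, update) is the same one the PyTorch loop follows, with the library taking over the batching, the differentiation, and the update rule.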

Okay here is the code:

def train(model: torch.nn.Module,  # Model to train
          data: SimODEData,  # Data to train on
          lr: float = 1e-2,  # learning rate for the Adam optimizer
          epochs: int = 10,  # Number of epochs to train for
          batch_size: int = 5,  # Batch size for training
          method = 'rk4',  # ODE solver to use
          step_size: float = 0.10,  # for fixed diffeq solver set the step size
          show_every: int = 10,  # How often to print the loss function message
          save_plots_every: Union[int,None] = None,  # save a plot of the fit, to disable make this None
          model_name: str = "",  # string for the model, used to reference the saved plots
          *args: tuple,
          **kwargs: dict
          ):

    # Create a data loader to iterate over the data. This takes in our dataset and returns batches of data
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    # Choose an optimizer. Adam is a good default choice as a fancy gradient descent
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Create a loss function this computes the error between the predicted and true values
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        running_loss = 0.0
        for batchdata in trainloader:
            optimizer.zero_grad()  # reset gradients, famous gotcha in a pytorch training loop
            ts, states = batchdata  # unpack the data
            initial_state = states[:,0,:]  # grab the initial state
            # odeint returns (time, batch, state_dim); transpose to (batch, time, state_dim)
            # so the prediction lines up with states. Pytorch expects the batch dimension first.
            # ts[0] is used because every sample in the batch shares the same time grid.
            pred = odeint(model,
                          initial_state,
                          ts[0],
                          method=method,
                          options={'step_size': step_size}).transpose(0,1)
            # Compute the loss
            loss = criterion(pred, states)
            # compute gradients
            loss.backward()
            # update parameters
            optimizer.step()
            running_loss += loss.item()  # record loss (summed over the batches of this epoch)
        if epoch % show_every == 0:
            print(f"Loss at {epoch}: {running_loss}")
        # Use this to save plots of the fit every save_plots_every epochs
        # (plot_time_series is a plotting helper that is not defined in this post)
        if save_plots_every is not None and epoch % save_plots_every == 0:
            with torch.no_grad():
                fig, ax = plot_time_series(data.true_model, model, data[0])
                ax.set_title(f"Epoch: {epoch}")
                fig.savefig(f"./tmp_plots/{epoch}_{model_name}_fit_plot")
                plt.close()

Examples

Fitting the VDP Oscillator

Let’s use this training loop to recover the parameters from simulated VDP oscillator data.

true_mu = 0.30
model_sim = VDP(mu=true_mu)
ts_data = torch.linspace(0.0,10.0,10)
data_vdp = create_sim_dataset(model_sim,
ts = ts_data,
num_samples=10,
sigma_noise=0.01)

Let’s create a model with the wrong parameter value and visualize the starting point.

vdp_model = VDP(mu = 0.10) 
plot_time_series(model_sim,
vdp_model,
data_vdp[0],
dyn_var_idx=1,
title = "VDP Model: Before Parameter Fits");

[Figure: VDP model before parameter fits]

Now, we will use the training loop to fit the parameters of the VDP oscillator to the simulated data.

train(vdp_model, data_vdp, epochs=50, model_name="vdp");
print(f"After training: {vdp_model}, where the true value is {true_mu}")
print(f"Final Parameter Recovery Error: {vdp_model.mu - true_mu}")
Loss at 0: 0.1369624137878418
Loss at 10: 0.0073615191504359245
Loss at 20: 0.0009214915917254984
Loss at 30: 0.0002127257248503156
Loss at 40: 0.00019956652977271006
After training: mu: 0.3018421530723572, where the true value is 0.3
Final Parameter Recovery Error: 0.0018421411514282227

Not too bad! Let’s see how the plot looks now…

[Figure: time series of the fitted VDP model]

The plot confirms that we almost perfectly recovered the parameter. One more quick plot, where we plot the dynamics of the system in the phase plane (a parametric plot of the state variables).

[Figure: phase plane plot of the fitted VDP model]

Here is a visual of the training process for this model:

Video of the training process for the VDP model

Lotka-Volterra Equations

Now let’s adapt our methods to fit simulated data from the Lotka-Volterra equations.

model_sim_lv = LotkaVolterra(1.5,1.0,3.0,1.0)
ts_data = torch.arange(0.0, 10.0, 0.1)
data_lv = create_sim_dataset(model_sim_lv,
ts = ts_data,
num_samples=10,
sigma_noise=0.1,
initial_conditions_default=torch.tensor([2.5, 2.5]))
model_lv = LotkaVolterra(alpha=1.6, beta=1.1,delta=2.7, gamma=1.2)

plot_time_series(model_sim_lv, model_lv, data = data_lv[0], title = "Lotka Volterra: Before Fitting");

Here are the initial fits for the starting parameters. We will then fit as before and take a look at the results.

[Figure: Lotka-Volterra model before fitting]

train(model_lv, data_lv, epochs=60, lr=1e-2, model_name="lotkavolterra")
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
Loss at 0: 1.1298701763153076
Loss at 10: 0.1296287178993225
Loss at 20: 0.045993587002158165
Loss at 30: 0.02311511617153883
Loss at 40: 0.020882505923509598
Loss at 50: 0.020726025104522705
Fitted model: alpha: 1.5965800285339355, beta: 1.0465354919433594, delta: 2.817030429840088, gamma: 0.939825177192688
True model: alpha: 1.5, beta: 1.0, delta: 3.0, gamma: 1.0

First a time-series plot of the fitted system:

[Figure: time series of the fitted Lotka-Volterra model]

Now let’s visualize the results using a phase plane plot.

[Figure: phase plane plot of the fitted Lotka-Volterra model]

Here is a visual of the fitting process.

Lorenz Equations

Finally, let’s try to fit the Lorenz equations.

model_sim_lorenz = Lorenz(sigma=10.0, rho=28.0, beta=8.0/3.0)
ts_data = torch.arange(0, 10.0, 0.05)
data_lorenz = create_sim_dataset(model_sim_lorenz,
ts = ts_data,
num_samples=30,
initial_conditions_default=torch.tensor([1.0, 0.0, 0.0]),
sigma_noise=0.01,
sigma_initial_conditions=0.10)
lorenz_model = Lorenz(sigma=10.2, rho=28.2, beta=9.0/3)
fig, ax = plot_time_series(model_sim_lorenz, lorenz_model, data_lorenz[0], title="Lorenz Model: Before Fitting");

ax.set_xlim((2,15));

Here are the initial fits. Next we will call our training loop.

[Figure: Lorenz model before fitting]

train(lorenz_model, 
data_lorenz,
epochs=300,
batch_size=5,
method = 'rk4',
step_size=0.05,
show_every=50,
lr = 1e-3)

Here are the training results:

Loss at 0: 114.25119400024414
Loss at 50: 4.364489555358887
Loss at 100: 2.055854558944702
Loss at 150: 1.2539702206850052
Loss at 200: 0.7839434593915939
Loss at 250: 0.5347371995449066

Let’s look at the fitted model. Starting with a full plot of the dynamics.

[Figure: full time series of the fitted Lorenz model]

Let’s zoom in on the bulk of the data and see how the fit looks.

[Figure: fitted Lorenz model, zoomed in on the data range]

You can see the model is very close to the true model for the data range, and generalizes well for t < 16 for the unseen data. Now the phase plane plot (zoomed in).

plot_phase_plane(model_sim_lorenz, lorenz_model, data_lorenz[0], title = "Lorenz Model: After Fitting", time_range=(0,20.0));

[Figure: phase plane plot of the Lorenz model after fitting]

You can see that our fitted model performs well for t in [0,16] and then starts to diverge.

This procedure works great for the situation where we know the form of the equations on the right-hand-side, but what if we don’t? Can we use this procedure to discover the model equations?

This is much too big of a subject to fully cover in this post, but one of the biggest advantages of moving our differential equations models into the torch framework is that we can mix and match them with artificial neural network layers.

The simplest thing we can do is to replace the right-hand-side f(y,t; θ) with a neural network layer. Equations of this type are called neural differential equations, and they can be viewed as a generalization of a recurrent neural network.

Neural differential equations replace the right-hand-side of the equations with an artificial neural network
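One way to see the connection to recurrent networks is to discretize the equation with a forward Euler step. This is a sketch; the step size h and the symbol f_NN for the network are notation introduced here, not from the original post:

```latex
\frac{dy}{dt} = f_{\mathrm{NN}}(y, t; \theta)
\quad\Longrightarrow\quad
y_{n+1} = y_n + h\, f_{\mathrm{NN}}(y_n, t_n; \theta),
\qquad t_n = t_0 + n h .
```

Each step applies the same network, with the same weights θ, to the previous state, which is the structure of a recurrent (or residual) update. Letting h → 0 recovers the continuous-time model.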

As a first example, let’s do this for our simple VDP oscillator system. To begin we will remake the simulated data. You will notice that I am creating longer time series and more samples: fitting a neural differential equation takes much more data and more computational power, since we have many more parameters to determine.

# remake the data 
model_sim_vdp = VDP(mu=0.20)
ts_data = torch.linspace(0.0,30.0,100) # longer time series than the custom ode layer
data_vdp = create_sim_dataset(model_sim_vdp,
ts = ts_data,
num_samples=30, # more samples than the custom ode layer
sigma_noise=0.1,
initial_conditions_default=torch.tensor([0.50,0.10]))

Now I define a simple feedforward neural network layer to fill in the right-hand-side of the equation.


class NeuralDiffEq(nn.Module):
    """
    Basic Neural ODE model
    """
    def __init__(self,
                 dim: int = 2,  # dimension of the state vector
                 ) -> None:
        super().__init__()
        self.ann = nn.Sequential(torch.nn.Linear(dim, 8),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(8, 16),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(16, 32),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(32, dim))

    def forward(self, t, state):
        return self.ann(state)
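For a sense of scale, the number of trainable values in this network can be tallied by hand. Each linear layer has one weight per input-output pair plus one bias per output; the sketch below assumes dim = 2, as used for the VDP data.

```python
layer_sizes = [2, 8, 16, 32, 2]   # dim -> 8 -> 16 -> 32 -> dim, with dim = 2
n_params = sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))
print(n_params)  # 778
```

That is 778 values to determine, compared with the single parameter μ of the VDP module.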

Here is a plot of the system before fitting:

model_vdp_nde = NeuralDiffEq(dim=2) 
plot_time_series(model_sim_vdp, model_vdp_nde, data_vdp[0], title = "Neural ODE: Before Fitting");

[Figure: neural ODE before fitting]

You can see we start very far away from the correct solution, but then again we are injecting much less information into our model. Let’s see if we can fit the model to get better results.

train(model_vdp_nde, 
data_vdp,
epochs=1500,
lr=1e-3,
batch_size=5,
show_every=100,
model_name = "nde")
Loss at 0: 84.39617252349854
Loss at 100: 84.34061241149902
Loss at 200: 73.75008296966553
Loss at 300: 3.4929964542388916
Loss at 400: 1.6555403769016266
Loss at 500: 0.7814530655741692
Loss at 600: 0.41551147401332855
Loss at 700: 0.3157300055027008
Loss at 800: 0.19066352397203445
Loss at 900: 0.15869349241256714
Loss at 1000: 0.12904016114771366
Loss at 1100: 0.23840919509530067
Loss at 1200: 0.1681726910173893
Loss at 1300: 0.09865255374461412
Loss at 1400: 0.09134986530989408

Visualizing the results, we can see that the model is able to fit the data and even extrapolate to the future (although it is not as good or fast as the specified model).

[Figure: time series of the fitted neural differential equation model]

Now the phase plane plot of our neural differential equation model.

[Figure: phase plane plot of the fitted neural differential equation model]

These models take a long time to train and more data to converge on a good fit. This makes sense since we are both trying to learn the model and the parameters at the same time.

Video of the fitting process for a NDE.

Conclusions and Wrap-Up

In this article I have demonstrated how we can use differential equation models within the PyTorch ecosystem using the torchdiffeq package. The code from this article is available on GitHub and can be opened directly in Google Colab for experimentation. You can also install the code from this article using pip:

pip install paramfittorchdemo

This post is an introduction. In the future I will be writing more about the following topics:

  • How to blend some mechanistic knowledge of the dynamics with deep learning. These have been called universal differential equations as they enable us to combine scientific knowledge with deep learning. This basically blends the two approaches together.
  • How to combine differential equation layers with other deep learning layers.
  • Model discovery: Can we recover the actual model equations from data? This uses tools like SINDy to extract the model equations from data.
  • MLOps tools for managing the training of these models. This includes tools like MLflow, Weights and Biases, and TensorBoard.
  • Anything else I hear back about from you!

If you liked this post, be sure to follow me and connect on LinkedIn. All images unless otherwise noted are by the author.


Differential equations are the mathematical foundation for most of modern science. They describe the state of a system using an equation for the rate of change (differential). It is remarkable how many systems can be well described by equations of this form. For example, the physical laws describing motion, electromagnetism and quantum mechanics all take this form. More broadly, differential equations describe chemical reaction rates through the law of mass action, neuronal firing and disease spread through the SIR model.

The deep learning revolution has brought with it a new set of tools for performing large scale optimizations over enormous datasets. In this post, we will see how you can use these tools to fit the parameters of a custom differential equation layer in pytorch.

What is the problem we are trying to solve?

Let’s say we have some time series data y(t) that we want to model with a differential equation. The data takes the form of a set of observations yᵢ at times tᵢ. Based on some domain knowledge of the underlying system we can write down a differential equation to approximate the system.

In the most general form this takes the form:

General ordinary differential equation system

where y is the state of the system, t is time, and θ are the parameters of the model. In this post we will assume that the parameters θ are unknown and we want to learn them from the data.

All of the code for this post is available on github or as a colab notebook, so no need to try and copy and paste if you want to follow along.

Let’s import the libraries we will need for this post. The only non standard machine learning library we will use the torchdiffeq library to solve the differential equations. This library implements numerical differential equation solvers in pytorch.

import torch 
import torch.nn as nn
from torchdiffeq import odeint as odeint
import pylab as plt
from torch.utils.data import Dataset, DataLoader
from typing import Callable, List, Tuple, Union, Optional
from pathlib import Path

Models

The first step of our modeling process is to define the model. For differential equations this means we must choose a form for the function f(y,t;θ) and a way to represent the parameters θ. We also need to do this in a way that is compatible with pytorch.

This means we need to encode our function as a torch.nn.Module class. As you will see this is pretty easy and only requires defining two methods. Lets get started with the first of out three example models.

van Der Pol Oscillator (VDP)

We can define a differential equation system using the torch.nn.Module class where the parameters are created using the torch.nn.Parameter declaration. This lets pytorch know that we want to accumulate gradients for those parameters. We can also include fixed parameters (parameters that we don’t want to fit) by just not wrapping them with this declaration.

The first example we will use is the classic VDP oscillator which is a nonlinear oscillator with a single parameter μ. The differential equations for this system are:

VDP oscillator equations

where x and y are the state variables. The VDP model is used to model everything from electronic circuits to cardiac arrhythmias and circadian rhythms. We can define this system in pytorch as follows:

class VDP(nn.Module):
"""
Define the Van der Pol oscillator as a PyTorch module.
"""
def __init__(self,
mu: float, # Stiffness parameter of the VDP oscillator
):
super().__init__()
self.mu = torch.nn.Parameter(torch.tensor(mu)) # make mu a learnable parameter

def forward(self,
t: float, # time index
state: torch.TensorType, # state of the system first dimension is the batch size
) -> torch.Tensor: # return the derivative of the state
"""
Define the right hand side of the VDP oscillator.
"""
x = state[..., 0] # first dimension is the batch size
y = state[..., 1]
dX = self.mu*(x-1/3*x**3 - y)
dY = 1/self.mu*x
# trick to make sure our return value has the same shape as the input
dfunc = torch.zeros_like(state)
dfunc[..., 0] = dX
dfunc[..., 1] = dY
return dfunc

def __repr__(self):
"""Print the parameters of the model."""
return f" mu: {self.mu.item()}"

You only need to define the __init__ method (init) and the forward method. I added a string method __repr__ to pretty print the parameter. The key point here is how we can translate from the differential equation to torch code in the forward method. This method needs to define the right-hand side of the differential equation.

Let’s see how we can integrate this model using the odeint method from torchdiffeq:

vdp_model = VDP(mu=0.5)

# Create a time vector, this is the time axis of the ODE
ts = torch.linspace(0,30.0,1000)
# Create a batch of initial conditions
batch_size = 30
# Creates some random initial conditions
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2*torch.randn((batch_size,2))

# Solve the ODE, odeint comes from torchdiffeq
sol = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
plt.plot(ts, sol[:,:,0], lw=0.5);
plt.title("Time series of the VDP oscillator");
plt.xlabel("time");
plt.ylabel("x");p

png

Here is a phase plane plot of the solution (a phase plane plot of a parametric plot of the dynamical state).

# Check the solution
plt.plot(sol[:,:,0], sol[:,:,1], lw=0.5);
plt.title("Phase plot of the VDP oscillator");
plt.xlabel("x");
plt.ylabel("y");
png

The colors indicate the 30 separate trajectories in our batch. The solution comes back as a torch tensor with dimensions (time_points, batch number, dynamical_dimension).

sol.shape
(1000, 30, 2)

Lotka Volterra Predator Prey equations

As another example we create a module for the Lotka-Volterra predator-prey equations. In the Lotka-Volterra (LV) predator-prey model, there are two primary variables: the population of prey (x) and the population of predators (y). The model is defined by the following equations:

Lotka-Volterra equations for predator prey dynamics

In addition to the primary variables, there are also four parameters that are used to describe various ecological factors in the model:

α represents the intrinsic growth rate of the prey population in the absence of predators. β represents the predation rate of the predators on the prey. γ represents the death rate of the predator population in the absence of prey. δ represents the efficiency with which the predators convert the consumed prey into new predator biomass.

Together, these variables and parameters describe the dynamics of predator-prey interactions in an ecosystem and are used to mathematically model the changes in the populations of prey and predators over time. Here is this system as a torch.nn.Module:

class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear, differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.
    """
    def __init__(self,
                 alpha: float = 1.5,  # prey growth rate
                 beta: float = 1.0,   # predation rate
                 delta: float = 3.0,  # predator death rate (multiplies -y below)
                 gamma: float = 1.0   # predator growth per prey consumed (multiplies x*y below)
                 ) -> None:
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([alpha, beta, delta, gamma]))

    def forward(self, t, state):
        x = state[..., 0]  # prey population
        y = state[..., 1]  # predator population
        sol = torch.zeros_like(state)

        # coefficients are stored in the model_params tensor
        alpha, beta, delta, gamma = self.model_params
        sol[..., 0] = alpha*x - beta*x*y
        sol[..., 1] = -delta*y + gamma*x*y
        return sol

    def __repr__(self):
        # adjacent f-strings avoid the stray whitespace of backslash continuations
        return (f"alpha: {self.model_params[0].item()}, "
                f"beta: {self.model_params[1].item()}, "
                f"delta: {self.model_params[2].item()}, "
                f"gamma: {self.model_params[3].item()}")

This follows the same pattern as the first example. The main difference is that we now have four parameters, which we store together in a single model_params tensor. Here is the integration and plotting code for the predator-prey equations.

lv_model = LotkaVolterra() #use default parameters
ts = torch.linspace(0,30.0,1000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[3,3]]) + 0.50*torch.randn((batch_size,2))
sol = odeint(lv_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Check the solution

plt.plot(ts, sol[:,:,0], lw=0.5);
plt.title("Time series of the Lotka-Volterra system");
plt.xlabel("time");
plt.ylabel("x");

[Figure: time series of the Lotka-Volterra system, x versus time]

Now a phase plane plot of the system:

[Figure: phase plane plot of the Lotka-Volterra system]
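As a small aside, the forward method above implies a coexistence equilibrium where both right-hand sides vanish: x* = δ/γ and y* = α/β (using the code's convention, in which δ multiplies −y and γ multiplies xy). The following dependency-free sketch checks this with plain Python floats. The helper name lv_rhs is hypothetical and is not part of the post's code.

```python
def lv_rhs(x, y, alpha=1.5, beta=1.0, delta=3.0, gamma=1.0):
    """Right-hand side of the Lotka-Volterra system, mirroring LotkaVolterra.forward."""
    dx = alpha * x - beta * x * y
    dy = -delta * y + gamma * x * y
    return dx, dy

alpha, beta, delta, gamma = 1.5, 1.0, 3.0, 1.0
x_star = delta / gamma  # prey level at which the predator population is stationary
y_star = alpha / beta   # predator level at which the prey population is stationary

print("equilibrium:", (x_star, y_star))
print("derivatives at equilibrium:", lv_rhs(x_star, y_star))
```

With the default parameters the equilibrium sits at (3, 1.5). Trajectories of this model that start nearby cycle around it, which is what the closed loops in a phase plane plot show.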

Lorenz system

The last example we will use is the Lorenz equations, which are famous for their beautiful plots illustrating chaotic dynamics. They originally came from a reduced model for fluid dynamics and take the form:

dx/dt = σ(y − x)
dy/dt = x(ρ − z) − y
dz/dt = xy − βz

where x, y, and z are the state variables, and σ, ρ, and β are the system parameters.

class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.
    """
    def __init__(self,
                 sigma: float = 10.0,  # The sigma parameter of the Lorenz system
                 rho: float = 28.0,    # The rho parameter of the Lorenz system
                 beta: float = 8.0/3,  # The beta parameter of the Lorenz system
                 ):
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([sigma, rho, beta]))

    def forward(self, t, state):
        x = state[..., 0]  # state variables are the last axis of the state tensor
        y = state[..., 1]
        z = state[..., 2]
        sol = torch.zeros_like(state)

        sigma, rho, beta = self.model_params  # coefficients are stored in model_params
        sol[..., 0] = sigma*(y-x)
        sol[..., 1] = x*(rho-z) - y
        sol[..., 2] = x*y - beta*z
        return sol

    def __repr__(self):
        # adjacent f-strings avoid the stray whitespace of backslash continuations
        return (f"sigma: {self.model_params[0].item()}, "
                f"rho: {self.model_params[1].item()}, "
                f"beta: {self.model_params[2].item()}")

This shows how to integrate this system and plot the results. This system (at these parameter values) shows chaotic dynamics so initial conditions that start off close together diverge from one another exponentially.

lorenz_model = Lorenz()
ts = torch.linspace(0,50.0,3000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[1.0,0.0,0.0]]) + 0.10*torch.randn((batch_size,3))
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()

# Check the solution
plt.plot(ts[:2000], sol[:2000,:,0], lw=0.5);
plt.title("Time series of the Lorenz system");
plt.xlabel("time");
plt.ylabel("x");

[Figure: time series of the Lorenz system, x versus time]

Here we show the famous butterfly plot (phase plane plot) for the first set of initial conditions in the batch.

[Figure: butterfly (phase plane) plot of the Lorenz system]
The Lorenz equations show chaotic dynamics and trace out a beautiful strange attractor.
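To make the sensitivity to initial conditions concrete, here is a minimal sketch in plain Python (no torch) that integrates two Lorenz trajectories starting 1e-8 apart and compares their separation at the end. The classical fourth-order Runge-Kutta stepper, the step size of 0.01, and the helper names are assumptions chosen for illustration; they are not the solver settings used elsewhere in this post.

```python
import math

def lorenz_rhs(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Right-hand side of the Lorenz system for a single (x, y, z) state."""
    x, y, z = state
    return (sigma * (y - x), x * (rho - z) - y, x * y - beta * z)

def rk4_step(f, state, h):
    """One classical Runge-Kutta step of size h for an autonomous system."""
    k1 = f(state)
    k2 = f(tuple(s + 0.5 * h * k for s, k in zip(state, k1)))
    k3 = f(tuple(s + 0.5 * h * k for s, k in zip(state, k2)))
    k4 = f(tuple(s + h * k for s, k in zip(state, k3)))
    return tuple(s + h / 6.0 * (a + 2 * b + 2 * c + d)
                 for s, a, b, c, d in zip(state, k1, k2, k3, k4))

def separation(a, b):
    """Euclidean distance between two states."""
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))

h, n_steps = 0.01, 3000            # integrate from t = 0 to t = 30
traj_a = (1.0, 0.0, 0.0)
traj_b = (1.0 + 1e-8, 0.0, 0.0)    # same start, nudged slightly in x
initial_gap = separation(traj_a, traj_b)
for _ in range(n_steps):
    traj_a = rk4_step(lorenz_rhs, traj_a, h)
    traj_b = rk4_step(lorenz_rhs, traj_b, h)
final_gap = separation(traj_a, traj_b)
print(f"gap grew from {initial_gap:.1e} to {final_gap:.1e}")
```

The gap grows by many orders of magnitude but stays bounded by the size of the attractor, which is why long-range prediction of a chaotic system is hard even when the parameters are known well.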

Data

Now that we can define the differential equation models in pytorch, we need to create some data to be used in training. This is where things start to get really neat, as we get our first glimpse of hijacking deep learning machinery to fit the parameters. We could just use a tensor of data directly, but a dataset class is a nice way to organize the data. It will also be useful if you have experimental data that you want to use.

Torch provides the Dataset class for loading data. To use it you just need to create a subclass and define two methods: a __len__ method that returns the number of data points, and a __getitem__ method that returns the data point at a given index. If you are wondering, these methods are what underlie len(array) and array[0] subscript access for python lists.

The rest of the boilerplate code needed is defined in the parent class torch.utils.data.Dataset. We will see the power of these methods when we go to define a training loop.

class SimODEData(Dataset):
    """
    A very simple dataset class for simulating ODEs
    """
    def __init__(self,
                 ts: List[torch.Tensor],  # List of time points as tensors
                 values: List[torch.Tensor],  # List of dynamical state values (tensor) at each time point
                 true_model: Union[torch.nn.Module,None] = None,
                 ) -> None:
        self.ts = ts
        self.values = values
        self.true_model = true_model

    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.ts[index], self.values[index]
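The two-method protocol described above can be seen in isolation without torch. The sketch below uses a hypothetical ToyTrajectories class holding plain lists; it only illustrates how len() and square-bracket indexing dispatch to __len__ and __getitem__, and it is not a replacement for torch.utils.data.Dataset.

```python
class ToyTrajectories:
    """Hypothetical stand-in for a dataset: stores (times, values) pairs."""

    def __init__(self, ts, values):
        if len(ts) != len(values):
            raise ValueError("ts and values must contain the same number of samples")
        self.ts = ts
        self.values = values

    def __len__(self):
        # Called by len(obj)
        return len(self.ts)

    def __getitem__(self, index):
        # Called by obj[index]
        return self.ts[index], self.values[index]

toy = ToyTrajectories(ts=[[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]],
                      values=[[1.0, 0.8, 0.6], [2.0, 1.6, 1.2]])
print(len(toy))  # dispatches to toy.__len__()
print(toy[1])    # dispatches to toy.__getitem__(1)
```

A data loader only needs these two operations to know how many samples exist and to fetch any one of them, which is why the dataset class can stay this small.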

Next let’s create a quick generator function to produce some simulated data to test the algorithms on. In a real use case the data would be loaded from a file or database, but for this example we will just generate it. In fact, I recommend that you always start with generated data to make sure your code is working before you try to load real data.

def create_sim_dataset(model: nn.Module,  # model to simulate from
                       ts: torch.Tensor,  # Time points to simulate for
                       num_samples: int = 10,  # Number of samples to generate
                       sigma_noise: float = 0.1,  # Noise level to add to the data
                       initial_conditions_default: torch.Tensor = torch.tensor([0.0, 0.0]),  # Default initial conditions
                       sigma_initial_conditions: float = 0.1,  # Noise level to add to the initial conditions
                       ) -> SimODEData:
    ts_list = []
    states_list = []
    dim = initial_conditions_default.shape[0]
    for i in range(num_samples):
        x0 = sigma_initial_conditions * torch.randn((1,dim)).detach() + initial_conditions_default
        ys = odeint(model, x0, ts).squeeze(1).detach()
        ys += sigma_noise*torch.randn_like(ys)
        ys[0,:] = x0  # Set the first value to the initial condition
        ts_list.append(ts)
        states_list.append(ys)
    return SimODEData(ts_list, states_list, true_model=model)

This takes in a differential equation model with some initial states, generates time-series data from it, and adds Gaussian noise. The data is then passed into our custom dataset container.

Training Loop

Next we will create a wrapper function for a pytorch training loop. Training means updating the model parameters to improve the agreement with the data (that is, to decrease the cost function).

One of the tricks from deep learning is to take a gradient step before using all of the data. Part of this is necessity: with enormous datasets you can’t fit all of the data inside a GPU’s memory. It can also help the gradient descent algorithm avoid getting stuck in local minima.

The training loop in words:

  • Divide the dataset into mini-batches, which are subsets of the entire dataset. You usually want to choose these randomly.
  • Iterate through the mini-batches.
  • Generate the predictions using the current model parameters.
  • Calculate the loss (here we will use the mean squared error).
  • Calculate the gradients using backpropagation.
  • Update the parameters using a gradient descent step. Here we use the Adam optimizer.
  • Each full pass through the dataset is called an epoch.
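The steps above can be sketched end to end on a toy problem in plain Python. To keep the sketch dependency-free it makes three substitutions, all of them assumptions for illustration: the model is the scalar decay equation dy/dt = −k·y with its known solution y0·exp(−k·t) standing in for a numerical solver, the gradient is written out by hand instead of using backpropagation, and the update is plain gradient descent rather than Adam.

```python
import math
import random

random.seed(0)  # reproducible shuffling

true_k = 0.3
ts = [float(i) for i in range(11)]        # t = 0, 1, ..., 10
y0s = [0.5 + 0.1 * i for i in range(10)]  # ten initial conditions
# Each sample is (y0, observations); the observations are noise-free here
dataset = [(y0, [y0 * math.exp(-true_k * t) for t in ts]) for y0 in y0s]

def loss_and_grad(k, batch):
    """Mean squared error over a batch and its derivative with respect to k."""
    loss, grad, count = 0.0, 0.0, 0
    for y0, obs in batch:
        for t, y_obs in zip(ts, obs):
            pred = y0 * math.exp(-k * t)     # prediction with current parameter
            err = pred - y_obs
            loss += err ** 2
            grad += 2.0 * err * (-t * pred)  # chain rule: d(pred)/dk = -t * pred
            count += 1
    return loss / count, grad / count

k, lr, batch_size = 0.1, 0.02, 5             # deliberately wrong starting value
for epoch in range(200):                     # one epoch = one pass over the data
    random.shuffle(dataset)                  # random mini-batches
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        _, grad = loss_and_grad(k, batch)
        k -= lr * grad                       # gradient descent step
print(f"recovered k = {k:.4f} (true value {true_k})")
```

The structure (shuffle, batch, predict, loss, gradient, update) is the same as in a torch loop; the framework's contribution is computing the gradient automatically through the solver.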

Okay here is the code:

def train(model: torch.nn.Module,  # Model to train
          data: SimODEData,  # Data to train on
          lr: float = 1e-2,  # learning rate for the Adam optimizer
          epochs: int = 10,  # Number of epochs to train for
          batch_size: int = 5,  # Batch size for training
          method = 'rk4',  # ODE solver to use
          step_size: float = 0.10,  # for fixed diffeq solver set the step size
          show_every: int = 10,  # How often to print the loss function message
          save_plots_every: Union[int,None] = None,  # save a plot of the fit, to disable make this None
          model_name: str = "",  # string for the model, used to reference the saved plots
          *args: tuple,
          **kwargs: dict
          ):

    # Create a data loader to iterate over the data. This takes in our dataset and returns batches of data
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    # Choose an optimizer. Adam is a good default choice as a fancy gradient descent
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Create a loss function, this computes the error between the predicted and true values
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        running_loss = 0.0
        for batchdata in trainloader:
            optimizer.zero_grad()  # reset gradients, famous gotcha in a pytorch training loop
            ts, states = batchdata  # unpack the data
            initial_state = states[:,0,:]  # grab the initial state
            # odeint returns (time, batch, state_dim); swap the first two dimensions to get
            # (batch, time, state_dim) so the prediction lines up with `states`
            pred = odeint(model,
                          initial_state,
                          ts[0],
                          method=method,
                          options={'step_size': step_size}).transpose(0,1)
            # Compute the loss
            loss = criterion(pred, states)
            # compute gradients
            loss.backward()
            # update parameters
            optimizer.step()
            running_loss += loss.item()  # record loss (summed over the batches in this epoch)
        if epoch % show_every == 0:
            print(f"Loss at {epoch}: {running_loss}")
        # Use this to save plots of the fit every save_plots_every epochs
        if save_plots_every is not None and epoch % save_plots_every == 0:
            with torch.no_grad():
                fig, ax = plot_time_series(data.true_model, model, data[0])
                ax.set_title(f"Epoch: {epoch}")
                fig.savefig(f"./tmp_plots/{epoch}_{model_name}_fit_plot")
                plt.close()
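The training loop defaults to a fixed-step solver (method='rk4' with a step_size option). As a rough illustration of what a fixed-step Runge-Kutta integrator computes, here is a dependency-free sketch of the classical fourth-order scheme applied to dy/dt = −y, whose exact solution is exp(−t). The function names are hypothetical, and the exact rule torchdiffeq uses for 'rk4' may differ in detail from this textbook version.

```python
import math

def rk4_step(f, t, y, h):
    """Advance y' = f(t, y) by one classical Runge-Kutta step of size h."""
    k1 = f(t, y)
    k2 = f(t + 0.5 * h, y + 0.5 * h * k1)
    k3 = f(t + 0.5 * h, y + 0.5 * h * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def integrate(f, y0, t_end, h):
    """Take fixed steps of size h from t = 0 until t_end."""
    n_steps = round(t_end / h)
    t, y = 0.0, y0
    for _ in range(n_steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y

def decay(t, y):
    """Test problem y' = -y with exact solution exp(-t)."""
    return -y

for h in (0.5, 0.1):
    approx = integrate(decay, 1.0, 1.0, h)
    print(f"h = {h}: error = {abs(approx - math.exp(-1.0)):.2e}")
```

A smaller step gives a more accurate solution at the cost of more right-hand-side evaluations per trajectory, and every one of those evaluations is also part of the graph that gradients flow through during training. That trade-off is what the step_size argument controls.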

Examples

Fitting the VDP Oscillator

Let’s use this training loop to recover the parameters from simulated VDP oscillator data.

true_mu = 0.30
model_sim = VDP(mu=true_mu)
ts_data = torch.linspace(0.0,10.0,10)
data_vdp = create_sim_dataset(model_sim,
ts = ts_data,
num_samples=10,
sigma_noise=0.01)

Let’s create a model with the wrong parameter value and visualize the starting point.

vdp_model = VDP(mu = 0.10) 
plot_time_series(model_sim,
vdp_model,
data_vdp[0],
dyn_var_idx=1,
title = "VDP Model: Before Parameter Fits");

[Figure: VDP Model: Before Parameter Fits]

Now, we will use the training loop to fit the parameters of the VDP oscillator to the simulated data.

train(vdp_model, data_vdp, epochs=50, model_name="vdp");
print(f"After training: {vdp_model}, where the true value is {true_mu}")
print(f"Final Parameter Recovery Error: {(vdp_model.mu - true_mu).item()}")

Loss at 0: 0.1369624137878418
Loss at 10: 0.0073615191504359245
Loss at 20: 0.0009214915917254984
Loss at 30: 0.0002127257248503156
Loss at 40: 0.00019956652977271006
After training: mu: 0.3018421530723572, where the true value is 0.3
Final Parameter Recovery Error: 0.0018421411514282227

Not too bad! Let’s see how the plot looks now…

[Figure: VDP model versus simulated data after parameter fitting]

The plot confirms that we almost perfectly recovered the parameter. One more quick plot, where we plot the dynamics of the system in the phase plane (a parametric plot of the state variables).

[Figure: phase plane plot of the fitted VDP model]

Here is a visual of the training process for this model:

Video of the training process for the VDP model

Lotka-Volterra Equations

Now let’s adapt our methods to fit simulated data from the Lotka-Volterra equations.

model_sim_lv = LotkaVolterra(1.5,1.0,3.0,1.0)
ts_data = torch.arange(0.0, 10.0, 0.1)
data_lv = create_sim_dataset(model_sim_lv,
ts = ts_data,
num_samples=10,
sigma_noise=0.1,
initial_conditions_default=torch.tensor([2.5, 2.5]))
model_lv = LotkaVolterra(alpha=1.6, beta=1.1,delta=2.7, gamma=1.2)

plot_time_series(model_sim_lv, model_lv, data = data_lv[0], title = "Lotka Volterra: Before Fitting");

Here are the initial fits for the starting parameters. We will then fit as before and take a look at the results.

[Figure: Lotka Volterra: Before Fitting]

train(model_lv, data_lv, epochs=60, lr=1e-2, model_name="lotkavolterra")
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
Loss at 0: 1.1298701763153076
Loss at 10: 0.1296287178993225
Loss at 20: 0.045993587002158165
Loss at 30: 0.02311511617153883
Loss at 40: 0.020882505923509598
Loss at 50: 0.020726025104522705
Fitted model: alpha: 1.5965800285339355, beta: 1.0465354919433594, delta: 2.817030429840088, gamma: 0.939825177192688
True model: alpha: 1.5, beta: 1.0, delta: 3.0, gamma: 1.0
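To put the printed values in perspective, the sketch below computes the relative error of each fitted parameter from the numbers reported above, along with the coexistence equilibrium x* = δ/γ, y* = α/β implied by the forward method. The equilibrium comparison is an added illustration, not part of the original analysis.

```python
# Values copied from the training output above
fitted = {"alpha": 1.5965800285339355, "beta": 1.0465354919433594,
          "delta": 2.817030429840088, "gamma": 0.939825177192688}
true = {"alpha": 1.5, "beta": 1.0, "delta": 3.0, "gamma": 1.0}

for name in true:
    rel_err = (fitted[name] - true[name]) / true[name]
    print(f"{name}: relative error {rel_err:+.1%}")

# Equilibrium where both derivatives vanish: x* = delta/gamma, y* = alpha/beta
x_star_fit = fitted["delta"] / fitted["gamma"]
y_star_fit = fitted["alpha"] / fitted["beta"]
print(f"fitted equilibrium: ({x_star_fit:.3f}, {y_star_fit:.3f}); true: (3.000, 1.500)")
```

Each rate is off by roughly 5 to 6 percent, while the two ratios land closer to their true values. One possible reading is that this data constrains the ratios more tightly than the individual rates, though this sketch alone does not establish that.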

First a time-series plot of the fitted system:

[Figure: time series of the fitted Lotka-Volterra model]

Now let’s visualize the results using a phase plane plot.

[Figure: phase plane plot of the fitted Lotka-Volterra model]

Here is a visual of the fitting process….

Lorenz Equations

Finally, let’s try to fit the Lorenz equations.

model_sim_lorenz = Lorenz(sigma=10.0, rho=28.0, beta=8.0/3.0)
ts_data = torch.arange(0, 10.0, 0.05)
data_lorenz = create_sim_dataset(model_sim_lorenz,
ts = ts_data,
num_samples=30,
initial_conditions_default=torch.tensor([1.0, 0.0, 0.0]),
sigma_noise=0.01,
sigma_initial_conditions=0.10)
lorenz_model = Lorenz(sigma=10.2, rho=28.2, beta=9.0/3)
fig, ax = plot_time_series(model_sim_lorenz, lorenz_model, data_lorenz[0], title="Lorenz Model: Before Fitting");

ax.set_xlim((2,15));

Here are the initial fits. Next we will call our training loop.

[Figure: Lorenz Model: Before Fitting]

train(lorenz_model, 
data_lorenz,
epochs=300,
batch_size=5,
method = 'rk4',
step_size=0.05,
show_every=50,
lr = 1e-3)

Here are the training results:

Loss at 0: 114.25119400024414
Loss at 50: 4.364489555358887
Loss at 100: 2.055854558944702
Loss at 150: 1.2539702206850052
Loss at 200: 0.7839434593915939
Loss at 250: 0.5347371995449066

Let’s look at the fitted model. Starting with a full plot of the dynamics.

[Figure: fitted Lorenz model, full time series]

Let’s zoom in on the bulk of the data and see how the fit looks.

[Figure: fitted Lorenz model, zoomed in on the bulk of the data]

You can see the model is very close to the true model for the data range, and generalizes well for t < 16 for the unseen data. Now the phase plane plot (zoomed in).

plot_phase_plane(model_sim_lorenz, lorenz_model, data_lorenz[0], title = "Lorenz Model: After Fitting", time_range=(0,20.0));

[Figure: Lorenz Model: After Fitting, phase plane plot]

You can see that our fitted model performs well for t in [0,16] and then starts to diverge.

This procedure works great for the situation where we know the form of the equations on the right-hand-side, but what if we don’t? Can we use this procedure to discover the model equations?

This is much too big of a subject to fully cover in this post, but one of the biggest advantages of moving our differential equations models into the torch framework is that we can mix and match them with artificial neural network layers.

The simplest thing we can do is to replace the right-hand side f(y,t; θ) with a neural network layer. Equations of this type have been called neural differential equations, and they can be viewed as a generalization of a recurrent neural network.

Neural differential equations replace the right-hand side of the equations with an artificial neural network
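A rough way to see the link to recurrent networks is to discretize the equation in time. The block below assumes a forward Euler step of size Δt purely for illustration (the solvers used in this post are rk4 and dopri5). The symbols s_n, u_n and g_θ are generic placeholders for a recurrent network's state, input and update function; they do not appear elsewhere in the post.

```latex
\[
\begin{aligned}
\frac{dy}{dt} &= f_\theta(y, t) && \text{neural differential equation} \\
y_{n+1} &= y_n + \Delta t \, f_\theta(y_n, t_n) && \text{forward Euler step of size } \Delta t \\
s_{n+1} &= g_\theta(s_n, u_n) && \text{generic recurrent update}
\end{aligned}
\]
```

In both discrete updates the same parameters θ are applied repeatedly to carry a state forward. The differential equation version adds the freedom to choose the step size and the time points after training.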

As a first example, let’s do this for our simple VDP oscillator system. To begin, we will remake the simulated data. You will notice that I am creating longer time series and more samples: fitting a neural differential equation takes much more data and more computational power, since we have many more parameters to determine.

# remake the data 
model_sim_vdp = VDP(mu=0.20)
ts_data = torch.linspace(0.0,30.0,100) # longer time series than the custom ode layer
data_vdp = create_sim_dataset(model_sim_vdp,
ts = ts_data,
num_samples=30, # more samples than the custom ode layer
sigma_noise=0.1,
initial_conditions_default=torch.tensor([0.50,0.10]))

Now I define a simple feedforward neural network layer to fill in the right-hand-side of the equation.


class NeuralDiffEq(nn.Module):
    """
    Basic Neural ODE model
    """
    def __init__(self,
                 dim: int = 2,  # dimension of the state vector
                 ) -> None:
        super().__init__()
        self.ann = nn.Sequential(torch.nn.Linear(dim, 8),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(8, 16),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(16, 32),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(32, dim))

    def forward(self, t, state):
        return self.ann(state)
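To get a feel for how many more parameters this network has than the single-parameter VDP model, the sketch below counts them by hand. It assumes each Linear layer carries a weight matrix plus a bias vector (the torch default) and that the LeakyReLU activations have no trainable parameters. The helper name linear_param_count is hypothetical.

```python
def linear_param_count(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

layer_sizes = [2, 8, 16, 32, 2]  # dim -> 8 -> 16 -> 32 -> dim, with dim = 2
per_layer = [linear_param_count(a, b)
             for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)
```

That is several hundred trainable numbers where the mechanistic VDP model had one, which goes some way toward explaining why the neural version needs more data and many more epochs.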

Here is a plot of the system before fitting:

model_vdp_nde = NeuralDiffEq(dim=2) 
plot_time_series(model_sim_vdp, model_vdp_nde, data_vdp[0], title = "Neural ODE: Before Fitting");

[Figure: Neural ODE: Before Fitting]

You can see we start very far away from the correct solution, but then again we are injecting much less information into our model. Let’s see if we can fit the model to get better results.

train(model_vdp_nde, 
data_vdp,
epochs=1500,
lr=1e-3,
batch_size=5,
show_every=100,
model_name = "nde")
Loss at 0: 84.39617252349854
Loss at 100: 84.34061241149902
Loss at 200: 73.75008296966553
Loss at 300: 3.4929964542388916
Loss at 400: 1.6555403769016266
Loss at 500: 0.7814530655741692
Loss at 600: 0.41551147401332855
Loss at 700: 0.3157300055027008
Loss at 800: 0.19066352397203445
Loss at 900: 0.15869349241256714
Loss at 1000: 0.12904016114771366
Loss at 1100: 0.23840919509530067
Loss at 1200: 0.1681726910173893
Loss at 1300: 0.09865255374461412
Loss at 1400: 0.09134986530989408

Visualizing the results, we can see that the model is able to fit the data and even extrapolate to the future (although it is not as good or fast as the specified model).

[Figure: time series of the fitted neural differential equation model]

Now the phase plane plot of our neural differential equation model.

[Figure: phase plane plot of the fitted neural differential equation model]

These models take a long time to train and more data to converge on a good fit. This makes sense since we are both trying to learn the model and the parameters at the same time.

Video of the fitting process for an NDE.

Conclusions and Wrap-Up

In this article I have demonstrated how we can use differential equation models within the pytorch ecosystem using the torchdiffeq package. The code from this article is available on GitHub and can be opened directly in Google Colab for experimentation. You can also install the code from this article using pip:

pip install paramfittorchdemo

This post is an introduction. In the future I will be writing more about the following topics:

  • How to blend some mechanistic knowledge of the dynamics with deep learning. These have been called universal differential equations as they enable us to combine scientific knowledge with deep learning. This basically blends the two approaches together.
  • How to combine differential equation layers with other deep learning layers.
  • Model discovery: Can we recover the actual model equations from data? This uses tools like SINDy to extract the model equations from data.
  • MLOps tools for managing the training of these models. This includes tools like MLFlow, Weights and Biases, and Tensorboard.
  • Anything else I hear back about from you!

If you liked this post, be sure to follow me and connect on LinkedIn. All images unless otherwise noted are by the author.
