Neural and Universal Ordinary Differential Equations: Part 02

Following along from notes by Chris Rackauckas.

The previous notebook documented how we can use a neural network to solve a parameter-estimation problem for an ODE: we essentially trained against an ODE model and attempted to find a best fit for the model parameters.

using OrdinaryDiffEq, Flux, DiffEqFlux, Plots

Defining and Training Neural ODEs

Defining a neural ODE is the same as defining a parameterized differential equation, except that here the parameterized right-hand side is a neural network.

Consider the following example; we wish to match the following data:

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODE!(du, u, p, t)
    # cubic dynamics: du = ((u.^3)' * true_A)', i.e. a fixed linear map
    # applied to the element-wise cube of the state
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODE!, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577125  …  1.40688   1.37023   1.29215
 0.0  0.798831  1.46473  1.80877  1.86465      0.451367  0.728692  0.972095

Let us quickly visualise this data:

plot(ode_data')
[Figure: the two components of ode_data plotted against the 30 save points]

We will use a so-called knowledge-infused approach; that is to say, we assume we know the ODE has cubic behaviour. We can encode that physical information into a dense NN as:

# build a NN
dudt = Chain(
    x -> x.^3,
    Dense(2, 50, tanh),
    Dense(50, 2)
)
Chain(
  var"#1#2"(),
  Dense(2, 50, tanh),                   # 150 parameters
  Dense(50, 2),                         # 102 parameters
)                   # Total: 4 arrays, 252 parameters, 1.234 KiB.

To train this network we will make use of Flux.destructure, which takes the parameters out of a NN into a flat vector p and returns a function re that rebuilds the NN from such a parameter vector.

Using these, we define the ODE:

p, re = Flux.destructure(dudt)
dudt2_(u, p, t) = re(p)(u)
prob2 = ODEProblem(dudt2_, u0, tspan, p)
ODEProblem with uType Vector{Float32} and tType Float32. In-place: false
timespan: (0.0f0, 1.5f0)
u0: 2-element Vector{Float32}:
 2.0
 0.0
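
As a quick sanity check (our own, not part of the notes), the rebuilt network behaves identically to the original:

x = Float32[1.0, -0.5]    # an arbitrary test state
re(p)(x) ≈ dudt(x)        # true: re(p) reconstructs the same network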

This is equivalent to

$$u^\prime = \text{NN}(u)$$

where the parameters are the parameters of the NN. We then use the same structure as before to train the network to reconstruct the ODE:

function predict_n_ode()
    # concrete_solve is DiffEqFlux's differentiation-aware solve (an older
    # API; newer releases use solve directly), letting Flux backpropagate
    # through the ODE solver
    Array(concrete_solve(prob2, Tsit5(), u0, p, saveat=t))
end

loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())


data = Iterators.repeated((), 300)
opt = ADAM(0.1)

training_plots = []
iter = 0

callback = function()
    global iter += 1
    if iter % 50 == 0
        @show loss_n_ode()

        cur_pred = predict_n_ode()
        pl = Plots.scatter(t, ode_data[1,:], label="data")
        Plots.scatter!(pl, t, cur_pred[1,:], label="prediction")
        push!(training_plots, plot(pl))
    end
end

ps = Flux.params(p)
Flux.train!(loss_n_ode, ps, data, opt, cb=callback)
"Done"
loss_n_ode() = 10.00524f0
loss_n_ode() = 2.0725229f0
loss_n_ode() = 0.83961827f0
loss_n_ode() = 0.5685626f0
loss_n_ode() = 0.38171425f0
loss_n_ode() = 0.19634308f0
"Done"

We can then view our plots:

# lay out the snapshots saved during the training callback
plot(training_plots...; legend=false)
[Figure: grid of data-vs-prediction scatter plots, one snapshot per 50 training iterations]

And sure enough, we were able to fit the parameters of the neural network. (Can we extract true_A from this? See the sketch below.)
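
On that question: the trained network is not literally of the form $((u^3)^\top A)^\top$, but we can estimate an effective $A$ by least-squares regression of the network's output against the cubed states. A rough diagnostic sketch (our own, not from the notes):

nn = re(p)                 # rebuild the network from the trained parameters
X = (ode_data .^ 3)'       # 30×2: cubed states, one row per time point
Y = nn(ode_data)'          # 30×2: network output at each state
A_eff = X \ Y              # 2×2 least-squares estimate; compare against true_A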

Augmented Neural ODE

Not every function can be represented by an ODE; specifically, a solution

$$u(t) : \mathbb{R} \rightarrow \mathbb{R}^n$$

cannot loop back over itself unless the solution is cyclic.

This is because the flow of the ODE is unique from every point in phase space; for the above mapping to have two directions of flow at a given point $u_i$, there would have to be two distinct solutions to

$$u^\prime = f(u, p, t),$$

with the same initial condition $u(0) = u_i$, which contradicts the uniqueness of ODE solutions.
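
As a concrete one-dimensional illustration (the standard counterexample motivating augmented neural ODEs; the notation $h$ is ours): no ODE flow $h(x) = u(1)$ with $u(0) = x$ can satisfy

$$h(1) = -1, \qquad h(-1) = 1,$$

since the trajectory starting at $1$ and the trajectory starting at $-1$ must, by continuity, cross at some time $t^\ast$, and then two distinct solutions would pass through the same point $(t^\ast, u(t^\ast))$.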

We can rectify this by introducing additional degrees of freedom, ensuring that trajectories of the ODE do not overlap. This is the so-called augmented neural ODE.

This can be built using the following prescription:

  • add a fake state to the ODE which is 0 everywhere

  • allow this extra dimension to bump around to let the function become a universal approximator

In Julia, this is:

dudt = Chain(...)                  # network now maps the augmented state R^3 -> R^3
p, re = Flux.destructure(dudt)

dudt_(u, p, t) = re(p)(u)
f0 = 0.0f0                         # the fake state, initialised to zero
prob = ODEProblem(dudt_, vcat(u0, f0), tspan, p)

# pad the data with a zero row to match the augmented state
augmented_data = vcat(ode_data, zeros(Float32, 1, size(ode_data, 2)))
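
Filling in this sketch, a minimal runnable version might look like the following (the layer sizes and helper names here are our own choices, not prescribed by the notes):

# augmented neural ODE: the state is [u; a], where a is a fake scalar
# state initialised to zero, giving trajectories room to avoid crossing
dudt_aug = Chain(
    Dense(3, 50, tanh),    # 2 real states + 1 augmented state in
    Dense(50, 3)           # derivatives for all 3 states out
)
p_aug, re_aug = Flux.destructure(dudt_aug)

dudt_aug_(u, p, t) = re_aug(p)(u)
prob_aug = ODEProblem(dudt_aug_, vcat(u0, 0.0f0), tspan, p_aug)

# training then proceeds exactly as before
predict_aug() = Array(concrete_solve(prob_aug, Tsit5(), vcat(u0, 0.0f0), p_aug, saveat=t))
loss_aug() = sum(abs2, augmented_data .- predict_aug())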