Arbitrary Butcher Tableaus and Adaptive Integration
In this notebook, we extend our code to support arbitrary Butcher tableaus and implement a step-size adaptation algorithm.
We have moved the compensated_sum function into our module as neuralode.util.compensated_sum and added a function, neuralode.util.partial_compensated_sum, which takes a new value together with the current partial sum and truncated bits and returns the updated pair, so the tracked values can be updated iteratively. This simplifies our integration code by removing the duplication.
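Here is a rough illustration of what iteratively updating a running sum and its truncated bits looks like: a minimal pure-Python sketch of a Kahan-style update. The name `partial_compensated_sum_sketch` is hypothetical. It only mirrors how this notebook calls neuralode.util.partial_compensated_sum, and the module's actual implementation may differ.

```python
def partial_compensated_sum_sketch(value, state=None):
    """Hypothetical stand-in for neuralode.util.partial_compensated_sum (Kahan-style update)."""
    if state is None:
        # Starting a new sum: nothing has been truncated yet
        return value, 0.0
    running_sum, truncated_bits = state
    # Re-inject the bits lost on previous additions before adding the new value
    y = value + truncated_bits
    new_sum = running_sum + y
    # (new_sum - running_sum) is what was actually added; the remainder was truncated
    truncated_bits = y - (new_sum - running_sum)
    return new_sum, truncated_bits


naive = 1.0
state = partial_compensated_sum_sketch(1.0)
for _ in range(10_000):
    naive += 1e-16  # each increment is below half an ulp of 1.0, so it is rounded away
    state = partial_compensated_sum_sketch(1e-16, state)
running_sum, truncated_bits = state
print(naive)                         # 1.0
print(running_sum + truncated_bits)  # close to 1 + 1e-12
```

The corrected total is read as `running_sum + truncated_bits`, which matches how the integration code later uses `c_state + truncated_bits_state`.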
The plotting code has been moved into neuralode.plot.trajectory.plot_trajectory and neuralode.plot.trajectory.plot_trajectory_with_reference. The former plots a trajectory without a reference solution, which we need for systems that don't have a closed-form solution. For such systems, a reference trajectory can instead be computed with a higher-precision solver.
import typing
import warnings
import torch
import einops
import neuralode
warnings.simplefilter('once', RuntimeWarning)
# For convenience, we define the default tensor device and dtype here
torch.set_default_device('cpu')
# In neural networks, we prefer 32-bit/16-bit floats, but for precise integration, 64-bit is preferred. We will revisit this later when we need to mix integration with neural network training
torch.set_default_dtype(torch.float64)
We'll start with the same function from the previous notebook, tidied up in two ways. It now uses neuralode.util.partial_compensated_sum to track the truncated bits instead of duplicating that code throughout the integration function, and its common elements live in their own modularised functions.
We will test that, with the modularised code, we can still reproduce the previous results on the same exponential system.
x0 = torch.tensor([1.0])
t0 = torch.tensor(0.0)
t1 = torch.tensor(1.0)
N = 10
dt = (t1 - t0) / N
(f_state, f_time), sub_states = neuralode.integrators.IntegrateRK4.apply(neuralode.dynamics.exponential_fn, x0, t0, t1, dt)
reference_trajectory = [neuralode.dynamics.exponential_fn_solution(x0, t) for _,t in sub_states]
fig, axes = neuralode.plot.trajectory.plot_trajectory_with_reference(sub_states, reference_trajectory, method_label="Compensated RK4")
print(f"Error in RK4: {(f_state - neuralode.dynamics.exponential_fn_solution(x0, t1)).abs().item()}")
Error in RK4: 3.3324105608301124e-07
Now that we have verified that the refactored code in the module still reproduces the previous results, we can start parametrising the integrator in terms of its Butcher tableau.
The following code takes advantage of Python's dynamic type system: we can generate a class at runtime, define its methods dynamically, and store the tableau itself as a class attribute. In this fashion, we can write a function that takes a tableau as a PyTorch tensor and returns a torch.autograd.Function subclass that integrates a system with it.
This will enable us to test out different integration algorithms without having to write a new class for each one. Eventually we will move this function into the module and add different integration algorithms at import time.
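The dynamic class creation relies on the three-argument form of the built-in `type(name, bases, namespace)`. The following sketch shows the pattern in isolation. `make_tableau_class` and `describe` are hypothetical names, and the sketch uses `object` as the base class instead of `torch.autograd.Function` so that it runs without PyTorch.

```python
def make_tableau_class(name, tableau):
    # type(name, bases, namespace) creates a new class at runtime; the tableau
    # becomes a class attribute, like `integrator_tableau` in the notebook
    cls = type(name, (object,), {"tableau": tableau})

    def describe():
        # The last row holds the b_i weights, so the remaining rows are stages
        return f"{cls.__name__} with {len(cls.tableau) - 1} stages"

    # Methods can be attached after the class exists, as the code below does with `forward`
    cls.describe = staticmethod(describe)
    return cls


Midpoint = make_tableau_class("ExplicitMidpoint", [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]])
print(Midpoint.__name__)    # ExplicitMidpoint
print(Midpoint.describe())  # ExplicitMidpoint with 2 stages
```

This is also why printing such a class shows the name that was passed to `type`, as in the `<class '__main__.ExplicitMidpoint'>` output further down.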
def get_integrator(integrator_tableau: torch.Tensor, integrator_name: str = None) -> torch.autograd.Function:
    __integrator_type = type(integrator_name, (torch.autograd.Function,), {
        "integrator_tableau": integrator_tableau
    })

    def __internal_forward(ctx, fn, x0, t0, t1, dt):
        """
        A general integration routine for solving an Initial Value Problem
        using any arbitrary Butcher Tableau
        Instead of naively summing the changes, we use compensated summation.
        :param fn: the function to be integrated
        :param x0: the initial state to integrate from
        :param t0: the initial time to integrate from
        :param t1: the final time to integrate to
        :param dt: the time increments to integrate with
        :return: a tuple of ((the final state, the final time), the intermediate states [list])
        """
        butcher_tableau = __integrator_type.integrator_tableau.clone().to(x0.device, x0.dtype)
        # The names for the variables have been shortened for concision, and
        # to avoid overlap with variables in the outer scope
        # I have also left the annotations as they are invaluable for tracking the method.
        c_time = t0.clone()
        c_state = x0.clone()
        c_state, c_time, i_states = neuralode.integrators.routines.integrate_system(fn, c_state, c_time, t1, dt, butcher_tableau)
        # `save_for_backward` only accepts tensors, so we save them individually rather than as a tuple
        ctx.save_for_backward(c_state, c_time)
        ctx.integration_function = fn
        return (c_state, c_time), i_states

    __integrator_type.forward = staticmethod(__internal_forward)
    return __integrator_type
explicit_rk4_integrator = get_integrator(torch.tensor([
# c0, a00, a01, a02, a03
[0.0, 0.0, 0.0, 0.0, 0.0],
# c1, a10, a11, a12, a13
[0.5, 0.5, 0.0, 0.0, 0.0],
# c2, a20, a21, a22, a23
[0.5, 0.0, 0.5, 0.0, 0.0],
# c3, a30, a31, a32, a33
[1.0, 0.0, 0.0, 1.0, 0.0],
# (unused), b0, b1, b2, b3
[0.0, 1/6, 2/6, 2/6, 1/6]
], dtype=torch.float64), integrator_name = "ExplicitRungeKutta4")
explicit_midpoint_integrator = get_integrator(torch.tensor([
# c0, a00, a01
[0.0, 0.0, 0.0],
# c1, a10, a11
[0.5, 0.5, 0.0],
# (unused), b0, b1
[0.0, 0.0, 1.0]
], dtype=torch.float64), integrator_name = "ExplicitMidpoint")
dt = (t1 - t0)/2000
(f_state, f_time), sub_states = explicit_midpoint_integrator.apply(neuralode.dynamics.exponential_fn, x0, t0, t1, dt)
reference_trajectory = [neuralode.dynamics.exponential_fn_solution(x0, t) for _,t in sub_states]
print(f"Error in {explicit_midpoint_integrator}: {(f_state - neuralode.dynamics.exponential_fn_solution(x0, t1)).abs().item()}")
fig, axes = neuralode.plot.trajectory.plot_trajectory_with_reference(sub_states, reference_trajectory, method_label="Midpoint Method")
(f_state, f_time), sub_states = explicit_rk4_integrator.apply(neuralode.dynamics.exponential_fn, x0, t0, t1, dt)
fig, axes = neuralode.plot.trajectory.plot_trajectory_with_reference(sub_states, reference_trajectory, axes=axes, method_label="Runge-Kutta 4")
print(f"Error in {explicit_rk4_integrator}: {(f_state - neuralode.dynamics.exponential_fn_solution(x0, t1)).abs().item()}")
Error in <class '__main__.ExplicitMidpoint'>: 1.5334058689475683e-08
Error in <class '__main__.ExplicitRungeKutta4'>: 1.1102230246251565e-16
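The routine `neuralode.integrators.routines.integrate_system` is not shown here. As a sketch under stated assumptions, a single explicit Runge-Kutta step driven by a tableau in the layout above computes roughly the following: stage rows of $[c_i, a_{i0}, a_{i1}, \dots]$, then a final row whose first entry is unused and whose remaining entries are the $b_i$. The helpers `explicit_rk_step` and `integrate` are hypothetical. They use plain floats without compensated summation, and they are checked on the decay equation $x' = -x$.

```python
import math


def explicit_rk_step(f, x, t, h, tableau):
    # Stage rows: [c_i, a_i0, a_i1, ...]; the last row: [unused, b_0, b_1, ...]
    k = []
    for c_i, *a_i in tableau[:-1]:
        # Explicit method: only the already-computed stages contribute
        x_stage = x + h * sum(a_ij * k_j for a_ij, k_j in zip(a_i, k))
        k.append(f(x_stage, t + c_i * h))
    return x + h * sum(b_j * k_j for b_j, k_j in zip(tableau[-1][1:], k))


def integrate(f, x0, t0, t1, n_steps, tableau):
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = explicit_rk_step(f, x, t, h, tableau)
        t += h
    return x


RK4 = [[0.0, 0.0, 0.0, 0.0, 0.0],
       [0.5, 0.5, 0.0, 0.0, 0.0],
       [0.5, 0.0, 0.5, 0.0, 0.0],
       [1.0, 0.0, 0.0, 1.0, 0.0],
       [0.0, 1/6, 2/6, 2/6, 1/6]]
MIDPOINT = [[0.0, 0.0, 0.0],
            [0.5, 0.5, 0.0],
            [0.0, 0.0, 1.0]]


def decay(x, t):
    return -x


exact = math.exp(-1.0)
rk4_error = abs(integrate(decay, 1.0, 0.0, 1.0, 10, RK4) - exact)
midpoint_error = abs(integrate(decay, 1.0, 0.0, 1.0, 2000, MIDPOINT) - exact)
print(rk4_error, midpoint_error)
```

Halving the step should cut the RK4 error by roughly $2^4 = 16$. This gives a quick way to check that a tableau has the order you expect.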
This simplifies testing different integration methods. We could even use it to optimise the tableau itself, since PyTorch can differentiate with respect to tensors (but I digress). For now, we will focus on implementing an adaptive integration scheme. An adaptive (embedded) method has two rows of $b_i$ coefficients, one for each order of accuracy. To mark a tableau as adaptive, we put a torch.inf value in the first column of both of these rows. That column is ignored for the $b_i$ rows during integration, so torch.inf there only signals that the rows hold $b_i$ coefficients.
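The following minimal pure-Python sketch makes this convention concrete. It uses the Heun–Euler pair (orders 2 and 1) as a hypothetical stand-in for a real tableau and flags both $b_i$ rows with `math.inf`. It estimates the local error as the difference between the two estimates, then rescales the step with the same $0.8\,(\text{max error}/\text{error})^{1/\text{order}}$ rule that `integrate_system_adaptive` applies. The scalar-only state and the cap on step growth are assumptions of this sketch.

```python
import math

# Hypothetical Heun-Euler pair in the notebook's layout: column 0 holds c_i for the
# stage rows and math.inf for the two b_i rows. The second-to-last row holds the
# higher-order (2nd) weights and the last row the lower-order (1st) weights.
HEUN_EULER = [
    [0.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
    [math.inf, 0.5, 0.5],
    [math.inf, 1.0, 0.0],
]


def is_adaptive(tableau):
    # Same test as the notebook: both b_i rows are flagged with inf in column 0
    return math.isinf(tableau[-1][0]) and math.isinf(tableau[-2][0])


def embedded_step(f, x, t, h, tableau):
    # Every row except the last two is a stage row [c_i, a_i0, a_i1, ...]
    k = []
    for c_i, *a_i in tableau[:-2]:
        x_stage = x + h * sum(a_ij * k_j for a_ij, k_j in zip(a_i, k))
        k.append(f(x_stage, t + c_i * h))
    higher = h * sum(b_j * k_j for b_j, k_j in zip(tableau[-2][1:], k))
    lower = h * sum(b_j * k_j for b_j, k_j in zip(tableau[-1][1:], k))
    return lower, higher


def integrate_adaptive(f, x, t0, t1, h, atol, rtol, order, tableau):
    t, rejected = t0, 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)  # never step past the final time
        lower, higher = embedded_step(f, x, t, h, tableau)
        error = abs(higher - lower)
        max_error = atol + rtol * abs(x)
        if error < max_error:
            # Accept the step, keeping the higher-order estimate (local extrapolation)
            x, t = x + higher, t + h
        else:
            rejected += 1  # reject: retry from the same (x, t) with a smaller step
        # Controller of the same form as in the notebook, with growth capped at 5x
        factor = 0.8 * (max_error / error) ** (1 / order) if error > 0 else 5.0
        h *= min(factor, 5.0)
    return x, rejected


x1, n_rejected = integrate_adaptive(lambda x, t: x, 1.0, 0.0, 1.0, 0.5, 1e-8, 1e-8, 2, HEUN_EULER)
print(is_adaptive(HEUN_EULER))  # True
print(abs(x1 - math.e), n_rejected)
```

The overly large initial step of 0.5 is rejected and shrunk on the first attempt. After that, the controller keeps the estimated local error just under the tolerance.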
Let's implement this!
def compute_step_adaptive(fn, state, time, step, tableau):
    # We need to store the intermediate stages.
    # We subtract two from the number of rows since the last two rows hold
    # the higher- and lower-order b_i coefficients rather than stages
    k_stages = torch.stack([torch.zeros_like(state)]*(tableau.shape[0]-2))
    for stage_index in range(k_stages.shape[0]):
        c_coeff, *a_coeff = tableau[stage_index]
        k_stages[stage_index] = fn(
            # We use `compensated_sum` instead of `sum` to avoid truncation at each stage calculation
            state + step * neuralode.util.compensated_sum(k*a for k,a in zip(k_stages, a_coeff)),
            time + c_coeff * step
        )
    # By convention, the last row holds the lower-order weights and the second-to-last the higher-order ones
    lower_order_estimate = step * neuralode.util.compensated_sum(k*b for k,b in zip(k_stages, tableau[-1, 1:]))
    higher_order_estimate = step * neuralode.util.compensated_sum(k*b for k,b in zip(k_stages, tableau[-2, 1:]))
    # From a numerical perspective, this implementation is not necessarily ideal as
    # we can lose precision when subtracting the two solutions. A more numerically accurate
    # implementation would have one row of `b_i` coefficients and another row with the coefficients
    # for computing the error directly
    return lower_order_estimate, higher_order_estimate, step
def integrate_system_adaptive(fn: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                              c_state: torch.Tensor,
                              c_time: torch.Tensor,
                              final_time: torch.Tensor,
                              dt: torch.Tensor,
                              atol: torch.Tensor,
                              rtol: torch.Tensor,
                              is_adaptive: bool,
                              use_local_extrapolation: bool,
                              integrator_order: int,
                              tableau: torch.Tensor):
    i_states = [(c_state.clone(), c_time.clone())]
    # The initial state is exact, so its error is zero
    error_in_state = [torch.zeros(tuple(), device=c_state.device, dtype=c_state.dtype)]
    c_state, truncated_bits_state = neuralode.util.partial_compensated_sum(c_state)
    c_time, truncated_bits_time = neuralode.util.partial_compensated_sum(c_time)
    if is_adaptive:
        compute_step = compute_step_adaptive
    else:
        compute_step = neuralode.integrators.helpers.compute_step
    while torch.any((c_time + dt) < final_time):
        delta_state_lower, delta_state_upper, delta_time = compute_step(fn, c_state + truncated_bits_state, c_time + truncated_bits_time, dt, tableau)
        # If local extrapolation is enabled, we take the higher order estimate, otherwise the lower order one
        delta_state = delta_state_upper if use_local_extrapolation else delta_state_lower
        # We use `torch.linalg.norm` to compute the magnitude of the error
        # we can adjust this by passing in the `ord` keyword to choose a different
        # vector norm, but the 2-norm suffices for our purposes
        # We also detach just in case since this variable should not be differentiated
        current_error = torch.linalg.norm(delta_state_upper - delta_state_lower).detach()
        if is_adaptive:
            # To save on computation, we only compute the max error tolerated and the step
            # correction when the method is adaptive
            max_error = (atol + torch.linalg.norm(rtol * c_state)).detach()
            step_correction = 0.8*torch.where(current_error != 0.0, max_error/current_error, 1.0)**(1/integrator_order)
            # Based on the error, we correct the step size
            dt = step_correction*dt
            if current_error >= max_error:
                # If the error exceeds our error threshold, we don't commit the step and redo it
                continue
        # The step is committed, so we record its error estimate
        error_in_state.append(current_error)
        c_state, truncated_bits_state = neuralode.util.partial_compensated_sum(delta_state, (c_state, truncated_bits_state))
        c_time, truncated_bits_time = neuralode.util.partial_compensated_sum(delta_time, (c_time, truncated_bits_time))
        i_states.append((c_state + truncated_bits_state, c_time + truncated_bits_time))
    # A final step that lands exactly on `final_time`
    delta_state_lower, delta_state_upper, delta_time = compute_step(fn, c_state + truncated_bits_state, c_time + truncated_bits_time, (final_time - c_time) - truncated_bits_time, tableau)
    delta_state = delta_state_upper if use_local_extrapolation else delta_state_lower
    c_state, truncated_bits_state = neuralode.util.partial_compensated_sum(delta_state, (c_state, truncated_bits_state))
    c_time, truncated_bits_time = neuralode.util.partial_compensated_sum(delta_time, (c_time, truncated_bits_time))
    error_in_state.append(torch.linalg.norm(delta_state_upper - delta_state_lower))
    i_states.append((c_state + truncated_bits_state, c_time + truncated_bits_time))
    return c_state, c_time, i_states, error_in_state
def get_integrator(integrator_tableau: torch.Tensor, integrator_order: int, use_local_extrapolation: bool = True, integrator_name: str = None) -> torch.autograd.Function:
    # We look at the first column of the last two rows, and if both are `inf`, we know the method is adaptive
    is_adaptive = torch.isinf(integrator_tableau[-1,0]) and torch.isinf(integrator_tableau[-2,0])
    # The number of stages is the number of rows minus the last row
    # (or last two rows if the method is adaptive)
    number_of_stages = integrator_tableau.shape[0] - 1
    if is_adaptive:
        number_of_stages -= 1
    # The `type` function in this form works to dynamically create a class
    # the first parameter is the class name, the second are parent classes,
    # and the last are the class attributes. We store the integrator attributes
    # here, and reference them in the integration code.
    # In this way, we can query these parameters at a future point.
    __integrator_type = type(integrator_name, (torch.autograd.Function,), {
        "integrator_tableau": integrator_tableau,
        "integrator_order": integrator_order,
        "is_adaptive": is_adaptive,
        "number_of_stages": number_of_stages
    })

    def __internal_forward(ctx, fn: typing.Callable[[torch.Tensor, torch.Tensor, typing.Any], torch.Tensor],
                           x0: torch.Tensor, t0: torch.Tensor, t1: torch.Tensor, dt: torch.Tensor,
                           atol: torch.Tensor, rtol: torch.Tensor, *additional_dynamic_args):
        """
        A general integration routine for solving an Initial Value Problem
        using any arbitrary Butcher Tableau
        Instead of naively summing the changes, we use compensated summation.
        :param fn: the function to be integrated
        :param x0: the initial state to integrate from
        :param t0: the initial time to integrate from
        :param t1: the final time to integrate to
        :param dt: the time increments to integrate with
        :param atol: The absolute tolerance for the error in an adaptive integration
        :param rtol: The relative tolerance for the error in an adaptive integration
        :param additional_dynamic_args: additional arguments to pass to the function
        :return: a tuple of ((the final state, the final time), the intermediate states [list[tuple[torch.Tensor, torch.Tensor]]], the error values [list[torch.Tensor]])
        """
        if __integrator_type.is_adaptive:
            # We need to check that both `atol` and `rtol` are valid values and are compatible with the state
            atol = neuralode.integrators.helpers.ensure_tolerance(atol, x0, "Absolute tolerance", "atol")
            rtol = neuralode.integrators.helpers.ensure_tolerance(rtol, x0, "Relative tolerance", "rtol")
        dt = neuralode.integrators.helpers.ensure_timestep(dt, t0, t1)
        butcher_tableau = __integrator_type.integrator_tableau.clone().to(x0.device, x0.dtype)

        def forward_fn(state, time):
            return fn(state, time, *additional_dynamic_args)

        c_state = x0.clone()
        c_time = t0.clone()
        c_state, c_time, i_states, error_in_state = integrate_system_adaptive(forward_fn, c_state, c_time, t1, dt, atol, rtol,
                                                                              __integrator_type.is_adaptive, use_local_extrapolation,
                                                                              __integrator_type.integrator_order, butcher_tableau)
        # We save parameters for the backward pass, but these won't be used
        # until we implement the adjoint method for backpropagation
        ctx.save_for_backward(c_state, c_time, *additional_dynamic_args, *[i[0] for i in i_states], *[i[1] for i in i_states])
        ctx.integration_function = fn
        return (c_state, c_time), i_states, error_in_state

    if not __integrator_type.is_adaptive:
        # If the method isn't adaptive, neither atol nor rtol are required, but because of
        # how `torch.autograd.Function` works, we cannot have keyword arguments
        # For that reason, we use an alternative implementation to fill those values with a stub
        def __internal_forward_nonadaptive(ctx, fn: typing.Callable[[torch.Tensor, torch.Tensor, typing.Any], torch.Tensor],
                                           x0: torch.Tensor, t0: torch.Tensor, t1: torch.Tensor, dt: torch.Tensor, *additional_dynamic_args):
            return __internal_forward(ctx, fn, x0, t0, t1, dt, torch.inf, torch.inf, *additional_dynamic_args)
        __integrator_type.forward = staticmethod(__internal_forward_nonadaptive)
    else:
        __integrator_type.forward = staticmethod(__internal_forward)
    return __integrator_type
adaptive_rk45_integrator = get_integrator(torch.tensor([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[1/5, 1/5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[3/10, 3/40, 9/40, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[4/5, 44/45, -56/15, 32/9, 0.0, 0.0, 0.0, 0.0 ],
[8/9, 19372/6561, -25360/2187, 64448/6561, -212/729, 0.0, 0.0, 0.0 ],
[1.0, 9017/3168, -355/33, 46732/5247, 49/176, -5103/18656, 0.0, 0.0 ],
[1.0, 35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0 ],
[torch.inf, 35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0 ],
[torch.inf, 5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]
], dtype=torch.float64), integrator_order = 5, use_local_extrapolation = False, integrator_name = "AdaptiveRK45Integrator")
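Transcribing a tableau like this by hand is error-prone, so a quick sanity check can help. The sketch below rewrites the Dormand–Prince coefficients above as exact fractions, using the hypothetical helpers `row_sums_match_c` and `low_order_conditions`. It checks the row-sum condition $c_i = \sum_j a_{ij}$. For both weight rows, it also checks $\sum_i b_i = 1$, $\sum_i b_i c_i = 1/2$ and $\sum_i b_i c_i^2 = 1/3$. Finally, it confirms that the last stage row equals the fifth-order weights (the "first same as last" property). These conditions are necessary but not sufficient for the claimed orders.

```python
from fractions import Fraction

# The Dormand-Prince tableau above as exact fractions: (c_i, [a_i0, a_i1, ...]) per stage
DOPRI_STAGES = [
    ("0", []),
    ("1/5", ["1/5"]),
    ("3/10", ["3/40", "9/40"]),
    ("4/5", ["44/45", "-56/15", "32/9"]),
    ("8/9", ["19372/6561", "-25360/2187", "64448/6561", "-212/729"]),
    ("1", ["9017/3168", "-355/33", "46732/5247", "49/176", "-5103/18656"]),
    ("1", ["35/384", "0", "500/1113", "125/192", "-2187/6784", "11/84"]),
]
# The two rows flagged with torch.inf: fifth-order weights, then fourth-order weights
B_FIFTH = ["35/384", "0", "500/1113", "125/192", "-2187/6784", "11/84", "0"]
B_FOURTH = ["5179/57600", "0", "7571/16695", "393/640", "-92097/339200", "187/2100", "1/40"]


def row_sums_match_c(stages):
    # Row-sum condition c_i = sum_j a_ij, checked exactly
    return all(Fraction(c_i) == sum(Fraction(a) for a in row) for c_i, row in stages)


def low_order_conditions(stages, weights):
    c = [Fraction(c_i) for c_i, _ in stages]
    b = [Fraction(w) for w in weights]
    return [
        sum(b) == 1,                                                      # order 1
        sum(b_i * c_i for b_i, c_i in zip(b, c)) == Fraction(1, 2),       # order 2
        sum(b_i * c_i ** 2 for b_i, c_i in zip(b, c)) == Fraction(1, 3),  # one of the order-3 conditions
    ]


fsal = [Fraction(a) for a in DOPRI_STAGES[-1][1]] == [Fraction(w) for w in B_FIFTH[:-1]]
print(row_sums_match_c(DOPRI_STAGES), fsal)           # True True
print(low_order_conditions(DOPRI_STAGES, B_FIFTH))   # [True, True, True]
print(low_order_conditions(DOPRI_STAGES, B_FOURTH))  # [True, True, True]
```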
Having built an adaptive integrator from an arbitrary Butcher tableau, we are now ready to try higher-dimensional systems and implement backpropagation through our algorithm.
First, let's write the dynamics of a system whose behaviour is controlled by parameters. A damped harmonic oscillator is an ideal candidate: its solution can be derived exactly (with some effort), and its parameters have an intuitive physical meaning. From physics, we know that the force acting on a damped harmonic oscillator is:
$$ \vec{F}=-k\vec{x}-c\vec{x}^{(1)} $$
And using the good old $\vec{F}=m\vec{a}$ and $\vec{a}=\vec{x}^{(2)}$ we can rewrite our equations as:
$$ \vec{x}^{(2)}=-\frac{k}{m}\vec{x}-\frac{c}{m}\vec{x}^{(1)} $$
This equation is written with arbitrary spatial dimensions in mind and is second order. For our purposes, we can restrict it to a single spatial dimension and convert it into a system of first-order equations using the methods we discussed in the introduction.
Let's introduce the variable $v=x^{(1)}$ which will allow us to write:
$$ \begin{bmatrix} x^{(1)} \\ v^{(1)} \end{bmatrix} = \begin{bmatrix} v \\ -\frac{k}{m}x-\frac{c}{m}v \end{bmatrix} $$
As physicists like to introduce quantities that are more intuitive, we'll rewrite the equation as:
$$ \begin{bmatrix} x^{(1)} \\ v^{(1)} \end{bmatrix} = \begin{bmatrix} v \\ -\omega^2x-2\zeta\omega v \end{bmatrix} $$
where $\omega = \sqrt{\frac{k}{m}}$ is the undamped angular frequency and $\zeta=\frac{c}{2\sqrt{mk}}$ is the damping ratio.
This system is linear, so we can write it as a matrix multiplying the state vector:
$$ \begin{bmatrix} x^{(1)} \\ v^{(1)} \end{bmatrix} = \mathbf{A} \begin{bmatrix} x \\ v \end{bmatrix} $$
where
$$ \mathbf{A} = \begin{bmatrix} 0 & 1 \\ -\omega^2 & -2\zeta\omega \end{bmatrix} $$
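The text above notes that the oscillator's solution can be derived exactly. Here is a sketch for the underdamped case $0 \le \zeta < 1$, where the eigenvalues of $\mathbf{A}$ are $-\zeta\omega \pm i\omega\sqrt{1-\zeta^2}$. The helper names are hypothetical. `ode_residual` uses central finite differences to confirm that the formula satisfies $x^{(2)} = -\omega^2 x - 2\zeta\omega x^{(1)}$.

```python
import math


def oscillator_parameters(m, k, c):
    # omega = sqrt(k / m) is the undamped angular frequency, zeta = c / (2 sqrt(m k)) the damping ratio
    return math.sqrt(k / m), c / (2.0 * math.sqrt(m * k))


def underdamped_position(t, x0, v0, omega, zeta):
    # Closed-form x(t) for 0 <= zeta < 1: an oscillation at the damped frequency
    # omega_d = omega * sqrt(1 - zeta^2) inside a decaying envelope exp(-zeta * omega * t)
    omega_d = omega * math.sqrt(1.0 - zeta ** 2)
    envelope = math.exp(-zeta * omega * t)
    return envelope * (x0 * math.cos(omega_d * t)
                       + (v0 + zeta * omega * x0) / omega_d * math.sin(omega_d * t))


def ode_residual(t, x0, v0, omega, zeta, h=1e-4):
    # Central differences for x' and x''; the residual of
    # x'' + 2 zeta omega x' + omega^2 x = 0 should be close to zero
    def x(s):
        return underdamped_position(s, x0, v0, omega, zeta)
    dx = (x(t + h) - x(t - h)) / (2 * h)
    d2x = (x(t + h) - 2 * x(t) + x(t - h)) / h ** 2
    return d2x + 2 * zeta * omega * dx + omega ** 2 * x(t)


omega, zeta = oscillator_parameters(m=1.0, k=1.0, c=0.5)
print(omega, zeta)  # 1.0 0.25
print(max(abs(ode_residual(t, 1.0, 0.0, omega, zeta)) for t in (0.5, 1.0, 5.0, 10.0)))
```

With $m=1$, $k=1$ and $c=0.5$, this gives $\omega=1$ and $\zeta=0.25$, the values used for `frequency` and `damping` in the code below. So `underdamped_position` could also serve as a closed-form reference for the integrator.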
With this, we now have a system where some parameters control the dynamics, and we can ask questions like "Given some samples of this oscillator, can we estimate these parameters from numerical integration?"
def sha_dynamics(x, t, frequency, damping):
    # The dynamics above can easily be represented as a matrix multiplication
    # First the matrix with the corresponding terms
    A = torch.stack([
        # no x term for the derivative of x as it is equal to v
        torch.stack([torch.zeros_like(frequency), torch.ones_like(frequency)], dim=-1),
        # first we have the -omega^2 term, then the -2*zeta*omega term
        torch.stack([-frequency**2, -2*frequency*damping], dim=-1),
    ], dim=-2)
    # We implement the matrix multiplication using einops
    # This is not necessarily the most efficient, but it allows
    # us to track the exact operation without worrying about the shapes
    # of our tensors too much
    # You can read '... row col,... col->... row' as:
    # - The first argument is a tensor with arbitrary leading dimensions,
    #   the last two of which are of interest, labelled 'row' and 'col'
    # - The second argument is a tensor with arbitrary leading dimensions,
    #   the last of which matches the number of columns of the matrix
    # - Take the sum of A[..., row, col]*x[..., col] over all 'col'; the output
    #   is indexed by 'row' in the last dimension
    return einops.einsum(A, x, '... row col,... col->... row')
Previously, we only had to define one initial state for the system. Since we now have two dimensions, there are two initial conditions to specify. Physically, these are the initial position and the initial velocity of the mass on the spring.
initial_position = torch.tensor(1.0)
initial_velocity = torch.tensor(0.0)
frequency = torch.ones_like(initial_position)
damping = torch.ones_like(initial_position)*0.25
initial_x = torch.stack([
initial_position,
initial_velocity,
], dim=-1)
initial_time = torch.tensor(0.0)
final_time = torch.tensor(25.0)
initial_timestep = (final_time - initial_time) / 10
atol, rtol = torch.tensor(5e-8), torch.tensor(5e-8)
_, sha_states, _ = adaptive_rk45_integrator.apply(sha_dynamics, initial_x, initial_time, final_time, dt, atol, rtol, frequency, damping)
fig, axes = neuralode.plot.trajectory.plot_trajectory(sha_states, method_label="RK4(5) - Simple Harmonic Oscillator")
Excellent! The oscillator is oscillating and decaying as expected, which suggests we've implemented it correctly. Now let's try something more interesting. Suppose we don't know the frequency and damping ratio but do know the initial state, and we would like to estimate the frequency and damping from the previously computed trajectory.
This requires splitting the problem into several parts:
- First, we integrate our system with the guessed parameters to each of the time points at which we solved the system above.
- Second, we compute an error at each of those time points.
- Third, we compute the gradients of the error with respect to our parameters.
- Fourth, we update our parameters based on the gradients.
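The four steps can be illustrated on a toy problem before we apply them to the oscillator. This minimal pure-Python sketch fits the decay rate $k$ of $x' = -kx$ to samples produced by the same integrator with $k=0.7$. It is not how the notebook does it: it substitutes a central finite-difference gradient for PyTorch's autograd and plain gradient descent for an optimiser. The helper names, learning rate and iteration count are assumptions chosen for this toy problem.

```python
def simulate(k, sample_times, dt=0.01):
    """Step 1: integrate x' = -k x from x(0) = 1 to each sample time (explicit midpoint)."""
    x, t, samples = 1.0, 0.0, []
    for t_next in sample_times:
        while t_next - t > 1e-12:
            h = min(dt, t_next - t)
            x_mid = x + 0.5 * h * (-k * x)
            x = x + h * (-k * x_mid)
            t += h
        samples.append(x)
    return samples


def loss(k, sample_times, reference):
    """Step 2: mean squared error between the simulated and the reference samples."""
    predicted = simulate(k, sample_times)
    return sum((p - r) ** 2 for p, r in zip(predicted, reference)) / len(reference)


def fit(k, sample_times, reference, learning_rate=0.5, iterations=200, eps=1e-6):
    for _ in range(iterations):
        # Step 3: central finite-difference gradient (a stand-in for autograd)
        gradient = (loss(k + eps, sample_times, reference)
                    - loss(k - eps, sample_times, reference)) / (2 * eps)
        # Step 4: plain gradient-descent update
        k -= learning_rate * gradient
    return k


sample_times = [0.2 * i for i in range(1, 11)]
reference = simulate(0.7, sample_times)    # "true" trajectory from the same integrator
k_fit = fit(0.1, sample_times, reference)  # start from a deliberately wrong guess
print(k_fit)                               # should be close to 0.7
```

Finite differences need two extra integrations per parameter per iteration, which becomes expensive as the number of parameters grows. That cost is one reason to differentiate through the integrator with autograd instead.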
Before writing the first part, we create the parameters that we will be optimising, starting from deliberately wrong guesses:
# You'll note that we set requires_grad=True to let PyTorch know that we want to track the operations
# involving these variables and compute their gradients
optimised_frequency = torch.tensor(0.1, requires_grad=True)
optimised_damping = torch.tensor(1.0, requires_grad=True)
Next, we write code that integrates our system piecewise to each point in time we're interested in and computes the error against the reference trajectory:
# We skip the first time as we know the initial state
times_to_integrate = [time for _, time in sha_states][1:]
current_state = initial_x.clone()
current_time = initial_time.clone()
# The error at each sample time is the 2-norm of the difference between the reference state
# from `sha_states` and the state integrated with our guessed parameters.
# Since `times_to_integrate` skips the initial time, the matching reference is `sha_states[sample_index + 1]`
# Let's look at the first sample_time and see what output we get
for sample_index, sample_time in enumerate(times_to_integrate[:1]):
    (current_state, current_time), _, _ = adaptive_rk45_integrator.apply(sha_dynamics, current_state, current_time, sample_time, dt, atol, rtol, optimised_frequency, optimised_damping)
    error = torch.linalg.norm(sha_states[sample_index + 1][0] - current_state)
    print(f"The current state: {current_state}")
    print(f"Error at time {current_time} is {error}")
The current state: tensor([ 1.0000e+00, -4.9998e-06])
Error at time 0.0005 is 4.999750162497297e-06
Now let's look at how PyTorch represents operations whose gradients are recorded on the autograd tape:
optimised_frequency*2
tensor(0.2000, grad_fn=<MulBackward0>)
The output from our integrator is missing the grad_fn that should point back to our integrator. The documentation and this issue show that torch.autograd.Function only tracks tensors returned directly or in a flat tuple, not tensors nested inside Python structures like lists and tuples. This means we need to rework our code so that the output is compatible with PyTorch's gradient tracking.
def get_forward_method(integrator_type, use_local_extrapolation):
    def __internal_forward(ctx, fn: typing.Callable[[torch.Tensor, torch.Tensor, typing.Any], torch.Tensor],
                           x0: torch.Tensor, t0: torch.Tensor, t1: torch.Tensor, dt: torch.Tensor,
                           atol: torch.Tensor, rtol: torch.Tensor, *additional_dynamic_args):
        """
        A general integration routine for solving an Initial Value Problem
        using any arbitrary Butcher Tableau
        Instead of naively summing the changes, we use compensated summation.
        :param fn: the function to be integrated
        :param x0: the initial state to integrate from
        :param t0: the initial time to integrate from
        :param t1: the final time to integrate to
        :param dt: the time increments to integrate with
        :param atol: The absolute tolerance for the error in an adaptive integration
        :param rtol: The relative tolerance for the error in an adaptive integration
        :param additional_dynamic_args: additional arguments to pass to the function
        :return: a flat tuple of (the final state, the final time, the intermediate states, the intermediate times, the error values), each a torch.Tensor
        """
        if integrator_type.is_adaptive:
            # We need to check that both `atol` and `rtol` are valid values and are compatible with the state
            atol = neuralode.integrators.helpers.ensure_tolerance(atol, x0, "Absolute tolerance", "atol")
            rtol = neuralode.integrators.helpers.ensure_tolerance(rtol, x0, "Relative tolerance", "rtol")
        dt = neuralode.integrators.helpers.ensure_timestep(dt, t0, t1)
        butcher_tableau = integrator_type.integrator_tableau.clone().to(x0.device, x0.dtype)

        def forward_fn(state, time):
            return fn(state, time, *additional_dynamic_args)

        c_state = x0.clone()
        c_time = t0.clone()
        c_state, c_time, i_states, error_in_state = integrate_system_adaptive(forward_fn, c_state, c_time, t1, dt, atol, rtol,
                                                                              integrator_type.is_adaptive, use_local_extrapolation,
                                                                              integrator_type.integrator_order, butcher_tableau)
        intermediate_states, intermediate_times = zip(*i_states)
        # As we said, these need to be converted to tensors for proper tracking
        intermediate_states = torch.stack(intermediate_states, dim=0)
        intermediate_times = torch.stack(intermediate_times, dim=0)
        # We should also put the errors we're returning into a tensor too
        error_in_state = torch.stack(error_in_state, dim=0)
        # We save parameters for the backward pass, but these won't be used
        # until we implement the adjoint method for backpropagation
        if ctx is not None:
            ctx.save_for_backward(x0, t0, t1, dt, atol, rtol, c_state, c_time, intermediate_states, intermediate_times, *additional_dynamic_args)
            ctx.integration_function = fn
        # Now we're returning a flat structure where each element is a tensor, and so its gradients can be properly tracked
        return c_state, c_time, intermediate_states, intermediate_times, error_in_state
    return __internal_forward
def get_integrator(integrator_tableau: torch.Tensor, integrator_order: int, use_local_extrapolation: bool = True, integrator_name: str = None) -> torch.autograd.Function:
    # We look at the first column of the last two rows, and if both are `inf`, we know the method is adaptive
    is_adaptive = torch.isinf(integrator_tableau[-1,0]) and torch.isinf(integrator_tableau[-2,0])
    # The number of stages is the number of rows minus the last row
    # (or last two rows if the method is adaptive)
    number_of_stages = integrator_tableau.shape[0] - 1
    if is_adaptive:
        number_of_stages -= 1
    # The `type` function in this form works to dynamically create a class
    # the first parameter is the class name, the second are parent classes,
    # and the last are the class attributes. We store the integrator attributes
    # here, and reference them in the integration code.
    # In this way, we can query these parameters at a future point.
    __integrator_type = type(integrator_name, (torch.autograd.Function,), {
        "integrator_tableau": integrator_tableau,
        "integrator_order": integrator_order,
        "is_adaptive": is_adaptive,
        "number_of_stages": number_of_stages
    })
    __internal_forward = get_forward_method(__integrator_type, use_local_extrapolation)
    if not __integrator_type.is_adaptive:
        # If the method isn't adaptive, neither atol nor rtol are required, but because of
        # how `torch.autograd.Function` works, we cannot have keyword arguments
        # For that reason, we use an alternative implementation to fill those values with a stub
        def __internal_forward_nonadaptive(ctx, fn: typing.Callable[[torch.Tensor, torch.Tensor, typing.Any], torch.Tensor],
                                           x0: torch.Tensor, t0: torch.Tensor, t1: torch.Tensor, dt: torch.Tensor, *additional_dynamic_args):
            return __internal_forward(ctx, fn, x0, t0, t1, dt, torch.inf, torch.inf, *additional_dynamic_args)
        __integrator_type.forward = staticmethod(__internal_forward_nonadaptive)
    else:
        __integrator_type.forward = staticmethod(__internal_forward)
    return __integrator_type
adaptive_rk45_integrator = get_integrator(adaptive_rk45_integrator.integrator_tableau, integrator_order = 5, use_local_extrapolation = False, integrator_name = "AdaptiveRK45Integrator")
# We skip the first time as we know the initial state
times_to_integrate = [time for _, time in sha_states][1:]
current_state = initial_x.clone()
current_time = initial_time.clone()
error = 0.0
# Let's look at the first sample_time and see what output we get
for sample_index, sample_time in enumerate(times_to_integrate[:1]):
    current_state, current_time, *_ = adaptive_rk45_integrator.apply(sha_dynamics, current_state, current_time, sample_time, dt, atol, rtol, optimised_frequency, optimised_damping)
    # `times_to_integrate` skips the initial time, so the matching reference is `sha_states[sample_index + 1]`
    error = error + torch.linalg.norm(sha_states[sample_index + 1][0] - current_state)/len(times_to_integrate)
    print(f"The current state: {current_state}")
    print(f"Error at time {current_time} is {error}")
The current state: tensor([ 1.0000e+00, -4.9998e-06], grad_fn=<AdaptiveRK45IntegratorBackward>)
Error at time 0.0005 is 4.03205658265911e-08
Excellent! PyTorch is now correctly tracking our integration as part of the gradient tape. Let's try computing the gradients!
try:
    error.backward()
except NotImplementedError as e:
    print(f"Encountered exception: {e}")
Encountered exception: You must implement either the backward or vjp method for your custom autograd.Function to use it with backward mode AD.
Ah yes, we haven't implemented a backward method yet, so let's do that:
def get_integrator(integrator_tableau: torch.Tensor, integrator_order: int, use_local_extrapolation: bool = True, integrator_name: str = None) -> torch.autograd.Function:
# We look at the first column of the last two rows, and if both are `inf`, we know the method is adaptive
is_adaptive = torch.isinf(integrator_tableau[-1,0]) and torch.isinf(integrator_tableau[-2,0])
# The number of stages is the number of rows minus the last row
# (or last two rows if the method is adaptive)
number_of_stages = integrator_tableau.shape[0] - 1
if is_adaptive:
number_of_stages -= 1
# The `type` function in this form works to dynamically create a class
# the first parameter is the class name, the second are parent classes,
# and the last are the class attributes. We store the integrator attributes
# here, and reference them in the integration code.
# In this way, we can query these parameters at a future point.
__integrator_type = type(integrator_name, (torch.autograd.Function,), {
"integrator_tableau": integrator_tableau,
"integrator_order": integrator_order,
"is_adaptive": is_adaptive,
"number_of_stages": number_of_stages
})
__internal_forward = get_forward_method(__integrator_type, use_local_extrapolation)
def __internal_backward(ctx, d_c_state, d_c_time, d_intermediate_states, d_intermediate_times, d_error_in_state):
"""
This function computes the gradient of the input variables for `__internal_forward` by exploiting the fact
that PyTorch can track the whole graph of operations used to derive a specific result. Thus each time backward is called,
we compute the actual graph of operations and propagate derivatives through it. Unfortunately, this is an exceptionally
slow method of computation that also uses a lot of memory.
This is implemented here as a demonstration of how we could compute gradients and how these are expected to be propagated back
to the autograd tape.
:param ctx:
:param d_c_state:
:param d_c_time:
:param d_intermediate_states:
:param d_intermediate_times:
:param d_error_in_state:
:return:
"""
# First we retrieve our integration function that we stored in `__internal_forward`
fn = ctx.integration_function
# Then we retrieve the input variables and clone them to avoid influencing them in the later operations
x0, t0, t1, dt, atol, rtol, _, _, _, _, *additional_dynamic_args = [i.clone().requires_grad_(True) for i in ctx.saved_tensors]
inputs = fn, x0, t0, t1, dt, atol, rtol, *additional_dynamic_args
if any(ctx.needs_input_grad):
# We ensure that gradients are enabled so that autograd tracks the variable operations
with torch.enable_grad():
# And then we integrate our system with the tracking of operations.
# We pass in `None` for the `ctx` to avoid issues with __internal_forward attempting to call methods that we don't want to use
# In the adjoint method; this will not be an issue
c_state, c_time, intermediate_states, intermediate_times, error_in_state = __internal_forward(None, fn, x0, t0, t1, dt, atol, rtol, *additional_dynamic_args)
# We collate the outputs that we can compute gradients for
# with this method, we are restricted to the final state and time
outputs = c_state, c_time #, intermediate_states, intermediate_times, error_in_state
grad_outputs = d_c_state, d_c_time, d_intermediate_states, d_intermediate_times, d_error_in_state
# We also only consider the input and output variables that actually have gradients enabled
inputs_with_grad = [i for idx, i in enumerate(inputs) if ctx.needs_input_grad[idx]]
outputs_with_grad = [idx for idx, i in enumerate(outputs) if i.grad_fn is not None]
grad_of_inputs_with_grad = torch.autograd.grad([outputs[idx] for idx in outputs_with_grad], inputs_with_grad, grad_outputs=[grad_outputs[idx] for idx in outputs_with_grad], allow_unused=True, materialize_grads=True)
else:
grad_of_inputs_with_grad = None
# For each input we must return a gradient
# Interestingly, this also includes the function we passed in...
# We create a list of None values
# (this tells autograd that there is no gradient for those variables).
# And for each variable that does have a gradient, we fill the values in
# before returning the list
input_grads = [None for _ in range(len(inputs))]
if grad_of_inputs_with_grad:
for idx in range(len(inputs)):
if ctx.needs_input_grad[idx]:
input_grads[idx], *grad_of_inputs_with_grad = grad_of_inputs_with_grad
return tuple(input_grads)
if not __integrator_type.is_adaptive:
# If the method isn't adaptive, neither atol nor rtol is required, but because of
# how `torch.autograd.Function` works, we cannot have keyword arguments.
# For that reason, we wrap the forward method in one that fills both tolerances with `torch.inf`
def __internal_forward_nonadaptive(ctx, fn: typing.Callable[[torch.Tensor, torch.Tensor, typing.Any], torch.Tensor],
x0: torch.Tensor, t0: torch.Tensor, t1: torch.Tensor, dt: torch.Tensor, *additional_dynamic_args):
return __internal_forward(ctx, fn, x0, t0, t1, dt, torch.inf, torch.inf, *additional_dynamic_args)
__integrator_type.forward = staticmethod(__internal_forward_nonadaptive)
else:
__integrator_type.forward = staticmethod(__internal_forward)
__integrator_type.backward = staticmethod(__internal_backward)
return __integrator_type
adaptive_rk45_integrator = get_integrator(adaptive_rk45_integrator.integrator_tableau, integrator_order = 5, use_local_extrapolation = False, integrator_name = "AdaptiveRK45Integrator")
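At the end of `__internal_backward`, `torch.autograd.grad` returns gradients only for the inputs that requested them. However, `backward` has to return exactly one entry per forward input. The scatter step can be sketched in plain Python as follows. `scatter_gradients` is a hypothetical helper and floats stand in for tensors, so this illustrates only the bookkeeping, not autograd itself.

```python
from typing import Optional, Sequence, Tuple


def scatter_gradients(needs_input_grad: Sequence[bool],
                      computed_grads: Sequence[float]) -> Tuple[Optional[float], ...]:
    """Hypothetical helper: return one entry per input, filling the computed
    gradients (in order) where a gradient is needed and None elsewhere."""
    if sum(needs_input_grad) != len(computed_grads):
        raise ValueError("expected one computed gradient per input that needs a gradient")
    remaining = iter(computed_grads)
    return tuple(next(remaining) if needs_grad else None for needs_grad in needs_input_grad)


# Made-up mask mimicking (fn, x0, t0, t1, dt, atol, rtol, frequency, damping),
# where only the two parameters require gradients
mask = (False, False, False, False, False, False, False, True, True)
print(scatter_gradients(mask, [0.5, -1.5]))
# (None, None, None, None, None, None, None, 0.5, -1.5)
```

Consuming an iterator keeps the gradients in the same order as `inputs_with_grad`, which is the property the unpacking loop in `__internal_backward` relies on.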
# We skip the first time as we know the initial state
times_to_integrate = [time for _, time in sha_states][1:]
current_state = initial_x.clone()
current_time = initial_time.clone()
error = 0.0
# Let's look at the first sample_time and see what output we get
for sample_index, sample_time in enumerate(times_to_integrate[:1]):
    current_state, current_time, *_ = adaptive_rk45_integrator.apply(sha_dynamics, current_state, current_time, sample_time, dt, atol, rtol, optimised_frequency, optimised_damping)
    # `times_to_integrate` skips the initial time, so the matching reference state is at index `sample_index + 1`
    error = error + torch.linalg.norm(sha_states[sample_index + 1][0] - current_state)/len(times_to_integrate)
    print(f"The current state: {current_state}")
    print(f"Error at time {current_time} is {error}")
error.backward()
The current state: tensor([ 1.0000e+00, -4.9998e-06], grad_fn=<AdaptiveRK45IntegratorBackward>)
Error at time 0.0005 is 4.03205658265911e-08
print(f"Current Frequency: {optimised_frequency}, Frequency grad: {optimised_frequency.grad}")
print(f"Current Damping: {optimised_damping}, Damping grad: {optimised_damping.grad}")
Current Frequency: 0.1, Frequency grad: 8.063911562493276e-07
Current Damping: 1.0, Damping grad: -2.0159946488567326e-12
Great, we see that the gradients of both optimised_frequency and optimised_damping have been populated. Let's check that these gradients are computed correctly using torch.autograd.gradcheck with random initial conditions, frequencies and damping coefficients, so that we test the gradient across multiple conditions rather than at a single point.
gradcheck works by computing the gradient (or, more generally, the Jacobian) numerically with finite differences and comparing it to the one obtained through autodiff. This requires multiple numerical integrations, and they must run at the highest precision possible (hence the tolerances of 1e-14, close to the limit of double precision) to ensure that numerical inaccuracy is not the cause of a mismatch in the gradients. This is quite an expensive procedure due to the depth of the autodiff graph that is generated, and thus it will take some time to compute.
from torch.autograd import gradcheck
def test_func(init_state, freq, damp):
    # Integrate over a short interval at very tight tolerances and return only the final state
    res = adaptive_rk45_integrator.apply(sha_dynamics, init_state, initial_time, initial_time+0.01, dt, torch.tensor(1e-14), torch.tensor(1e-14), freq, damp)
    return res[0]
test_variables = [initial_x, frequency, damping]
def generate_test_vars():
    # Random initial state in [-1, 1), frequency and damping in [0, 1)
    test_x = (2*torch.rand_like(initial_x) - 1.0)
    test_frequency = torch.rand_like(frequency)
    test_damping = torch.rand_like(damping)
    return [i.requires_grad_(True) for i in [test_x, test_frequency, test_damping]]
with warnings.catch_warnings():
    # Silence any RuntimeWarnings raised while integrating at such tight tolerances
    warnings.simplefilter("ignore", RuntimeWarning)
    num_tests = 4
    print(f"[0/{num_tests}] - vars: {[i.detach().cpu().tolist() for i in test_variables]}, success: ", end='')
    print(gradcheck(test_func, [i.detach().clone().requires_grad_(True) for i in test_variables]))
    for iter_idx in range(num_tests):
        variables = generate_test_vars()
        print(f"[{iter_idx+1}/{num_tests}] - vars: {[i.detach().cpu().tolist() for i in variables]}, success: ", end='')
        print(gradcheck(test_func, variables))
[0/4] - vars: [[1.0, 0.0], 1.0, 0.25], success: True
[1/4] - vars: [[0.5662623888314278, -0.13291998917513825], 0.5862130150542093, 0.35079282251958555], success: True
[2/4] - vars: [[-0.03426942938557631, -0.6035063069149991], 0.0016512400245205505, 0.8237204099086072], success: True
[3/4] - vars: [[0.3611170099769341, -0.5509630565383719], 0.08240171994937384, 0.9659915931678263], success: True
[4/4] - vars: [[0.14063778041676822, -0.22986764674631344], 0.5371906530543378, 0.9915331859191349], success: True
Now that we've validated the correctness of the gradients, we can implement an optimisation loop to fit our parameters to a prior trajectory.
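The loop below relies on the `optimiser.step(closure)` convention. A stripped-down, pure-Python version of that convention is sketched here. `Parameter` and `GradientDescent` are hypothetical stand-ins for a tensor and a `torch.optim` optimiser, and the loss is a toy quadratic. The point to notice is the order of operations: the closure is evaluated first, and the update is applied afterwards.

```python
class Parameter:
    """Hypothetical scalar stand-in for a tensor with a `.grad` attribute."""
    def __init__(self, value: float):
        self.value = value
        self.grad = 0.0


class GradientDescent:
    """Hypothetical minimal optimiser following the `step(closure)` convention."""
    def __init__(self, params, lr: float):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self, closure):
        # The closure is evaluated before the update, so the returned loss
        # describes the parameters as they were at the start of this step
        loss = closure()
        for p in self.params:
            p.value -= self.lr * p.grad
        return loss


theta = Parameter(0.0)
optimiser = GradientDescent([theta], lr=0.1)


def closure():
    # Same responsibilities as a training closure: zero grads, compute loss, fill grads
    optimiser.zero_grad()
    residual = theta.value - 3.0
    theta.grad = 2.0 * residual  # derivative of residual**2
    return residual ** 2


losses = [optimiser.step(closure) for _ in range(100)]
print(losses[0])  # 9.0
print(theta.value)
```

Because `step` returns the loss of the parameters as they were before the update, any "best so far" bookkeeping should snapshot the parameters before calling `step`.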
# We skip the first time as we know the initial state
times_to_integrate = [time for _, time in sha_states][1:]
# We reinitialise our variables
optimised_frequency = torch.tensor(0.1, requires_grad=True)
optimised_damping = torch.tensor(1.0, requires_grad=True)
# As damping needs to be a strictly positive quantity, we log-encode it
log_encoded_damping = torch.log(optimised_damping.detach()).requires_grad_(True)
# First, we'll create an `optimiser` following PyTorch convention
optimiser = torch.optim.Adam([optimised_frequency, log_encoded_damping], lr=1e-2, amsgrad=True)
# Next, we'll define a closure function whose sole purpose is to
# zero the gradients, compute the error and backpropagate it. This is useful as it allows switching to other
# optimisers such as LBFGS, or anything that re-evaluates the error without
# computing its gradient
def sha_closure():
    current_state = initial_x.clone()
    current_time = initial_time.clone()
    optimiser.zero_grad()
    error = 0.0
    for sample_index, sample_time in enumerate(times_to_integrate):
        # We don't require high precision here as this system, with the parameters we've set, is fairly stable.
        # We only need the gradients to have the right direction and roughly the right magnitude to do gradient descent
        new_state, new_time, *_ = adaptive_rk45_integrator.apply(sha_dynamics, current_state, current_time, sample_time, torch.minimum(dt, sample_time - current_time), atol, rtol, optimised_frequency, torch.exp(log_encoded_damping))
        # `times_to_integrate` skips the initial time, so the matching reference state is at index `sample_index + 1`
        error = error + torch.linalg.norm(sha_states[sample_index + 1][0] - new_state)/len(times_to_integrate)
        # Detaching restarts the graph at each sample, so gradients only flow through the most recent segment
        current_state, current_time = new_state.detach(), new_time.detach()
    # Skip backpropagation when the closure is evaluated with gradients disabled
    if error.requires_grad:
        error.backward()
    return error
# Now we need an optimisation `loop` where we will take steps to minimise the error
number_of_gd_steps = 256
# We also need to track the best solution thus far
best_error = torch.inf
best_frequency, best_damping = optimised_frequency.detach().clone(), optimised_damping.detach().clone()
for step in range(number_of_gd_steps):
    # `optimiser.step` evaluates the closure before updating the parameters, so we
    # snapshot the parameters that `step_error` actually corresponds to
    frequency_before_step = optimised_frequency.detach().clone()
    damping_before_step = torch.exp(log_encoded_damping.detach())
    step_error = optimiser.step(sha_closure)
    if step_error < best_error:
        best_error = step_error.item()
        best_frequency, best_damping = frequency_before_step, damping_before_step
    print(f"[{step+1}/{number_of_gd_steps}] Error: {step_error.item():.6f}, Current Frequency: {optimised_frequency.item():.6f}, Current Damping: {torch.exp(log_encoded_damping).item():.6f}")
print(f"Best frequency: {best_frequency.item():.6f}, relative error: {torch.mean(torch.abs(1 - best_frequency / frequency)).item():.6%}")
print(f"Best damping: {best_damping.item():.6f}, relative error: {torch.mean(torch.abs(1 - best_damping / damping)).item():.6%}")
[1/256] Error: 0.718138, Current Frequency: 0.110000, Current Damping: 0.990051 [2/256] Error: 0.686355, Current Frequency: 0.120008, Current Damping: 0.980191 [3/256] Error: 0.655750, Current Frequency: 0.130027, Current Damping: 0.970422 [4/256] Error: 0.626408, Current Frequency: 0.140062, Current Damping: 0.960753 [5/256] Error: 0.598379, Current Frequency: 0.150117, Current Damping: 0.951206 [6/256] Error: 0.571688, Current Frequency: 0.160195, Current Damping: 0.941820 [7/256] Error: 0.546343, Current Frequency: 0.170300, Current Damping: 0.932656 [8/256] Error: 0.522344, Current Frequency: 0.180436, Current Damping: 0.923802 [9/256] Error: 0.499689, Current Frequency: 0.190608, Current Damping: 0.915383 [10/256] Error: 0.478377, Current Frequency: 0.200820, Current Damping: 0.907575 [11/256] Error: 0.458415, Current Frequency: 0.211077, Current Damping: 0.900601 [12/256] Error: 0.439823, Current Frequency: 0.221385, Current Damping: 0.894715 [13/256] Error: 0.422625, Current Frequency: 0.231747, Current Damping: 0.890171 [14/256] Error: 0.406842, Current Frequency: 0.242170, Current Damping: 0.887130 [15/256] Error: 0.392461, Current Frequency: 0.252655, Current Damping: 0.885589 [16/256] Error: 0.379451, Current Frequency: 0.263205, Current Damping: 0.885403 [17/256] Error: 0.367767, Current Frequency: 0.273820, Current Damping: 0.886400 [18/256] Error: 0.357104, Current Frequency: 0.284501, Current Damping: 0.888357 [19/256] Error: 0.347307, Current Frequency: 0.295248, Current Damping: 0.891028 [20/256] Error: 0.338277, Current Frequency: 0.306060, Current Damping: 0.894147 [21/256] Error: 0.329941, Current Frequency: 0.316937, Current Damping: 0.897394 [22/256] Error: 0.322278, Current Frequency: 0.327880, Current Damping: 0.900522 [23/256] Error: 0.315177, Current Frequency: 0.338887, Current Damping: 0.903220 [24/256] Error: 0.308461, Current Frequency: 0.349960, Current Damping: 0.905077 [25/256] Error: 0.302019, Current Frequency: 0.361099, Current 
Damping: 0.905660 [26/256] Error: 0.295773, Current Frequency: 0.372305, Current Damping: 0.904608 [27/256] Error: 0.289680, Current Frequency: 0.383577, Current Damping: 0.901750 [28/256] Error: 0.283764, Current Frequency: 0.394918, Current Damping: 0.897167 [29/256] Error: 0.278155, Current Frequency: 0.406327, Current Damping: 0.891216 [30/256] Error: 0.273123, Current Frequency: 0.417807, Current Damping: 0.884071 [31/256] Error: 0.268553, Current Frequency: 0.429359, Current Damping: 0.875970 [32/256] Error: 0.264319, Current Frequency: 0.440986, Current Damping: 0.867120 [33/256] Error: 0.260362, Current Frequency: 0.452691, Current Damping: 0.857686 [34/256] Error: 0.256635, Current Frequency: 0.464476, Current Damping: 0.847800 [35/256] Error: 0.253097, Current Frequency: 0.476344, Current Damping: 0.837563 [36/256] Error: 0.249711, Current Frequency: 0.488295, Current Damping: 0.827056 [37/256] Error: 0.246444, Current Frequency: 0.500333, Current Damping: 0.816345 [38/256] Error: 0.243264, Current Frequency: 0.512460, Current Damping: 0.805480 [39/256] Error: 0.240144, Current Frequency: 0.524676, Current Damping: 0.794503 [40/256] Error: 0.237061, Current Frequency: 0.536984, Current Damping: 0.783448 [41/256] Error: 0.233992, Current Frequency: 0.549385, Current Damping: 0.772342 [42/256] Error: 0.230918, Current Frequency: 0.561879, Current Damping: 0.761208 [43/256] Error: 0.227826, Current Frequency: 0.574467, Current Damping: 0.750065 [44/256] Error: 0.224699, Current Frequency: 0.587150, Current Damping: 0.738929 [45/256] Error: 0.221528, Current Frequency: 0.599929, Current Damping: 0.727814 [46/256] Error: 0.218301, Current Frequency: 0.612802, Current Damping: 0.716732 [47/256] Error: 0.215009, Current Frequency: 0.625771, Current Damping: 0.705692 [48/256] Error: 0.211644, Current Frequency: 0.638834, Current Damping: 0.694702 [49/256] Error: 0.208198, Current Frequency: 0.651991, Current Damping: 0.683771 [50/256] Error: 0.204664, Current 
Frequency: 0.665240, Current Damping: 0.672904 [51/256] Error: 0.201035, Current Frequency: 0.678579, Current Damping: 0.662108 [52/256] Error: 0.197305, Current Frequency: 0.692005, Current Damping: 0.651388 [53/256] Error: 0.193470, Current Frequency: 0.705512, Current Damping: 0.640747 [54/256] Error: 0.189525, Current Frequency: 0.719092, Current Damping: 0.630192 [55/256] Error: 0.185471, Current Frequency: 0.732722, Current Damping: 0.619725 [56/256] Error: 0.181314, Current Frequency: 0.746363, Current Damping: 0.609352 [57/256] Error: 0.177069, Current Frequency: 0.759956, Current Damping: 0.599080 [58/256] Error: 0.172748, Current Frequency: 0.773464, Current Damping: 0.588911 [59/256] Error: 0.168354, Current Frequency: 0.786880, Current Damping: 0.578846 [60/256] Error: 0.163886, Current Frequency: 0.800205, Current Damping: 0.568887 [61/256] Error: 0.159341, Current Frequency: 0.813448, Current Damping: 0.559033 [62/256] Error: 0.154722, Current Frequency: 0.826611, Current Damping: 0.549287 [63/256] Error: 0.150035, Current Frequency: 0.839691, Current Damping: 0.539648 [64/256] Error: 0.145289, Current Frequency: 0.852672, Current Damping: 0.530120 [65/256] Error: 0.140505, Current Frequency: 0.865514, Current Damping: 0.520706 [66/256] Error: 0.135711, Current Frequency: 0.878160, Current Damping: 0.511411 [67/256] Error: 0.130938, Current Frequency: 0.890555, Current Damping: 0.502239 [68/256] Error: 0.126211, Current Frequency: 0.902653, Current Damping: 0.493190 [69/256] Error: 0.121555, Current Frequency: 0.914407, Current Damping: 0.484268 [70/256] Error: 0.117001, Current Frequency: 0.925759, Current Damping: 0.475475 [71/256] Error: 0.112585, Current Frequency: 0.936632, Current Damping: 0.466816 [72/256] Error: 0.108341, Current Frequency: 0.946943, Current Damping: 0.458296 [73/256] Error: 0.104298, Current Frequency: 0.956605, Current Damping: 0.449918 [74/256] Error: 0.100481, Current Frequency: 0.965520, Current Damping: 0.441689 [75/256] 
Error: 0.096906, Current Frequency: 0.973594, Current Damping: 0.433613 [76/256] Error: 0.093580, Current Frequency: 0.980734, Current Damping: 0.425697 [77/256] Error: 0.090497, Current Frequency: 0.986856, Current Damping: 0.417946 [78/256] Error: 0.087637, Current Frequency: 0.991888, Current Damping: 0.410366 [79/256] Error: 0.084970, Current Frequency: 0.995767, Current Damping: 0.402960 [80/256] Error: 0.082456, Current Frequency: 0.998458, Current Damping: 0.395733 [81/256] Error: 0.080051, Current Frequency: 0.999949, Current Damping: 0.388685 [82/256] Error: 0.077711, Current Frequency: 1.000257, Current Damping: 0.381817 [83/256] Error: 0.075395, Current Frequency: 0.999422, Current Damping: 0.375128 [84/256] Error: 0.073072, Current Frequency: 0.997512, Current Damping: 0.368614 [85/256] Error: 0.070725, Current Frequency: 0.994614, Current Damping: 0.362269 [86/256] Error: 0.068353, Current Frequency: 0.990838, Current Damping: 0.356087 [87/256] Error: 0.065968, Current Frequency: 0.986310, Current Damping: 0.350061 [88/256] Error: 0.063604, Current Frequency: 0.981175, Current Damping: 0.344185 [89/256] Error: 0.061307, Current Frequency: 0.975590, Current Damping: 0.338450 [90/256] Error: 0.059134, Current Frequency: 0.969731, Current Damping: 0.332852 [91/256] Error: 0.057144, Current Frequency: 0.963785, Current Damping: 0.327386 [92/256] Error: 0.055386, Current Frequency: 0.957946, Current Damping: 0.322051 [93/256] Error: 0.053888, Current Frequency: 0.952415, Current Damping: 0.316846 [94/256] Error: 0.052642, Current Frequency: 0.947385, Current Damping: 0.311773 [95/256] Error: 0.051598, Current Frequency: 0.943027, Current Damping: 0.306836 [96/256] Error: 0.050671, Current Frequency: 0.939483, Current Damping: 0.302037 [97/256] Error: 0.049750, Current Frequency: 0.936860, Current Damping: 0.297379 [98/256] Error: 0.048718, Current Frequency: 0.935224, Current Damping: 0.292863 [99/256] Error: 0.047465, Current Frequency: 0.934593, Current 
Damping: 0.288490 [100/256] Error: 0.045908, Current Frequency: 0.934943, Current Damping: 0.284258 [101/256] Error: 0.043998, Current Frequency: 0.936210, Current Damping: 0.280164 [102/256] Error: 0.041722, Current Frequency: 0.938294, Current Damping: 0.276202 [103/256] Error: 0.039108, Current Frequency: 0.941064, Current Damping: 0.272367 [104/256] Error: 0.036221, Current Frequency: 0.944357, Current Damping: 0.268653 [105/256] Error: 0.033159, Current Frequency: 0.947982, Current Damping: 0.265056 [106/256] Error: 0.030053, Current Frequency: 0.951718, Current Damping: 0.261571 [107/256] Error: 0.027058, Current Frequency: 0.955319, Current Damping: 0.258197 [108/256] Error: 0.024338, Current Frequency: 0.958529, Current Damping: 0.254939 [109/256] Error: 0.022050, Current Frequency: 0.961105, Current Damping: 0.251807 [110/256] Error: 0.020332, Current Frequency: 0.962857, Current Damping: 0.248819 [111/256] Error: 0.019311, Current Frequency: 0.963734, Current Damping: 0.246005 [112/256] Error: 0.019093, Current Frequency: 0.963559, Current Damping: 0.243415 [113/256] Error: 0.019669, Current Frequency: 0.962275, Current Damping: 0.241101 [114/256] Error: 0.020845, Current Frequency: 0.960001, Current Damping: 0.239078 [115/256] Error: 0.022372, Current Frequency: 0.957020, Current Damping: 0.237332 [116/256] Error: 0.024107, Current Frequency: 0.953687, Current Damping: 0.235833 [117/256] Error: 0.025951, Current Frequency: 0.950354, Current Damping: 0.234553 [118/256] Error: 0.027791, Current Frequency: 0.947331, Current Damping: 0.233465 [119/256] Error: 0.029500, Current Frequency: 0.944854, Current Damping: 0.232547 [120/256] Error: 0.030960, Current Frequency: 0.943076, Current Damping: 0.231781 [121/256] Error: 0.032079, Current Frequency: 0.942071, Current Damping: 0.231152 [122/256] Error: 0.032806, Current Frequency: 0.941834, Current Damping: 0.230649 [123/256] Error: 0.033135, Current Frequency: 0.942293, Current Damping: 0.230265 [124/256] 
Error: 0.033102, Current Frequency: 0.943323, Current Damping: 0.229995 [125/256] Error: 0.032775, Current Frequency: 0.944763, Current Damping: 0.229834 [126/256] Error: 0.032240, Current Frequency: 0.946428, Current Damping: 0.229783 [127/256] Error: 0.031589, Current Frequency: 0.948126, Current Damping: 0.229839 [128/256] Error: 0.030901, Current Frequency: 0.949679, Current Damping: 0.230002 [129/256] Error: 0.030243, Current Frequency: 0.950939, Current Damping: 0.230269 [130/256] Error: 0.029655, Current Frequency: 0.951802, Current Damping: 0.230635 [131/256] Error: 0.029162, Current Frequency: 0.952222, Current Damping: 0.231094 [132/256] Error: 0.028773, Current Frequency: 0.952212, Current Damping: 0.231633 [133/256] Error: 0.028490, Current Frequency: 0.951839, Current Damping: 0.232239 [134/256] Error: 0.028309, Current Frequency: 0.951209, Current Damping: 0.232897 [135/256] Error: 0.028214, Current Frequency: 0.950451, Current Damping: 0.233588 [136/256] Error: 0.028179, Current Frequency: 0.949698, Current Damping: 0.234296 [137/256] Error: 0.028165, Current Frequency: 0.949068, Current Damping: 0.235005 [138/256] Error: 0.028130, Current Frequency: 0.948651, Current Damping: 0.235700 [139/256] Error: 0.028034, Current Frequency: 0.948505, Current Damping: 0.236371 [140/256] Error: 0.027848, Current Frequency: 0.948650, Current Damping: 0.237009 [141/256] Error: 0.027559, Current Frequency: 0.949073, Current Damping: 0.237608 [142/256] Error: 0.027169, Current Frequency: 0.949727, Current Damping: 0.238163 [143/256] Error: 0.026697, Current Frequency: 0.950546, Current Damping: 0.238675 [144/256] Error: 0.026175, Current Frequency: 0.951446, Current Damping: 0.239143 [145/256] Error: 0.025637, Current Frequency: 0.952340, Current Damping: 0.239568 [146/256] Error: 0.025122, Current Frequency: 0.953147, Current Damping: 0.239951 [147/256] Error: 0.024662, Current Frequency: 0.953800, Current Damping: 0.240295 [148/256] Error: 0.024285, Current 
Frequency: 0.954252, Current Damping: 0.240601 [149/256] Error: 0.024007, Current Frequency: 0.954484, Current Damping: 0.240869 [150/256] Error: 0.023833, Current Frequency: 0.954502, Current Damping: 0.241099 [151/256] Error: 0.023763, Current Frequency: 0.954336, Current Damping: 0.241290 [152/256] Error: 0.023782, Current Frequency: 0.954034, Current Damping: 0.241441 [153/256] Error: 0.023872, Current Frequency: 0.953650, Current Damping: 0.241550 [154/256] Error: 0.024010, Current Frequency: 0.953245, Current Damping: 0.241617 [155/256] Error: 0.024169, Current Frequency: 0.952871, Current Damping: 0.241643 [156/256] Error: 0.024328, Current Frequency: 0.952570, Current Damping: 0.241628 [157/256] Error: 0.024464, Current Frequency: 0.952370, Current Damping: 0.241575 [158/256] Error: 0.024566, Current Frequency: 0.952281, Current Damping: 0.241488 [159/256] Error: 0.024626, Current Frequency: 0.952300, Current Damping: 0.241372 [160/256] Error: 0.024646, Current Frequency: 0.952408, Current Damping: 0.241231 [161/256] Error: 0.024633, Current Frequency: 0.952578, Current Damping: 0.241071 [162/256] Error: 0.024598, Current Frequency: 0.952777, Current Damping: 0.240898 [163/256] Error: 0.024556, Current Frequency: 0.952969, Current Damping: 0.240718 [164/256] Error: 0.024521, Current Frequency: 0.953123, Current Damping: 0.240536 [165/256] Error: 0.024504, Current Frequency: 0.953215, Current Damping: 0.240357 [166/256] Error: 0.024515, Current Frequency: 0.953230, Current Damping: 0.240186 [167/256] Error: 0.024558, Current Frequency: 0.953164, Current Damping: 0.240023 [168/256] Error: 0.024634, Current Frequency: 0.953026, Current Damping: 0.239873 [169/256] Error: 0.024737, Current Frequency: 0.952831, Current Damping: 0.239735 [170/256] Error: 0.024861, Current Frequency: 0.952602, Current Damping: 0.239610 [171/256] Error: 0.024997, Current Frequency: 0.952365, Current Damping: 0.239497 [172/256] Error: 0.025133, Current Frequency: 0.952143, Current 
Damping: 0.239396 [173/256] Error: 0.025258, Current Frequency: 0.951959, Current Damping: 0.239307 [174/256] Error: 0.025365, Current Frequency: 0.951826, Current Damping: 0.239228 [175/256] Error: 0.025446, Current Frequency: 0.951753, Current Damping: 0.239160 [176/256] Error: 0.025499, Current Frequency: 0.951739, Current Damping: 0.239102 [177/256] Error: 0.025523, Current Frequency: 0.951777, Current Damping: 0.239054 [178/256] Error: 0.025522, Current Frequency: 0.951855, Current Damping: 0.239016 [179/256] Error: 0.025500, Current Frequency: 0.951958, Current Damping: 0.238989 [180/256] Error: 0.025464, Current Frequency: 0.952069, Current Damping: 0.238973 [181/256] Error: 0.025422, Current Frequency: 0.952171, Current Damping: 0.238968 [182/256] Error: 0.025380, Current Frequency: 0.952253, Current Damping: 0.238974 [183/256] Error: 0.025344, Current Frequency: 0.952305, Current Damping: 0.238989 [184/256] Error: 0.025317, Current Frequency: 0.952326, Current Damping: 0.239014 [185/256] Error: 0.025300, Current Frequency: 0.952316, Current Damping: 0.239047 [186/256] Error: 0.025294, Current Frequency: 0.952282, Current Damping: 0.239086 [187/256] Error: 0.025296, Current Frequency: 0.952232, Current Damping: 0.239130 [188/256] Error: 0.025303, Current Frequency: 0.952177, Current Damping: 0.239178 [189/256] Error: 0.025312, Current Frequency: 0.952127, Current Damping: 0.239226 [190/256] Error: 0.025318, Current Frequency: 0.952091, Current Damping: 0.239275 [191/256] Error: 0.025318, Current Frequency: 0.952075, Current Damping: 0.239323 [192/256] Error: 0.025310, Current Frequency: 0.952081, Current Damping: 0.239368 [193/256] Error: 0.025293, Current Frequency: 0.952110, Current Damping: 0.239411 [194/256] Error: 0.025268, Current Frequency: 0.952157, Current Damping: 0.239451 [195/256] Error: 0.025236, Current Frequency: 0.952217, Current Damping: 0.239487 [196/256] Error: 0.025199, Current Frequency: 0.952282, Current Damping: 0.239519 [197/256] 
Error: 0.025161, Current Frequency: 0.952346, Current Damping: 0.239549 [198/256] Error: 0.025125, Current Frequency: 0.952401, Current Damping: 0.239575 [199/256] Error: 0.025093, Current Frequency: 0.952442, Current Damping: 0.239598 [200/256] Error: 0.025069, Current Frequency: 0.952467, Current Damping: 0.239618 [201/256] Error: 0.025052, Current Frequency: 0.952475, Current Damping: 0.239636 [202/256] Error: 0.025043, Current Frequency: 0.952468, Current Damping: 0.239650 [203/256] Error: 0.025042, Current Frequency: 0.952448, Current Damping: 0.239662 [204/256] Error: 0.025047, Current Frequency: 0.952420, Current Damping: 0.239670 [205/256] Error: 0.025057, Current Frequency: 0.952389, Current Damping: 0.239675 [206/256] Error: 0.025068, Current Frequency: 0.952361, Current Damping: 0.239677 [207/256] Error: 0.025080, Current Frequency: 0.952338, Current Damping: 0.239676 [208/256] Error: 0.025090, Current Frequency: 0.952323, Current Damping: 0.239672 [209/256] Error: 0.025098, Current Frequency: 0.952318, Current Damping: 0.239665 [210/256] Error: 0.025102, Current Frequency: 0.952321, Current Damping: 0.239656 [211/256] Error: 0.025103, Current Frequency: 0.952331, Current Damping: 0.239645 [212/256] Error: 0.025102, Current Frequency: 0.952345, Current Damping: 0.239632 [213/256] Error: 0.025100, Current Frequency: 0.952360, Current Damping: 0.239619 [214/256] Error: 0.025097, Current Frequency: 0.952374, Current Damping: 0.239605 [215/256] Error: 0.025096, Current Frequency: 0.952383, Current Damping: 0.239592 [216/256] Error: 0.025096, Current Frequency: 0.952387, Current Damping: 0.239578 [217/256] Error: 0.025098, Current Frequency: 0.952383, Current Damping: 0.239566 [218/256] Error: 0.025104, Current Frequency: 0.952374, Current Damping: 0.239554 [219/256] Error: 0.025111, Current Frequency: 0.952359, Current Damping: 0.239543 [220/256] Error: 0.025121, Current Frequency: 0.952341, Current Damping: 0.239533 [221/256] Error: 0.025132, Current 
Frequency: 0.952323, Current Damping: 0.239524 [222/256] Error: 0.025142, Current Frequency: 0.952305, Current Damping: 0.239516 [223/256] Error: 0.025152, Current Frequency: 0.952291, Current Damping: 0.239509 [224/256] Error: 0.025160, Current Frequency: 0.952281, Current Damping: 0.239503 [225/256] Error: 0.025166, Current Frequency: 0.952277, Current Damping: 0.239498 [226/256] Error: 0.025170, Current Frequency: 0.952277, Current Damping: 0.239493 [227/256] Error: 0.025171, Current Frequency: 0.952281, Current Damping: 0.239490 [228/256] Error: 0.025171, Current Frequency: 0.952288, Current Damping: 0.239487 [229/256] Error: 0.025169, Current Frequency: 0.952296, Current Damping: 0.239485 [230/256] Error: 0.025165, Current Frequency: 0.952305, Current Damping: 0.239484 [231/256] Error: 0.025162, Current Frequency: 0.952312, Current Damping: 0.239484 [232/256] Error: 0.025159, Current Frequency: 0.952318, Current Damping: 0.239485 [233/256] Error: 0.025156, Current Frequency: 0.952320, Current Damping: 0.239487 [234/256] Error: 0.025155, Current Frequency: 0.952320, Current Damping: 0.239489 [235/256] Error: 0.025154, Current Frequency: 0.952318, Current Damping: 0.239493 [236/256] Error: 0.025154, Current Frequency: 0.952314, Current Damping: 0.239496 [237/256] Error: 0.025155, Current Frequency: 0.952309, Current Damping: 0.239500 [238/256] Error: 0.025155, Current Frequency: 0.952305, Current Damping: 0.239504 [239/256] Error: 0.025156, Current Frequency: 0.952302, Current Damping: 0.239508 [240/256] Error: 0.025156, Current Frequency: 0.952301, Current Damping: 0.239512 [241/256] Error: 0.025155, Current Frequency: 0.952302, Current Damping: 0.239516 [242/256] Error: 0.025154, Current Frequency: 0.952304, Current Damping: 0.239519 [243/256] Error: 0.025152, Current Frequency: 0.952308, Current Damping: 0.239522 [244/256] Error: 0.025149, Current Frequency: 0.952314, Current Damping: 0.239525 [245/256] Error: 0.025146, Current Frequency: 0.952319, Current 
Damping: 0.239528 [246/256] Error: 0.025143, Current Frequency: 0.952324, Current Damping: 0.239530 [247/256] Error: 0.025140, Current Frequency: 0.952328, Current Damping: 0.239532 [248/256] Error: 0.025138, Current Frequency: 0.952331, Current Damping: 0.239534 [249/256] Error: 0.025136, Current Frequency: 0.952332, Current Damping: 0.239535 [250/256] Error: 0.025135, Current Frequency: 0.952332, Current Damping: 0.239537 [251/256] Error: 0.025134, Current Frequency: 0.952331, Current Damping: 0.239538 [252/256] Error: 0.025135, Current Frequency: 0.952329, Current Damping: 0.239538 [253/256] Error: 0.025135, Current Frequency: 0.952326, Current Damping: 0.239539 [254/256] Error: 0.025136, Current Frequency: 0.952324, Current Damping: 0.239539 [255/256] Error: 0.025137, Current Frequency: 0.952321, Current Damping: 0.239539 [256/256] Error: 0.025138, Current Frequency: 0.952320, Current Damping: 0.239538 Best frequency: 0.963559, relative error: 3.644131% Best damping: 0.243415, relative error: 2.633959%
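The log-encoding applied to the damping (`log_encoded_damping`, decoded with `torch.exp`) can be isolated on a toy problem. The sketch below is plain Python, with an arbitrary target of 0.25 and a deliberately large learning rate. It contrasts one plain gradient step on the damping itself with repeated steps on its logarithm, using the chain rule d loss / d log(damping) = (d loss / d damping) * damping.

```python
import math

target_damping = 0.25
learning_rate = 1.0


def loss_gradient(damping):
    """Derivative of the toy loss (damping - target_damping)**2 with respect to damping."""
    return 2.0 * (damping - target_damping)


# Direct parameterisation: a single step with this learning rate overshoots below zero
direct_damping = 1.0
direct_damping -= learning_rate * loss_gradient(direct_damping)

# Log-encoded parameterisation: optimise log(damping) and decode it with exp
log_damping = math.log(1.0)
history = []
for _ in range(200):
    damping = math.exp(log_damping)
    # Chain rule: d loss / d log_damping = d loss / d damping * d damping / d log_damping
    log_damping -= learning_rate * loss_gradient(damping) * damping
    history.append(math.exp(log_damping))

print(direct_damping)  # -0.5, an unphysical damping coefficient
print(history[-1])
```

Since exp is always positive, the decoded damping cannot become negative however large the step. The trade-off is that the effective step size in damping now scales with the damping itself.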
_, _, sha_states_ref, sha_times_ref, _ = adaptive_rk45_integrator.apply(sha_dynamics, initial_x, initial_time, final_time, dt, atol, rtol, frequency, damping)
fig_ref_position, axes_ref_position = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Position Ref.")
fig_ref_velocity, axes_ref_velocity = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Velocity Ref.")
_, _, sha_states_optimised, sha_times_optimised, _ = adaptive_rk45_integrator.apply(sha_dynamics, initial_x, initial_time, final_time, dt, atol, rtol, best_frequency, best_damping)
_ = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_position, method_label="RK4(5) - SHA Position Opt.")
_ = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_velocity, method_label="RK4(5) - SHA Velocity Opt.")
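In `sha_closure`, each segment starts from `new_state.detach()`. As a result, the error at a sample only backpropagates through the segment that produced it, and not through how that segment's starting state depended on the parameters. The mini-batch closure defined later keeps the full graph instead. The difference can be seen with finite differences on a two-segment exponential decay, where holding the intermediate state fixed plays the role of `.detach()`. The system and numbers are illustrative only.

```python
import math


def segment(x, k, h):
    """Exact solution of dx/dt = -k * x over one segment of length h."""
    return x * math.exp(-k * h)


x0, k, h, eps = 1.0, 0.7, 0.5, 1e-6

# Full gradient of the two-segment result with respect to k (graph kept end to end)
full = (segment(segment(x0, k + eps, h), k + eps, h)
        - segment(segment(x0, k - eps, h), k - eps, h)) / (2.0 * eps)

# Truncated gradient: the intermediate state is frozen at its current value,
# which is what `.detach()` does to the starting state of the next segment
x1_frozen = segment(x0, k, h)
truncated = (segment(x1_frozen, k + eps, h) - segment(x1_frozen, k - eps, h)) / (2.0 * eps)

# Analytically x2 = x0 * exp(-2 k h): the full derivative is -2 h x2, the truncated one is -h x2
print(full, truncated, truncated / full)
```

Here the truncated gradient has the same sign as the full one but only half its magnitude.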
While we were able to infer the parameters closely through gradient descent, they are still imperfect, and you will notice that this implementation is quite slow. We could speed up the optimisation by using only a subset of the samples at each gradient descent step, i.e. mini-batching.
First, we'll collate the reference states into one tensor and the reference times into another. At each iteration, we'll sample a subset and integrate the system to the corresponding times, taking care to sort the integration times to avoid integrating backwards (although we could).
Second, we'll write a closure function that takes in the batch at each step and runs the same routine as our previous closure on that batch.
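The mini-batching plan above can be sketched in plain Python. The approach is to shuffle the sample indices once per pass, cut them into batches, and sort each batch by time so that the integrator only steps forward. Here `random.Random` stands in for `torch.randperm`, `sorted` stands in for `torch.argsort`, and `make_sorted_minibatches` is a hypothetical helper. Note that one pass over the data takes ceil(N / batch_size) optimiser steps, one per batch.

```python
import math
import random


def make_sorted_minibatches(times, states, batch_size, rng):
    """Hypothetical helper: shuffle (time, state) pairs into mini-batches sorted by time."""
    indices = list(range(len(times)))
    rng.shuffle(indices)
    batches = []
    for start in range(0, len(indices), batch_size):
        # Sort within the batch so the sample times can be visited in increasing order
        batch_indices = sorted(indices[start:start + batch_size], key=lambda i: times[i])
        batches.append(([times[i] for i in batch_indices], [states[i] for i in batch_indices]))
    return batches


times = [0.1 * (i + 1) for i in range(10)]
states = [t ** 2 for t in times]  # placeholder states paired with each time
batches = make_sorted_minibatches(times, states, batch_size=4, rng=random.Random(0))

print([len(batch_times) for batch_times, _ in batches])  # [4, 4, 2]
print(math.ceil(len(times) / 4))  # 3 optimiser steps per pass over the data
```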
# We skip the first time as we know the initial state
time_dataset = torch.stack([time for _, time in sha_states][1:])
state_dataset = torch.stack([state for state, _ in sha_states][1:])
# We reinitialise our variables
optimised_frequency = torch.tensor(0.1, requires_grad=True)
optimised_damping = torch.tensor(1.0, requires_grad=True)
# As damping needs to be a strictly positive quantity, we log-encode it
log_encoded_damping = torch.log(optimised_damping.detach()).requires_grad_(True)
# First, we'll create an `optimiser` following PyTorch convention
optimiser = torch.optim.Adam([optimised_frequency, log_encoded_damping], lr=1e-2, amsgrad=True)
# Next, we'll define a closure function whose sole purpose is to
# zero the gradients, compute the error and backpropagate it. This is useful as it allows switching to other
# optimisers such as LBFGS, or anything that re-evaluates the error without
# computing its gradient
def sha_closure(minibatch):
current_state = initial_x.clone()
current_time = initial_time.clone()
optimiser.zero_grad()
error = 0.0
times = minibatch['times']
states = minibatch['states']
# We need to sort both times and states simultaneously, so we'll use `argsort`
sorted_time_indices = torch.argsort(times)
times, states = times[sorted_time_indices], states[sorted_time_indices]
for sample_state, sample_time in zip(states, times):
new_state, new_time, *_ = adaptive_rk45_integrator.apply(sha_dynamics, current_state, current_time, sample_time, torch.minimum(dt, sample_time - current_time), atol, rtol, optimised_frequency, torch.exp(log_encoded_damping))
error = error + torch.linalg.norm(sample_state - new_state)/times.shape[0]
current_state, current_time = new_state, new_time
if error.requires_grad:
error.backward()
return error
# We need to set the size of our mini-batches
batch_size = 16
# Now we need an optimisation `loop` where we will take steps to minimise the error
# We set the number of steps proportionally smaller to account for the fact that at each iteration
# we take `time_dataset.shape[0]//batch_size` steps instead of just 1
number_of_gd_steps = 256*batch_size//time_dataset.shape[0]
# We also need to track the best solution thus far
best_error = torch.inf
best_frequency, best_damping = optimised_frequency.detach().clone(), optimised_damping.detach().clone()
for step in range(number_of_gd_steps):
epoch_error = 0.0
shuffled_indices = torch.randperm(time_dataset.shape[0])
for batch_idx in range(0, time_dataset.shape[0], batch_size):
batch_dict = {
'times': time_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
'states': state_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
}
step_error = optimiser.step(lambda: sha_closure(batch_dict))
epoch_error = epoch_error + step_error.item()*batch_dict['times'].shape[0]
print(f"[{step+1}/{number_of_gd_steps}]/[{batch_idx}/{time_dataset.shape[0]}] Batch Error: {step_error:.6f}, Current Frequency: {optimised_frequency.item():.4f}, Current Damping: {torch.exp(log_encoded_damping).item():.4f}", end='\r')
epoch_error = epoch_error/time_dataset.shape[0]
if epoch_error < best_error:
best_error = epoch_error
best_frequency = optimised_frequency.detach().clone()
best_damping = torch.exp(log_encoded_damping.detach().clone())
print(" "*128, end="\r")
print(f"[{step+1}/{number_of_gd_steps}] Epoch Error: {epoch_error:.6f}, Current Frequency: {optimised_frequency.item():.6f}, Current Damping: {torch.exp(log_encoded_damping).item():.6f}")
print(f"Best frequency: {best_frequency.item():.6f}, relative error: {torch.mean(torch.abs(1 - best_frequency / frequency)).item():.6%}")
print(f"Best damping: {best_damping.item():.6f}, relative error: {torch.mean(torch.abs(1 - best_damping / damping)).item():.6%}")
[1/33] Epoch Error: 0.628324, Current Frequency: 0.178396, Current Damping: 0.923665
[2/33] Epoch Error: 0.452757, Current Frequency: 0.250116, Current Damping: 0.852145
[3/33] Epoch Error: 0.352843, Current Frequency: 0.309024, Current Damping: 0.791153
[4/33] Epoch Error: 0.302845, Current Frequency: 0.355167, Current Damping: 0.748466
[5/33] Epoch Error: 0.281069, Current Frequency: 0.391215, Current Damping: 0.717315
[6/33] Epoch Error: 0.269684, Current Frequency: 0.417420, Current Damping: 0.700159
[7/33] Epoch Error: 0.264258, Current Frequency: 0.439627, Current Damping: 0.689579
[8/33] Epoch Error: 0.259327, Current Frequency: 0.458514, Current Damping: 0.684602
[9/33] Epoch Error: 0.254482, Current Frequency: 0.474936, Current Damping: 0.685596
[10/33] Epoch Error: 0.250911, Current Frequency: 0.492765, Current Damping: 0.683432
[11/33] Epoch Error: 0.246691, Current Frequency: 0.510047, Current Damping: 0.685052
[12/33] Epoch Error: 0.242817, Current Frequency: 0.528726, Current Damping: 0.684228
[13/33] Epoch Error: 0.238562, Current Frequency: 0.547995, Current Damping: 0.680715
[14/33] Epoch Error: 0.233992, Current Frequency: 0.567164, Current Damping: 0.677625
[15/33] Epoch Error: 0.229790, Current Frequency: 0.588169, Current Damping: 0.670971
[16/33] Epoch Error: 0.224707, Current Frequency: 0.609309, Current Damping: 0.664911
[17/33] Epoch Error: 0.219676, Current Frequency: 0.631531, Current Damping: 0.656331
[18/33] Epoch Error: 0.214181, Current Frequency: 0.654534, Current Damping: 0.644180
[19/33] Epoch Error: 0.208597, Current Frequency: 0.679353, Current Damping: 0.631005
[20/33] Epoch Error: 0.201547, Current Frequency: 0.703593, Current Damping: 0.618783
[21/33] Epoch Error: 0.195240, Current Frequency: 0.730725, Current Damping: 0.600403
[22/33] Epoch Error: 0.186872, Current Frequency: 0.758010, Current Damping: 0.582101
[23/33] Epoch Error: 0.177497, Current Frequency: 0.785443, Current Damping: 0.562757
[24/33] Epoch Error: 0.167940, Current Frequency: 0.815522, Current Damping: 0.536682
[25/33] Epoch Error: 0.155794, Current Frequency: 0.847253, Current Damping: 0.508472
[26/33] Epoch Error: 0.142342, Current Frequency: 0.881644, Current Damping: 0.473993
[27/33] Epoch Error: 0.124190, Current Frequency: 0.916691, Current Damping: 0.439423
[28/33] Epoch Error: 0.104539, Current Frequency: 0.953759, Current Damping: 0.399611
[29/33] Epoch Error: 0.080617, Current Frequency: 0.991097, Current Damping: 0.358279
[30/33] Epoch Error: 0.054462, Current Frequency: 1.022754, Current Damping: 0.317208
[31/33] Epoch Error: 0.034670, Current Frequency: 1.032316, Current Damping: 0.280054
[32/33] Epoch Error: 0.018756, Current Frequency: 1.004898, Current Damping: 0.255116
[33/33] Epoch Error: 0.006446, Current Frequency: 0.994217, Current Damping: 0.239427
Best frequency: 0.994217, relative error: 0.578333%
Best damping: 0.239427, relative error: 4.229135%
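The log-encoding of the damping in the cell above reparameterises it as `damping = exp(theta)`: any real `theta` decodes to a strictly positive damping, and the chain rule supplies the gradient with respect to `theta`. The sketch below illustrates this on a toy quadratic loss with a hypothetical target of 0.25, using plain gradient descent rather than the Adam optimiser from the notebook; the function names are illustrative only.

```python
import math


def grad_wrt_damping(damping, target):
    """Gradient of the toy loss (damping - target)**2 with respect to damping."""
    return 2.0 * (damping - target)


def grad_wrt_log_damping(theta, target):
    """Same loss differentiated w.r.t. theta, where damping = exp(theta) (chain rule)."""
    damping = math.exp(theta)
    return grad_wrt_damping(damping, target) * damping


def fit_log_encoded(initial_damping, target, lr=1.0, steps=500):
    theta = math.log(initial_damping)  # analogous to log_encoded_damping
    for _ in range(steps):
        theta -= lr * grad_wrt_log_damping(theta, target)
    return math.exp(theta)  # decoding with exp can never give a non-positive damping


target_damping = 0.25
# One plain gradient step on the raw parameter overshoots to an unphysical negative damping
direct = 1.0 - 1.0 * grad_wrt_damping(1.0, target_damping)
print(direct)  # -0.5
# The log-encoded parameter stays positive while converging towards the target
fitted = fit_log_encoded(1.0, target_damping)
print(fitted)
```

The extra factor of `damping` in the log-space gradient also shrinks the update as the damping approaches zero, which is what prevents the overshoot seen in the direct update.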
_, _, sha_states_ref, sha_times_ref, _ = adaptive_rk45_integrator.apply(sha_dynamics, initial_x, initial_time, final_time, dt, atol, rtol, frequency, damping)
fig_ref_position, axes_ref_position = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Position Ref.")
fig_ref_velocity, axes_ref_velocity = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Velocity Ref.")
_, _, sha_states_optimised, sha_times_optimised, _ = adaptive_rk45_integrator.apply(sha_dynamics, initial_x, initial_time, final_time, dt, atol, rtol, best_frequency, best_damping)
_ = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_position, method_label="RK4(5) - SHA Position Opt.")
_ = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_velocity, method_label="RK4(5) - SHA Velocity Opt.")
Excellent: this takes roughly the same number of optimisation steps, but we achieve better results. Although each step's gradient estimate is noisy, because it only accounts for a subset of the points, we gain in convergence speed. The noise in the mini-batch gradients may also make the optimiser less likely to be trapped in a local minimum, which would explain the lower final error.
Let's now suppose that we don't know the dynamics at all: can we learn the appropriate system? Probably, but our current method of computing the gradients is exceptionally inefficient, and it is the reason we cannot run our integration at higher precision. To resolve this, we will need to employ the adjoint method.
Appendix - Local Extrapolation
You may have noticed that the adaptive integration schemes above produce two estimates of the trajectory at each step: a higher-order one and a lower-order one. While the error analysis behind step-size adaptation applies to the lower-order estimate, in practice it can make sense to propagate the higher-order estimate instead, as it improves convergence for many systems. This is referred to as local extrapolation.
Below, we demonstrate how local extrapolation affects the error terms and their growth over a trajectory:
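To make the distinction concrete, here is a minimal pure-Python sketch of a single scalar Dormand-Prince RK4(5) step, using the same coefficients as the tableau in this notebook. The helper names `dormand_prince_step` and `advance` are hypothetical and not the `get_integrator` implementation; the point is only that both estimates come from the same stages, the error estimate is their difference in either case, and local extrapolation merely changes which estimate is carried forward.

```python
import math

# Dormand-Prince RK4(5) coefficients (same values as the notebook's tableau)
C = [0.0, 1/5, 3/10, 4/5, 8/9, 1.0, 1.0]
A = [
    [],
    [1/5],
    [3/40, 9/40],
    [44/45, -56/15, 32/9],
    [19372/6561, -25360/2187, 64448/6561, -212/729],
    [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
    [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84],
]
B_HIGH = [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0]  # fifth-order weights
B_LOW = [5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]  # fourth-order weights


def dormand_prince_step(f, t, x, h):
    """Take one scalar RK4(5) step and return (fifth-order estimate, fourth-order estimate)."""
    k = []
    for c_i, a_i in zip(C, A):
        x_stage = x + h * sum(a_ij * k_j for a_ij, k_j in zip(a_i, k))
        k.append(f(t + c_i * h, x_stage))
    x_high = x + h * sum(b_i * k_i for b_i, k_i in zip(B_HIGH, k))
    x_low = x + h * sum(b_i * k_i for b_i, k_i in zip(B_LOW, k))
    return x_high, x_low


def advance(f, t, x, h, use_local_extrapolation):
    """Return (next state, error estimate); only the propagated estimate depends on the flag."""
    x_high, x_low = dormand_prince_step(f, t, x, h)
    error_estimate = abs(x_high - x_low)  # drives step-size control in both cases
    x_next = x_high if use_local_extrapolation else x_low
    return x_next, error_estimate


h = 0.1
x_high, x_low = dormand_prince_step(lambda t, x: x, 0.0, 1.0, h)  # dx/dt = x, x(0) = 1
print(f"fifth-order error:  {abs(x_high - math.exp(h)):.2e}")
print(f"fourth-order error: {abs(x_low - math.exp(h)):.2e}")
```

On this exponential test problem the fifth-order estimate should land noticeably closer to e^0.1 than the fourth-order one, even though the step-size controller only ever sees their difference.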
adaptive_rk45_integrator_without_local_extrapolation = get_integrator(torch.tensor([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[1/5, 1/5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[3/10, 3/40, 9/40, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[4/5, 44/45, -56/15, 32/9, 0.0, 0.0, 0.0, 0.0 ],
[8/9, 19372/6561, -25360/2187, 64448/6561, -212/729, 0.0, 0.0, 0.0 ],
[1.0, 9017/3168, -355/33, 46732/5247, 49/176, -5103/18656, 0.0, 0.0 ],
[1.0, 35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0 ],
[torch.inf, 35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0 ],
[torch.inf, 5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]
], dtype=torch.float64), integrator_order=5, use_local_extrapolation=False, integrator_name="AdaptiveRK45IntegratorNoLocalExtrapolation")
adaptive_rk45_integrator_with_local_extrapolation = get_integrator(adaptive_rk45_integrator_without_local_extrapolation.integrator_tableau, integrator_order=5, use_local_extrapolation=True, integrator_name="AdaptiveRK45IntegratorLocalExtrapolation")
dt = (t1 - t0)/1e6
*_, sub_states_upper, sub_times_upper, state_errors_upper = adaptive_rk45_integrator_with_local_extrapolation.apply(neuralode.dynamics.exponential_fn, x0, t0, t1, dt, torch.tensor(5e-16), torch.tensor(5e-16))
*_, sub_states_lower, sub_times_lower, state_errors_lower = adaptive_rk45_integrator_without_local_extrapolation.apply(neuralode.dynamics.exponential_fn, x0, t0, t1, dt, torch.tensor(5e-16), torch.tensor(5e-16))
print(f"Error in {adaptive_rk45_integrator_with_local_extrapolation}: {(sub_states_upper[-1] - neuralode.dynamics.exponential_fn_solution(x0, t1)).abs().item()}")
print(f"Error in {adaptive_rk45_integrator_without_local_extrapolation}: {(sub_states_lower[-1] - neuralode.dynamics.exponential_fn_solution(x0, t1)).abs().item()}")
reference_trajectory = [neuralode.dynamics.exponential_fn_solution(x0, t) for t in sub_times_upper]
fig, axes = neuralode.plot.trajectory.plot_trajectory_with_reference([(i[0], j) for i, j in zip(sub_states_upper, sub_times_upper)], reference_trajectory, method_label="RK4(5) Method with Local Extrapolation")
# axes[1].plot([t.item() for _, t in sub_states_upper], [e.abs().item() for e in state_errors_upper], marker='x', label="Estimated Error")
# axes[1].legend()
reference_trajectory = [neuralode.dynamics.exponential_fn_solution(x0, t) for t in sub_times_lower]
fig, axes = neuralode.plot.trajectory.plot_trajectory_with_reference([(i[0], j) for i, j in zip(sub_states_lower, sub_times_lower)], reference_trajectory, method_label="RK4(5) Method without Local Extrapolation", axes=axes)
# axes[1].plot([t.item() for _, t in sub_states_lower], [e.abs().item() for e in state_errors_lower], marker='x', label="Estimated Error")
# axes[1].legend()
C:\Users\ekin4\AppData\Local\Temp\ipykernel_5888\3984266084.py:21: RuntimeWarning: Atol is smaller than the square root of the epsilon for torch.float64, this may increase truncation error
  warnings.warn(f"{tol_name.title()} is smaller than the square root of the epsilon for {x0.dtype}, this may increase truncation error", RuntimeWarning)
C:\Users\ekin4\AppData\Local\Temp\ipykernel_5888\3984266084.py:21: RuntimeWarning: Rtol is smaller than the square root of the epsilon for torch.float64, this may increase truncation error
  warnings.warn(f"{tol_name.title()} is smaller than the square root of the epsilon for {x0.dtype}, this may increase truncation error", RuntimeWarning)
Error in <class '__main__.AdaptiveRK45IntegratorLocalExtrapolation'>: 5.551115123125783e-17
Error in <class '__main__.AdaptiveRK45IntegratorNoLocalExtrapolation'>: 4.6129766673175254e-14
As you can see, local extrapolation yields higher-precision results over the whole trajectory for the same error tolerances. This usually means that we can relax the tolerances and, as a result, take fewer integration steps while achieving the same accuracy.
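The "relax the tolerances, take fewer steps" trade-off can be illustrated without the RK4(5) machinery. This hedged sketch swaps in the much smaller Heun-Euler 2(1) embedded pair (not a method used elsewhere in this notebook) and a simple controller with a purely relative tolerance; `heun_euler_step`, `integrate`, and all controller constants are illustrative assumptions. On dx/dt = x over [0, 1], the extrapolated run at the looser `rtol=1e-3` should need far fewer accepted steps than the non-extrapolated run at `rtol=1e-6`, while still ending closer to e.

```python
import math


def heun_euler_step(f, t, x, h):
    """Embedded Heun-Euler 2(1) pair: return (second-order, first-order) estimates."""
    k1 = f(t, x)
    k2 = f(t + h, x + h * k1)
    x_low = x + h * k1  # explicit Euler, first order
    x_high = x + 0.5 * h * (k1 + k2)  # Heun's method, second order
    return x_high, x_low


def integrate(f, x0, t0, t1, rtol, use_local_extrapolation,
              h0=1e-2, atol=1e-12, safety=0.9, min_factor=0.2, max_factor=5.0):
    """Adaptively integrate a scalar ODE; return (x(t1), number of accepted steps)."""
    t, x, h, accepted = t0, x0, h0, 0
    while t < t1:
        h = min(h, t1 - t)  # never step past the end of the interval
        x_high, x_low = heun_euler_step(f, t, x, h)
        scale = atol + rtol * max(abs(x), abs(x_high))
        error_ratio = abs(x_high - x_low) / scale  # <= 1 means the step is accepted
        if error_ratio <= 1.0:
            t += h
            x = x_high if use_local_extrapolation else x_low
            accepted += 1
        # The estimate behaves like h**2, so rescale h by error_ratio**(-1/2)
        if error_ratio == 0.0:
            factor = max_factor
        else:
            factor = min(max_factor, max(min_factor, safety * error_ratio ** -0.5))
        h *= factor
    return x, accepted


def exponential(t, x):
    return x  # dx/dt = x with x(0) = 1, so x(1) = e


x_ext, n_ext = integrate(exponential, 1.0, 0.0, 1.0, rtol=1e-3, use_local_extrapolation=True)
x_low, n_low = integrate(exponential, 1.0, 0.0, 1.0, rtol=1e-6, use_local_extrapolation=False)
print(f"with local extrapolation,    rtol=1e-3: {n_ext} steps, error {abs(x_ext - math.e):.2e}")
print(f"without local extrapolation, rtol=1e-6: {n_low} steps, error {abs(x_low - math.e):.2e}")
```

Because the controller only looks at the difference between the two estimates, both runs choose their step sizes the same way; the gain comes entirely from carrying forward the more accurate estimate.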