The Adjoint Method
In this notebook, we will go through the process of implementing the adjoint method for computing gradients arising from a numerical integration.
import typing
import warnings
import math
import torch
import einops
import neuralode
import random
import numpy as np
warnings.simplefilter('ignore', RuntimeWarning)
# For convenience, we define the default tensor device and dtype here
torch.set_default_device('cpu')
# In neural networks, we prefer 32-bit/16-bit floats, but for precise integration, 64-bit is preferred. We will revisit this later when we need to mix integration with neural network training
torch.set_default_dtype(torch.float64)
# Set the random seeds so that each run of this notebook is reproducible
torch.manual_seed(1)
random.seed(1)
np.random.seed(1)
We'll be starting with the same function from the previous notebook, but we have tidied up the code by using `neuralode.util.partial_compensated_sum` to track the truncated bits instead of duplicating that code throughout our integration function.
We have moved many of the functions into submodules. Most importantly, class creation and finalisation have been moved into `neuralode.integrators.classes`, and the integrators are now subclasses of `neuralode.integrators.classes.Integrator`, which enables type checking and signature checking both statically and at runtime if needed. The duplicated tolerance checks have been moved into the `neuralode.integrators.helpers.ensure_tolerance` function, and consistency checking of the timestep has been put into `neuralode.integrators.helpers.ensure_timestep`.
Also, the entire forward integration loop has been moved into `neuralode.integrators.routines.integrate_system`, which allows us to separate the concerns of `torch.autograd.Function` from the actual integration (e.g. `ctx` does not need to be in the same scope). This requires passing around several packaged arguments, but it makes the code here more concise, allowing us to focus on the adjoint method implementation.
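As a reminder of what tracking the truncated bits means, the idea behind compensated (Kahan) summation can be sketched with the standard library alone. This is a generic illustration and not the implementation of `neuralode.util.partial_compensated_sum`, whose signature is not shown in this notebook; the helper name `compensated_add` is hypothetical.

```python
import math


def compensated_add(total: float, truncated: float, increment: float) -> tuple[float, float]:
    """One step of Kahan summation (hypothetical helper, for illustration only).

    Returns the updated running total and the rounding error made in this step.
    """
    corrected = increment - truncated  # re-inject the bits lost in the previous step
    new_total = total + corrected  # low-order bits of `corrected` may be lost here
    new_truncated = (new_total - total) - corrected  # recover what was just lost
    return new_total, new_truncated


num_terms = 1_000_000
naive_total = 0.0
kahan_total, lost_bits = 0.0, 0.0
for _ in range(num_terms):
    naive_total += 0.1
    kahan_total, lost_bits = compensated_add(kahan_total, lost_bits, 0.1)

# math.fsum returns the correctly rounded sum and serves as the reference
reference = math.fsum([0.1] * num_terms)
naive_error = abs(naive_total - reference)
kahan_error = abs(kahan_total - reference)
print(f"naive error: {naive_error:.3e}, compensated error: {kahan_error:.3e}")
```

The same bookkeeping matters when accumulating many small timesteps into the integration time, where the naive sum drifts away from the intended final time.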
def get_backward_method_adjoint(integrator_type: typing.Type[neuralode.integrators.classes.Integrator]):
    def __internal_backward(ctx: torch.autograd.function.FunctionCtx,
                            d_c_state: torch.Tensor,
                            d_c_time: torch.Tensor,
                            d_intermediate_states: torch.Tensor,
                            d_intermediate_times: torch.Tensor,
                            d_error_in_state: torch.Tensor) -> tuple[torch.Tensor | None, ...]:
        """
        This function computes the gradient of the input variables for `__internal_forward` by implementing
        the adjoint method. This involves computing the adjoint state and adjoint state equation and systematically
        integrating backwards, from t1 to t0, accumulating the gradient to obtain the gradient wrt. the input variables.

        :param ctx: Function context for storing variables and function from the forward pass
        :param d_c_state: incoming gradient of c_state
        :param d_c_time: incoming gradient of c_time
        :param d_intermediate_states: incoming gradient of intermediate_states
        :param d_intermediate_times: incoming gradient of intermediate_times
        :param d_error_in_state: incoming gradient of error_in_state. This output is non-differentiable.
        :return: The gradients wrt. all the inputs.
        """
        # First, we retrieve our integration function that we stored in `__internal_forward`
        forward_fn: neuralode.integrators.signatures.integration_fn_signature = ctx.forward_fn
        integrator_kwargs = ctx.integrator_kwargs
        # Then we retrieve the input variables
        (
            x0,
            t0,
            t1,
            dt,
            c_state,
            c_time,
            intermediate_times,
            *additional_dynamic_args,
        ) = ctx.saved_tensors
        inputs = forward_fn, x0, t0, t1, dt, integrator_kwargs, *additional_dynamic_args
        input_grads: list[torch.Tensor | None] = [None for _ in range(len(inputs))]

        if any(ctx.needs_input_grad):
            # Construct the adjoint equation
            adjoint_fn = neuralode.integrators.helpers.construct_adjoint_fn(forward_fn, c_state.shape)
            # We ensure that gradients are enabled so that autograd tracks the variable operations
            # For pointwise functionals, the initial adjoint state is simply the incoming gradients
            parameter_shapes = [i.shape for i in additional_dynamic_args]
            # The packed state is laid out as [state, adjoint state, parameter gradients]
            packed_reverse_state = torch.cat(
                [
                    c_state.ravel(),
                    (d_c_state + d_intermediate_states[-1]).ravel(),
                ]
            )
            if len(additional_dynamic_args) > 0:
                packed_reverse_state = torch.cat(
                    [
                        packed_reverse_state,
                        torch.zeros(
                            sum(map(math.prod, parameter_shapes)),
                            device=c_state.device,
                            dtype=c_state.dtype,
                        ),
                    ]
                )

            current_adj_time = t1
            current_adj_state = packed_reverse_state

            if torch.any(d_intermediate_states != 0.0):
                adj_indices = torch.arange(
                    c_state.numel(), 2 * c_state.numel(), device=c_state.device
                )
                # We only need to account for the incoming gradients if any are non-zero
                for next_adj_time, d_inter_state in zip(
                    intermediate_times[1:-1].flip(dims=[0]),
                    d_intermediate_states[1:-1].flip(dims=[0]),
                ):
                    # The incoming gradients of the intermediate states are the gradients of the state defined at
                    # various points in time. For each of these incoming gradients, we need to integrate up to that
                    # temporal boundary and add them to the adjoint state
                    if torch.all(d_inter_state == 0.0):
                        # No need to integrate up to the boundary if the incoming gradients are zero
                        continue
                    current_adj_state, current_adj_time, _, _, _ = (
                        integrator_type.apply(
                            adjoint_fn,
                            current_adj_state,
                            current_adj_time,
                            next_adj_time,
                            -dt,
                            integrator_kwargs,
                            *additional_dynamic_args,
                        )
                    )
                    # `scatter_add` accumulates the incoming gradient onto the adjoint state;
                    # a plain `scatter` would overwrite the adjoint integrated so far
                    current_adj_state = torch.scatter_add(
                        current_adj_state, 0, adj_indices, d_inter_state.ravel()
                    )

            final_adj_state, final_adj_time, _, _, _ = integrator_type.apply(
                adjoint_fn,
                current_adj_state,
                current_adj_time,
                t0,
                -dt,
                integrator_kwargs,
                *additional_dynamic_args,
            )

            # This should be equivalent to the initial state we passed in, but it will
            # be appropriately attached to the autograd graph for higher order derivatives
            if torch.is_grad_enabled() and any(
                i.requires_grad for i in [d_c_state, d_c_time, d_intermediate_states]
            ):
                adj_initial_state = final_adj_state[: c_state.numel()].reshape(
                    c_state.shape
                )
            else:
                adj_initial_state = x0.clone()
            adj_variables = final_adj_state[c_state.numel() : 2 * c_state.numel()]
            adj_parameter_gradients = final_adj_state[2 * c_state.numel() :]

            # The gradients of the incoming state are equal to the gradients from the first element of the
            # intermediate state plus the lagrange variables
            initial_state_grads_from_adj = adj_variables.reshape(c_state.shape)
            initial_state_grads_from_intermediate = d_intermediate_states[0]
            input_grads[1] = (
                initial_state_grads_from_adj + initial_state_grads_from_intermediate
            )

            # The gradient of the initial time is equal to the gradient from the first element of the intermediate times
            # minus the product of the lagrange variables and the derivative of the system at the initial time
            derivative_at_t0 = forward_fn(
                adj_initial_state, final_adj_time, *additional_dynamic_args
            )
            initial_time_grads_from_ode = torch.sum(
                adj_variables * derivative_at_t0.ravel()
            )
            initial_time_grads_from_intermediate = d_intermediate_times[0].ravel()
            input_grads[2] = (
                initial_time_grads_from_intermediate - initial_time_grads_from_ode
            )

            # The gradient of the final time is equal to the incoming gradient of the final time
            # plus the product of the lagrange variables and the derivative of the system at the final time
            derivative_at_t1 = forward_fn(c_state, c_time, *additional_dynamic_args)
            final_time_grads_from_ode = torch.sum(
                (d_c_state + d_intermediate_states[-1]) * derivative_at_t1
            )
            final_time_grads_from_intermediate = d_c_time + d_intermediate_times[-1]
            input_grads[3] = (
                final_time_grads_from_intermediate + final_time_grads_from_ode
            )

            # Unpack the flat parameter gradients back into the shapes of the parameters
            parameter_gradients = []
            for p_shape, num_elem in zip(
                parameter_shapes, map(math.prod, parameter_shapes)
            ):
                parameter_gradients.append(
                    adj_parameter_gradients[:num_elem].reshape(p_shape)
                )
                adj_parameter_gradients = adj_parameter_gradients[num_elem:]
            input_grads[6:] = parameter_gradients

        # Fail loudly if any of the computed gradients are NaN or infinite
        inputs_grad_not_finite = list(
            map(
                lambda x: False if x is None else (~x.isfinite()).any(), input_grads
            )
        )
        if any(inputs_grad_not_finite):
            inp_nonfinite_indices = [
                inp_idx
                for inp_idx, inp_grad_is_not_finite in enumerate(
                    inputs_grad_not_finite
                )
                if inp_grad_is_not_finite
            ]
            raise ValueError(
                f"Encountered non-finite grads for inputs: {inp_nonfinite_indices}"
            )
        return tuple(input_grads)

    return __internal_backward
def get_integrator(integrator_tableau: torch.Tensor, integrator_order: int, use_local_extrapolation: bool = True,
                   integrator_name: str = None) -> typing.Type[torch.autograd.Function]:
    __integrator_type = neuralode.integrators.classes.create_integrator_class(integrator_tableau, integrator_order,
                                                                              use_local_extrapolation, integrator_name)
    # Forward integration method
    __internal_forward = neuralode.integrators.integrators.get_forward_method(__integrator_type, use_local_extrapolation)
    # Backward integration method
    __internal_backward = get_backward_method_adjoint(__integrator_type)
    # Enables batching along arbitrary dimensions using `torch.vmap`
    __internal_vmap = neuralode.integrators.integrators.get_vmap_method(__integrator_type)
    neuralode.integrators.classes.finalise_integrator_class(__integrator_type, __internal_forward,
                                                            __internal_backward, __internal_vmap)
    return __integrator_type
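To make the backward pass above more concrete, the same idea can be sketched on a scalar problem with a known answer. The example below is an illustration under stated assumptions and not the `neuralode` implementation: it uses a fixed-step RK4 scheme written with the standard library, the system $\dot{x} = -a x$, and the loss $L = x(t_1)$; all function names are hypothetical. The packed state `(x, lam, grad_a)` mirrors the layout `[state, adjoint state, parameter gradients]` used in `__internal_backward`.

```python
import math


def rhs(x: float, a: float) -> float:
    # Forward dynamics dx/dt = f(x, a) = -a * x
    return -a * x


def forward_rhs(state, a):
    (x,) = state
    return (rhs(x, a),)


def augmented_rhs(state, a):
    # Packed state: (x, adjoint lam, accumulated gradient wrt. a)
    x, lam, _ = state
    df_dx = -a  # partial derivative of f wrt. x
    df_da = -x  # partial derivative of f wrt. a
    return (rhs(x, a), -lam * df_dx, -lam * df_da)


def rk4_step(fn, state, dt, a):
    def shifted(base, slope, scale):
        return tuple(b + scale * s for b, s in zip(base, slope))

    k1 = fn(state, a)
    k2 = fn(shifted(state, k1, dt / 2), a)
    k3 = fn(shifted(state, k2, dt / 2), a)
    k4 = fn(shifted(state, k3, dt), a)
    return tuple(
        s + dt / 6 * (p + 2 * q + 2 * r + w)
        for s, p, q, r, w in zip(state, k1, k2, k3, k4)
    )


def integrate(fn, state, t_start, t_end, steps, a):
    # dt is negative when t_end < t_start, i.e. when integrating backwards in time
    dt = (t_end - t_start) / steps
    for _ in range(steps):
        state = rk4_step(fn, state, dt, a)
    return state


x0, a, t0, t1 = 1.0, 0.25, 0.0, 10.0

# Forward pass: integrate the state from t0 to t1
(x_final,) = integrate(forward_rhs, (x0,), t0, t1, 1000, a)

# Backward pass: the adjoint starts at dL/dx(t1) = 1 and the parameter gradient at 0
x_back, grad_x0, grad_a = integrate(augmented_rhs, (x_final, 1.0, 0.0), t1, t0, 1000, a)

# Analytic gradients of x(t1) = x0 * exp(-a * (t1 - t0)) for comparison
exact_grad_x0 = math.exp(-a * (t1 - t0))
exact_grad_a = -(t1 - t0) * x0 * math.exp(-a * (t1 - t0))
print(f"dL/dx0: adjoint {grad_x0:.10f}, exact {exact_grad_x0:.10f}")
print(f"dL/da:  adjoint {grad_a:.10f}, exact {exact_grad_a:.10f}")
```

Integrating the state backwards alongside the adjoint recovers `x0` at `t0`, which is why the backward pass needs only the final state rather than the stored trajectory.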
initial_position = torch.tensor(1.0)
initial_velocity = torch.tensor(0.0)
frequency = (torch.ones_like(initial_position)).requires_grad_(True)
damping = (torch.ones_like(initial_position)*0.25).requires_grad_(True)
initial_state = torch.stack([
initial_position,
initial_velocity,
], dim=-1).requires_grad_(True)
initial_time = torch.tensor(0.0).requires_grad_(True)
final_time = torch.tensor(10.0).requires_grad_(True)
initial_timestep = (final_time - initial_time) / 100
adaptive_rk45_integrator = get_integrator(torch.tensor([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[1/5, 1/5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[3/10, 3/40, 9/40, 0.0, 0.0, 0.0, 0.0, 0.0 ],
[4/5, 44/45, -56/15, 32/9, 0.0, 0.0, 0.0, 0.0 ],
[8/9, 19372/6561, -25360/2187, 64448/6561, -212/729, 0.0, 0.0, 0.0 ],
[1.0, 9017/3168, -355/33, 46732/5247, 49/176, -5103/18656, 0.0, 0.0 ],
[1.0, 35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0 ],
[torch.inf, 35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0 ],
[torch.inf, 5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]
], dtype=torch.float64), integrator_order = 5, integrator_name = "AdaptiveRK45Integrator")
atol = torch.tensor(0.0)
rtol = torch.tensor(5e-8)
integrator_kwargs = {'atol': atol, 'rtol': rtol}
final_state, _, sha_states, sha_times, _ = adaptive_rk45_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, initial_state, initial_time, final_time, initial_timestep, integrator_kwargs, frequency, damping)
fig, axes = neuralode.plot.trajectory.plot_trajectory([(i, j) for i, j in zip(sha_states, sha_times)], method_label="RK4(5) - Simple Harmonic Oscillator")
The Jacobian is the matrix of partial derivatives of each output with respect to each input. For a given vector function $\vec{f}(\vec{x})$ with $m$ outputs and $n$ inputs, the Jacobian is:
$$ J = \begin{bmatrix} \frac{\partial{f_1}}{\partial{x_1}} & \cdots & \frac{\partial{f_1}}{\partial{x_n}} \\ \vdots & \ddots & \vdots \\ \frac{\partial{f_m}}{\partial{x_1}} & \cdots & \frac{\partial{f_m}}{\partial{x_n}} \\ \end{bmatrix} $$
where $\frac{\partial{f_i}}{\partial{x_j}}$ is the (partial) derivative of the $i^{th}$ component of $\vec{f}$ with respect to the $j^{th}$ component of $\vec{x}$. This matrix tells us how $\vec{f}$ changes when $\vec{x}$ changes.
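As a small worked example (chosen here purely for illustration and unrelated to the oscillator), take $\vec{f}(x_1, x_2) = (x_1 x_2,\; x_1 + x_2^2)$. Differentiating each component with respect to each input gives
$$ J = \begin{bmatrix} x_2 & x_1 \\ 1 & 2 x_2 \\ \end{bmatrix}, \qquad J\big|_{(1, 2)} = \begin{bmatrix} 2 & 1 \\ 1 & 4 \\ \end{bmatrix} $$
so a small perturbation $\delta\vec{x} = (0.1, 0)$ around $(1, 2)$ changes the output by approximately $J\,\delta\vec{x} = (0.2, 0.1)$, which matches $\vec{f}(1.1, 2) - \vec{f}(1, 2) = (2.2, 5.1) - (2, 5)$.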
If we think of our integration as a function $\vec{f}$, the Jacobian indicates how the result of our integration changes with respect to the inputs. We can compute it by creating a function that simply returns the part of the trajectory we are interested in and then using the `torch.autograd.functional` API as follows.
def compute_trajectory_return_final_state(x, t0, t1, f, d):
    # Returns the first component of the final state
    return adaptive_rk45_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, x, t0, t1, initial_timestep, {'atol': atol, 'rtol': rtol*1e-3}, f, d)[0][0]

def compute_trajectory_return_intermediate_state(x, t0, t1, f, d):
    # Returns the first component of the fourth stored intermediate state
    return adaptive_rk45_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, x, t0, t1, initial_timestep, {'atol': atol, 'rtol': rtol*1e-3}, f, d)[2][3][0]
test_variables = [initial_state.clone(), initial_time.clone(), final_time.clone(), frequency.clone(), damping.clone()]
test_variables = [i.detach().clone().requires_grad_(True) for i in test_variables]
jac_wrt_initial_state, jac_wrt_initial_time, jac_wrt_final_time, jac_wrt_freq, jac_wrt_damp = torch.autograd.functional.jacobian(compute_trajectory_return_final_state, tuple(test_variables))
print(f"The Jacobian of x(t1) wrt. x(t0) is: {jac_wrt_initial_state.cpu().numpy()}")
print(f"The Jacobian of x(t1) wrt. t0 is: {jac_wrt_initial_time.cpu().numpy()}")
print(f"The Jacobian of x(t1) wrt. t1 is: {jac_wrt_final_time.cpu().numpy()}")
print(f"The Jacobian of x(t1) wrt. frequency is: {jac_wrt_freq.cpu().numpy()}")
print(f"The Jacobian of x(t1) wrt. damping is: {jac_wrt_damp.cpu().numpy()}")
jac_wrt_initial_state, jac_wrt_initial_time, jac_wrt_final_time, jac_wrt_freq, jac_wrt_damp = torch.autograd.functional.jacobian(compute_trajectory_return_intermediate_state, tuple(test_variables))
print()
print(f"The Jacobian of x({sha_times[3].item():.4e}) wrt. x(t0) is: {jac_wrt_initial_state.cpu().numpy()}")
print(f"The Jacobian of x({sha_times[3].item():.4e}) wrt. t0 is: {jac_wrt_initial_time.cpu().numpy()}")
print(f"The Jacobian of x({sha_times[3].item():.4e}) wrt. t1 is: {jac_wrt_final_time.cpu().numpy()}")
print(f"The Jacobian of x({sha_times[3].item():.4e}) wrt. frequency is: {jac_wrt_freq.cpu().numpy()}")
print(f"The Jacobian of x({sha_times[3].item():.4e}) wrt. damping is: {jac_wrt_damp.cpu().numpy()}")
The Jacobian of x(t1) wrt. x(t0) is: [-0.08477596 -0.02160443]
The Jacobian of x(t1) wrt. t0 is: -0.021604426103042366
The Jacobian of x(t1) wrt. t1 is: 0.021604426126233627
The Jacobian of x(t1) wrt. frequency is: 0.2160442607119664
The Jacobian of x(t1) wrt. damping is: 0.823620406656798

The Jacobian of x(2.9702e-01) wrt. x(t0) is: [0.9986817 0.05089237]
The Jacobian of x(2.9702e-01) wrt. t0 is: 0.05089236517502573
The Jacobian of x(2.9702e-01) wrt. t1 is: 0.0
The Jacobian of x(2.9702e-01) wrt. frequency is: -0.0026247345057088646
The Jacobian of x(2.9702e-01) wrt. damping is: 4.513039652480764e-05
We can further use the `torch.autograd.gradcheck` function to test that these Jacobians are correct. In brief, `gradcheck` computes the gradient using finite differences^[1] and compares the result to our implementation of the gradient. Within some numerical tolerance the two should be identical, and this function tests exactly that. We can also test that the gradient of the gradient is correct by using a similar procedure and the `torch.autograd.gradgradcheck` function.
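The comparison that `gradcheck` performs can be sketched for a scalar function using only the standard library. This is a simplified illustration of the idea and not PyTorch's implementation; the helper names are hypothetical, while the default values of `eps`, `atol` and `rtol` are chosen to match the documented defaults of `torch.autograd.gradcheck`.

```python
import math


def central_difference(fn, x: float, eps: float = 1e-6) -> float:
    # Second-order accurate finite-difference estimate of d(fn)/dx
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


def check_gradient(fn, analytic_grad, x: float, eps: float = 1e-6,
                   atol: float = 1e-5, rtol: float = 1e-3) -> bool:
    # Passes when the numerical and analytical gradients agree within tolerance
    numeric = central_difference(fn, x, eps)
    analytic = analytic_grad(x)
    return abs(numeric - analytic) <= atol + rtol * abs(analytic)


duration = 10.0


def decayed_state(a: float) -> float:
    # Final state of dx/dt = -a * x with x(0) = 1 after `duration`
    return math.exp(-a * duration)


def correct_grad(a: float) -> float:
    return -duration * math.exp(-a * duration)


def wrong_sign_grad(a: float) -> float:
    # A deliberately broken gradient, to show that the check can fail
    return duration * math.exp(-a * duration)


print(check_gradient(decayed_state, correct_grad, 0.25))
print(check_gradient(decayed_state, wrong_sign_grad, 0.25))
```

The first check passes and the second fails. Because the finite-difference estimate is itself only approximate, the function under test has to be evaluated precisely, which is why the tests below use 64-bit floats and tight integration tolerances.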
We need to test three cases: that the initial state, the final state and an arbitrarily selected intermediate state are each differentiated correctly. We also need to run these tests with tighter tolerances than our previous integration, as we need to minimise the error accumulated during both the forward and the backward integration.
def test_func_adaptive(init_state, integration_t0, integration_t1, freq, damp):
    # Integrates the system with tighter tolerances
    res = adaptive_rk45_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, init_state, integration_t0, integration_t1, initial_timestep.double()*1e-2, {'atol': atol.double(), 'rtol': rtol.double()*1e-3}, freq, damp)
    return res

def test_func_initial_state(init_state, integration_t0, integration_t1, freq, damp):
    # Integrates the system and returns the initial state stored in the intermediate states
    res = test_func_adaptive(init_state, integration_t0, integration_t1, freq, damp)
    return res[2][0]

def test_func_intermediate_state(init_state, integration_t0, _, freq, damp):
    # Ideally we'd pick from the intermediate states tensor, but during finite differencing
    # the time values may change and as a result we will sample the wrong point in the trajectory.
    # The solution to this is to treat the system as being integrated to some intermediate time and return
    # the final state.
    res = test_func_adaptive(init_state, integration_t0, sha_times[sha_times.shape[0]//2], freq, damp)
    return res[0]

def test_func_final_state(init_state, integration_t0, integration_t1, freq, damp):
    # Return the final state of the integration
    res = test_func_adaptive(init_state, integration_t0, integration_t1, freq, damp)
    return res[0]
from torch.autograd import gradcheck
test_variables = [initial_state, initial_time, initial_time.detach().clone()+0.1, frequency, damping]
# 64-bit floats are required for correct evaluation of finite differences and derivatives
test_variables = [i.double() for i in test_variables]
test_functions = [test_func_final_state, test_func_initial_state, test_func_intermediate_state]
def generate_test_vars():
    # Randomly generate different integration conditions
    test_t0 = (torch.rand_like(initial_time.double()) - 1.0)*0.1
    test_t1 = torch.rand_like(initial_time.double())*0.1 + test_t0.detach()
    test_x = 2*torch.rand_like(initial_state.double()) - 1.0
    test_frequency = torch.rand_like(frequency.double())
    test_damping = torch.rand_like(damping.double())
    return [i.requires_grad_(True) for i in [test_x, test_t0, test_t1, test_frequency, test_damping]]

num_tests = 16
# Run test on our initial conditions defined earlier
print(f"[0/{num_tests}] - vars: {[i.detach().cpu().numpy() for i in test_variables]}, success_jacobian: [", end='')
for fn in test_functions:
    print(gradcheck(fn, [i.detach().clone().requires_grad_(True) for i in test_variables]), end=', ' if fn != test_functions[-1] else '')
print(']')
# Run test on the randomly generated states
for iter_idx in range(num_tests):
    variables = generate_test_vars()
    print(f"[{iter_idx+1}/{num_tests}] - vars: {[i.detach().cpu().numpy() for i in variables]}, success: [", end='')
    for fn in test_functions:
        print(gradcheck(fn, variables), end=', ' if fn != test_functions[-1] else '')
    print(']')
[0/16] - vars: [array([1., 0.]), array(0.), array(0.1), array(1.), array(0.25)], success_jacobian: [True, True, True]
[1/16] - vars: [array([-0.53149497, -0.64580155]), array(-0.09389467), array(-0.07143922), array(0.55606806), array(0.10944385)], success: [True, True, True]
[2/16] - vars: [array([ 0.15955251, -0.00666538]), array(-0.05390871), array(0.01692782), array(0.51037517), array(0.32953764)], success: [True, True, True]
[3/16] - vars: [array([-0.8204064 , -0.76508816]), array(-0.02817939), array(0.01027172), array(0.64023946), array(0.19676658)], success: [True, True, True]
[4/16] - vars: [array([0.84974453, 0.999399 ]), array(-0.04875527), array(0.02242855), array(0.89273028), array(0.87671972)], success: [True, True, True]
[5/16] - vars: [array([-0.65892713, 0.96839645]), array(-0.01550284), array(-5.80456654e-05), array(0.81270555), array(0.43584903)], success: [True, True, True]
[6/16] - vars: [array([0.51552335, 0.84502525]), array(-0.05856788), array(-0.01572713), array(0.964327), array(0.17601831)], success: [True, True, True]
[7/16] - vars: [array([-0.09120461, -0.40889585]), array(-0.0046106), array(0.02672735), array(0.18750747), array(0.24325766)], success: [True, True, True]
[8/16] - vars: [array([-0.18625443, -0.42812303]), array(-0.06507037), array(-0.02066313), array(0.8035932), array(0.32176585)], success: [True, True, True]
[9/16] - vars: [array([ 0.32706173, -0.48966606]), array(-0.06360974), array(-0.03375875), array(0.41437192), array(0.83955493)], success: [True, True, True]
[10/16] - vars: [array([5.85717787e-01, 2.32655082e-04]), array(-0.02581675), array(0.00283237), array(0.89774049), array(0.10512458)], success: [True, True, True]
[11/16] - vars: [array([-0.73695231, -0.52172514]), array(-0.04190862), array(0.05675738), array(0.30468405), array(0.51584484)], success: [True, True, True]
[12/16] - vars: [array([ 0.0601311 , -0.47056019]), array(-0.05485591), array(-0.00556665), array(0.16711785), array(0.54819128)], success: [True, True, True]
[13/16] - vars: [array([-0.11568781, 0.29077878]), array(-0.07620483), array(-0.02246851), array(0.53756922), array(0.22447995)], success: [True, True, True]
[14/16] - vars: [array([-0.97824889, -0.43864284]), array(-0.03368136), array(0.05070647), array(0.93011158), array(0.54379786)], success: [True, True, True]
[15/16] - vars: [array([0.46151627, 0.98484136]), array(-0.01876734), array(0.05872961), array(0.72818875), array(0.23283437)], success: [True, True, True]
[16/16] - vars: [array([-0.15990223, 0.08383269]), array(-2.5330212e-05), array(0.05537509), array(0.86417502), array(0.43124712)], success: [True, True, True]
# Generate reference trajectory for optimisation/learning
with torch.no_grad():
    _, _, sha_states_ref, sha_times_ref, _ = adaptive_rk45_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, initial_state, initial_time, final_time, initial_timestep, integrator_kwargs, frequency, damping)
sha_states_ref, sha_times_ref = sha_states_ref.detach(), sha_times_ref.detach()
state_dataset = sha_states_ref[1:].clone()
time_dataset = sha_times_ref[1:].clone()

# Next, we'll define a closure function whose sole purpose is to
# zero the gradients and compute the error. This is useful as it allows switching to other
# optimizers such as LBFGS or anything that re-evaluates the error without
# computing its gradient
def sha_closure(rhs: neuralode.integrators.signatures.integration_fn_signature, parameters: list[torch.Tensor], minibatch: dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Computes the error on a minibatch of states and times according to the given function `rhs` and its parameters.

    :param rhs: The function to integrate
    :param parameters: The parameters wrt. which the gradients should be computed
    :param minibatch: The minibatch to compute the error on
    :return: The error of the integration on the minibatch
    """
    current_state = initial_state.detach().clone()
    current_time = initial_time.detach().clone()
    optimiser.zero_grad()
    error = 0.0

    times = minibatch['times']
    states = minibatch['states']

    # We need to sort both times and states simultaneously, so we'll use `argsort`
    sorted_time_indices = torch.argsort(times)
    times, states = times[sorted_time_indices], states[sorted_time_indices]

    for sample_state, sample_time in zip(states, times):
        # Integrate from the previous sample time to the current one, reusing the last state
        dt = torch.minimum(initial_timestep, sample_time - current_time).detach()
        current_state, current_time, _, _, _ = adaptive_rk45_integrator.apply(rhs, current_state, current_time, sample_time, dt, integrator_kwargs, *parameters)
        error = error + torch.linalg.norm(sample_state - current_state)/times.shape[0]
    if error.requires_grad:
        error.backward()
    return error
# We need to set the size of our mini-batches
batch_size = 4
# Now we need an optimisation `loop` where we will take steps to minimise the error
number_of_gd_steps = 128
# We reinitialise our variables
optimised_frequency = torch.tensor(0.1, requires_grad=True)
optimised_damping = torch.tensor(1.0, requires_grad=True)
# As damping needs to be a strictly positive quantity, we log-encode it
log_encoded_damping = torch.log(optimised_damping.detach()).requires_grad_(True)
# First, we'll create an `optimiser` following pytorch convention
optimiser = torch.optim.Adam([optimised_frequency, log_encoded_damping], lr=1e-1, amsgrad=True)
# Whenever the loss plateaus, we can reduce the learning rate to improve convergence
lr_on_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser)
# We also need to track the best solution thus far
best_error = torch.inf
best_frequency, best_damping = optimised_frequency.detach().clone(), optimised_damping.detach().clone()
for step in range(number_of_gd_steps):
    epoch_error = 0.0
    shuffled_indices = torch.randperm(time_dataset.shape[0])
    for batch_idx in range(0, time_dataset.shape[0], batch_size):
        batch_dict = {
            'times': time_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
            'states': state_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
        }

        step_error = optimiser.step(lambda: sha_closure(neuralode.dynamics.simple_harmonic_oscillator, [optimised_frequency, torch.exp(log_encoded_damping)], batch_dict))
        epoch_error = epoch_error + step_error.item()*batch_dict['times'].shape[0]
        print(f"[{step+1}/{number_of_gd_steps}]/[{batch_idx}/{time_dataset.shape[0]}] Batch Error: {step_error:.6f}, Current Frequency: {optimised_frequency.item():.4f}, Current Damping: {torch.exp(log_encoded_damping).item():.4f}", end='\r')
    epoch_error = epoch_error/time_dataset.shape[0]
    if epoch_error < best_error:
        best_error = epoch_error
        best_frequency = optimised_frequency.detach().clone()
        best_damping = torch.exp(log_encoded_damping.detach().clone())
    lr_on_plateau.step(epoch_error)
    print(" "*128, end="\r")
    print(f"[{step+1}/{number_of_gd_steps} - lr: {lr_on_plateau.get_last_lr()[0]:.4e}] Epoch Error: {epoch_error:.6f}, Current Frequency: {optimised_frequency.item():.6f}, Current Damping: {torch.exp(log_encoded_damping).item():.6f}")
    # If the step size is too small, then we can interrupt the
    # training as it will not lead to significant improvements
    if lr_on_plateau.get_last_lr()[0] < 1e-6:
        break
rel_err = torch.mean(torch.abs(1 - best_frequency / frequency)).item()
mae_err = torch.mean(torch.abs(frequency - best_frequency)).item()
print(f"Best frequency: {best_frequency.item():.6f}, relative error: {rel_err:.6%}, mean absolute error: {mae_err:.6f}")
rel_err = torch.mean(torch.abs(1 - best_damping / damping)).item()
mae_err = torch.mean(torch.abs(damping - best_damping)).item()
print(f"Best damping: {best_damping.item():.6f}, relative error: {rel_err:.6%}, mean absolute error: {mae_err:.6f}")
[1/128 - lr: 1.0000e-01] Epoch Error: 0.321610, Current Frequency: 1.251675, Current Damping: 0.337432
[2/128 - lr: 1.0000e-01] Epoch Error: 0.077131, Current Frequency: 0.995374, Current Damping: 0.271970
[3/128 - lr: 1.0000e-01] Epoch Error: 0.012469, Current Frequency: 1.007481, Current Damping: 0.247264
[4/128 - lr: 1.0000e-01] Epoch Error: 0.010313, Current Frequency: 0.993685, Current Damping: 0.251497
[5/128 - lr: 1.0000e-01] Epoch Error: 0.015036, Current Frequency: 0.987936, Current Damping: 0.255128
[6/128 - lr: 1.0000e-01] Epoch Error: 0.015201, Current Frequency: 0.992519, Current Damping: 0.247674
[7/128 - lr: 1.0000e-01] Epoch Error: 0.014721, Current Frequency: 1.003300, Current Damping: 0.246292
[8/128 - lr: 1.0000e-01] Epoch Error: 0.010783, Current Frequency: 1.003838, Current Damping: 0.252907
[9/128 - lr: 1.0000e-01] Epoch Error: 0.005755, Current Frequency: 0.997486, Current Damping: 0.244679
[10/128 - lr: 1.0000e-01] Epoch Error: 0.009421, Current Frequency: 0.995029, Current Damping: 0.253222
[11/128 - lr: 1.0000e-01] Epoch Error: 0.018796, Current Frequency: 0.999096, Current Damping: 0.247862
[12/128 - lr: 1.0000e-01] Epoch Error: 0.012960, Current Frequency: 1.006188, Current Damping: 0.246765
[13/128 - lr: 1.0000e-01] Epoch Error: 0.012631, Current Frequency: 0.997460, Current Damping: 0.251612
[14/128 - lr: 1.0000e-01] Epoch Error: 0.012186, Current Frequency: 1.004606, Current Damping: 0.252483
[15/128 - lr: 1.0000e-01] Epoch Error: 0.007520, Current Frequency: 1.009415, Current Damping: 0.247402
[16/128 - lr: 1.0000e-01] Epoch Error: 0.011840, Current Frequency: 0.986589, Current Damping: 0.253592
[17/128 - lr: 1.0000e-01] Epoch Error: 0.013378, Current Frequency: 1.000095, Current Damping: 0.243079
[18/128 - lr: 1.0000e-01] Epoch Error: 0.013211, Current Frequency: 1.008167, Current Damping: 0.250376
[19/128 - lr: 1.0000e-01] Epoch Error: 0.009024, Current Frequency: 0.999350, Current Damping: 0.247020
[20/128 - lr: 1.0000e-02] Epoch Error: 0.006457, Current Frequency: 0.997652, Current Damping: 0.254450
[21/128 - lr: 1.0000e-02] Epoch Error: 0.002457, Current Frequency: 1.000375, Current Damping: 0.249883
[22/128 - lr: 1.0000e-02] Epoch Error: 0.001674, Current Frequency: 1.001694, Current Damping: 0.250331
[23/128 - lr: 1.0000e-02] Epoch Error: 0.001102, Current Frequency: 1.000206, Current Damping: 0.249461
[24/128 - lr: 1.0000e-02] Epoch Error: 0.000884, Current Frequency: 0.999689, Current Damping: 0.249995
[25/128 - lr: 1.0000e-02] Epoch Error: 0.000733, Current Frequency: 1.000436, Current Damping: 0.249912
[26/128 - lr: 1.0000e-02] Epoch Error: 0.001635, Current Frequency: 0.997968, Current Damping: 0.250234
[27/128 - lr: 1.0000e-02] Epoch Error: 0.000655, Current Frequency: 0.999234, Current Damping: 0.249902
[28/128 - lr: 1.0000e-02] Epoch Error: 0.000835, Current Frequency: 1.001496, Current Damping: 0.249199
[29/128 - lr: 1.0000e-02] Epoch Error: 0.001248, Current Frequency: 1.000008, Current Damping: 0.250116
[30/128 - lr: 1.0000e-02] Epoch Error: 0.000784, Current Frequency: 1.000849, Current Damping: 0.249335
[31/128 - lr: 1.0000e-02] Epoch Error: 0.001683, Current Frequency: 0.997040, Current Damping: 0.250993
[32/128 - lr: 1.0000e-02] Epoch Error: 0.001396, Current Frequency: 1.001091, Current Damping: 0.249821
[33/128 - lr: 1.0000e-02] Epoch Error: 0.000768, Current Frequency: 1.000053, Current Damping: 0.250131
[34/128 - lr: 1.0000e-02] Epoch Error: 0.001088, Current Frequency: 0.999780, Current Damping: 0.248895
[35/128 - lr: 1.0000e-02] Epoch Error: 0.000714, Current Frequency: 1.000858, Current Damping: 0.249915
[36/128 - lr: 1.0000e-02] Epoch Error: 0.001056, Current Frequency: 0.998587, Current Damping: 0.250181
[37/128 - lr: 1.0000e-02] Epoch Error: 0.001231, Current Frequency: 1.001789, Current Damping: 0.249745
[38/128 - lr: 1.0000e-03] Epoch Error: 0.000752, Current Frequency: 0.999486, Current Damping: 0.249796
[39/128 - lr: 1.0000e-03] Epoch Error: 0.000264, Current Frequency: 0.999915, Current Damping: 0.250041
[40/128 - lr: 1.0000e-03] Epoch Error: 0.000137, Current Frequency: 0.999792, Current Damping: 0.249997
[41/128 - lr: 1.0000e-03] Epoch Error: 0.000192, Current Frequency: 1.000024, Current Damping: 0.249953
[42/128 - lr: 1.0000e-03] Epoch Error: 0.000080, Current Frequency: 1.000043, Current Damping: 0.250012
[43/128 - lr: 1.0000e-03] Epoch Error: 0.000099, Current Frequency: 0.999943, Current Damping: 0.250036
[44/128 - lr: 1.0000e-03] Epoch Error: 0.000119, Current Frequency: 1.000001, Current Damping: 0.250088
[45/128 - lr: 1.0000e-03] Epoch Error: 0.000089, Current Frequency: 0.999957, Current Damping: 0.249983
[46/128 - lr: 1.0000e-03] Epoch Error: 0.000088, Current Frequency: 1.000017, Current Damping: 0.249934
[47/128 - lr: 1.0000e-03] Epoch Error: 0.000085, Current Frequency: 0.999884, Current Damping: 0.249989
[48/128 - lr: 1.0000e-03] Epoch Error: 0.000158, Current Frequency: 0.999842, Current Damping: 0.250018
[49/128 - lr: 1.0000e-03] Epoch Error: 0.000154, Current Frequency: 1.000107, Current Damping: 0.249976
[50/128 - lr: 1.0000e-03] Epoch Error: 0.000074, Current Frequency: 0.999905, Current Damping: 0.250041
[51/128 - lr: 1.0000e-03] Epoch Error: 0.000072, Current Frequency: 0.999941, Current Damping: 0.250007
[52/128 - lr: 1.0000e-03] Epoch Error: 0.000119, Current Frequency: 0.999948, Current Damping: 0.249982
[53/128 - lr: 1.0000e-03] Epoch Error: 0.000127, Current Frequency: 0.999979, Current Damping: 0.249996
[54/128 - lr: 1.0000e-03] Epoch Error: 0.000129, Current Frequency: 0.999846, Current Damping: 0.250023
[55/128 - lr: 1.0000e-03] Epoch Error: 0.000169, Current Frequency: 0.999789, Current Damping: 0.250029
[56/128 - lr: 1.0000e-03] Epoch Error: 0.000269, Current Frequency: 1.000029, Current Damping: 0.250036
[57/128 - lr: 1.0000e-03] Epoch Error: 0.000136, Current Frequency: 0.999884, Current Damping: 0.250002
[58/128 - lr: 1.0000e-03] Epoch Error: 0.000187, Current Frequency: 1.000121, Current Damping: 0.250004
[59/128 - lr: 1.0000e-03] Epoch Error: 0.000164, Current Frequency: 0.999913, Current Damping: 0.249983
[60/128 - lr: 1.0000e-03] Epoch Error: 0.000108, Current Frequency: 0.999945, Current Damping: 0.250092
[61/128 - lr: 1.0000e-03] Epoch Error: 0.000076, Current Frequency: 1.000125, Current Damping: 0.249983
[62/128 - lr: 1.0000e-04] Epoch Error: 0.000149, Current Frequency: 0.999997, Current Damping: 0.250033
[63/128 - lr: 1.0000e-04] Epoch Error: 0.000012, Current Frequency: 1.000014, Current Damping: 0.250001
[64/128 - lr: 1.0000e-04] Epoch Error: 0.000013, Current Frequency: 1.000021, Current Damping: 0.249993
[65/128 - lr: 1.0000e-04] Epoch Error: 0.000013, Current Frequency: 1.000002, Current Damping: 0.249991
[66/128 - lr: 1.0000e-04] Epoch Error: 0.000013, Current Frequency: 1.000010, Current Damping: 0.250002
[67/128 - lr: 1.0000e-04] Epoch Error: 0.000014, Current Frequency: 1.000012, Current Damping: 0.249998
[68/128 - lr: 1.0000e-04] Epoch Error: 0.000007, Current Frequency: 1.000012, Current Damping: 0.250006
[69/128 - lr: 1.0000e-04] Epoch Error: 0.000010, Current Frequency: 1.000016, Current Damping: 0.250000
[70/128 - lr: 1.0000e-04] Epoch Error: 0.000018, Current Frequency: 0.999979, Current Damping: 0.249999
[71/128 - lr: 1.0000e-04] Epoch Error: 0.000014, Current Frequency: 0.999994, Current Damping: 0.249999
[72/128 - lr: 1.0000e-04] Epoch Error: 0.000016, Current Frequency: 1.000012, Current Damping: 0.250004
[73/128 - lr: 1.0000e-04] Epoch Error: 0.000021, Current Frequency: 1.000002, Current Damping: 0.250002
[74/128 - lr: 1.0000e-04] Epoch Error: 0.000016, Current Frequency: 0.999998, Current Damping: 0.249997
[75/128 - lr: 1.0000e-04] Epoch Error: 0.000010, Current Frequency: 0.999992, Current Damping: 0.250010
[76/128 - lr: 1.0000e-04] Epoch Error: 0.000013, Current Frequency: 1.000000, Current Damping: 0.250005
[77/128 - lr: 1.0000e-04] Epoch Error: 0.000007, Current Frequency: 0.999988, Current Damping: 0.250002
[78/128 - lr: 1.0000e-04] Epoch Error: 0.000013, Current Frequency: 0.999991, Current Damping: 0.250005
[79/128 - lr: 1.0000e-04] Epoch Error: 0.000016, Current Frequency: 0.999998, Current Damping: 0.249999
[80/128 - lr: 1.0000e-04] Epoch Error: 0.000012, Current Frequency: 0.999999, Current Damping: 0.249996
[81/128 - lr: 1.0000e-04] Epoch Error: 0.000011, Current Frequency: 0.999985, Current Damping: 0.250004
[82/128 - lr: 1.0000e-04] Epoch Error: 0.000022, Current Frequency: 0.999975, Current Damping: 0.250006
[83/128 - lr: 1.0000e-04] Epoch Error: 0.000026, Current Frequency: 0.999980, Current Damping: 0.250006
[84/128 - lr: 1.0000e-04] Epoch Error: 0.000018, Current Frequency: 1.000034, Current Damping: 0.249991
[85/128 - lr: 1.0000e-04] Epoch Error: 0.000018, Current Frequency: 0.999967, Current Damping: 0.250005
[86/128 - lr: 1.0000e-04] Epoch Error: 0.000020, Current Frequency: 0.999997, Current Damping: 0.250000
[87/128 - lr: 1.0000e-04] Epoch Error: 0.000006, Current Frequency: 1.000013, Current Damping: 0.249999
[88/128 - lr: 1.0000e-04] Epoch Error: 0.000021, Current Frequency: 1.000005, Current Damping: 0.249997
[89/128 - lr: 1.0000e-04] Epoch Error: 0.000014, Current Frequency: 1.000006, Current Damping: 0.249997
[90/128 - lr: 1.0000e-04] Epoch Error: 0.000006, Current Frequency: 0.999994, Current Damping: 0.249995
[91/128 - lr: 1.0000e-04] Epoch Error: 0.000010, Current Frequency: 0.999995, Current Damping: 0.250004
[92/128 - lr: 1.0000e-04] Epoch Error: 0.000006, Current Frequency: 0.999996, Current Damping: 0.249999
[93/128 - lr: 1.0000e-04] Epoch Error: 0.000009, Current Frequency: 1.000001, Current Damping: 0.250003
[94/128 - lr: 1.0000e-04] Epoch Error: 0.000012, Current Frequency: 1.000007, Current Damping: 0.249999
[95/128 - lr: 1.0000e-04] Epoch Error: 0.000011, Current Frequency: 1.000000, Current Damping: 0.249996
[96/128 - lr: 1.0000e-04] Epoch Error: 0.000015, Current Frequency: 0.999975, Current Damping: 0.250005
[97/128 - lr: 1.0000e-04] Epoch Error: 0.000017, Current Frequency: 0.999996, Current Damping: 0.250002
[98/128 - lr: 1.0000e-04] Epoch Error: 0.000007, Current Frequency: 1.000005, Current Damping: 0.249999
[99/128 - lr: 1.0000e-04] Epoch Error: 0.000007, Current Frequency: 1.000005, Current Damping: 0.249999
[100/128 - lr: 1.0000e-04] Epoch Error: 0.000008, Current Frequency: 0.999997, Current Damping: 0.250001
[101/128 - lr: 1.0000e-05] Epoch Error: 0.000010, Current Frequency: 1.000004, Current Damping: 0.250001
[102/128 - lr: 1.0000e-05] Epoch Error: 0.000002, Current Frequency: 1.000000, Current Damping: 0.250001
[103/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 0.999999, Current Damping: 0.250000
[104/128 - lr: 1.0000e-05] Epoch Error: 0.000002, Current Frequency: 0.999998, Current Damping: 0.250001
[105/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000
[106/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000000, Current Damping: 0.250000
[107/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000
[108/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000000, Current Damping: 0.250000
[109/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000
[110/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 0.999999, Current Damping: 0.250000
[111/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000
[112/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000
[113/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000000, Current Damping: 0.250000
[114/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000000, Current Damping: 0.250000
[115/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000
[116/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000000, Current Damping: 0.250000 [117/128 - lr: 1.0000e-05] Epoch Error: 0.000001, Current Frequency: 1.000001, Current Damping: 0.250000 [118/128 - lr: 1.0000e-06] Epoch Error: 0.000001, Current Frequency: 1.000000, Current Damping: 0.250000 [119/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [120/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [121/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [122/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [123/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [124/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [125/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [126/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [127/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 [128/128 - lr: 1.0000e-06] Epoch Error: 0.000000, Current Frequency: 1.000000, Current Damping: 0.250000 Best frequency: 1.000000, relative error: 0.000001%, mean absolute error: 0.000000 Best damping: 0.250000, relative error: 0.000023%, mean absolute error: 0.000000
fig_ref_position, axes_ref_position = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Position Ref.")
fig_ref_velocity, axes_ref_velocity = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Velocity Ref.")
_, _, sha_states_optimised, sha_times_optimised, _ = adaptive_rk45_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, initial_state, initial_time, final_time, initial_timestep, integrator_kwargs, best_frequency, best_damping)
_ = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_position, method_label="RK4(5) - SHA Position Opt.")
_ = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_velocity, method_label="RK4(5) - SHA Velocity Opt.")
This replicates the prior results and actually improves on them by a small margin.
Now let's add a neural network into the mix! This will be the simplest network possible: a single matrix multiplied with the input vector. While this may seem too simple, our simple harmonic oscillator can also be expressed as a matrix multiplied by the state vector. The matrix has a very specific structure that arises from the fact that the oscillator is a second-order equation rewritten as a first-order system in position and velocity. I've written this system below:
$$ \begin{bmatrix} x^{(1)} \\ v^{(1)} \end{bmatrix} = \mathbf{A} \begin{bmatrix} x \\ v \end{bmatrix} $$
where
$$ \mathbf{A} = \begin{bmatrix} 0 & 1 \\ -\omega^2 & -2\zeta\omega \end{bmatrix} $$
Given that matrix multiplication underlies most neural networks, we can try to learn this $\mathbf{A}$-matrix and at the same time introduce some of the Neural Network machinery in PyTorch. We will revisit these later when learning more interesting/complex systems.
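As a quick check of the matrix form, the sketch below builds $\mathbf{A}$ in plain Python and compares $\mathbf{A}$ applied to $(x, v)$ against the derivative implied by its rows, i.e. $x^{(1)} = v$ and $v^{(1)} = -\omega^2 x - 2\zeta\omega v$. The helper names `sho_matrix` and `matvec` are hypothetical and not part of `neuralode`; the values $\omega = 1$ and $\zeta = 0.25$ are the ones recovered by the optimisation above, and the test state is arbitrary.

```python
import cmath

def sho_matrix(frequency, damping):
    """Hypothetical helper: the 2x2 matrix of the damped oscillator."""
    return [[0.0, 1.0],
            [-frequency**2, -2.0 * damping * frequency]]

def matvec(matrix, vector):
    """Plain matrix-vector product for nested lists."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]

frequency, damping = 1.0, 0.25
A = sho_matrix(frequency, damping)

# An arbitrary test state (position, velocity)
x, v = 1.0, -0.5
derivative = matvec(A, [x, v])
# First row: dx/dt = v; second row: dv/dt = -w^2 x - 2 zeta w v
expected = [v, -frequency**2 * x - 2.0 * damping * frequency * v]

# Eigenvalues of A are the roots of l^2 + 2 zeta w l + w^2 = 0,
# i.e. -zeta*w +/- w*sqrt(zeta^2 - 1)
root = cmath.sqrt(damping**2 - 1.0)
eigenvalues = [-damping * frequency + frequency * root,
               -damping * frequency - frequency * root]
print(A, derivative, eigenvalues)
```

For an underdamped oscillator the eigenvalues form a complex pair with negative real part $-\zeta\omega$, which corresponds to a decaying oscillation. The entries of `A` for these values are the targets we hope a learned weight matrix approaches.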
# We define our network as a subclass of torch.nn.Module.
# This allows PyTorch to appropriately track parameters.
class OscillatorNet(torch.nn.Module):
    def __init__(self):
        # First we initialise the superclass, `torch.nn.Module`
        super().__init__()
        # Then we define the actual neural network
        # Most Neural Networks operate sequentially so they can be wrapped
        # inside a torch.nn.Sequential which takes each layer
        # as an argument.
        # Since we're only learning one matrix, we have
        # one layer, the `torch.nn.Linear`.
        # `torch.nn.Linear` stores a matrix and a bias which actually makes it
        # an Affine transformation rather than a purely linear transformation
        self.internal_net = torch.nn.Sequential(
            torch.nn.Linear(2, 2),
        )

    def forward(self, x, t):
        # Our network only depends on x, but since it could also depend on t, we have
        # included it for completeness
        # Additionally, PyTorch layers and modules expect a batched tensor
        # i.e. a tensor where the first dimension is over different samples
        # Since we don't depend on batches, we check if the input is 1-dimensional
        # and add a batch dimension as needed for the internal module
        if x.dim() == 1:
            return self.internal_net(x[None])[0]
        else:
            return self.internal_net(x)
# And then define how the weights of the network are initialised
def init_weights(m):
    # For each layer type, we can define how we initialise its values
    if isinstance(m, torch.nn.Linear):
        # A linear equation with a positive coefficient
        # translates to exponential growth and a negative coefficient
        # to exponential decay. We sample the entries from a zero-mean
        # normal distribution with a small standard deviation (0.1).
        # Note that this does not bias the entries to be negative; it only
        # keeps their magnitudes small so that the initial system evolves
        # slowly instead of growing rapidly.
        m.weight.data.normal_(0.0, 0.1)
        if m.bias is not None:
            m.bias.data.normal_(0.0, 0.1)
# Here we instantiate our network.
torch.manual_seed(36)
simple_oscillator_net = OscillatorNet()
simple_oscillator_net.apply(init_weights)
# `torch.autograd.Function`s track computation on all input tensors.
# For that reason, we must pass our neural network parameters to the integrator,
# which will pass them to the derivative function.
# The network also stores these parameters itself, so we use
# `torch.func.functional_call` to evaluate it with the tensors that were
# passed in, matched to the parameter names in order.
def sha_nn_fn(x, t, *nn_parameters):
    parameter_names = [name for name, _ in simple_oscillator_net.named_parameters()]
    return torch.func.functional_call(simple_oscillator_net, dict(zip(parameter_names, nn_parameters)), (x, t))
optimiser = torch.optim.Adam(simple_oscillator_net.parameters(), lr=1e-3, amsgrad=True)
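# OneCycleLR starts the learning rate at max_lr/25, raises it to max_lr and then lowers it again.
# With three_phase=True it first returns to max_lr/25 and then anneals to a much smaller final value.
# Once attached, the scheduler controls the learning rate, so the `lr` passed to Adam above is overridden.
# The large learning rates in the middle of the cycle help the network escape local minima and find better solutions.
# It is an alternative to ReduceLROnPlateau which can sometimes get stuck in local minima
# for problems with a small number of degrees of freedom.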
one_cycle_lr = torch.optim.lr_scheduler.OneCycleLR(optimiser, max_lr=1e-1, steps_per_epoch=round(time_dataset.shape[0]/batch_size+0.5), epochs=number_of_gd_steps, three_phase=True)
ideal_matrix = neuralode.dynamics.get_simple_harmonic_oscillator_matrix(frequency, damping)
ideal_bias = torch.zeros_like(initial_state)
best_error = torch.inf
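# For pytorch modules, the `state_dict` method returns all the parameters
# that define the model, thus enabling us to store the state as well as restore it.
# The returned tensors share memory with the live parameters, so we clone them
# to obtain a snapshot that later optimiser steps cannot modify.
best_parameters = {k: v.clone() for k, v in simple_oscillator_net.state_dict().items()}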
for step in range(number_of_gd_steps):
    epoch_error = 0.0
    shuffled_indices = torch.randperm(time_dataset.shape[0])
    for batch_idx in range(0, time_dataset.shape[0], batch_size):
        batch_dict = {
            'times': time_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
            'states': state_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
        }
        step_error = optimiser.step(lambda: sha_closure(sha_nn_fn, list(simple_oscillator_net.parameters()), batch_dict))
        epoch_error = epoch_error + step_error.item()*batch_dict['times'].shape[0]
        one_cycle_lr.step()
        # print(f"[{step+1}/{number_of_gd_steps}]/[{batch_idx}/{time_dataset.shape[0]}] Batch Error: {step_error:.6f} ", end='\r')
    epoch_error = epoch_error/time_dataset.shape[0]
    if epoch_error < best_error:
        best_error = epoch_error
        # Clone the tensors: `state_dict` returns references to the live parameters,
        # which the optimiser keeps updating in place
        best_parameters = {k: v.clone() for k, v in simple_oscillator_net.state_dict().items()}
    learned_matrix = simple_oscillator_net.state_dict()['internal_net.0.weight']
    learned_bias = simple_oscillator_net.state_dict()['internal_net.0.bias']
    # Ideally our matrix is equivalent to our simple harmonic oscillator matrix and our bias goes to zero
    print(f"[{step+1}/{number_of_gd_steps} - lr: {one_cycle_lr.get_last_lr()[0]:.4e}] Epoch Error: {epoch_error:.6f}, \nW={learned_matrix.cpu()}, \nb={learned_bias.cpu()}")
    print()
simple_oscillator_net.load_state_dict(best_parameters)
learned_matrix = simple_oscillator_net.state_dict()['internal_net.0.weight']
learned_bias = simple_oscillator_net.state_dict()['internal_net.0.bias']
# Before we were looking at relative error, but in the case of a matrix with zeros,
# the relative error is undefined, so we look at another common metric: mean absolute error
print(f"Best matrix: {learned_matrix}, mean absolute error: {torch.mean(torch.abs(ideal_matrix - learned_matrix)).item():.6f}")
print(f"Best bias: {learned_bias}, mean absolute error: {torch.mean(torch.abs(ideal_bias - learned_bias)).item():.6f}")
[1/128 - lr: 4.1609e-03] Epoch Error: 1.609792, W=tensor([[ 0.0641, 0.1550], [-0.0738, -0.1089]]), b=tensor([-0.0626, -0.0914]) [2/128 - lr: 4.6425e-03] Epoch Error: 0.720409, W=tensor([[ 0.0104, 0.2160], [-0.0536, -0.1961]]), b=tensor([-0.1232, -0.0374]) [3/128 - lr: 5.4416e-03] Epoch Error: 0.504084, W=tensor([[-0.0127, 0.2244], [-0.0064, -0.2662]]), b=tensor([-0.1308, 0.0365]) [4/128 - lr: 6.5528e-03] Epoch Error: 0.483388, W=tensor([[-0.0308, 0.2290], [-0.0187, -0.3025]]), b=tensor([-0.1357, 0.0223]) [5/128 - lr: 7.9688e-03] Epoch Error: 0.471892, W=tensor([[-0.0409, 0.2373], [-0.0125, -0.3169]]), b=tensor([-0.1141, 0.0231]) [6/128 - lr: 9.6799e-03] Epoch Error: 0.457693, W=tensor([[-0.0528, 0.2441], [-0.0243, -0.3263]]), b=tensor([-0.0928, -0.0056]) [7/128 - lr: 1.1675e-02] Epoch Error: 0.451436, W=tensor([[-0.0744, 0.2438], [-0.0301, -0.3349]]), b=tensor([-0.0837, -0.0036]) [8/128 - lr: 1.3940e-02] Epoch Error: 0.443324, W=tensor([[-0.1043, 0.2411], [-0.0366, -0.3436]]), b=tensor([-0.0817, 0.0052]) [9/128 - lr: 1.6460e-02] Epoch Error: 0.428885, W=tensor([[-0.1330, 0.2390], [-0.0429, -0.3516]]), b=tensor([-0.0619, -0.0138]) [10/128 - lr: 1.9219e-02] Epoch Error: 0.423491, W=tensor([[-0.1724, 0.2298], [-0.0604, -0.3743]]), b=tensor([-0.0684, -0.0139]) [11/128 - lr: 2.2197e-02] Epoch Error: 0.415295, W=tensor([[-0.2186, 0.2158], [-0.0939, -0.4081]]), b=tensor([-0.0541, -0.0083]) [12/128 - lr: 2.5375e-02] Epoch Error: 0.393295, W=tensor([[-0.2631, 0.2108], [-0.1239, -0.4492]]), b=tensor([-0.0418, -0.0093]) [13/128 - lr: 2.8732e-02] Epoch Error: 0.378483, W=tensor([[-0.3168, 0.2258], [-0.1701, -0.4994]]), b=tensor([-0.0218, 0.0029]) [14/128 - lr: 3.2244e-02] Epoch Error: 0.362960, W=tensor([[-0.3653, 0.2573], [-0.2244, -0.5573]]), b=tensor([-0.0408, -0.0072]) [15/128 - lr: 3.5890e-02] Epoch Error: 0.343125, W=tensor([[-0.4226, 0.3115], [-0.3458, -0.6166]]), b=tensor([-0.0015, -0.0041]) [16/128 - lr: 3.9643e-02] Epoch Error: 0.315085, W=tensor([[-0.4736, 0.4208], 
[-0.5121, -0.6513]]), b=tensor([-0.0264, -0.0381]) [17/128 - lr: 4.3479e-02] Epoch Error: 0.269020, W=tensor([[-0.5067, 0.5902], [-0.7364, -0.6283]]), b=tensor([-0.0192, -0.0658]) [18/128 - lr: 4.7372e-02] Epoch Error: 0.197346, W=tensor([[-0.4880, 0.7967], [-1.0256, -0.3488]]), b=tensor([-0.0205, -0.0677]) [19/128 - lr: 5.1296e-02] Epoch Error: 0.107882, W=tensor([[-0.4213, 0.8397], [-1.1120, -0.1394]]), b=tensor([0.0012, 0.0378]) [20/128 - lr: 5.5224e-02] Epoch Error: 0.089753, W=tensor([[-0.3504, 0.8440], [-1.0732, -0.1570]]), b=tensor([-0.0072, 0.0150]) [21/128 - lr: 5.9132e-02] Epoch Error: 0.074635, W=tensor([[-0.2692, 0.8929], [-1.0202, -0.2536]]), b=tensor([0.0095, 0.0032]) [22/128 - lr: 6.2991e-02] Epoch Error: 0.055025, W=tensor([[-0.1796, 0.9106], [-0.9909, -0.3446]]), b=tensor([ 0.0009, -0.0002]) [23/128 - lr: 6.6777e-02] Epoch Error: 0.042198, W=tensor([[-0.0918, 0.9161], [-0.9663, -0.4102]]), b=tensor([0.0152, 0.0084]) [24/128 - lr: 7.0463e-02] Epoch Error: 0.030391, W=tensor([[-0.0337, 0.9564], [-1.0250, -0.4770]]), b=tensor([ 0.0103, -0.0006]) [25/128 - lr: 7.4026e-02] Epoch Error: 0.022276, W=tensor([[-0.0159, 1.0148], [-0.9859, -0.4600]]), b=tensor([0.0005, 0.0003]) [26/128 - lr: 7.7441e-02] Epoch Error: 0.018522, W=tensor([[ 0.0070, 1.0024], [-1.0076, -0.4974]]), b=tensor([0.0049, 0.0168]) [27/128 - lr: 8.0686e-02] Epoch Error: 0.020377, W=tensor([[ 0.0027, 0.9898], [-0.9873, -0.4936]]), b=tensor([-0.0004, -0.0022]) [28/128 - lr: 8.3738e-02] Epoch Error: 0.022518, W=tensor([[ 0.0015, 1.0029], [-0.9754, -0.5097]]), b=tensor([-0.0057, 0.0114]) [29/128 - lr: 8.6578e-02] Epoch Error: 0.020440, W=tensor([[ 0.0074, 1.0234], [-1.0268, -0.5093]]), b=tensor([-0.0029, 0.0104]) [30/128 - lr: 8.9186e-02] Epoch Error: 0.027248, W=tensor([[-1.7884e-05, 1.0147e+00], [-9.9924e-01, -5.1175e-01]]), b=tensor([ 0.0046, -0.0083]) [31/128 - lr: 9.1544e-02] Epoch Error: 0.029168, W=tensor([[-3.4838e-04, 1.0066e+00], [-9.7280e-01, -5.1309e-01]]), b=tensor([-0.0065, 
0.0063]) [32/128 - lr: 9.3637e-02] Epoch Error: 0.030273, W=tensor([[ 0.0034, 0.9911], [-1.0255, -0.5189]]), b=tensor([0.0078, 0.0182]) [33/128 - lr: 9.5451e-02] Epoch Error: 0.022938, W=tensor([[-0.0146, 1.0047], [-1.0120, -0.4691]]), b=tensor([-0.0095, 0.0011]) [34/128 - lr: 9.6974e-02] Epoch Error: 0.023357, W=tensor([[-0.0071, 1.0359], [-0.9823, -0.5283]]), b=tensor([-0.0299, -0.0092]) [35/128 - lr: 9.8196e-02] Epoch Error: 0.028715, W=tensor([[ 0.0179, 0.9993], [-1.0036, -0.5020]]), b=tensor([0.0063, 0.0142]) [36/128 - lr: 9.9107e-02] Epoch Error: 0.032321, W=tensor([[ 0.0182, 0.9964], [-0.9976, -0.5147]]), b=tensor([-0.0158, -0.0083]) [37/128 - lr: 9.9703e-02] Epoch Error: 0.020442, W=tensor([[-0.0135, 0.9988], [-1.0087, -0.5116]]), b=tensor([0.0238, 0.0263]) [38/128 - lr: 9.9979e-02] Epoch Error: 0.037006, W=tensor([[-0.0259, 0.9824], [-1.0131, -0.4769]]), b=tensor([-0.0046, 0.0186]) [39/128 - lr: 9.9934e-02] Epoch Error: 0.029400, W=tensor([[-0.0120, 0.9996], [-1.0392, -0.4779]]), b=tensor([ 0.0089, -0.0208]) [40/128 - lr: 9.9567e-02] Epoch Error: 0.033170, W=tensor([[-0.0055, 1.0142], [-1.0174, -0.4937]]), b=tensor([-0.0042, -0.0404]) [41/128 - lr: 9.8881e-02] Epoch Error: 0.034741, W=tensor([[ 0.0095, 0.9889], [-0.9669, -0.5061]]), b=tensor([-0.0012, -0.0139]) [42/128 - lr: 9.7881e-02] Epoch Error: 0.025106, W=tensor([[-0.0160, 0.9722], [-1.0266, -0.4955]]), b=tensor([0.0083, 0.0179]) [43/128 - lr: 9.6573e-02] Epoch Error: 0.022553, W=tensor([[-5.3376e-04, 9.7187e-01], [-9.9602e-01, -4.8455e-01]]), b=tensor([0.0142, 0.0078]) [44/128 - lr: 9.4967e-02] Epoch Error: 0.030136, W=tensor([[ 0.0033, 0.9065], [-1.0110, -0.4894]]), b=tensor([-0.0296, -0.0152]) [45/128 - lr: 9.3072e-02] Epoch Error: 0.037048, W=tensor([[ 0.0018, 1.0208], [-0.9850, -0.4883]]), b=tensor([0.0109, 0.0114]) [46/128 - lr: 9.0902e-02] Epoch Error: 0.015217, W=tensor([[-0.0096, 0.9875], [-1.0184, -0.5115]]), b=tensor([0.0029, 0.0136]) [47/128 - lr: 8.8471e-02] Epoch Error: 0.037965, 
W=tensor([[-0.0263, 1.0420], [-0.9906, -0.5544]]), b=tensor([ 0.0238, -0.0071]) [48/128 - lr: 8.5796e-02] Epoch Error: 0.027586, W=tensor([[ 0.0051, 1.0088], [-1.0085, -0.4161]]), b=tensor([0.0016, 0.0284]) [49/128 - lr: 8.2894e-02] Epoch Error: 0.036071, W=tensor([[-0.0171, 0.9916], [-0.9907, -0.4765]]), b=tensor([ 0.0048, -0.0028]) [50/128 - lr: 7.9785e-02] Epoch Error: 0.016789, W=tensor([[ 0.0072, 1.0112], [-0.9874, -0.5083]]), b=tensor([-0.0076, 0.0059]) [51/128 - lr: 7.6490e-02] Epoch Error: 0.018357, W=tensor([[ 8.1218e-04, 1.0239e+00], [-1.0008e+00, -5.0510e-01]]), b=tensor([ 0.0013, -0.0125]) [52/128 - lr: 7.3031e-02] Epoch Error: 0.023514, W=tensor([[-0.0103, 0.9784], [-1.0374, -0.5118]]), b=tensor([ 0.0046, -0.0131]) [53/128 - lr: 6.9430e-02] Epoch Error: 0.021051, W=tensor([[-0.0074, 0.9536], [-0.9919, -0.5058]]), b=tensor([ 0.0024, -0.0115]) [54/128 - lr: 6.5713e-02] Epoch Error: 0.017005, W=tensor([[-0.0068, 1.0226], [-1.0077, -0.4919]]), b=tensor([0.0081, 0.0121]) [55/128 - lr: 6.1904e-02] Epoch Error: 0.017439, W=tensor([[ 0.0139, 0.9870], [-1.0104, -0.5034]]), b=tensor([-0.0074, 0.0184]) [56/128 - lr: 5.8028e-02] Epoch Error: 0.011525, W=tensor([[ 0.0023, 1.0026], [-1.0045, -0.4947]]), b=tensor([-0.0010, 0.0025]) [57/128 - lr: 5.4112e-02] Epoch Error: 0.012064, W=tensor([[-0.0022, 1.0142], [-0.9994, -0.5060]]), b=tensor([-0.0066, 0.0103]) [58/128 - lr: 5.0182e-02] Epoch Error: 0.012457, W=tensor([[-0.0036, 1.0150], [-1.0009, -0.4884]]), b=tensor([0.0092, 0.0004]) [59/128 - lr: 4.6264e-02] Epoch Error: 0.010368, W=tensor([[-0.0141, 0.9839], [-0.9979, -0.4968]]), b=tensor([-0.0075, -0.0010]) [60/128 - lr: 4.2385e-02] Epoch Error: 0.007346, W=tensor([[-0.0030, 1.0039], [-0.9966, -0.4866]]), b=tensor([ 0.0062, -0.0022]) [61/128 - lr: 3.8570e-02] Epoch Error: 0.009046, W=tensor([[ 0.0055, 0.9880], [-0.9992, -0.4988]]), b=tensor([-0.0058, -0.0106]) [62/128 - lr: 3.4845e-02] Epoch Error: 0.008171, W=tensor([[ 0.0029, 1.0003], [-1.0060, -0.5129]]), 
b=tensor([ 0.0022, -0.0007]) [63/128 - lr: 3.1235e-02] Epoch Error: 0.008128, W=tensor([[-0.0015, 0.9900], [-0.9928, -0.4823]]), b=tensor([-0.0033, 0.0060]) [64/128 - lr: 2.7764e-02] Epoch Error: 0.010931, W=tensor([[ 0.0035, 1.0005], [-0.9982, -0.4849]]), b=tensor([-0.0046, -0.0011]) [65/128 - lr: 2.4456e-02] Epoch Error: 0.010787, W=tensor([[-0.0014, 1.0096], [-1.0108, -0.5066]]), b=tensor([ 0.0035, -0.0010]) [66/128 - lr: 2.1332e-02] Epoch Error: 0.010705, W=tensor([[-0.0018, 1.0000], [-0.9986, -0.5067]]), b=tensor([0.0024, 0.0030]) [67/128 - lr: 1.8414e-02] Epoch Error: 0.005952, W=tensor([[-0.0019, 0.9892], [-1.0084, -0.5000]]), b=tensor([-0.0002, -0.0005]) [68/128 - lr: 1.5721e-02] Epoch Error: 0.005751, W=tensor([[-0.0027, 0.9980], [-0.9947, -0.4958]]), b=tensor([-0.0012, 0.0012]) [69/128 - lr: 1.3271e-02] Epoch Error: 0.004823, W=tensor([[-0.0019, 1.0019], [-0.9976, -0.4980]]), b=tensor([0.0009, 0.0048]) [70/128 - lr: 1.1081e-02] Epoch Error: 0.003509, W=tensor([[ 6.6505e-04, 9.9884e-01], [-1.0029e+00, -4.9987e-01]]), b=tensor([0.0005, 0.0010]) [71/128 - lr: 9.1657e-03] Epoch Error: 0.002061, W=tensor([[ 4.3162e-05, 1.0008e+00], [-1.0001e+00, -4.9976e-01]]), b=tensor([0.0003, 0.0010]) [72/128 - lr: 7.5372e-03] Epoch Error: 0.002211, W=tensor([[-7.0880e-04, 9.9970e-01], [-1.0024e+00, -5.0111e-01]]), b=tensor([-0.0001, 0.0015]) [73/128 - lr: 6.2067e-03] Epoch Error: 0.002731, W=tensor([[ 3.7402e-04, 9.9696e-01], [-9.9885e-01, -4.9766e-01]]), b=tensor([-0.0014, -0.0012]) [74/128 - lr: 5.1832e-03] Epoch Error: 0.001516, W=tensor([[-4.3274e-04, 9.9969e-01], [-9.9975e-01, -4.9795e-01]]), b=tensor([-0.0005, -0.0004]) [75/128 - lr: 4.4736e-03] Epoch Error: 0.001567, W=tensor([[-1.5633e-04, 9.9906e-01], [-9.9966e-01, -4.9970e-01]]), b=tensor([-0.0006, -0.0003]) [76/128 - lr: 4.0827e-03] Epoch Error: 0.002169, W=tensor([[-3.6453e-04, 9.9956e-01], [-1.0023e+00, -4.9796e-01]]), b=tensor([ 0.0005, -0.0007]) [77/128 - lr: 3.9997e-03] Epoch Error: 0.001693, 
W=tensor([[-3.3162e-04, 9.9971e-01], [-1.0013e+00, -4.9860e-01]]), b=tensor([-7.7370e-05, -6.7228e-04]) [78/128 - lr: 3.9938e-03] Epoch Error: 0.001061, W=tensor([[-2.3620e-04, 9.9928e-01], [-9.9912e-01, -4.9921e-01]]), b=tensor([0.0004, 0.0001]) [79/128 - lr: 3.9804e-03] Epoch Error: 0.001350, W=tensor([[ 6.0312e-05, 1.0014e+00], [-1.0004e+00, -5.0052e-01]]), b=tensor([-0.0007, 0.0010]) [80/128 - lr: 3.9596e-03] Epoch Error: 0.001763, W=tensor([[ 1.9713e-04, 9.9968e-01], [-9.9843e-01, -5.0022e-01]]), b=tensor([ 0.0005, -0.0008]) [81/128 - lr: 3.9314e-03] Epoch Error: 0.001147, W=tensor([[-2.6483e-04, 1.0007e+00], [-1.0007e+00, -4.9960e-01]]), b=tensor([0.0015, 0.0001]) [82/128 - lr: 3.8960e-03] Epoch Error: 0.001231, W=tensor([[ 2.3433e-04, 1.0013e+00], [-1.0011e+00, -5.0177e-01]]), b=tensor([ 0.0004, -0.0007]) [83/128 - lr: 3.8534e-03] Epoch Error: 0.000781, W=tensor([[ 4.1967e-04, 9.9999e-01], [-1.0002e+00, -5.0061e-01]]), b=tensor([-0.0002, 0.0004]) [84/128 - lr: 3.8039e-03] Epoch Error: 0.000817, W=tensor([[-2.7855e-04, 1.0003e+00], [-9.9841e-01, -4.9858e-01]]), b=tensor([-0.0010, -0.0007]) [85/128 - lr: 3.7476e-03] Epoch Error: 0.000948, W=tensor([[ 2.9369e-04, 1.0004e+00], [-1.0010e+00, -4.9932e-01]]), b=tensor([-0.0001, 0.0002]) [86/128 - lr: 3.6847e-03] Epoch Error: 0.001242, W=tensor([[-7.7710e-04, 9.9899e-01], [-9.9972e-01, -5.0095e-01]]), b=tensor([-0.0008, -0.0001]) [87/128 - lr: 3.6155e-03] Epoch Error: 0.000979, W=tensor([[ 2.2082e-04, 9.9973e-01], [-1.0007e+00, -5.0043e-01]]), b=tensor([0.0007, 0.0010]) [88/128 - lr: 3.5403e-03] Epoch Error: 0.000985, W=tensor([[ 4.5567e-04, 9.9948e-01], [-1.0007e+00, -5.0107e-01]]), b=tensor([-0.0004, -0.0003]) [89/128 - lr: 3.4592e-03] Epoch Error: 0.000793, W=tensor([[-4.3191e-04, 9.9916e-01], [-9.9958e-01, -4.9898e-01]]), b=tensor([-0.0001, 0.0006]) [90/128 - lr: 3.3727e-03] Epoch Error: 0.000918, W=tensor([[ 2.3193e-04, 9.9881e-01], [-9.9952e-01, -5.0025e-01]]), b=tensor([-0.0006, -0.0004]) [91/128 - lr: 
3.2810e-03] Epoch Error: 0.001123, W=tensor([[ 1.3298e-04, 1.0004e+00], [-1.0008e+00, -5.0098e-01]]), b=tensor([-0.0008, 0.0009]) [92/128 - lr: 3.1845e-03] Epoch Error: 0.001331, W=tensor([[ 4.5332e-04, 1.0006e+00], [-9.9945e-01, -4.9805e-01]]), b=tensor([-0.0003, -0.0008]) [93/128 - lr: 3.0835e-03] Epoch Error: 0.001310, W=tensor([[-2.6666e-04, 1.0002e+00], [-1.0006e+00, -4.9925e-01]]), b=tensor([-1.0954e-03, 2.1377e-05]) [94/128 - lr: 2.9785e-03] Epoch Error: 0.001219, W=tensor([[ 7.3422e-05, 9.9906e-01], [-1.0004e+00, -5.0057e-01]]), b=tensor([-0.0003, -0.0002]) [95/128 - lr: 2.8698e-03] Epoch Error: 0.000698, W=tensor([[ 2.7087e-04, 9.9979e-01], [-9.9974e-01, -4.9882e-01]]), b=tensor([ 0.0002, -0.0001]) [96/128 - lr: 2.7578e-03] Epoch Error: 0.001001, W=tensor([[-1.4277e-04, 1.0007e+00], [-1.0007e+00, -5.0149e-01]]), b=tensor([-8.3528e-05, 1.0019e-04]) [97/128 - lr: 2.6430e-03] Epoch Error: 0.001115, W=tensor([[ 3.0692e-04, 9.9981e-01], [-9.9941e-01, -4.9819e-01]]), b=tensor([ 0.0003, -0.0009]) [98/128 - lr: 2.5258e-03] Epoch Error: 0.001035, W=tensor([[-2.2373e-04, 1.0001e+00], [-1.0002e+00, -5.0032e-01]]), b=tensor([0.0002, 0.0002]) [99/128 - lr: 2.4065e-03] Epoch Error: 0.000735, W=tensor([[-6.8873e-04, 9.9979e-01], [-1.0008e+00, -4.9984e-01]]), b=tensor([0.0004, 0.0006]) [100/128 - lr: 2.2858e-03] Epoch Error: 0.000692, W=tensor([[-4.0295e-04, 1.0013e+00], [-9.9915e-01, -5.0024e-01]]), b=tensor([4.8592e-05, 9.7183e-05]) [101/128 - lr: 2.1640e-03] Epoch Error: 0.000592, W=tensor([[-4.4167e-05, 9.9992e-01], [-1.0011e+00, -5.0021e-01]]), b=tensor([-1.5593e-04, -6.6904e-05]) [102/128 - lr: 2.0416e-03] Epoch Error: 0.000730, W=tensor([[-6.3255e-05, 1.0001e+00], [-9.9936e-01, -5.0025e-01]]), b=tensor([ 0.0002, -0.0004]) [103/128 - lr: 1.9190e-03] Epoch Error: 0.000489, W=tensor([[-1.0249e-04, 1.0001e+00], [-1.0002e+00, -5.0025e-01]]), b=tensor([ 2.0460e-04, -2.3722e-05]) [104/128 - lr: 1.7967e-03] Epoch Error: 0.000276, W=tensor([[ 9.1129e-06, 1.0003e+00], 
[-9.9994e-01, -4.9993e-01]]), b=tensor([-1.5875e-04, 4.5437e-05]) [105/128 - lr: 1.6752e-03] Epoch Error: 0.000311, W=tensor([[-1.8016e-04, 9.9989e-01], [-9.9976e-01, -5.0008e-01]]), b=tensor([-6.5294e-06, 6.9909e-05]) [106/128 - lr: 1.5549e-03] Epoch Error: 0.000256, W=tensor([[ 1.8151e-04, 9.9986e-01], [-9.9995e-01, -5.0058e-01]]), b=tensor([-0.0001, -0.0003]) [107/128 - lr: 1.4363e-03] Epoch Error: 0.000459, W=tensor([[-2.8963e-05, 1.0004e+00], [-9.9995e-01, -5.0030e-01]]), b=tensor([-6.7238e-05, -8.7982e-05]) [108/128 - lr: 1.3198e-03] Epoch Error: 0.000367, W=tensor([[ 1.3832e-04, 9.9923e-01], [-1.0001e+00, -4.9991e-01]]), b=tensor([ 5.8027e-05, -5.0988e-05]) [109/128 - lr: 1.2059e-03] Epoch Error: 0.000400, W=tensor([[-1.0852e-04, 1.0000e+00], [-9.9959e-01, -4.9979e-01]]), b=tensor([-1.4766e-04, 1.1857e-05]) [110/128 - lr: 1.0949e-03] Epoch Error: 0.000225, W=tensor([[-1.1258e-04, 9.9941e-01], [-1.0000e+00, -5.0016e-01]]), b=tensor([-9.7096e-05, 4.7644e-04]) [111/128 - lr: 9.8735e-04] Epoch Error: 0.000410, W=tensor([[ 6.0554e-06, 9.9996e-01], [-9.9976e-01, -4.9961e-01]]), b=tensor([-0.0002, 0.0002]) [112/128 - lr: 8.8360e-04] Epoch Error: 0.000480, W=tensor([[-8.7386e-05, 1.0002e+00], [-9.9986e-01, -4.9996e-01]]), b=tensor([0.0002, 0.0001]) [113/128 - lr: 7.8405e-04] Epoch Error: 0.000378, W=tensor([[-1.1393e-04, 9.9990e-01], [-1.0001e+00, -5.0009e-01]]), b=tensor([ 1.0459e-05, -3.1913e-04]) [114/128 - lr: 6.8907e-04] Epoch Error: 0.000355, W=tensor([[-2.1793e-04, 1.0000e+00], [-9.9990e-01, -5.0007e-01]]), b=tensor([ 0.0001, -0.0002]) [115/128 - lr: 5.9901e-04] Epoch Error: 0.000216, W=tensor([[-1.0103e-04, 1.0000e+00], [-1.0001e+00, -4.9990e-01]]), b=tensor([ 1.8183e-04, -2.1772e-05]) [116/128 - lr: 5.1422e-04] Epoch Error: 0.000152, W=tensor([[-7.8162e-05, 1.0000e+00], [-1.0001e+00, -5.0001e-01]]), b=tensor([ 9.9333e-05, -5.6128e-05]) [117/128 - lr: 4.3501e-04] Epoch Error: 0.000081, W=tensor([[ 3.8341e-05, 9.9996e-01], [-1.0002e+00, -4.9998e-01]]), 
b=tensor([-1.5643e-05, 7.2027e-05]) [118/128 - lr: 3.6169e-04] Epoch Error: 0.000105, W=tensor([[-5.7335e-05, 1.0001e+00], [-9.9995e-01, -4.9999e-01]]), b=tensor([-2.6819e-05, -7.5429e-06]) [119/128 - lr: 2.9452e-04] Epoch Error: 0.000100, W=tensor([[-1.2137e-05, 9.9996e-01], [-1.0000e+00, -4.9986e-01]]), b=tensor([2.3959e-05, 1.1439e-04]) [120/128 - lr: 2.3376e-04] Epoch Error: 0.000074, W=tensor([[-1.0361e-05, 9.9995e-01], [-9.9993e-01, -4.9996e-01]]), b=tensor([-4.1451e-05, -3.1818e-06]) [121/128 - lr: 1.7964e-04] Epoch Error: 0.000091, W=tensor([[ 6.0607e-06, 1.0000e+00], [-1.0000e+00, -4.9995e-01]]), b=tensor([1.6299e-05, 7.9060e-05]) [122/128 - lr: 1.3235e-04] Epoch Error: 0.000048, W=tensor([[-1.5443e-06, 1.0001e+00], [-9.9996e-01, -5.0007e-01]]), b=tensor([2.9580e-06, 1.5721e-05]) [123/128 - lr: 9.2093e-05] Epoch Error: 0.000021, W=tensor([[ 3.2709e-06, 1.0000e+00], [-1.0000e+00, -4.9999e-01]]), b=tensor([-1.2055e-05, 3.4754e-06]) [124/128 - lr: 5.9001e-05] Epoch Error: 0.000014, W=tensor([[-1.9627e-06, 9.9999e-01], [-1.0000e+00, -5.0000e-01]]), b=tensor([-8.8738e-06, -9.2938e-06]) [125/128 - lr: 3.3203e-05] Epoch Error: 0.000012, W=tensor([[-1.9953e-06, 9.9999e-01], [-1.0000e+00, -4.9999e-01]]), b=tensor([ 5.6471e-06, -1.0543e-06]) [126/128 - lr: 1.4797e-05] Epoch Error: 0.000006, W=tensor([[-4.6119e-06, 1.0000e+00], [-9.9999e-01, -5.0000e-01]]), b=tensor([-1.9640e-06, -5.5545e-06]) [127/128 - lr: 3.8508e-06] Epoch Error: 0.000004, W=tensor([[-9.6412e-07, 1.0000e+00], [-1.0000e+00, -5.0000e-01]]), b=tensor([-1.0393e-06, -1.7788e-07]) [128/128 - lr: 4.0653e-07] Epoch Error: 0.000001, W=tensor([[-5.6339e-07, 1.0000e+00], [-1.0000e+00, -5.0000e-01]]), b=tensor([-1.1311e-07, -3.7789e-07]) Best matrix: tensor([[-5.6339e-07, 1.0000e+00], [-1.0000e+00, -5.0000e-01]]), mean absolute error: 0.000001 Best bias: tensor([-1.1311e-07, -3.7789e-07]), mean absolute error: 0.000000
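The `lr` column in the log above follows the three-phase one-cycle shape: it starts near `max_lr/25`, peaks at `max_lr` roughly 30% of the way through, returns to about `max_lr/25` at roughly 60%, and then decays to a very small value. The sketch below reproduces that shape in plain Python. It is an approximation based on my understanding of the scheduler's defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing), not the library's own code, and `total_steps` here is a hypothetical value.

```python
import math

def cosine_anneal(start, end, fraction):
    """Cosine interpolation from `start` (fraction=0) to `end` (fraction=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * fraction) + 1.0)

def three_phase_one_cycle(step, total_steps, max_lr,
                          pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Sketch of a three-phase one-cycle schedule (assumed defaults)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    # (end_step, start_lr, end_lr) for warm-up, cool-down and final annealing
    phases = [
        (pct_start * total_steps - 1, initial_lr, max_lr),
        (2 * pct_start * total_steps - 2, max_lr, initial_lr),
        (total_steps - 1, initial_lr, min_lr),
    ]
    start_step = 0.0
    for index, (end_step, start_lr, end_lr) in enumerate(phases):
        # The last phase also catches any step beyond its nominal end
        if step <= end_step or index == len(phases) - 1:
            fraction = (step - start_step) / (end_step - start_step)
            return cosine_anneal(start_lr, end_lr, fraction)
        start_step = end_step

total_steps = 1000  # hypothetical; the real value is epochs * steps_per_epoch
schedule = [three_phase_one_cycle(s, total_steps, max_lr=1e-1) for s in range(total_steps)]
print(f"start={schedule[0]:.4e}, peak={max(schedule):.4e}, end={schedule[-1]:.4e}")
```

The long final phase at small learning rates is where the epoch error in the log drops by several further orders of magnitude, since the parameters are only being fine-tuned around the solution found earlier.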
fig_ref_position, axes_ref_position = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Position Ref.")
fig_ref_velocity, axes_ref_velocity = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK4(5) - SHA Velocity Ref.")
simple_oscillator_net.load_state_dict(best_parameters)
_, _, sha_states_optimised, sha_times_optimised, _ = adaptive_rk45_integrator.apply(sha_nn_fn, initial_state, initial_time, final_time, initial_timestep, integrator_kwargs)
_ = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_position, method_label="RK4(5) - SHA Position Opt.")
_ = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_optimised, sha_times_optimised)], axes=axes_ref_velocity, method_label="RK4(5) - SHA Velocity Opt.")
And we can see that the neural network is able to effectively learn the dynamics of this system. However, the learned matrix is specific to this frequency and damping: a system with different values has a different matrix, so the network would not extend to it.
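To make this limitation concrete, the sketch below compares the learned matrix (rounded from the final line of the training log) with the ideal matrix of an oscillator that has a different frequency and damping, using the same mean-absolute-error metric as above. The second parameter set ($\omega = 2$, $\zeta = 0.1$) is a hypothetical choice for illustration, and `ideal_matrix` and `mean_absolute_error` are illustrative helpers rather than `neuralode` functions.

```python
def ideal_matrix(frequency, damping):
    """Illustrative helper: oscillator matrix for the given parameters."""
    return [[0.0, 1.0],
            [-frequency**2, -2.0 * damping * frequency]]

def mean_absolute_error(matrix_a, matrix_b):
    """Mean of the element-wise absolute differences of two nested lists."""
    differences = [abs(a - b)
                   for row_a, row_b in zip(matrix_a, matrix_b)
                   for a, b in zip(row_a, row_b)]
    return sum(differences) / len(differences)

# Learned weights, rounded from the last epoch of the log above
learned = [[0.0, 1.0],
           [-1.0, -0.5]]

# Same parameters as the training data: the matrices agree
error_same = mean_absolute_error(learned, ideal_matrix(1.0, 0.25))
# Hypothetical different oscillator: the second row no longer matches
error_other = mean_absolute_error(learned, ideal_matrix(2.0, 0.1))
print(f"same system: {error_same:.6f}, different system: {error_other:.6f}")
```

Only the second row of the matrix depends on the frequency and damping, so that is where the mismatch appears; the first row, which encodes $x^{(1)} = v$, is shared by every such oscillator.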
In the coming notebooks, we will extend our network to learn the general dynamics by passing frequency and damping as parameters. Further, we will learn more complex system dynamics and how to manipulate these systems.