The Driven Harmonic Oscillator¶
In this notebook, we will implement the driven harmonic oscillator and train a neural network to control the oscillator state.
We will also train a separate network to learn the system dynamics.
A part of this exercise will also show you how to make animations in matplotlib!
# We restrict the number of PyTorch CPU threads to 1 as CPU training is
# slower with multiple threads due to the small size of the networks we're dealing with
%env OMP_NUM_THREADS = 1
%env OPENBLAS_NUM_THREADS = 1
%env MKL_NUM_THREADS = 1
%env VECLIB_MAXIMUM_THREADS = 1
%env NUMEXPR_NUM_THREADS = 1
env: OMP_NUM_THREADS=1
env: OPENBLAS_NUM_THREADS=1
env: MKL_NUM_THREADS=1
env: VECLIB_MAXIMUM_THREADS=1
env: NUMEXPR_NUM_THREADS=1
import warnings
import torch
import random
import numpy as np
import neuralode
import copy
warnings.simplefilter('ignore', RuntimeWarning)
# Again restrict the number of pytorch threads
torch.set_num_threads(1)
# For convenience, we define the default tensor device and dtype here
torch.set_default_device('cpu')
# In neural networks, we prefer 32-bit/16-bit floats, but for precise integration, 64-bit is preferred.
# We will revisit this later when we need to mix integration with neural network training
torch.set_default_dtype(torch.float64)
# Set the random seeds so that each run of this notebook is reproducible
torch.manual_seed(1)
random.seed(1)
np.random.seed(1)
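The precision gap behind the 64-bit default above is easy to quantify: machine epsilon for float32 is roughly nine orders of magnitude larger than for float64, which matters once an adaptive integrator pushes its tolerances towards it. A quick check:

```python
import torch

# Machine epsilon: the gap between 1.0 and the next representable float
print(torch.finfo(torch.float32).eps)  # ~1.19e-07 (2**-23)
print(torch.finfo(torch.float64).eps)  # ~2.22e-16 (2**-52)
```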
initial_position = torch.tensor(1.0)
initial_velocity = torch.tensor(0.0)
frequency = (torch.ones_like(initial_position))
damping = (torch.ones_like(initial_position)*0.25)
initial_state = torch.stack([
initial_position,
initial_velocity,
], dim=-1).requires_grad_(True)
initial_time = torch.tensor(0.0)
final_time = torch.tensor(25.0)
initial_timestep = (final_time - initial_time) / 100
current_integrator = neuralode.integrators.AdaptiveRKV87Integrator
atol = rtol = torch.tensor(torch.finfo(initial_state.dtype).eps**0.5)
final_state, _, sha_states, sha_times, _ = current_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, initial_state, initial_time, final_time, initial_timestep, {'atol': atol, 'rtol': rtol}, frequency, damping)
fig, axes = neuralode.plot.trajectory.plot_trajectory([(i, j) for i, j in zip(sha_states, sha_times)], method_label="RK7(8) - Simple Harmonic Oscillator")
Learning to Fix the Oscillator State¶
Previously we learned the dynamics of this oscillator directly; here we will incorporate control of the oscillator. We'll use the forced harmonic oscillator model: in addition to the motion arising from the simple harmonic dynamics, a driving force does work on the system.
Mathematically, this is equivalent to
$$ \begin{bmatrix} x^{(1)} \\ v^{(1)} \end{bmatrix} = \mathbf{A} \begin{bmatrix} x \\ v \end{bmatrix} + \begin{bmatrix} 0 \\ F \end{bmatrix} $$
where
$$ \mathbf{A} = \begin{bmatrix} 0 & 1 \\ -\omega^2 & -2\zeta\omega \end{bmatrix} $$
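Written out as code, the RHS of the driven system follows directly from $\mathbf{A}$. This is a minimal illustrative sketch; the function name and the constant force here are ours, not part of neuralode:

```python
import torch

def driven_oscillator_rhs(state, omega, zeta, force):
    # state[..., 0] = position x, state[..., 1] = velocity v
    x, v = state[..., 0], state[..., 1]
    dx = v                                        # first row of A
    dv = -omega**2 * x - 2*zeta*omega*v + force   # second row of A plus the drive F
    return torch.stack([dx, dv], dim=-1)

state = torch.tensor([1.0, 0.0])
out = driven_oscillator_rhs(state, omega=torch.tensor(1.0),
                            zeta=torch.tensor(0.25), force=torch.tensor(0.0))
print(out)  # with F = 0 this reduces to the undriven oscillator: [0, -1]
```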
If we were to learn both the dynamics and the driving force together, we'd need both driven and undriven oscillator states; otherwise the two networks could learn some combined dynamics that folds the driving force into the RHS. In that case, we'd no longer be able to distinguish the dynamics of the oscillator itself from the driving force.
Dataset and Loss Function Configuration¶
We create a dataset of 8 initial states that cover a range of positions and velocities, and use a loss function that penalises the deviation from the target state not only at the end of the integration but also cumulatively over the whole trajectory.
state_range_min = torch.tensor([-25.0, -5.0])
state_range_max = torch.tensor([ 25.0, 5.0])
state_dataset = torch.rand(8, 2) * (state_range_max - state_range_min)[None] + state_range_min[None]
target_state = torch.tensor([1.0, 0.0])
def damping_closure(rhs, parameters, minibatch, optimiser, integrator_kwargs, integrator=None):
if integrator is None:
integrator = neuralode.integrators.AdaptiveRK45Integrator
optimiser.zero_grad()
t_final = minibatch['final_time']
t_initial = minibatch['initial_time']
initial_dt = minibatch['dt']
states = minibatch['states']
def damping_loss(state, time):
# Measure of the unnormalised cosine alignment between the target state and the current state:
# |state||target|(1 - cos θ), where |.| denotes the vector 2-norm.
# When the current state points along the target state, the error is 0.0;
# when it points in the opposite direction (i.e. towards -target_state), the error
# reaches its maximum of 2|state||target|
deviation_from_target = torch.linalg.vector_norm(state.detach(), dim=-1, keepdim=True)*torch.linalg.norm(target_state, dim=-1, keepdim=True) - torch.linalg.vecdot(state, target_state, dim=-1)[...,None]
return deviation_from_target
def damping_augmented_rhs(state, time, *nn_parameters):
return torch.cat([
rhs(state[...,:-1], time, *nn_parameters),
damping_loss(state[...,:-1], time)  # the loss is accumulated over the physical state components
], dim=-1)
augmented_states = torch.cat([
states,
torch.zeros_like(states[...,0,None])
], dim=-1)
final_augmented_state, _, _, _, _ = integrator.apply(damping_augmented_rhs, augmented_states, t_initial, t_final, initial_dt, integrator_kwargs, *parameters)
error = (target_state - final_augmented_state[...,:-1]).square().sum(dim=-1).mean() + final_augmented_state[...,-1:].sum(dim=-1).mean()
if error.requires_grad:
error.backward()
return error
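To make the behaviour of the alignment term concrete, here is a standalone sketch of the same quantity (without the `detach` used in the closure above) evaluated at the two extremes:

```python
import torch

target = torch.tensor([1.0, 0.0])

def alignment_loss(state, target):
    # |state||target| - <state, target> = |state||target|(1 - cos θ)
    return (torch.linalg.vector_norm(state, dim=-1)
            * torch.linalg.vector_norm(target, dim=-1)
            - torch.linalg.vecdot(state, target, dim=-1))

print(alignment_loss(target, target).item())   # 0.0 when aligned with the target
print(alignment_loss(-target, target).item())  # 2.0 = 2|state||target| when opposed
```

Note that the term vanishes for any state pointing along the target direction regardless of magnitude; it is the terminal squared error in the closure that pins down the magnitude.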
Network Specification¶
The network is a simple feed-forward network composed of dense layers and non-linearities, outputting a single control value, the acceleration of the oscillator.
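The actual architecture lives in neuralode.models.oscillator; as a rough, hypothetical stand-in, such a network might look like the following (the layer sizes are assumptions; the bounded tanh output mirrors the scaling by max_acceleration below):

```python
import torch

class TinyControllerNet(torch.nn.Module):
    """Hypothetical stand-in for DrivenOscillatorNet: dense layers with
    tanh non-linearities and a single bounded control output."""
    def __init__(self, hidden=32):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2, hidden), torch.nn.Tanh(),
            torch.nn.Linear(hidden, hidden), torch.nn.Tanh(),
            torch.nn.Linear(hidden, 1), torch.nn.Tanh(),  # output in [-1, 1]
        )

    def forward(self, x, t):
        # t is accepted for API compatibility; this sketch conditions
        # only on the state
        return self.net(x)

net = TinyControllerNet()
u = net(torch.randn(4, 2), torch.tensor(0.0))
print(u.shape)  # one control value per batch element
```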
driven_oscillator_net = neuralode.models.oscillator.DrivenOscillatorNet()
driven_oscillator_net.apply(neuralode.models.util.init_weights)
max_acceleration = torch.tensor(25.0)
def get_control(x, t, controller, *nn_parameters):
# Rescale the inputs to a smaller range for better stability of the network
x_encoded = torch.stack([x[...,0]/100.0, x[...,1]/100.0], dim=-1)
control_u = torch.func.functional_call(controller, {k: p for (k, _), p in zip(controller.named_parameters(), nn_parameters)}, (x_encoded, t))
control_u = control_u*max_acceleration
return torch.cat([
torch.zeros_like(control_u),
control_u, # The network can only apply forces, not directly control the velocity of the oscillator
], dim=-1)
def dha_nn_fn(x, t, frequency, damping, *nn_parameters):
# The learned control force (with inputs rescaled inside get_control) is added on top of the simple harmonic dynamics
control_vec = get_control(x, t, driven_oscillator_net, *nn_parameters)
return neuralode.dynamics.simple_harmonic_oscillator(x, t, frequency, damping) + control_vec
Training Configuration¶
We train for 64 epochs with a batch size of 4, using the Adam optimiser and the OneCycleLR learning-rate schedule for accelerated convergence.
batch_size = 4
number_of_epochs = 64
driven_oscillator_optimiser = torch.optim.Adam(driven_oscillator_net.parameters(), lr=1e-4, amsgrad=True)
one_cycle_lr_driven_oscillator = torch.optim.lr_scheduler.OneCycleLR(driven_oscillator_optimiser, max_lr=2e-2, div_factor=200.0, steps_per_epoch=round(state_dataset.shape[0]/batch_size+0.5), epochs=number_of_epochs, three_phase=False)
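To see the shape of this schedule in isolation, we can drive a throwaway optimiser through one full cycle (the dummy parameter and total_steps here are purely illustrative):

```python
import torch

# OneCycleLR warms the lr up from max_lr/div_factor towards max_lr,
# then anneals it far below the starting value by the end of the cycle
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=2e-2, div_factor=200.0, total_steps=128, three_phase=False)

lrs = []
for _ in range(128):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])

print(f"start: {lrs[0]:.2e}, peak: {max(lrs):.2e}, end: {lrs[-1]:.2e}")
```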
Training the Network¶
driven_oscillator_net.train()
best_error = torch.inf
best_parameters = copy.deepcopy(driven_oscillator_net.state_dict())
common_closure_args = [driven_oscillator_optimiser, {'atol': atol, 'rtol': rtol}, current_integrator]
for step in range(number_of_epochs):
epoch_error = 0.0
shuffled_indices = torch.randperm(state_dataset.shape[0])
for batch_idx in range(0, state_dataset.shape[0], batch_size):
batch_dict = {
'states': state_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
'initial_time': initial_time.detach().clone(),
'final_time': initial_time.detach().clone() + 2.5,
'dt': initial_timestep.detach().clone()
}
step_error = driven_oscillator_optimiser.step(lambda: damping_closure(dha_nn_fn, [frequency, torch.zeros_like(damping)] + list(driven_oscillator_net.parameters()), batch_dict, *common_closure_args))
one_cycle_lr_driven_oscillator.step()
epoch_error = epoch_error + step_error.item()*batch_dict['states'].shape[0]
print(f"[{step+1}/{number_of_epochs} - lr: {one_cycle_lr_driven_oscillator.get_last_lr()[0]:.4e}]/[{batch_idx}/{state_dataset.shape[0]}] Batch Error: {step_error:.4f} ", end='\r', flush=True)
epoch_error = epoch_error/state_dataset.shape[0]
if epoch_error < best_error:
best_error = epoch_error
best_parameters = copy.deepcopy(driven_oscillator_net.state_dict())
print(" "*128, end='\r', flush=True)
print(f"[{step+1}/{number_of_epochs} - lr: {one_cycle_lr_driven_oscillator.get_last_lr()[0]:.4e}] Epoch Error: {epoch_error:.6}")
print()
[1/64 - lr: 2.4008e-04] Epoch Error: 151.945
[2/64 - lr: 6.5639e-04] Epoch Error: 151.252
[3/64 - lr: 1.3372e-03] Epoch Error: 149.567
[4/64 - lr: 2.2633e-03] Epoch Error: 146.498
[5/64 - lr: 3.4087e-03] Epoch Error: 139.44
[6/64 - lr: 4.7411e-03] Epoch Error: 140.579
[7/64 - lr: 6.2230e-03] Epoch Error: 128.389
[8/64 - lr: 7.8126e-03] Epoch Error: 117.897
[9/64 - lr: 9.4653e-03] Epoch Error: 94.8791
[10/64 - lr: 1.1134e-02] Epoch Error: 74.8794
[11/64 - lr: 1.2773e-02] Epoch Error: 42.4302
[12/64 - lr: 1.4335e-02] Epoch Error: 33.4042
[13/64 - lr: 1.5776e-02] Epoch Error: 31.0913
[14/64 - lr: 1.7056e-02] Epoch Error: 19.6234
[15/64 - lr: 1.8139e-02] Epoch Error: 7.62874
[16/64 - lr: 1.8994e-02] Epoch Error: 4.02601
[17/64 - lr: 1.9597e-02] Epoch Error: 1.27341
[18/64 - lr: 1.9931e-02] Epoch Error: 2.04025
[19/64 - lr: 1.9998e-02] Epoch Error: 2.61187
[20/64 - lr: 1.9958e-02] Epoch Error: 1.27375
[21/64 - lr: 1.9870e-02] Epoch Error: 0.852886
[22/64 - lr: 1.9733e-02] Epoch Error: 0.332227
[23/64 - lr: 1.9549e-02] Epoch Error: 0.525715
[24/64 - lr: 1.9317e-02] Epoch Error: 0.386378
[25/64 - lr: 1.9040e-02] Epoch Error: 0.451681
[26/64 - lr: 1.8718e-02] Epoch Error: 0.369986
[27/64 - lr: 1.8353e-02] Epoch Error: 0.171453
[28/64 - lr: 1.7948e-02] Epoch Error: 0.126298
[29/64 - lr: 1.7503e-02] Epoch Error: 0.096483
[30/64 - lr: 1.7021e-02] Epoch Error: 0.0432904
[31/64 - lr: 1.6505e-02] Epoch Error: 0.111354
[32/64 - lr: 1.5957e-02] Epoch Error: 0.12867
[33/64 - lr: 1.5380e-02] Epoch Error: 0.0925516
[34/64 - lr: 1.4776e-02] Epoch Error: 0.130442
[35/64 - lr: 1.4148e-02] Epoch Error: 0.0731119
[36/64 - lr: 1.3501e-02] Epoch Error: 0.0871143
[37/64 - lr: 1.2836e-02] Epoch Error: 0.0568891
[38/64 - lr: 1.2157e-02] Epoch Error: 0.0849734
[39/64 - lr: 1.1467e-02] Epoch Error: 0.0672761
[40/64 - lr: 1.0771e-02] Epoch Error: 0.0692755
[41/64 - lr: 1.0070e-02] Epoch Error: 0.0560489
[42/64 - lr: 9.3693e-03] Epoch Error: 0.0513731
[43/64 - lr: 8.6716e-03] Epoch Error: 0.0604199
[44/64 - lr: 7.9804e-03] Epoch Error: 0.0440056
[45/64 - lr: 7.2991e-03] Epoch Error: 0.0508528
[46/64 - lr: 6.6311e-03] Epoch Error: 0.0547071
[47/64 - lr: 5.9797e-03] Epoch Error: 0.0434798
[48/64 - lr: 5.3480e-03] Epoch Error: 0.0451367
[49/64 - lr: 4.7392e-03] Epoch Error: 0.047124
[50/64 - lr: 4.1562e-03] Epoch Error: 0.0455598
[51/64 - lr: 3.6020e-03] Epoch Error: 0.0439628
[52/64 - lr: 3.0793e-03] Epoch Error: 0.0436574
[53/64 - lr: 2.5905e-03] Epoch Error: 0.044996
[54/64 - lr: 2.1382e-03] Epoch Error: 0.0455611
[55/64 - lr: 1.7245e-03] Epoch Error: 0.0442852
[56/64 - lr: 1.3515e-03] Epoch Error: 0.0433767
[57/64 - lr: 1.0210e-03] Epoch Error: 0.043196
[58/64 - lr: 7.3461e-04] Epoch Error: 0.0436175
[59/64 - lr: 4.9379e-04] Epoch Error: 0.0432106
[60/64 - lr: 2.9970e-04] Epoch Error: 0.0431902
[61/64 - lr: 1.5329e-04] Epoch Error: 0.0432194
[62/64 - lr: 5.5281e-05] Epoch Error: 0.0432728
[63/64 - lr: 6.1562e-06] Epoch Error: 0.043241
[64/64 - lr: 6.1562e-06] Epoch Error: 0.0432406
Testing the Network on Unseen Systems¶
In order to test this controller, let's generate several cases that the network hasn't seen and integrate the system. Ideally, we should observe rapid damping of oscillations, akin to critical damping.
test_range_min = torch.tensor([-256, -32.0])
test_range_max = torch.tensor([ 256, 32.0])
test_state_dataset = torch.rand(64, 2) * (test_range_max - test_range_min)[None] + test_range_min[None]
driven_oscillator_net.load_state_dict(best_parameters)
driven_oscillator_net.eval()
test_final_time = final_time*2.0
_, _, dha_states_optimised, dha_times_optimised, _ = current_integrator.apply(dha_nn_fn, initial_state, initial_time, test_final_time, initial_timestep, {'atol': atol, 'rtol': rtol}, frequency, torch.zeros_like(damping))
fig_ref_position, axes_ref_position = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(dha_states_optimised, dha_times_optimised)], method_label=None)
fig_ref_position.suptitle("Position")
fig_ref_velocity, axes_ref_velocity = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(dha_states_optimised, dha_times_optimised)], method_label=None)
fig_ref_velocity.suptitle("Velocity")
_, _, integrated_test_states, integrated_test_times, _ = current_integrator.apply(dha_nn_fn, test_state_dataset, initial_time, test_final_time, initial_timestep, {'atol': atol, 'rtol': rtol}, frequency, torch.zeros_like(damping))
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,0], j) for i, j in zip(integrated_test_states, integrated_test_times)], axes=axes_ref_position, method_label=None)
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,1], j) for i, j in zip(integrated_test_states, integrated_test_times)], axes=axes_ref_velocity, method_label=None)
print("Training results on test data:")
print(f"\tMean position at t={test_final_time.item():.2f}s: {integrated_test_states[-1, ..., 0].mean().item():.4e} ± {integrated_test_states[-1, ..., 0].std().item():.4e}")
print(f"\tMean position error at t={test_final_time.item():.2f}s: {(target_state - integrated_test_states)[-1, ..., 0].mean().item():.4e} ± {(target_state - integrated_test_states)[-1, ..., 0].std().item():.4e}")
print()
print(f"\tMean velocity at t={test_final_time.item():.2f}s: {integrated_test_states[-1, ..., 1].mean().item():.4e} ± {integrated_test_states[-1, ..., 1].std().item():.4e}")
print(f"\tMean velocity error at t={test_final_time.item():.2f}s: {(target_state - integrated_test_states)[-1, ..., 1].mean().item():.4e} ± {(target_state - integrated_test_states)[-1, ..., 1].std().item():.4e}")
fig_ref_position_closeup, axes_ref_position_closeup = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(dha_states_optimised, dha_times_optimised)], method_label=None)
fig_ref_velocity_closeup, axes_ref_velocity_closeup = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(dha_states_optimised, dha_times_optimised)], method_label=None)
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,0], j) for i, j in zip(integrated_test_states, integrated_test_times)], axes=axes_ref_position_closeup, method_label=None)
fig_ref_position_closeup.suptitle("Close-Up of Position")
axes_ref_position_closeup[0].set_xlim(test_final_time.item() - 5.0, test_final_time.item())
axes_ref_position_closeup[0].set_ylim(target_state[0].item() - 1.0, target_state[0].item() + 1.0)
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,1], j) for i, j in zip(integrated_test_states, integrated_test_times)], axes=axes_ref_velocity_closeup, method_label=None)
fig_ref_velocity_closeup.suptitle("Close-Up of Velocity")
axes_ref_velocity_closeup[0].set_xlim(test_final_time.item() - 5.0, test_final_time.item())
axes_ref_velocity_closeup[0].set_ylim(target_state[1].item() - 1.0, target_state[1].item() + 1.0)
Training results on test data:
    Mean position at t=50.00s: 9.3575e-01 ± 3.4454e-11
    Mean position error at t=50.00s: 6.4254e-02 ± 3.4454e-11

    Mean velocity at t=50.00s: -2.4343e-11 ± 2.2306e-10
    Mean velocity error at t=50.00s: 2.4343e-11 ± 2.2306e-10
Here we can see that the network has learned to dampen the oscillations of the system and nudge it towards the target state by the end of the trajectory. Although the damping isn't perfect (i.e. not critically damped), the network is able to damp configurations that it has not encountered before within about 20 seconds.
Co-Learning Dynamics and Control of an Oscillator¶
Loosely based on https://arxiv.org/html/2401.01836v1.
Suppose we didn't know the dynamics of the system either: we have one reference trajectory and many initial states that we'd like to control. One way to solve this is to train two networks in an alternating fashion. One network learns the dynamics, as we've seen before, and another network, like the one above, learns the control. We posit that the same structure applies: the control network can only apply forces, not manipulate the velocity directly, which should preserve the physicality of the system.
Generating Reference Trajectory Data¶
We take a damped oscillator as previously specified and simply take the points generated along its trajectory as our dataset. Note that with a high-order adaptive integrator, which takes few, large steps, this dataset could end up being too sparse.
# Generate reference trajectory for optimisation/learning
with torch.no_grad():
_, _, sha_states_ref, sha_times_ref, _ = current_integrator.apply(neuralode.dynamics.simple_harmonic_oscillator, initial_state, initial_time, final_time, initial_timestep, {'atol': torch.zeros_like(atol), 'rtol': rtol}, frequency, damping)
sha_states_ref, sha_times_ref = sha_states_ref.detach(), sha_times_ref.detach()
Network Specification¶
We specify two networks: one dynamics network which is simply a single affine transformation (which we know fits the simple harmonic oscillator and can be easily analysed) and a control network, same as before.
The PyTorch manual seed is set to make the notebook more reproducible, and this particular seed has been found to converge. With other seeds, the network can sometimes learn the time-averaged dynamics, which neglects the oscillations.
torch.manual_seed(36)
dynamics_network = neuralode.models.oscillator.OscillatorNet()
control_network = neuralode.models.oscillator.DrivenOscillatorNet()
dynamics_network.apply(neuralode.models.util.init_weights)
control_network.apply(neuralode.models.util.init_weights)
max_acceleration = torch.tensor(25.0)
def learned_dynamics_rhs(x, t, *nn_parameters):
if len(nn_parameters) > 0:
return torch.func.functional_call(dynamics_network, {k: p for (k, _), p in zip(dynamics_network.named_parameters(), nn_parameters)}, (x, t))
else:
return dynamics_network(x, t)
def controlled_dynamics_rhs(x, t, *nn_parameters):
control_vec = get_control(x, t, control_network, *nn_parameters)
return learned_dynamics_rhs(x, t) + control_vec
Training Configuration¶
Ideally, the dynamics network converges before the control network so that the control network learns controls suitable for the system we're interested in.
By running the training in cycles every epoch, we train the dynamics network for twice as many cycles as the control network (8 vs. 4). These are tunable parameters that can be tweaked to control the trade-off between convergence of the dynamics and the control network.
number_of_epochs = 32
number_of_dynamics_cycles = 8
number_of_control_cycles = 4
dynamics_optimiser = torch.optim.Adam(dynamics_network.parameters(), lr=1e-3, amsgrad=True)
control_optimiser = torch.optim.Adam(control_network.parameters(), lr=1e-3, amsgrad=True)
one_cycle_lr_dynamics = torch.optim.lr_scheduler.OneCycleLR(dynamics_optimiser, max_lr=1e-1, div_factor=50.0, steps_per_epoch=round(sha_states_ref.shape[0]/batch_size+0.5), epochs=number_of_epochs*number_of_dynamics_cycles, three_phase=True)
one_cycle_lr_control = torch.optim.lr_scheduler.OneCycleLR(control_optimiser, max_lr=2e-2, div_factor=200.0, steps_per_epoch=round(state_dataset.shape[0]/batch_size+0.5), epochs=number_of_epochs*number_of_control_cycles, three_phase=False)
ideal_matrix = neuralode.dynamics.get_simple_harmonic_oscillator_matrix(frequency, damping)
ideal_bias = torch.zeros_like(initial_state)
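Assuming get_simple_harmonic_oscillator_matrix returns the matrix $\mathbf{A}$ from the earlier equation, the ideal dynamics can be written down directly as a cross-check against the learned affine layer:

```python
import torch

omega, zeta = 1.0, 0.25  # the frequency and damping used in this notebook

# A = [[0, 1], [-omega^2, -2*zeta*omega]], as in the equation above
A = torch.tensor([
    [0.0,       1.0],
    [-omega**2, -2*zeta*omega],
])
print(A)  # [[0, 1], [-1, -0.5]]
```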
Training the Networks¶
best_dynamics_error = torch.inf
best_control_error = torch.inf
best_dynamics_parameters = copy.deepcopy(dynamics_network.state_dict())
best_control_parameters = copy.deepcopy(control_network.state_dict())
control_network.train()
dynamics_network.train()
common_dyn_closure_args = [dynamics_optimiser, {'atol': atol, 'rtol': rtol}, current_integrator]
common_ctrl_closure_args = [control_optimiser, {'atol': atol, 'rtol': rtol}, current_integrator]
for epoch in range(number_of_epochs):
# First phase is to train the dynamics as it wouldn't make sense to learn control of
# random dynamics.
dynamics_network.train()
for cycle in range(number_of_dynamics_cycles):
dynamics_epoch_error = 0.0
shuffled_indices = torch.randperm(sha_times_ref.shape[0])
for batch_idx in range(0, sha_times_ref.shape[0], batch_size):
batch_dict = {
'times': sha_times_ref[shuffled_indices][batch_idx:batch_idx+batch_size],
'states': sha_states_ref[shuffled_indices][batch_idx:batch_idx+batch_size],
'initial_state': initial_state.detach().clone(),
'initial_time': initial_time.detach().clone(),
'dt': initial_timestep.detach().clone(),
}
step_error = dynamics_optimiser.step(lambda: neuralode.closures.dynamics_closure(learned_dynamics_rhs, list(dynamics_network.parameters()), batch_dict, *common_dyn_closure_args))
one_cycle_lr_dynamics.step()
dynamics_epoch_error = dynamics_epoch_error + step_error.item()*batch_dict['times'].shape[0]
print(f"(Dynamics) Batch: [{batch_idx}/{sha_times_ref.shape[0]}] Error: {step_error:.6f} ", end='\r', flush=True)
dynamics_epoch_error = dynamics_epoch_error/sha_times_ref.shape[0]
if dynamics_epoch_error < best_dynamics_error:
best_dynamics_error = dynamics_epoch_error
best_dynamics_parameters = copy.deepcopy(dynamics_network.state_dict())
learned_matrix = dynamics_network.state_dict()['internal_net.0.weight']
learned_bias = dynamics_network.state_dict()['internal_net.0.bias']
print(" "*128, end='\r', flush=True)
print(f"(Dynamics) Cycle: [{cycle+1}/{number_of_dynamics_cycles} - lr: {one_cycle_lr_dynamics.get_last_lr()[0]:.4e}] Error: {dynamics_epoch_error:.6f}")
# Second phase is to train the control of the dynamics
# We would like to learn the controls on the best representation of the dynamics,
# but continue the optimisation as normal.
# So we save the current dynamics network state and load the best state
current_dynamics_parameters = copy.deepcopy(dynamics_network.state_dict())
dynamics_network.load_state_dict(best_dynamics_parameters)
dynamics_network.eval()
for cycle in range(number_of_control_cycles):
control_epoch_error = 0.0
shuffled_indices = torch.randperm(state_dataset.shape[0])
for batch_idx in range(0, state_dataset.shape[0], batch_size):
batch_dict = {
'states': state_dataset[shuffled_indices][batch_idx:batch_idx+batch_size],
'initial_time': initial_time.detach().clone(),
'final_time': initial_time.detach().clone() + 2.5,
'dt': initial_timestep.detach().clone(),
}
step_error = control_optimiser.step(lambda: damping_closure(controlled_dynamics_rhs, list(control_network.parameters()), batch_dict, *common_ctrl_closure_args))
one_cycle_lr_control.step()
control_epoch_error = control_epoch_error + step_error.item()*batch_dict['states'].shape[0]
print(f"(Control) Batch: [{batch_idx}/{state_dataset.shape[0]}] Error: {step_error:.4f} ", end='\r', flush=True)
control_epoch_error = control_epoch_error/state_dataset.shape[0]
if control_epoch_error < best_control_error:
best_control_error = control_epoch_error
best_control_parameters = copy.deepcopy(control_network.state_dict())
print(" "*128, end='\r', flush=True)
print(f"(Control) Cycle: [{cycle+1}/{number_of_control_cycles} - lr: {one_cycle_lr_control.get_last_lr()[0]:.4e}] Error: {control_epoch_error:.4f}")
print(f"[{epoch+1}/{number_of_epochs}] Best Error: (Dynamics) {best_dynamics_error:.6f}/(Control) {best_control_error:.6f}")
print()
# And then restore the state to continue the optimisation
# (If we reset to the best every epoch, the optimisation may get stuck in a local minimum)
dynamics_network.load_state_dict(current_dynamics_parameters)
(Dynamics) Cycle: [1/8 - lr: 2.0411e-03] Error: 1.631016
(Dynamics) Cycle: [2/8 - lr: 2.1642e-03] Error: 1.084869
(Dynamics) Cycle: [3/8 - lr: 2.3692e-03] Error: 0.713909
(Dynamics) Cycle: [4/8 - lr: 2.6557e-03] Error: 0.534566
(Dynamics) Cycle: [5/8 - lr: 3.0232e-03] Error: 0.421840
(Dynamics) Cycle: [6/8 - lr: 3.4712e-03] Error: 0.396048
(Dynamics) Cycle: [7/8 - lr: 3.9988e-03] Error: 0.388483
(Dynamics) Cycle: [8/8 - lr: 4.6052e-03] Error: 0.368714
(Control) Cycle: [1/4 - lr: 1.3416e-04] Error: 197.5658
(Control) Cycle: [2/4 - lr: 2.3642e-04] Error: 196.5516
(Control) Cycle: [3/4 - lr: 4.0607e-04] Error: 195.6410
(Control) Cycle: [4/4 - lr: 6.4194e-04] Error: 194.0422
[1/32] Best Error: (Dynamics) 0.368714/(Control) 194.042170
(Dynamics) Cycle: [1/8 - lr: 5.2894e-03] Error: 0.357290
(Dynamics) Cycle: [2/8 - lr: 6.0502e-03] Error: 0.341688
(Dynamics) Cycle: [3/8 - lr: 6.8863e-03] Error: 0.328375
(Dynamics) Cycle: [4/8 - lr: 7.7964e-03] Error: 0.308969
(Dynamics) Cycle: [5/8 - lr: 8.7789e-03] Error: 0.296022
(Dynamics) Cycle: [6/8 - lr: 9.8322e-03] Error: 0.278045
(Dynamics) Cycle: [7/8 - lr: 1.0954e-02] Error: 0.261383
(Dynamics) Cycle: [8/8 - lr: 1.2144e-02] Error: 0.244250
(Control) Cycle: [1/4 - lr: 9.4242e-04] Error: 113.7902
(Control) Cycle: [2/4 - lr: 1.3054e-03] Error: 110.8693
(Control) Cycle: [3/4 - lr: 1.7285e-03] Error: 108.8710
(Control) Cycle: [4/4 - lr: 2.2087e-03] Error: 105.6138
[2/32] Best Error: (Dynamics) 0.244250/(Control) 105.613768
(Dynamics) Cycle: [1/8 - lr: 1.3398e-02] Error: 0.228030
(Dynamics) Cycle: [2/8 - lr: 1.4716e-02] Error: 0.216600
(Dynamics) Cycle: [3/8 - lr: 1.6094e-02] Error: 0.209134
(Dynamics) Cycle: [4/8 - lr: 1.7531e-02] Error: 0.202860
(Dynamics) Cycle: [5/8 - lr: 1.9024e-02] Error: 0.193912
(Dynamics) Cycle: [6/8 - lr: 2.0571e-02] Error: 0.191886
(Dynamics) Cycle: [7/8 - lr: 2.2168e-02] Error: 0.187635
(Dynamics) Cycle: [8/8 - lr: 2.3814e-02] Error: 0.182537
(Control) Cycle: [1/4 - lr: 2.7428e-03] Error: 57.6988
(Control) Cycle: [2/4 - lr: 3.3270e-03] Error: 60.4677
(Control) Cycle: [3/4 - lr: 3.9574e-03] Error: 61.5710
(Control) Cycle: [4/4 - lr: 4.6297e-03] Error: 58.6384
[3/32] Best Error: (Dynamics) 0.182537/(Control) 57.698777
(Dynamics) Cycle: [1/8 - lr: 2.5506e-02] Error: 0.177405
(Dynamics) Cycle: [2/8 - lr: 2.7240e-02] Error: 0.175327
(Dynamics) Cycle: [3/8 - lr: 2.9014e-02] Error: 0.173151
(Dynamics) Cycle: [4/8 - lr: 3.0825e-02] Error: 0.175517
(Dynamics) Cycle: [5/8 - lr: 3.2670e-02] Error: 0.167917
(Dynamics) Cycle: [6/8 - lr: 3.4545e-02] Error: 0.168342
(Dynamics) Cycle: [7/8 - lr: 3.6448e-02] Error: 0.173596
(Dynamics) Cycle: [8/8 - lr: 3.8376e-02] Error: 0.167835
(Control) Cycle: [1/4 - lr: 5.3392e-03] Error: 32.6097
(Control) Cycle: [2/4 - lr: 6.0810e-03] Error: 29.8663
(Control) Cycle: [3/4 - lr: 6.8501e-03] Error: 32.3584
(Control) Cycle: [4/4 - lr: 7.6411e-03] Error: 33.1556
[4/32] Best Error: (Dynamics) 0.167835/(Control) 29.866251
(Dynamics) Cycle: [1/8 - lr: 4.0324e-02] Error: 0.161230
(Dynamics) Cycle: [2/8 - lr: 4.2291e-02] Error: 0.158924
(Dynamics) Cycle: [3/8 - lr: 4.4272e-02] Error: 0.157252
(Dynamics) Cycle: [4/8 - lr: 4.6264e-02] Error: 0.158818
(Dynamics) Cycle: [5/8 - lr: 4.8264e-02] Error: 0.155971
(Dynamics) Cycle: [6/8 - lr: 5.0269e-02] Error: 0.151703
(Dynamics) Cycle: [7/8 - lr: 5.2275e-02] Error: 0.152077
(Dynamics) Cycle: [8/8 - lr: 5.4279e-02] Error: 0.147333
(Control) Cycle: [1/4 - lr: 8.4487e-03] Error: 16.9703
(Control) Cycle: [2/4 - lr: 9.2673e-03] Error: 16.6878
(Control) Cycle: [3/4 - lr: 1.0091e-02] Error: 17.3312
(Control) Cycle: [4/4 - lr: 1.0915e-02] Error: 16.9633
[5/32] Best Error: (Dynamics) 0.147333/(Control) 16.687759
(Dynamics) Cycle: [1/8 - lr: 5.6278e-02] Error: 0.145760
(Dynamics) Cycle: [2/8 - lr: 5.8267e-02] Error: 0.143774
(Dynamics) Cycle: [3/8 - lr: 6.0245e-02] Error: 0.142286
(Dynamics) Cycle: [4/8 - lr: 6.2207e-02] Error: 0.142686
(Dynamics) Cycle: [5/8 - lr: 6.4150e-02] Error: 0.140096
(Dynamics) Cycle: [6/8 - lr: 6.6071e-02] Error: 0.139707
(Dynamics) Cycle: [7/8 - lr: 6.7967e-02] Error: 0.136047
(Dynamics) Cycle: [8/8 - lr: 6.9834e-02] Error: 0.134965
(Control) Cycle: [1/4 - lr: 1.1733e-02] Error: 10.8345
(Control) Cycle: [2/4 - lr: 1.2539e-02] Error: 11.2427
(Control) Cycle: [3/4 - lr: 1.3328e-02] Error: 10.9508
(Control) Cycle: [4/4 - lr: 1.4095e-02] Error: 10.3681
[6/32] Best Error: (Dynamics) 0.134965/(Control) 10.368097
(Dynamics) Cycle: [1/8 - lr: 7.1670e-02] Error: 0.132654
(Dynamics) Cycle: [2/8 - lr: 7.3471e-02] Error: 0.131947
(Dynamics) Cycle: [3/8 - lr: 7.5235e-02] Error: 0.128349
(Dynamics) Cycle: [4/8 - lr: 7.6958e-02] Error: 0.127554
(Dynamics) Cycle: [5/8 - lr: 7.8637e-02] Error: 0.123253
(Dynamics) Cycle: [6/8 - lr: 8.0270e-02] Error: 0.121095
(Dynamics) Cycle: [7/8 - lr: 8.1854e-02] Error: 0.118037
(Dynamics) Cycle: [8/8 - lr: 8.3386e-02] Error: 0.115236
(Control) Cycle: [1/4 - lr: 1.4833e-02] Error: 8.8873
(Control) Cycle: [2/4 - lr: 1.5539e-02] Error: 8.4264
(Control) Cycle: [3/4 - lr: 1.6208e-02] Error: 6.8011
(Control) Cycle: [4/4 - lr: 1.6834e-02] Error: 6.7792
[7/32] Best Error: (Dynamics) 0.115236/(Control) 6.779247
(Dynamics) Cycle: [1/8 - lr: 8.4864e-02] Error: 0.110782
(Dynamics) Cycle: [2/8 - lr: 8.6286e-02] Error: 0.107762
(Dynamics) Cycle: [3/8 - lr: 8.7648e-02] Error: 0.103760
(Dynamics) Cycle: [4/8 - lr: 8.8948e-02] Error: 0.100746
(Dynamics) Cycle: [5/8 - lr: 9.0185e-02] Error: 0.094112
(Dynamics) Cycle: [6/8 - lr: 9.1357e-02] Error: 0.082775
(Dynamics) Cycle: [7/8 - lr: 9.2461e-02] Error: 0.068870
(Dynamics) Cycle: [8/8 - lr: 9.3495e-02] Error: 0.053022
(Control) Cycle: [1/4 - lr: 1.7413e-02] Error: 12.4582
(Control) Cycle: [2/4 - lr: 1.7942e-02] Error: 5.6166
(Control) Cycle: [3/4 - lr: 1.8416e-02] Error: 2.0667
(Control) Cycle: [4/4 - lr: 1.8834e-02] Error: 1.8554
[8/32] Best Error: (Dynamics) 0.053022/(Control) 1.855380
(Dynamics) Cycle: [1/8 - lr: 9.4458e-02] Error: 0.047605
(Dynamics) Cycle: [2/8 - lr: 9.5348e-02] Error: 0.045129
(Dynamics) Cycle: [3/8 - lr: 9.6164e-02] Error: 0.044022
(Dynamics) Cycle: [4/8 - lr: 9.6904e-02] Error: 0.041730
(Dynamics) Cycle: [5/8 - lr: 9.7567e-02] Error: 0.040621
(Dynamics) Cycle: [6/8 - lr: 9.8152e-02] Error: 0.037731
(Dynamics) Cycle: [7/8 - lr: 9.8658e-02] Error: 0.036201
(Dynamics) Cycle: [8/8 - lr: 9.9085e-02] Error: 0.034634
(Control) Cycle: [1/4 - lr: 1.9190e-02] Error: 8.4843
(Control) Cycle: [2/4 - lr: 1.9485e-02] Error: 3.5692
(Control) Cycle: [3/4 - lr: 1.9714e-02] Error: 3.6376
(Control) Cycle: [4/4 - lr: 1.9877e-02] Error: 1.6553
[9/32] Best Error: (Dynamics) 0.034634/(Control) 1.655267
(Dynamics) Cycle: [1/8 - lr: 9.9430e-02] Error: 0.033063
(Dynamics) Cycle: [2/8 - lr: 9.9695e-02] Error: 0.030263
(Dynamics) Cycle: [3/8 - lr: 9.9877e-02] Error: 0.031640
(Dynamics) Cycle: [4/8 - lr: 9.9978e-02] Error: 0.029122
(Dynamics) Cycle: [5/8 - lr: 9.9997e-02] Error: 0.025567
(Dynamics) Cycle: [6/8 - lr: 9.9934e-02] Error: 0.024702
(Dynamics) Cycle: [7/8 - lr: 9.9788e-02] Error: 0.021571
(Dynamics) Cycle: [8/8 - lr: 9.9561e-02] Error: 0.021457
(Control) Cycle: [1/4 - lr: 1.9972e-02] Error: 2.0148
(Control) Cycle: [2/4 - lr: 2.0000e-02] Error: 0.7153
(Control) Cycle: [3/4 - lr: 1.9993e-02] Error: 0.4305
(Control) Cycle: [4/4 - lr: 1.9973e-02] Error: 0.3249
[10/32] Best Error: (Dynamics) 0.021457/(Control) 0.324926
(Dynamics) Cycle: [1/8 - lr: 9.9253e-02] Error: 0.018149
(Dynamics) Cycle: [2/8 - lr: 9.8863e-02] Error: 0.020014
(Dynamics) Cycle: [3/8 - lr: 9.8393e-02] Error: 0.016999
(Dynamics) Cycle: [4/8 - lr: 9.7844e-02] Error: 0.015486
(Dynamics) Cycle: [5/8 - lr: 9.7217e-02] Error: 0.012047
(Dynamics) Cycle: [6/8 - lr: 9.6512e-02] Error: 0.014024
(Dynamics) Cycle: [7/8 - lr: 9.5730e-02] Error: 0.009492
(Dynamics) Cycle: [8/8 - lr: 9.4874e-02] Error: 0.008979
(Control) Cycle: [1/4 - lr: 1.9941e-02] Error: 0.4691
(Control) Cycle: [2/4 - lr: 1.9897e-02] Error: 0.5185
(Control) Cycle: [3/4 - lr: 1.9841e-02] Error: 0.4165
(Control) Cycle: [4/4 - lr: 1.9772e-02] Error: 0.3335
[11/32] Best Error: (Dynamics) 0.008979/(Control) 0.324926
(Dynamics) Cycle: [1/8 - lr: 9.3944e-02] Error: 0.007531
(Dynamics) Cycle: [2/8 - lr: 9.2942e-02] Error: 0.008360
(Dynamics) Cycle: [3/8 - lr: 9.1870e-02] Error: 0.007104
(Dynamics) Cycle: [4/8 - lr: 9.0729e-02] Error: 0.004836
(Dynamics) Cycle: [5/8 - lr: 8.9522e-02] Error: 0.008515
(Dynamics) Cycle: [6/8 - lr: 8.8250e-02] Error: 0.008518
(Dynamics) Cycle: [7/8 - lr: 8.6916e-02] Error: 0.008236
(Dynamics) Cycle: [8/8 - lr: 8.5521e-02] Error: 0.004738
(Control) Cycle: [1/4 - lr: 1.9692e-02] Error: 0.1379
(Control) Cycle: [2/4 - lr: 1.9599e-02] Error: 0.1403
(Control) Cycle: [3/4 - lr: 1.9495e-02] Error: 0.1123
(Control) Cycle: [4/4 - lr: 1.9379e-02] Error: 0.1116
[12/32] Best Error: (Dynamics) 0.004738/(Control) 0.111606
(Dynamics) Cycle: [1/8 - lr: 8.4069e-02] Error: 0.010618
(Dynamics) Cycle: [2/8 - lr: 8.2561e-02] Error: 0.007935
(Dynamics) Cycle: [3/8 - lr: 8.1001e-02] Error: 0.006500
(Dynamics) Cycle: [4/8 - lr: 7.9390e-02] Error: 0.006584
(Dynamics) Cycle: [5/8 - lr: 7.7731e-02] Error: 0.002766
(Dynamics) Cycle: [6/8 - lr: 7.6028e-02] Error: 0.004383
(Dynamics) Cycle: [7/8 - lr: 7.4282e-02] Error: 0.003150
(Dynamics) Cycle: [8/8 - lr: 7.2498e-02] Error: 0.005326
(Control) Cycle: [1/4 - lr: 1.9252e-02] Error: 0.1311
(Control) Cycle: [2/4 - lr: 1.9113e-02] Error: 0.0898
(Control) Cycle: [3/4 - lr: 1.8964e-02] Error: 0.0637
(Control) Cycle: [4/4 - lr: 1.8803e-02] Error: 0.0758
[13/32] Best Error: (Dynamics) 0.002766/(Control) 0.063721
(Dynamics) Cycle: [1/8 - lr: 7.0678e-02] Error: 0.004473
(Dynamics) Cycle: [2/8 - lr: 6.8824e-02] Error: 0.007433
(Dynamics) Cycle: [3/8 - lr: 6.6941e-02] Error: 0.003469
(Dynamics) Cycle: [4/8 - lr: 6.5031e-02] Error: 0.003518
(Dynamics) Cycle: [5/8 - lr: 6.3098e-02] Error: 0.004494
(Dynamics) Cycle: [6/8 - lr: 6.1144e-02] Error: 0.004412
(Dynamics) Cycle: [7/8 - lr: 5.9173e-02] Error: 0.007177
(Dynamics) Cycle: [8/8 - lr: 5.7189e-02] Error: 0.003617
(Control) Cycle: [1/4 - lr: 1.8631e-02] Error: 0.0641
(Control) Cycle: [2/4 - lr: 1.8449e-02] Error: 0.0616
(Control) Cycle: [3/4 - lr: 1.8256e-02] Error: 0.0469
(Control) Cycle: [4/4 - lr: 1.8053e-02] Error: 0.0617
[14/32] Best Error: (Dynamics) 0.002766/(Control) 0.046933
(Dynamics) Cycle: [1/8 - lr: 5.5194e-02] Error: 0.003768
(Dynamics) Cycle: [2/8 - lr: 5.3192e-02] Error: 0.005125
(Dynamics) Cycle: [3/8 - lr: 5.1186e-02] Error: 0.004564
(Dynamics) Cycle: [4/8 - lr: 4.9180e-02] Error: 0.007029
(Dynamics) Cycle: [5/8 - lr: 4.7178e-02] Error: 0.009791
(Dynamics) Cycle: [6/8 - lr: 4.5181e-02] Error: 0.006955
(Dynamics) Cycle: [7/8 - lr: 4.3195e-02] Error: 0.003950
(Dynamics) Cycle: [8/8 - lr: 4.1221e-02] Error: 0.005212
(Control) Cycle: [1/4 - lr: 1.7840e-02] Error: 0.0419
(Control) Cycle: [2/4 - lr: 1.7618e-02] Error: 0.0482
(Control) Cycle: [3/4 - lr: 1.7386e-02] Error: 0.0378
(Control) Cycle: [4/4 - lr: 1.7145e-02] Error: 0.0391
[15/32] Best Error: (Dynamics) 0.002766/(Control) 0.037761
(Dynamics) Cycle: [1/8 - lr: 3.9264e-02] Error: 0.002857
(Dynamics) Cycle: [2/8 - lr: 3.7326e-02] Error: 0.002244
(Dynamics) Cycle: [3/8 - lr: 3.5412e-02] Error: 0.002740
(Dynamics) Cycle: [4/8 - lr: 3.3523e-02] Error: 0.001380
(Dynamics) Cycle: [5/8 - lr: 3.1664e-02] Error: 0.002367
(Dynamics) Cycle: [6/8 - lr: 2.9838e-02] Error: 0.002010
(Dynamics) Cycle: [7/8 - lr: 2.8046e-02] Error: 0.002192
(Dynamics) Cycle: [8/8 - lr: 2.6293e-02] Error: 0.002703
(Control) Cycle: [1/4 - lr: 1.6895e-02] Error: 0.0339
(Control) Cycle: [2/4 - lr: 1.6637e-02] Error: 0.0314
(Control) Cycle: [3/4 - lr: 1.6371e-02] Error: 0.0317
(Control) Cycle: [4/4 - lr: 1.6097e-02] Error: 0.0290
[16/32] Best Error: (Dynamics) 0.001380/(Control) 0.029003
(Dynamics) Cycle: [1/8 - lr: 2.4582e-02] Error: 0.002062
(Dynamics) Cycle: [2/8 - lr: 2.2915e-02] Error: 0.001393
(Dynamics) Cycle: [3/8 - lr: 2.1295e-02] Error: 0.002654
(Dynamics) Cycle: [4/8 - lr: 1.9725e-02] Error: 0.001584
(Dynamics) Cycle: [5/8 -
lr: 1.8207e-02] Error: 0.001202 (Dynamics) Cycle: [6/8 - lr: 1.6744e-02] Error: 0.001402 (Dynamics) Cycle: [7/8 - lr: 1.5339e-02] Error: 0.001286 (Dynamics) Cycle: [8/8 - lr: 1.3993e-02] Error: 0.000670 (Control) Cycle: [1/4 - lr: 1.5815e-02] Error: 0.0395 (Control) Cycle: [2/4 - lr: 1.5527e-02] Error: 0.0269 (Control) Cycle: [3/4 - lr: 1.5231e-02] Error: 0.0310 (Control) Cycle: [4/4 - lr: 1.4929e-02] Error: 0.0281 [17/32] Best Error: (Dynamics) 0.000670/(Control) 0.026862 (Dynamics) Cycle: [1/8 - lr: 1.2709e-02] Error: 0.000749 (Dynamics) Cycle: [2/8 - lr: 1.1490e-02] Error: 0.001233 (Dynamics) Cycle: [3/8 - lr: 1.0337e-02] Error: 0.000934 (Dynamics) Cycle: [4/8 - lr: 9.2517e-03] Error: 0.000848 (Dynamics) Cycle: [5/8 - lr: 8.2367e-03] Error: 0.000517 (Dynamics) Cycle: [6/8 - lr: 7.2933e-03] Error: 0.000334 (Dynamics) Cycle: [7/8 - lr: 6.4232e-03] Error: 0.000463 (Dynamics) Cycle: [8/8 - lr: 5.6278e-03] Error: 0.000418 (Control) Cycle: [1/4 - lr: 1.4621e-02] Error: 0.0280 (Control) Cycle: [2/4 - lr: 1.4307e-02] Error: 0.0245 (Control) Cycle: [3/4 - lr: 1.3988e-02] Error: 0.0247 (Control) Cycle: [4/4 - lr: 1.3664e-02] Error: 0.0232 [18/32] Best Error: (Dynamics) 0.000334/(Control) 0.023191 (Dynamics) Cycle: [1/8 - lr: 4.9084e-03] Error: 0.000251 (Dynamics) Cycle: [2/8 - lr: 4.2663e-03] Error: 0.000312 (Dynamics) Cycle: [3/8 - lr: 3.7026e-03] Error: 0.000282 (Dynamics) Cycle: [4/8 - lr: 3.2181e-03] Error: 0.000175 (Dynamics) Cycle: [5/8 - lr: 2.8137e-03] Error: 0.000131 (Dynamics) Cycle: [6/8 - lr: 2.4901e-03] Error: 0.000142 (Dynamics) Cycle: [7/8 - lr: 2.2478e-03] Error: 0.000099 (Dynamics) Cycle: [8/8 - lr: 2.0872e-03] Error: 0.000172 (Control) Cycle: [1/4 - lr: 1.3336e-02] Error: 0.0223 (Control) Cycle: [2/4 - lr: 1.3003e-02] Error: 0.0247 (Control) Cycle: [3/4 - lr: 1.2667e-02] Error: 0.0209 (Control) Cycle: [4/4 - lr: 1.2328e-02] Error: 0.0217 [19/32] Best Error: (Dynamics) 0.000099/(Control) 0.020869 (Dynamics) Cycle: [1/8 - lr: 2.0086e-03] Error: 0.000175 
(Dynamics) Cycle: [2/8 - lr: 1.9999e-03] Error: 0.000222 (Dynamics) Cycle: [3/8 - lr: 1.9989e-03] Error: 0.000174 (Dynamics) Cycle: [4/8 - lr: 1.9970e-03] Error: 0.000165 (Dynamics) Cycle: [5/8 - lr: 1.9941e-03] Error: 0.000137 (Dynamics) Cycle: [6/8 - lr: 1.9903e-03] Error: 0.000148 (Dynamics) Cycle: [7/8 - lr: 1.9856e-03] Error: 0.000133 (Dynamics) Cycle: [8/8 - lr: 1.9800e-03] Error: 0.000121 (Control) Cycle: [1/4 - lr: 1.1985e-02] Error: 0.0258 (Control) Cycle: [2/4 - lr: 1.1640e-02] Error: 0.0208 (Control) Cycle: [3/4 - lr: 1.1294e-02] Error: 0.0198 (Control) Cycle: [4/4 - lr: 1.0945e-02] Error: 0.0197 [20/32] Best Error: (Dynamics) 0.000099/(Control) 0.019681 (Dynamics) Cycle: [1/8 - lr: 1.9734e-03] Error: 0.000144 (Dynamics) Cycle: [2/8 - lr: 1.9659e-03] Error: 0.000083 (Dynamics) Cycle: [3/8 - lr: 1.9575e-03] Error: 0.000082 (Dynamics) Cycle: [4/8 - lr: 1.9482e-03] Error: 0.000136 (Dynamics) Cycle: [5/8 - lr: 1.9380e-03] Error: 0.000115 (Dynamics) Cycle: [6/8 - lr: 1.9270e-03] Error: 0.000073 (Dynamics) Cycle: [7/8 - lr: 1.9150e-03] Error: 0.000117 (Dynamics) Cycle: [8/8 - lr: 1.9023e-03] Error: 0.000139 (Control) Cycle: [1/4 - lr: 1.0596e-02] Error: 0.0190 (Control) Cycle: [2/4 - lr: 1.0245e-02] Error: 0.0196 (Control) Cycle: [3/4 - lr: 9.8948e-03] Error: 0.0190 (Control) Cycle: [4/4 - lr: 9.5444e-03] Error: 0.0198 [21/32] Best Error: (Dynamics) 0.000073/(Control) 0.018959 (Dynamics) Cycle: [1/8 - lr: 1.8886e-03] Error: 0.000135 (Dynamics) Cycle: [2/8 - lr: 1.8741e-03] Error: 0.000176 (Dynamics) Cycle: [3/8 - lr: 1.8588e-03] Error: 0.000112 (Dynamics) Cycle: [4/8 - lr: 1.8427e-03] Error: 0.000150 (Dynamics) Cycle: [5/8 - lr: 1.8258e-03] Error: 0.000199 (Dynamics) Cycle: [6/8 - lr: 1.8082e-03] Error: 0.000216 (Dynamics) Cycle: [7/8 - lr: 1.7897e-03] Error: 0.000183 (Dynamics) Cycle: [8/8 - lr: 1.7706e-03] Error: 0.000162 (Control) Cycle: [1/4 - lr: 9.1944e-03] Error: 0.0180 (Control) Cycle: [2/4 - lr: 8.8455e-03] Error: 0.0182 (Control) Cycle: [3/4 - lr: 
8.4980e-03] Error: 0.0177 (Control) Cycle: [4/4 - lr: 8.1524e-03] Error: 0.0199 [22/32] Best Error: (Dynamics) 0.000073/(Control) 0.017730 (Dynamics) Cycle: [1/8 - lr: 1.7507e-03] Error: 0.000166 (Dynamics) Cycle: [2/8 - lr: 1.7301e-03] Error: 0.000087 (Dynamics) Cycle: [3/8 - lr: 1.7088e-03] Error: 0.000127 (Dynamics) Cycle: [4/8 - lr: 1.6868e-03] Error: 0.000198 (Dynamics) Cycle: [5/8 - lr: 1.6642e-03] Error: 0.000187 (Dynamics) Cycle: [6/8 - lr: 1.6410e-03] Error: 0.000125 (Dynamics) Cycle: [7/8 - lr: 1.6171e-03] Error: 0.000097 (Dynamics) Cycle: [8/8 - lr: 1.5927e-03] Error: 0.000066 (Control) Cycle: [1/4 - lr: 7.8090e-03] Error: 0.0176 (Control) Cycle: [2/4 - lr: 7.4683e-03] Error: 0.0172 (Control) Cycle: [3/4 - lr: 7.1307e-03] Error: 0.0170 (Control) Cycle: [4/4 - lr: 6.7967e-03] Error: 0.0166 [23/32] Best Error: (Dynamics) 0.000066/(Control) 0.016637 (Dynamics) Cycle: [1/8 - lr: 1.5678e-03] Error: 0.000089 (Dynamics) Cycle: [2/8 - lr: 1.5423e-03] Error: 0.000112 (Dynamics) Cycle: [3/8 - lr: 1.5163e-03] Error: 0.000087 (Dynamics) Cycle: [4/8 - lr: 1.4898e-03] Error: 0.000122 (Dynamics) Cycle: [5/8 - lr: 1.4628e-03] Error: 0.000119 (Dynamics) Cycle: [6/8 - lr: 1.4354e-03] Error: 0.000155 (Dynamics) Cycle: [7/8 - lr: 1.4076e-03] Error: 0.000264 (Dynamics) Cycle: [8/8 - lr: 1.3795e-03] Error: 0.000272 (Control) Cycle: [1/4 - lr: 6.4666e-03] Error: 0.0164 (Control) Cycle: [2/4 - lr: 6.1408e-03] Error: 0.0163 (Control) Cycle: [3/4 - lr: 5.8198e-03] Error: 0.0159 (Control) Cycle: [4/4 - lr: 5.5039e-03] Error: 0.0165 [24/32] Best Error: (Dynamics) 0.000066/(Control) 0.015885 (Dynamics) Cycle: [1/8 - lr: 1.3509e-03] Error: 0.000164 (Dynamics) Cycle: [2/8 - lr: 1.3221e-03] Error: 0.000148 (Dynamics) Cycle: [3/8 - lr: 1.2929e-03] Error: 0.000091 (Dynamics) Cycle: [4/8 - lr: 1.2634e-03] Error: 0.000084 (Dynamics) Cycle: [5/8 - lr: 1.2337e-03] Error: 0.000082 (Dynamics) Cycle: [6/8 - lr: 1.2038e-03] Error: 0.000088 (Dynamics) Cycle: [7/8 - lr: 1.1737e-03] Error: 0.000101 
(Dynamics) Cycle: [8/8 - lr: 1.1435e-03] Error: 0.000102 (Control) Cycle: [1/4 - lr: 5.1935e-03] Error: 0.0170 (Control) Cycle: [2/4 - lr: 4.8891e-03] Error: 0.0156 (Control) Cycle: [3/4 - lr: 4.5909e-03] Error: 0.0162 (Control) Cycle: [4/4 - lr: 4.2994e-03] Error: 0.0154 [25/32] Best Error: (Dynamics) 0.000066/(Control) 0.015429 (Dynamics) Cycle: [1/8 - lr: 1.1131e-03] Error: 0.000086 (Dynamics) Cycle: [2/8 - lr: 1.0826e-03] Error: 0.000066 (Dynamics) Cycle: [3/8 - lr: 1.0520e-03] Error: 0.000073 (Dynamics) Cycle: [4/8 - lr: 1.0213e-03] Error: 0.000059 (Dynamics) Cycle: [5/8 - lr: 9.9068e-04] Error: 0.000091 (Dynamics) Cycle: [6/8 - lr: 9.6004e-04] Error: 0.000054 (Dynamics) Cycle: [7/8 - lr: 9.2943e-04] Error: 0.000098 (Dynamics) Cycle: [8/8 - lr: 8.9889e-04] Error: 0.000094 (Control) Cycle: [1/4 - lr: 4.0149e-03] Error: 0.0154 (Control) Cycle: [2/4 - lr: 3.7377e-03] Error: 0.0156 (Control) Cycle: [3/4 - lr: 3.4683e-03] Error: 0.0157 (Control) Cycle: [4/4 - lr: 3.2069e-03] Error: 0.0153 [26/32] Best Error: (Dynamics) 0.000054/(Control) 0.015268 (Dynamics) Cycle: [1/8 - lr: 8.6845e-04] Error: 0.000074 (Dynamics) Cycle: [2/8 - lr: 8.3813e-04] Error: 0.000093 (Dynamics) Cycle: [3/8 - lr: 8.0796e-04] Error: 0.000070 (Dynamics) Cycle: [4/8 - lr: 7.7797e-04] Error: 0.000067 (Dynamics) Cycle: [5/8 - lr: 7.4819e-04] Error: 0.000056 (Dynamics) Cycle: [6/8 - lr: 7.1864e-04] Error: 0.000070 (Dynamics) Cycle: [7/8 - lr: 6.8937e-04] Error: 0.000054 (Dynamics) Cycle: [8/8 - lr: 6.6038e-04] Error: 0.000037 (Control) Cycle: [1/4 - lr: 2.9538e-03] Error: 0.0151 (Control) Cycle: [2/4 - lr: 2.7094e-03] Error: 0.0151 (Control) Cycle: [3/4 - lr: 2.4739e-03] Error: 0.0152 (Control) Cycle: [4/4 - lr: 2.2477e-03] Error: 0.0152 [27/32] Best Error: (Dynamics) 0.000037/(Control) 0.015056 (Dynamics) Cycle: [1/8 - lr: 6.3171e-04] Error: 0.000022 (Dynamics) Cycle: [2/8 - lr: 6.0339e-04] Error: 0.000022 (Dynamics) Cycle: [3/8 - lr: 5.7544e-04] Error: 0.000016 (Dynamics) Cycle: [4/8 - lr: 
5.4789e-04] Error: 0.000031 (Dynamics) Cycle: [5/8 - lr: 5.2077e-04] Error: 0.000038 (Dynamics) Cycle: [6/8 - lr: 4.9409e-04] Error: 0.000030 (Dynamics) Cycle: [7/8 - lr: 4.6790e-04] Error: 0.000015 (Dynamics) Cycle: [8/8 - lr: 4.4220e-04] Error: 0.000026 (Control) Cycle: [1/4 - lr: 2.0310e-03] Error: 0.0153 (Control) Cycle: [2/4 - lr: 1.8242e-03] Error: 0.0149 (Control) Cycle: [3/4 - lr: 1.6273e-03] Error: 0.0147 (Control) Cycle: [4/4 - lr: 1.4408e-03] Error: 0.0147 [28/32] Best Error: (Dynamics) 0.000015/(Control) 0.014693 (Dynamics) Cycle: [1/8 - lr: 4.1702e-04] Error: 0.000038 (Dynamics) Cycle: [2/8 - lr: 3.9240e-04] Error: 0.000026 (Dynamics) Cycle: [3/8 - lr: 3.6834e-04] Error: 0.000017 (Dynamics) Cycle: [4/8 - lr: 3.4488e-04] Error: 0.000037 (Dynamics) Cycle: [5/8 - lr: 3.2204e-04] Error: 0.000041 (Dynamics) Cycle: [6/8 - lr: 2.9983e-04] Error: 0.000037 (Dynamics) Cycle: [7/8 - lr: 2.7828e-04] Error: 0.000036 (Dynamics) Cycle: [8/8 - lr: 2.5741e-04] Error: 0.000024 (Control) Cycle: [1/4 - lr: 1.2648e-03] Error: 0.0147 (Control) Cycle: [2/4 - lr: 1.0995e-03] Error: 0.0147 (Control) Cycle: [3/4 - lr: 9.4518e-04] Error: 0.0147 (Control) Cycle: [4/4 - lr: 8.0198e-04] Error: 0.0146 [29/32] Best Error: (Dynamics) 0.000015/(Control) 0.014644 (Dynamics) Cycle: [1/8 - lr: 2.3724e-04] Error: 0.000027 (Dynamics) Cycle: [2/8 - lr: 2.1778e-04] Error: 0.000021 (Dynamics) Cycle: [3/8 - lr: 1.9906e-04] Error: 0.000010 (Dynamics) Cycle: [4/8 - lr: 1.8109e-04] Error: 0.000011 (Dynamics) Cycle: [5/8 - lr: 1.6389e-04] Error: 0.000009 (Dynamics) Cycle: [6/8 - lr: 1.4748e-04] Error: 0.000010 (Dynamics) Cycle: [7/8 - lr: 1.3187e-04] Error: 0.000005 (Dynamics) Cycle: [8/8 - lr: 1.1708e-04] Error: 0.000004 (Control) Cycle: [1/4 - lr: 6.7008e-04] Error: 0.0146 (Control) Cycle: [2/4 - lr: 5.4965e-04] Error: 0.0147 (Control) Cycle: [3/4 - lr: 4.4084e-04] Error: 0.0146 (Control) Cycle: [4/4 - lr: 3.4378e-04] Error: 0.0146 [30/32] Best Error: (Dynamics) 0.000004/(Control) 0.014585 
(Dynamics) Cycle: [1/8 - lr: 1.0311e-04] Error: 0.000006 (Dynamics) Cycle: [2/8 - lr: 8.9989e-05] Error: 0.000007 (Dynamics) Cycle: [3/8 - lr: 7.7723e-05] Error: 0.000005 (Dynamics) Cycle: [4/8 - lr: 6.6323e-05] Error: 0.000006 (Dynamics) Cycle: [5/8 - lr: 5.5801e-05] Error: 0.000007 (Dynamics) Cycle: [6/8 - lr: 4.6166e-05] Error: 0.000008 (Dynamics) Cycle: [7/8 - lr: 3.7428e-05] Error: 0.000005 (Dynamics) Cycle: [8/8 - lr: 2.9595e-05] Error: 0.000003 (Control) Cycle: [1/4 - lr: 2.5859e-04] Error: 0.0146 (Control) Cycle: [2/4 - lr: 1.8538e-04] Error: 0.0146 (Control) Cycle: [3/4 - lr: 1.2423e-04] Error: 0.0145 (Control) Cycle: [4/4 - lr: 7.5215e-05] Error: 0.0145 [31/32] Best Error: (Dynamics) 0.000003/(Control) 0.014533 (Dynamics) Cycle: [1/8 - lr: 2.2674e-05] Error: 0.000002 (Dynamics) Cycle: [2/8 - lr: 1.6671e-05] Error: 0.000002 (Dynamics) Cycle: [3/8 - lr: 1.1593e-05] Error: 0.000001 (Dynamics) Cycle: [4/8 - lr: 7.4438e-06] Error: 0.000000 (Dynamics) Cycle: [5/8 - lr: 4.2275e-06] Error: 0.000000 (Dynamics) Cycle: [6/8 - lr: 1.9473e-06] Error: 0.000000 (Dynamics) Cycle: [7/8 - lr: 6.0516e-07] Error: 0.000000 (Dynamics) Cycle: [8/8 - lr: 2.0240e-07] Error: 0.000000 (Control) Cycle: [1/4 - lr: 3.8403e-05] Error: 0.0145 (Control) Cycle: [2/4 - lr: 1.3837e-05] Error: 0.0145 (Control) Cycle: [3/4 - lr: 1.5467e-06] Error: 0.0145 (Control) Cycle: [4/4 - lr: 1.5467e-06] Error: 0.0145 [32/32] Best Error: (Dynamics) 0.000000/(Control) 0.014524
Testing the Network¶
We first look at the dynamics matrix learned by the dynamics network.
dynamics_network.load_state_dict(best_dynamics_parameters)
control_network.load_state_dict(best_control_parameters)
dynamics_network.eval()
control_network.eval()
learned_matrix = dynamics_network.state_dict()['internal_net.0.weight']
learned_bias = dynamics_network.state_dict()['internal_net.0.bias']
print(f"Best dynamics matrix: {learned_matrix}, mean absolute error: {torch.mean(torch.abs(ideal_matrix - learned_matrix)).item():.6f}")
print(f"Best dynamics bias: {learned_bias}, mean absolute error: {torch.mean(torch.abs(ideal_bias - learned_bias)).item():.6f}")
Best dynamics matrix: tensor([[ 1.0136e-07, 1.0000e+00], [-1.0000e+00, -5.0000e-01]]), mean absolute error: 0.000000 Best dynamics bias: tensor([ 4.7776e-08, -4.1761e-08]), mean absolute error: 0.000000
We see that the learned matrix closely matches the expected dynamics matrix, and the bias is essentially zero (on the order of 1e-8), as expected for this autonomous linear system.
Taking a look at the trajectory below, we see that this matrix translates into a near-perfect match to the reference trajectory. We also see that the control network has learned an appropriate control function to drive the system to the target state under the learned dynamics.
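As a sanity check, the expected matrix can be written down directly by expressing the damped oscillator as a first-order linear system, d/dt [x, v] = A [x, v]. The sketch below (using the frequency and damping values set at the top of the notebook; the variable name `ideal_matrix` mirrors the one used in the comparison above) reproduces the values the network converged to:

```python
import torch

omega = torch.tensor(1.0)   # oscillator frequency, as set at the top of the notebook
zeta = torch.tensor(0.25)   # damping coefficient, as set at the top of the notebook

# Damped harmonic oscillator as a linear system d/dt [x, v] = A @ [x, v]:
#   dx/dt = v
#   dv/dt = -omega**2 * x - 2 * zeta * omega * v
ideal_matrix = torch.stack([
    torch.stack([torch.tensor(0.0), torch.tensor(1.0)]),
    torch.stack([-omega**2, -2 * zeta * omega]),
])
print(ideal_matrix)  # [[0, 1], [-1, -0.5]], matching the learned matrix above
```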
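The `controlled_dynamics_rhs` used below is defined earlier in the notebook; conceptually, it composes the two networks. The following is only a rough illustration of that idea with toy stand-in networks, not the notebook's actual definition (the network shapes and the way the control enters the dynamics are assumptions):

```python
import torch

# Toy stand-ins for the trained networks above.
dynamics_net = torch.nn.Linear(2, 2)  # maps state -> d(state)/dt
control_net = torch.nn.Linear(2, 1)   # maps state -> a scalar driving force

def controlled_rhs(t, state):
    # Learned free dynamics plus a learned forcing term; applying the force
    # to the velocity component only is one plausible choice for a driven
    # oscillator, made here purely for illustration.
    force = control_net(state)
    forcing = torch.cat([torch.zeros_like(force), force], dim=-1)
    return dynamics_net(state) + forcing

state = torch.zeros(2)
print(controlled_rhs(torch.tensor(0.0), state).shape)  # torch.Size([2])
```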
_, _, learned_dynamics_states, learned_dynamics_times, _ = current_integrator.apply(learned_dynamics_rhs, initial_state, initial_time, final_time, initial_timestep, {'atol': atol, 'rtol': rtol})
fig_ref_position, axes_ref_position = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK7(8) - SHA Position Ref.")
fig_ref_velocity, axes_ref_velocity = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(sha_states_ref, sha_times_ref)], method_label="RK7(8) - SHA Velocity Ref.")
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,0], j) for i, j in zip(learned_dynamics_states, learned_dynamics_times)], axes=axes_ref_position, method_label="RK7(8) - SHA Position Opt.")
fig_ref_position.suptitle("Position")
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,1], j) for i, j in zip(learned_dynamics_states, learned_dynamics_times)], axes=axes_ref_velocity, method_label="RK7(8) - SHA Velocity Opt.")
fig_ref_velocity.suptitle("Velocity")
_, _, integrated_test_states, integrated_test_times, _ = current_integrator.apply(controlled_dynamics_rhs, test_state_dataset, initial_time, final_time, initial_timestep, {'atol': atol, 'rtol': rtol})
print("Training results on test data:")
print(f"\tMean position at t={final_time.item():.2f}s: {integrated_test_states[-1, ..., 0].mean().item():.4e} ± {integrated_test_states[-1, ..., 0].std().item():.4e}")
print(f"\tMean position error at t={final_time.item():.2f}s: {(target_state - integrated_test_states)[-1, ..., 0].mean().item():.4e} ± {(target_state - integrated_test_states)[-1, ..., 0].std().item():.4e}")
print()
print(f"\tMean velocity at t={final_time.item():.2f}s: {integrated_test_states[-1, ..., 1].mean().item():.4e} ± {integrated_test_states[-1, ..., 1].std().item():.4e}")
print(f"\tMean velocity error at t={final_time.item():.2f}s: {(target_state - integrated_test_states)[-1, ..., 1].mean().item():.4e} ± {(target_state - integrated_test_states)[-1, ..., 1].std().item():.4e}")
neuralode.plot.trajectory.plot_trajectory([(i[...,0], j) for i, j in zip(integrated_test_states, integrated_test_times)], method_label=None)
neuralode.plot.trajectory.plot_trajectory([(i[...,1], j) for i, j in zip(integrated_test_states, integrated_test_times)], method_label=None)
fig_ref_position_closeup, axes_ref_position_closeup = neuralode.plot.trajectory.plot_trajectory([(i[0], j) for i, j in zip(dha_states_optimised, dha_times_optimised)], method_label=None)
fig_ref_velocity_closeup, axes_ref_velocity_closeup = neuralode.plot.trajectory.plot_trajectory([(i[1], j) for i, j in zip(dha_states_optimised, dha_times_optimised)], method_label=None)
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,0], j) for i, j in zip(integrated_test_states, integrated_test_times)], axes=axes_ref_position_closeup, method_label=None)
fig_ref_position_closeup.suptitle("Close-Up of Position")
axes_ref_position_closeup[0].set_xlim(final_time.item() - 5.0, final_time.item())
axes_ref_position_closeup[0].set_ylim(target_state[0].item() - 1.0, target_state[0].item() + 1.0)
_ = neuralode.plot.trajectory.plot_trajectory([(i[...,1], j) for i, j in zip(integrated_test_states, integrated_test_times)], axes=axes_ref_velocity_closeup, method_label=None)
fig_ref_velocity_closeup.suptitle("Close-Up of Velocity")
axes_ref_velocity_closeup[0].set_xlim(final_time.item() - 5.0, final_time.item())
axes_ref_velocity_closeup[0].set_ylim(target_state[1].item() - 1.0, target_state[1].item() + 1.0)
Training results on test data:
	Mean position at t=25.00s: 9.4453e-01 ± 9.8263e-12
	Mean position error at t=25.00s: 5.5466e-02 ± 9.8263e-12

	Mean velocity at t=25.00s: -1.4351e-07 ± 6.6683e-12
	Mean velocity error at t=25.00s: 1.4351e-07 ± 6.6683e-12
C:\Users\ekin4\PycharmProjects\ReCoDE-NeuralODEs\neuralode\plot\trajectory.py:79: UserWarning: No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument. ax.legend()
(-1.0, 1.0)