import torch
import numpy as np
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
from tqdm import tqdm
#import holoviews as hv
# Display notation used for each torch function.
# Insertion order matters: EXP_TORCH_MAP inherits it, and expression_to_torch
# applies its substitutions in that order.
TORCH_EXP_MAP = dict([
    ('torch.exp', 'e^'),
    ('torch.sin', 'sin'),
    ('torch.cos', 'cos'),
    ('torch.tanh', 'tanh'),
    ('torch.log', 'log'),
    ('torch.sqrt', 'sqrt'),
])
# Inverse table: user notation -> torch call (same order as above).
EXP_TORCH_MAP = dict(zip(TORCH_EXP_MAP.values(), TORCH_EXP_MAP.keys()))
Run to view results
def expression_to_torch(expression: str, mapping=None):
    """Translate a user-typed expression into Python/torch source suitable for ``eval``.

    Powers may be written as ``^`` or ``**``. Each key of ``mapping`` is replaced
    by its value with plain substring matching. With the default table,
    ``e^(...)`` becomes ``torch.exp(...)`` (the parentheses are required), and
    ``sin``/``cos``/``tanh``/``log``/``sqrt`` become their ``torch.*`` calls.

    Args:
        expression: expression text, e.g. ``"x^2 + sin(x)"``.
        mapping: ordered {user notation: torch call} table. Defaults to
            ``EXP_TORCH_MAP``. Substitutions are applied in iteration order.

    Returns:
        Python source with every power written as ``**``.
    """
    if mapping is None:
        mapping = EXP_TORCH_MAP
    # Normalise powers to '^' so that 'e**(...)' is also matched by the 'e^' key.
    expression = expression.replace('**', '^')
    for user_notation, torch_function in mapping.items():
        expression = expression.replace(user_notation, torch_function)
    # Any remaining '^' is an ordinary power. In Python '^' is bitwise XOR
    # (an error on float tensors), so restore '**' before the string is eval'd.
    return expression.replace('^', '**')
def expression_to_latex(expression: str):
    """Return the LaTeX rendering of a user-typed expression, via SymPy.

    ``**`` is first rewritten as ``^``. SymPy's ``sympify`` parses ``^`` as
    exponentiation by default, so both power spellings render identically.

    NOTE(review): ``sympify`` evaluates its input with ``eval``; only pass
    trusted text.
    """
    import sympy as sp
    caret_form = expression.replace('**', '^')
    parsed = sp.sympify(caret_form)
    return sp.latex(parsed)
def expression_to_ODE(expression: str):
    '''
    Return function to evaluate y_x - F(x,y)

    ``expression`` is the right-hand side F of y' = F(x, y), written in the
    notation accepted by ``expression_to_torch``. The returned callable
    ``ODE(x, y, y_x)`` computes the residual ``y_x - F(x, y)``. Inside F, the
    names ``x`` and ``y`` refer to the arguments of the same name.

    Raises:
        ValueError: if ``expression`` mentions ``y_x`` (F may depend on x and y only).
        SyntaxError: if the translated expression is not valid Python.
    '''
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if 'y_x' in expression:
        raise ValueError('F must be function of x and y, not y_x')
    torch_expression = expression_to_torch(expression = expression)
    # Compile once. Syntax errors surface here instead of on the first training
    # epoch, and the string is not re-parsed on every call.
    # NOTE(review): eval executes arbitrary code; only use with trusted input.
    compiled_expression = compile(torch_expression, '<ODE right-hand side>', 'eval')

    def ODE(x, y, y_x):
        # eval sees this frame's locals, so F can reference x and y directly.
        eq = y_x - eval(compiled_expression)
        return eq

    return ODE
Run to view results
Settings
ODE: y'=...
My solution (optional)
expression
my_solution
Initial condition
x_0
y_0
PINN
EPOCHS
100 / 5000
# Input preprocess
# Coerce the initial condition to float (the settings widgets presumably
# supply text; TODO confirm).
x_0 = float(x_0)
y_0 = float(y_0)
# Derive both views of F(x, y) from the same text: LaTeX for display and the
# residual function used as the training loss.
latex_expression = expression_to_latex(expression=expression)
ODE = expression_to_ODE(expression=expression)
Run to view results
# Optional reference solution y(x) typed by the user. It stays None when the
# field is empty; otherwise it becomes a function of x that evaluates the
# translated expression.
my_solution_fun = None
if my_solution:
    torch_my_solution = expression_to_torch(expression=my_solution)
    print(torch_my_solution)

    def my_solution_fun(x):
        """Evaluate the user's reference solution at ``x``."""
        return eval(torch_my_solution)
Run to view results
from IPython.display import display, Markdown, Latex
# Last expression of the cell, so the notebook renders the problem statement.
Markdown("ODE: $y' = {ode}$. Boundary condition: $y({x}) = {y}$".format(ode=latex_expression, x=x_0, y=y_0))
Run to view results
PINN design
class Net(nn.Module):
    """Fully connected network mapping a scalar input x to a scalar y(x).

    Layer widths are 1 -> 100 -> 100 -> 10 -> 1, with tanh after each hidden
    layer and a linear output layer.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes parameter order, initialisation RNG draws
        # and the printed module summary, so keep it as is.
        self.hidden1 = nn.Linear(1, 100, bias=True)
        self.act1 = nn.Tanh()
        self.hidden2 = nn.Linear(100, 100, bias=True)
        self.act2 = nn.Tanh()
        self.hidden3 = nn.Linear(100, 10, bias=True)
        self.act3 = nn.Tanh()
        self.output = nn.Linear(10, 1, bias=True)

    def forward(self, x):
        """Apply the three tanh hidden layers, then the linear output layer."""
        hidden_stages = (
            (self.hidden1, self.act1),
            (self.hidden2, self.act2),
            (self.hidden3, self.act3),
        )
        for linear, activation in hidden_stages:
            x = activation(linear(x))
        return self.output(x)
# Build the network and show its layer summary (nn.Module's repr).
model = Net()
print(repr(model))
Run to view results
Fit
# Training setup: MSE for both the ODE residual and the initial-condition term,
# Adam with its default hyperparameters.
loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Set up boundary condition y(x_0) = y_0 as 1-element float tensors.
X_BOUNDARY = torch.from_numpy(np.array([x_0])).float()
Y_BOUNDARY = torch.from_numpy(np.array([y_0])).float()

# Discretize x axis: collocation points where the ODE residual is penalised.
# The default window is [-1, 1]. It is widened to contain x_0 so that the
# initial condition is connected to the region where the ODE is enforced.
N_ODE = 1000
X_ODE_MIN = min(-1.0, x_0)
X_ODE_MAX = max(1.0, x_0)
X_ODE = np.linspace(X_ODE_MIN, X_ODE_MAX, N_ODE)
ZEROS_ODE = np.zeros(N_ODE)

# The collocation inputs and the zero residual target do not change between
# epochs, so build them once. x_ode needs requires_grad so that autograd can
# differentiate the network output with respect to x. The target does not.
x_ode = torch.from_numpy(X_ODE.reshape(-1, 1)).float().requires_grad_(True)
zeros_ode = torch.from_numpy(ZEROS_ODE.reshape(-1, 1)).float()

loss_epochs = pd.DataFrame(np.nan, columns=['ode_loss', 'bc_loss'], index=range(EPOCHS))
solution_evolution = []
for epoch in tqdm(range(EPOCHS)):
    optimizer.zero_grad()
    # Loss boundary: network value at x_0 versus y_0.
    loss_boundary = loss(model(X_BOUNDARY), Y_BOUNDARY)
    # Loss ODE: residual y' - F(x, y) at every collocation point.
    u = model(x_ode)
    # Reuse u (one forward pass). Differentiating sum(u) yields du_i/dx_i
    # because Net processes each row independently (Linear + Tanh layers only).
    u_x = torch.autograd.grad(u.sum(), x_ode, create_graph=True)[0]
    eq = ODE(x=x_ode, y_x=u_x, y=u)
    loss_ode = loss(eq, zeros_ode)
    total_loss = loss_boundary + loss_ode
    total_loss.backward()
    optimizer.step()
    # Store plain Python floats rather than 0-d numpy arrays.
    loss_epochs.loc[epoch, 'ode_loss'] = loss_ode.item()
    loss_epochs.loc[epoch, 'bc_loss'] = loss_boundary.item()
    solution_evolution.append(u.detach().numpy())
Run to view results
Visualisation
# Diagnostic figure layout: solution evolution on the left (spans both rows),
# boundary loss top right, ODE residual loss bottom right.
plt.figure(figsize=(10, 5))
plt.suptitle(r"PINN diagnostic: $y'=" + latex_expression + r"$")
solution_plot = plt.subplot2grid(shape=(2, 3), loc=(0, 0), colspan=2, rowspan=2)
bc_loss_plot = plt.subplot2grid(shape=(2, 3), loc=(0, 2), colspan=1, rowspan=1)
ode_loss_plot = plt.subplot2grid(shape=(2, 3), loc=(1, 2), colspan=1, rowspan=1)

# One faint curve per epoch shows how the network output evolved during training.
for v in solution_evolution:
    solution_plot.plot(X_ODE, v, color='orange', alpha=0.1)
if solution_evolution:  # empty when EPOCHS == 0
    solution_plot.plot(X_ODE, solution_evolution[0], color='green', alpha=1, label='Initial solution', linewidth=3)
    solution_plot.plot(X_ODE, solution_evolution[-1], color='red', alpha=1, label='Final solution', linewidth=3)

if my_solution_fun is not None:
    my_solution_values = my_solution_fun(torch.from_numpy(X_ODE).float())
    # A solution that does not depend on x (e.g. "1") evaluates to a scalar.
    # Broadcast it over the grid so x and y have matching lengths for plot().
    my_solution_values = np.broadcast_to(np.asarray(my_solution_values, dtype=float), X_ODE.shape)
    solution_plot.plot(X_ODE, my_solution_values,
                       color='black',
                       alpha=1,
                       label='My solution',
                       linewidth=2,
                       linestyle=':')
solution_plot.scatter(x=x_0, y=y_0, s=100,
                      label='Boundary condition',
                      marker='x',
                      color='blue',
                      zorder=100)
solution_plot.legend()
solution_plot.grid(linestyle=':')
solution_plot.set_xlabel('x')
solution_plot.set_ylabel('y')

ode_loss_plot.plot(loss_epochs.index, loss_epochs['ode_loss'], color='red', label='ODE Loss')
ode_loss_plot.legend()
ode_loss_plot.grid(linestyle=':')
ode_loss_plot.set_yscale('log')
ode_loss_plot.set_xlabel('Epochs')

bc_loss_plot.plot(loss_epochs.index, loss_epochs['bc_loss'], color='blue', label='Boundary Loss')
bc_loss_plot.legend()
bc_loss_plot.grid(linestyle=':')
bc_loss_plot.set_yscale('log')
bc_loss_plot.set_xlabel('Epochs')

plt.tight_layout()
plt.show()
Run to view results