%config InlineBackend.figure_format = "svg"
import ase.io
import matplotlib.pyplot as plt
import numpy as np
import jax
from ase.visualize.plot import plot_atoms
from bop_jax import bop as bop_jax # JAX 버전
from bop_rs import bop as bop_rs # 검증용 Rust 버전
# Allow float64 in JAX (it defaults to 32-bit floats). Set this before any JAX
# computation so the energies below are evaluated in double precision.
jax.config.update("jax_enable_x64", True)

Pt_3atoms = ase.io.read("training_3atoms.traj")
r_cut = 6.5

# Positions of the center atoms: every atom in the structure acts as a center.
centers = Pt_3atoms.positions
# Neighbor positions for each center: row i is `centers` with row i removed,
# keeping the original atom order. For 3 atoms this is exactly
# [centers[[1, 2]], centers[[0, 2]], centers[[0, 1]]], but it no longer depends
# on the structure having exactly 3 atoms.
neighbors = np.array([np.delete(centers, i, axis=0) for i in range(len(centers))])

# Pre-calculated BOP parameters, one row per center atom.
# Column order: A, a, alpha, B, beta, h, lambda, sigma
bop_params = np.array([
    [8.45925165024959, 0.07248156460734, 3.33119467517366, 4.67977296766391, 1.24501156141294, 1.90843223151883, 0.87041937045874, 0.20909290922816],
    [8.46842563162531, 0.07209710360845, 3.33385847596619, 4.68345712045158, 1.25002282124036, 1.91056105318887, 0.87115826698344, 0.20813240220751],
    [8.46717196267474, 0.07214868604406, 3.33349396868284, 4.68295356541593, 1.24933854987283, 1.91027013440084, 0.87105869827311, 0.20826332815426],
])
# The cells below zip centers/neighbors/bop_params together; a length mismatch
# would otherwise silently drop atoms from the energy sum.
if bop_params.shape[0] != len(centers):
    raise ValueError(
        f"expected one bop_params row per atom ({len(centers)}), got {bop_params.shape[0]}"
    )

plot_atoms(Pt_3atoms);
# Per-atom BOP energies from the JAX implementation, one center at a time,
# followed by their total.
print("===== JAX version =====")
jax_energies = [
    bop_jax.calculate_energy(center, nbrs, params, eps=1e-12, r_cut=r_cut)
    for center, nbrs, params in zip(centers, neighbors, bop_params)
]
for idx, energy in enumerate(jax_energies):
    print(f"E_{idx} = {energy:.5f}")
# sum() starts from 0 and adds in order, matching a running `+=` total.
total_energy = sum(jax_energies)
print(f"E_tot = {total_energy:.5f}")
from functools import partial

from jax import vmap

# Same per-atom energies, batched over all centers with jax.vmap.
# vmap maps over the leading axis of each positional argument (default
# in_axes=0), so `neighbors` must be a regular stacked array: every center needs
# the same number of neighbors. eps and r_cut are bound through partial, so vmap
# never sees them and they stay unmapped constants.
print("===== JAX version (vmap) =====")
bop_func = partial(bop_jax.calculate_energy, eps=1e-12, r_cut=r_cut)
atomic_energies = vmap(bop_func)(centers, neighbors, bop_params)
# Use the same 0-based `E_{idx}` labels and format as the loop cells so the
# outputs can be compared line by line.
for idx, energy in enumerate(atomic_energies):
    print(f"E_{idx} = {energy:.5f}")
print(f"E_tot = {atomic_energies.sum():.5f}")
# Per-atom BOP energies from the Rust implementation (used as the reference to
# validate the JAX version), followed by their total.
print("===== Rust version =====")
rust_energies = [
    bop_rs.calculate_energy_atom(center, nbrs, params, r_cut)
    for center, nbrs, params in zip(centers, neighbors, bop_params)
]
for idx, energy in enumerate(rust_energies):
    print(f"E_{idx} = {energy:.5f}")
# sum() starts from 0 and adds in order, matching a running `+=` total.
total_energy = sum(rust_energies)
print(f"E_tot = {total_energy:.5f}")
# Randomized cross-check of the JAX implementation against the Rust reference.
# Each trial places `num_neighbors` neighbors at per-coordinate offsets in
# [1, 9) from a random center. Distances therefore span roughly 1.7 to 15.6,
# so the r_cut cutoff is typically exercised.
# NOTE(review): all offsets are positive, so neighbors lie in a single octant
# around the center; consider random signs for broader angular coverage.
num_neighbors = 100
num_trials = 30
tolerance = 1e-9  # absolute tolerance on |E_rust - E_jax|
rng = np.random.default_rng(0)  # fixed seed so any failure is reproducible

# The parameters are the same for every trial, so build them once. Distinct
# names avoid clobbering the `neighbors` / `bop_params` arrays of earlier cells.
test_params = np.array([8.46, 0.07, 3.33, 4.68, 1.25, 1.91, 0.87, 0.20])

num_failed = 0
for i in range(num_trials):
    center = rng.random(3)
    # Broadcasting adds `center` to every row (same as np.tile(center, ...)).
    test_neighbors = center + rng.random((num_neighbors, 3)) * 8 + 1.0
    energy_jax = float(
        bop_jax.calculate_energy(center, test_neighbors, test_params, eps=1e-12, r_cut=r_cut)
    )
    energy_rust = float(
        bop_rs.calculate_energy_atom(center, test_neighbors, test_params, r_cut)
    )
    diff = energy_rust - energy_jax
    print(f"===== Trial {i + 1} =====")
    print(f"JAX: \tE0 = {energy_jax:.9f}")
    print(f"Rust: \tE0 = {energy_rust:.9f}")
    print(f"Difference: {diff}")
    # Compare the magnitude of the difference: a signed `diff < tolerance` would
    # pass every trial where Rust is lower than JAX, however large the gap.
    # A NaN difference fails this test, as it should.
    if abs(diff) < tolerance:
        print("OK\n")
    else:
        print("FAILED\n")
        num_failed += 1

print(f"{num_trials - num_failed}/{num_trials} trials passed")