%config InlineBackend.figure_format = "svg"
import ase.io
import matplotlib.pyplot as plt
import numpy as np
import jax
from ase.visualize.plot import plot_atoms
from bop_jax import bop as bop_jax # JAX 버전
from bop_rs import bop as bop_rs # 검증용 Rust 버전
# Allow float64 in JAX so the JAX and Rust energies can be compared at ~1e-9.
# NOTE(review): this runs after `bop_jax` is imported; if that module creates
# arrays at import time they may already be float32 — confirm.
jax.config.update("jax_enable_x64", True)

Pt_3atoms = ase.io.read("training_3atoms.traj")
r_cut = 6.5  # cutoff radius passed to both implementations (presumably Å — confirm)

# Positions of every atom; each one is used in turn as the center atom.
centers = Pt_3atoms.positions
# Neighbors of center i are all the other atoms (row i removed). For the 3-atom
# structure this yields rows [[1, 2], [0, 2], [0, 1]], the same as hand-picked
# indices, but it also works for any number of atoms.
neighbors = np.array([np.delete(centers, i, axis=0) for i in range(len(centers))])

# Pre-calculated BOP parameters, one row per center atom.
# Columns: A, a, alpha, B, beta, h, lambda, sigma
bop_params = np.array([
    [8.45925165024959, 0.07248156460734, 3.33119467517366, 4.67977296766391, 1.24501156141294, 1.90843223151883, 0.87041937045874, 0.20909290922816],
    [8.46842563162531, 0.07209710360845, 3.33385847596619, 4.68345712045158, 1.25002282124036, 1.91056105318887, 0.87115826698344, 0.20813240220751],
    [8.46717196267474, 0.07214868604406, 3.33349396868284, 4.68295356541593, 1.24933854987283, 1.91027013440084, 0.87105869827311, 0.20826332815426],
])
# The per-atom loops below zip centers/neighbors/bop_params together; zip would
# silently drop atoms if the row counts disagreed, so fail loudly instead.
if len(bop_params) != len(centers):
    raise ValueError(
        f"expected one bop_params row per atom: got {len(bop_params)} rows "
        f"for {len(centers)} atoms"
    )

plot_atoms(Pt_3atoms);
# Per-atom energies from the JAX implementation, evaluated one center atom at a
# time, followed by their sum.
print("===== JAX version =====")
total_energy = 0
for idx, (pos, nbr_pos, params) in enumerate(zip(centers, neighbors, bop_params)):
    atomic_energy = bop_jax.calculate_energy(
        pos, nbr_pos, params, eps=1e-12, r_cut=r_cut
    )
    total_energy += atomic_energy  # running sum over all centers
    print(f"E_{idx} = {atomic_energy:.5f}")
print(f"E_tot = {total_energy:.5f}")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
===== JAX version =====
E_0 = -2.83641
E_1 = -2.76670
E_2 = -2.77614
E_tot = -8.37925
from jax import vmap
from functools import partial

# Same per-atom energies, but computed for all centers in one call: vmap maps
# bop_func over the leading axis of each argument (one row per center atom).
# This requires every center to have the same number of neighbors.
print("===== JAX version (vmap) =====")
bop_func = partial(bop_jax.calculate_energy, eps=1e-12, r_cut=r_cut)
atomic_energies = vmap(bop_func)(centers, neighbors, bop_params)
# Label energies 0-based (E_0, E_1, ...) to match the per-atom loops.
labels = ", ".join(f"E_{i}" for i in range(len(atomic_energies)))
print(f"{labels}: {atomic_energies}")
print(f"E_tot = {atomic_energies.sum():.5f}")
===== JAX version (vmap) =====
E_1, E_2, E_3: [-2.83640736 -2.76670081 -2.7761413 ]
E_tot = -8.37925
# Per-atom energies from the Rust reference implementation, printed in the
# same format as the JAX run so the two can be compared line by line.
print("===== Rust version =====")
total_energy = 0
for idx, atom_inputs in enumerate(zip(centers, neighbors, bop_params)):
    ctr, nbrs, coeffs = atom_inputs
    atomic_energy = bop_rs.calculate_energy_atom(ctr, nbrs, coeffs, r_cut)
    total_energy += atomic_energy
    print(f"E_{idx} = {atomic_energy:.5f}")
print(f"E_tot = {total_energy:.5f}")
===== Rust version =====
E_0 = -2.83641
E_1 = -2.76670
E_2 = -2.77614
E_tot = -8.37925
# Randomized cross-check: compare the JAX energy against the Rust reference on
# random single-center configurations.
# NOTE: this rebinds the module-level `neighbors` and `bop_params` that the
# 3-atom cells above use; re-run the setup cell before re-running those.
num_neighbors = 100
num_trials = 30
tolerance = 1e-9  # max allowed |E_rust - E_jax|
rng = np.random.default_rng(0)  # fixed seed so a failing trial can be reproduced
bop_params = np.array([8.46, 0.07, 3.33, 4.68, 1.25, 1.91, 0.87, 0.20])

num_failed = 0
for i in range(num_trials):
    center = rng.random(3)
    # Each neighbor is offset from the center by [1, 9) along every axis, so
    # distances span roughly 1.7 to 15.6 and straddle r_cut.
    neighbors = np.tile(center, (num_neighbors, 1)) + rng.random((num_neighbors, 3)) * 8 + 1.0
    atomic_energy_jax = bop_jax.calculate_energy(center, neighbors, bop_params, eps=1e-12, r_cut=r_cut)
    atomic_energy_rust = bop_rs.calculate_energy_atom(center, neighbors, bop_params, r_cut)
    diff = atomic_energy_rust - atomic_energy_jax
    print(f"===== Trial {i + 1} =====")
    print(f"JAX: \tE0 = {atomic_energy_jax:.9f}")
    print(f"Rust: \tE0 = {atomic_energy_rust:.9f}")
    print(f"Difference: {diff}")
    # Compare the magnitude: a signed check would accept any negative difference.
    if abs(diff) < tolerance:
        print("OK\n")
    else:
        num_failed += 1
        print("FAILED\n")
print(f"{num_trials - num_failed}/{num_trials} trials passed")
===== Trial 1 =====
JAX: E0 = -0.569821372
Rust: E0 = -0.569821372
Difference: 1.8407497748285095e-13
OK
===== Trial 2 =====
JAX: E0 = -0.595373428
Rust: E0 = -0.595373428
Difference: 1.474376176702208e-13
OK
===== Trial 3 =====
JAX: E0 = -0.361962476
Rust: E0 = -0.361962476
Difference: 1.8884893648873913e-13
OK
===== Trial 4 =====
JAX: E0 = -1.571904625
Rust: E0 = -1.571904625
Difference: 2.7533531010703882e-14
OK
===== Trial 5 =====
JAX: E0 = -0.138897935
Rust: E0 = -0.138897935
Difference: 1.0871858968641845e-13
OK
===== Trial 6 =====
JAX: E0 = -0.153499073
Rust: E0 = -0.153499073
Difference: 4.2299497238218464e-14
OK
===== Trial 7 =====
JAX: E0 = -0.532255234
Rust: E0 = -0.532255234
Difference: 1.454392162258955e-13
OK
===== Trial 8 =====
JAX: E0 = -0.294012081
Rust: E0 = -0.294012081
Difference: 7.877032359715486e-14
OK
===== Trial 9 =====
JAX: E0 = -0.319658504
Rust: E0 = -0.319658504
Difference: 8.765210779415611e-14
OK
===== Trial 10 =====
JAX: E0 = -1.034761482
Rust: E0 = -1.034761482
Difference: 1.9673151996357774e-13
OK
===== Trial 11 =====
JAX: E0 = -0.845069568
Rust: E0 = -0.845069568
Difference: 1.7341683644644945e-13
OK
===== Trial 12 =====
JAX: E0 = -0.143740470
Rust: E0 = -0.143740470
Difference: 5.5538906806873456e-14
OK
===== Trial 13 =====
JAX: E0 = -0.477497663
Rust: E0 = -0.477497663
Difference: 8.64863736182997e-14
OK
===== Trial 14 =====
JAX: E0 = -0.823570602
Rust: E0 = -0.823570602
Difference: 3.4072744625746054e-13
OK
===== Trial 15 =====
JAX: E0 = -0.249483465
Rust: E0 = -0.249483465
Difference: 9.525713551283843e-14
OK
===== Trial 16 =====
JAX: E0 = -0.966488637
Rust: E0 = -0.966488637
Difference: 6.405986852087153e-14
OK
===== Trial 17 =====
JAX: E0 = -0.520015448
Rust: E0 = -0.520015448
Difference: 1.8818280267396403e-13
OK
===== Trial 18 =====
JAX: E0 = -0.639614003
Rust: E0 = -0.639614003
Difference: 1.9817480989559044e-13
OK
===== Trial 19 =====
JAX: E0 = -0.874637103
Rust: E0 = -0.874637103
Difference: 3.490541189421492e-13
OK
===== Trial 20 =====
JAX: E0 = -0.764815010
Rust: E0 = -0.764815010
Difference: 2.347011474057581e-13
OK
===== Trial 21 =====
JAX: E0 = -1.076461176
Rust: E0 = -1.076461176
Difference: 2.8332891588433995e-13
OK
===== Trial 22 =====
JAX: E0 = -0.633589016
Rust: E0 = -0.633589016
Difference: 1.9972912213006566e-13
OK
===== Trial 23 =====
JAX: E0 = -0.176238506
Rust: E0 = -0.176238506
Difference: 8.174017018802715e-14
OK
===== Trial 24 =====
JAX: E0 = -0.529420924
Rust: E0 = -0.529420924
Difference: 1.475486399726833e-13
OK
===== Trial 25 =====
JAX: E0 = -0.237769238
Rust: E0 = -0.237769238
Difference: 1.8041124150158794e-13
OK
===== Trial 26 =====
JAX: E0 = -0.431682202
Rust: E0 = -0.431682202
Difference: 1.268984917146554e-13
OK
===== Trial 27 =====
JAX: E0 = -0.791213975
Rust: E0 = -0.791213975
Difference: 1.8773871346411397e-13
OK
===== Trial 28 =====
JAX: E0 = -1.475962655
Rust: E0 = -1.475962655
Difference: 4.276579090856103e-13
OK
===== Trial 29 =====
JAX: E0 = -0.722553025
Rust: E0 = -0.722553025
Difference: 3.0531133177191805e-13
OK
===== Trial 30 =====
JAX: E0 = -0.356022409
Rust: E0 = -0.356022409
Difference: 9.259260025373806e-14
OK