import tensorflow as tf
import numpy as np
from sklearn.utils import shuffle
from sklearn.preprocessing import MinMaxScaler
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
%matplotlib
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import itertools
import matplotlib.pyplot as plt
import os.path
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Input, Lambda, Dense, Flatten,Conv2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator,load_img
from tensorflow.keras.models import Sequential
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
from tensorflow.keras.layers import MaxPooling2D
import numpy as np
from enum import Enum
from random import random, randint, choice
from collections import defaultdict
from typing import List, Callable, Tuple
from copy import deepcopy
# Generate Gridworld Function
def gen_gridworld(dim: int, prob: float, solvable: bool = None, is_terrain_enabled: bool = True, source: Tuple[int, int] = None, target: Tuple[int, int] = None) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
    """Generate a random ``dim`` x ``dim`` gridworld.

    Every cell is blocked (``1.0``) with probability ``prob``. An unblocked
    cell holds a terrain value when ``is_terrain_enabled`` is true -- ``0.8``
    (flat), ``0.5`` (hilly) or ``0.2`` (forest), picked uniformly -- and
    ``0.0`` otherwise. A supplied ``source`` / ``target`` cell is always
    forced to ``0.0``. A missing one is drawn at random from the unblocked
    cells (the two may coincide) and is then kept for any later attempt.

    Args:
        dim: side length of the square grid; must be positive.
        prob: probability of a cell being blocked.
        solvable: ``True`` / ``False`` regenerates the grid until ``a_star``
            does / does not find a path from source to target; ``None``
            accepts the first grid that contains an unblocked cell.
        is_terrain_enabled: whether unblocked cells receive a terrain value.
        source: optional start cell, indexed as ``grid[source[0]][source[1]]``.
        target: optional goal cell, indexed the same way.

    Returns:
        The grid as a 2-D float array, the source cell and the target cell.

    Raises:
        ValueError: if ``dim`` is not positive.
    """
    if dim <= 0:
        # An empty grid never has an unblocked cell, so the loop below
        # would spin forever.
        raise ValueError("dim must be a positive integer")

    terrain_values = {1: 0.8, 2: 0.5, 3: 0.2}  # FLAT, HILLY, FOREST
    grid = np.zeros((dim, dim))
    while True:
        for i in range(dim):
            for j in range(dim):
                if random() >= prob:
                    if is_terrain_enabled:
                        grid[i][j] = terrain_values[randint(1, 3)]
                    else:
                        grid[i][j] = 0.0
                else:
                    grid[i][j] = 1.0  # blocked cell

        # Known endpoints must never be blocked.
        if source is not None:
            grid[source[0]][source[1]] = 0.0
        if target is not None:
            grid[target[0]][target[1]] = 0.0

        unblocked = [(i, j) for i in range(dim) for j in range(dim) if grid[i][j] != 1.0]
        if not unblocked:
            continue  # fully blocked grid: try again

        # Draw whichever endpoint the caller left out.
        if source is None:
            source = choice(unblocked)
        if target is None:
            # The original assigned this pick to `source`, leaving `target`
            # as None and breaking the a_star call below.
            target = choice(unblocked)

        path, *_ = a_star(grid, source, target)
        if solvable is None or solvable == bool(path):
            break
    return grid, source, target
# Heuristics
def manhattan_heuristic(x1: int, y1: int, x2: int, y2: int):
    """Return the Manhattan (L1) distance between (x1, y1) and (x2, y2)."""
    first_axis_gap = abs(x1 - x2)
    second_axis_gap = abs(y1 - y2)
    return first_axis_gap + second_axis_gap
# A*
from collections import defaultdict
import heapq
from typing import Tuple, Callable, List, Dict
import numpy as np
def a_star(
    grid: List[List[float]],
    source: Tuple[int, int] = (0, 0),
    goal: Tuple[int, int] = None,
    h: Callable = manhattan_heuristic,
    W: int = 1,
) -> Tuple[List[Tuple[int, int]], float, List[Tuple[int, int]], int, int]:
    """Run A* from ``source`` to ``goal`` on a fully known grid.

    Moves are 4-connected and every step costs 1, whatever the terrain value
    of the cell. A cell whose value is ``>= 1.0`` is impassable. The grid may
    be rectangular; the column count is taken from the first row.

    Args:
        grid: 2-D grid indexed as ``grid[x][y]``.
        source: cell the search starts from.
        goal: cell the search is looking for. The ``None`` default is only a
            placeholder: the heuristic is called with ``*goal``, so a real
            cell has to be supplied.
        h: heuristic called as ``h(x, y, goal_x, goal_y)``.
        W: weight applied to the heuristic value.

    Returns:
        A 5-tuple of:
        the path from source to goal as a list of cells (empty if no path);
        the path cost (``float("inf")`` if no path);
        every cell that was assigned a cost during the search;
        the number of pops from the priority queue;
        the number of times a cell's cost was set or improved.
    """
    if source == goal:
        return [source], 0, [source], 0, 0

    num_rows = len(grid)
    # Use the real column count rather than assuming the grid is square.
    num_cols = len(grid[0]) if num_rows else 0

    g = defaultdict(lambda: float("inf"))  # cell -> best known cost from source
    g[source] = 0
    parent = {source: None}  # cell -> predecessor on the best known path
    pq = [(0, -g[source], source)]  # entries are (f, -g, cell)
    trajectory = cells_processed = 0

    while pq:
        _, _, curr = heapq.heappop(pq)
        trajectory += 1
        if curr == goal:
            break
        for dx, dy in [(0, 1), (-1, 0), (0, -1), (1, 0)]:
            new_x, new_y = curr[0] + dx, curr[1] + dy
            # Skip neighbours that are outside the grid or blocked.
            if not (0 <= new_x < num_rows and 0 <= new_y < num_cols) or grid[new_x][new_y] >= 1.0:
                continue
            children = (new_x, new_y)
            new_g = g[curr] + 1
            # Only newly discovered cells or strictly cheaper routes matter.
            if children not in g or new_g < g[children]:
                cells_processed += 1
                g[children] = new_g
                parent[children] = curr
                f = new_g + W * h(*children, *goal)
                # -new_g breaks ties between equal f in favour of deeper cells.
                heapq.heappush(pq, (f, -new_g, children))
    else:
        # The queue drained without reaching the goal: there is no path.
        return [], float("inf"), list(g.keys()), trajectory, cells_processed

    # Walk the parent links back from the goal, then flip into travel order.
    path = [curr]
    while curr != source:
        curr = parent[curr]
        path.append(curr)
    path.reverse()
    return path, g[goal], list(g.keys()), trajectory, cells_processed
# Nicely prints out the grid
# Testing use only.
class colors:
    """ANSI terminal escape sequences used when printing the grid."""

    # Style reset / emphasis
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    # Foreground colours
    RED = '\033[91m'
    BLUE = '\033[94m'
    PINK = '\033[95m'
def pretty_print_explored_region(grid, visited=None, path=None):
    """Print the grid to stdout with ANSI colours (debugging aid).

    Cells in ``visited`` are printed in red. Cells in ``path`` are printed
    in pink when they were also visited and in blue otherwise. Every other
    cell gets a colour chosen from its value; a value with no entry in the
    colour table falls back to the default colour instead of inheriting the
    colour of the previously printed cell.

    Args:
        grid: 2-D grid exposing ``tolist()`` (e.g. a numpy array).
        visited: iterable of ``(x, y)`` cells to mark as visited.
        path: iterable of ``(x, y)`` cells to mark as being on the path.
    """
    visited = [] if visited is None else visited
    path = [] if path is None else path

    default_color = '\33[102m'
    value_colors = {
        0.8: '\33[43m',
        0.5: '\33[102m',
        0.2: '\33[42m',
        1: '\33[100m',
        0: '\033[94m',
        2.0: '\033[104m',
    }

    # tolist() already returns fresh nested lists, so no extra copy is needed.
    grid_copy = grid.tolist()

    for x, y in visited:
        grid_copy[x][y] = colors.RED + str(grid[x][y]) + colors.ENDC
    for x, y in path:
        # A cell that is already a string was coloured as visited above.
        if isinstance(grid_copy[x][y], str):
            grid_copy[x][y] = colors.PINK + str(grid[x][y]) + colors.ENDC
        else:
            grid_copy[x][y] = colors.BLUE + str(grid[x][y]) + colors.ENDC

    for row in grid_copy:
        for j, cell in enumerate(row):
            if not isinstance(cell, str):
                color = value_colors.get(cell, default_color)
                row[j] = color + str(cell) + colors.ENDC

    for row in grid_copy:
        for col in row:
            print(col, end=" ")  # cells on a row are separated by spaces
        print()
    print("-------")
# RED == visited
# Purple == path that was also visited
# Blue == path but was not visited (i.e. if agent knew the complete grid then it was have used the blue path)
class CardinalDirection(Enum):
    """The four compass headings, numbered 0-3 starting at NORTH."""

    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

    def __int__(self):
        """Let ``int(member)`` yield the member's numeric value."""
        return self.value
def _get_cardinal_direction(x1, y1, x2, y2):
    """Return the heading of a move from (x1, y1) to (x2, y2).

    The first coordinate is compared before the second: a smaller ``x2``
    is NORTH and a larger one SOUTH; otherwise a smaller ``y2`` is WEST and
    a larger one EAST.

    Raises:
        Exception: if both points are identical.
    """
    if x1 > x2:
        return CardinalDirection.NORTH
    if x1 < x2:
        return CardinalDirection.SOUTH
    # Same first coordinate: decide on the second one.
    if y1 > y2:
        return CardinalDirection.WEST
    if y1 < y2:
        return CardinalDirection.EAST
    raise Exception("Invalid cardinal direction")
import numpy as np
### Repeated A* - Remastered for P4
def repeated_a_star(complete_grid: List[List[int]], use_full_grid=False, h: Callable = manhattan_heuristic, source: Tuple[int, int] = (0, 0), goal: Tuple[int, int] = None, W: int = 1) -> Tuple[np.ndarray, List[Tuple[int, int]], list, list, int, List[Tuple[int, int]], int, int, int]:
    """Repeated A*: plan on the agent's current knowledge, walk, replan.

    The agent plans a path with ``a_star`` on its own grid and follows it.
    It learns the true value of a cell only when it steps onto it, and
    replans whenever the next cell turns out to be blocked (``1.0``).
    Before every move the current grid and a description of the agent's
    surroundings are recorded together with the direction taken.

    Args:
        complete_grid: the true grid, indexed as ``complete_grid[x][y]``.
        use_full_grid: if true the agent starts with a copy of the true
            grid; otherwise every cell starts as ``0.5`` (source ``0.0``).
        h: heuristic forwarded to ``a_star`` while planning.
        source: cell the agent starts from.
        goal: cell the agent tries to reach; defaults to the bottom-right
            cell of the grid.
        W: heuristic weight forwarded to ``a_star`` while planning.

    Returns:
        A 9-tuple of:
        the agent's grid after the run (when ``use_full_grid`` is false,
        cells that were never stepped on and are not known blocks are 2.0);
        the ``a_star`` path from the original source to the goal on that
        final grid (``[]`` if planning failed);
        the recorded states, each ``[grid snapshot, source_neighbors]`` where
        ``source_neighbors`` is the agent position followed by
        ``(x, y, value)`` for each of its four neighbours (``0, 0, 1`` for a
        neighbour outside the grid);
        the ``CardinalDirection`` taken from each recorded state;
        the number of steps in the returned path (0 if planning failed);
        the cells the agent stepped on;
        the number of cells the agent stepped on (counting the start);
        the cells processed by ``a_star``, summed over all planning calls;
        the number of planning calls.
    """
    original_source = source  # `source` tracks the agent's position below
    n, m = len(complete_grid) - 1, len(complete_grid[0]) - 1  # largest valid indices
    if goal is None:
        # Previously the argument was overwritten unconditionally.
        goal = (n, m)
    trajectory = cells_processed = repeats = 0
    safe = set()  # cells the agent has actually stood on

    if use_full_grid:
        grid = np.array(complete_grid)  # np.array copies its input by default
    else:
        grid = np.array([[0.5] * (m + 1) for _ in range(n + 1)])
        grid[source[0]][source[1]] = 0.0

    grid_states = []
    actions = []

    def failure():
        # Shared result shape for every "no route" outcome.
        return grid, [], grid_states, actions, 0, list(safe), trajectory, cells_processed, repeats

    while source != goal:
        repeats += 1
        path, _, _, _, planned_cells = a_star(grid, source, goal, h, W)
        cells_processed += planned_cells
        if not path:
            return failure()

        for pos in path:
            x, y = pos
            if source != pos:
                # Record what the agent knows just before it moves.
                source_neighbors = [source[0], source[1]]
                for dx, dy in [(0, 1), (-1, 0), (0, -1), (1, 0)]:
                    new_x, new_y = source[0] + dx, source[1] + dy
                    if 0 <= new_x <= n and 0 <= new_y <= m:
                        source_neighbors.extend([new_x, new_y, grid[new_x][new_y]])
                    else:
                        source_neighbors.extend([0, 0, 1])
                grid_states.append([grid.copy(), source_neighbors])
                actions.append(_get_cardinal_direction(*source, *pos))

            grid[x][y] = complete_grid[x][y]  # observe the cell being entered
            if grid[x][y] == 1.0:
                if pos == source:
                    # The agent's own cell is blocked; replanning would
                    # return the same path forever.
                    return failure()
                break  # bumped into a block: replan from the current cell
            safe.add(pos)
            trajectory += 1
            source = pos

    if not use_full_grid:
        # Mark everything the agent never stood on (and does not know to be
        # blocked) as 2.0 so the final search treats it as impassable.
        for i in range(len(grid)):
            for j in range(len(grid[0])):
                if (i, j) not in safe and grid[i][j] != 1.0:
                    grid[i][j] = 2.0

    # Path on the final grid; uses `goal`, not the module-level `target`.
    path_discovered_grid, *_ = a_star(grid, original_source, goal)
    return grid, path_discovered_grid, grid_states, actions, len(path_discovered_grid) - 1, list(safe), trajectory, cells_processed, repeats
# --- Experiment configuration ---
# Maze counts; not referenced by the active code visible in this part of the file.
NUM_TRAINING_MAZES = 100
NUM_TESTING_MAZES = 100

# Arguments handed to gen_gridworld by gen_grid_data.
MAZE_DIM = 50               # side length of each square maze
DENSITY = 0.3               # probability that a cell is blocked
SOLVABILITY = True          # only keep mazes with a path from source to target
IS_TERRAIN_ENABLED = False  # unblocked cells are plain 0.0, no terrain values

# Fixed endpoints: top-left corner to bottom-right corner.
source = (0, 0)
target = (MAZE_DIM - 1, MAZE_DIM - 1)
# grid, _, _ = gen_gridworld(dim, density, solvability, is_terrain_enabled, source, target)
# # pretty_print_explored_region(grid)
# explored_grid, path, grid_states, actions, len_path, visited_cells, *_ = repeated_a_star(grid, True)
# pretty_print_explored_region(explored_grid, visited_cells, path)
# explored_grid, path, grid_states, actions, len_path, visited_cells, *_ = repeated_a_star(grid)
# pretty_print_explored_region(explored_grid, visited_cells, path)
# print(len(grid_states))
# print(len(actions))
def gen_grid_data(num_iter: int, data_format: int):
    """Generate (grid, source-description, action) samples from random mazes.

    For each of ``num_iter`` mazes built by ``gen_gridworld`` with the
    module-level settings, ``agent`` is run and the states and actions it
    recorded are appended to the output.

    NOTE(review): ``agent`` is not defined in the visible part of this file;
    it is presumably a function returning the same layout as
    ``repeated_a_star`` -- confirm before relying on this.

    Args:
        num_iter: number of mazes to generate.
        data_format: passed through unchanged as the second argument of
            ``agent``.

    Returns:
        Three arrays: the grid snapshots, the source descriptions and the
        actions, aligned by index.
    """
    all_grids = []
    all_sources = []
    all_actions = []
    for i in range(num_iter):
        if i % 100 == 0:
            print("Maze", i)
        grid, _, _ = gen_gridworld(MAZE_DIM, DENSITY, SOLVABILITY, IS_TERRAIN_ENABLED, source, target)
        # The original passed the undefined name `dataformat` here.
        _, _, grid_states, actions, *_ = agent(grid, data_format)
        if not grid_states:
            # No move was recorded, so zip(*...) would have nothing to unpack.
            continue
        grids, sources = zip(*grid_states)
        all_grids.extend(grids)
        all_sources.extend(sources)
        all_actions.extend(actions)
    return (np.array(all_grids), np.array(all_sources), np.array(all_actions))
def generate_confusion_matrix(data, labels, num_classes: int = 10):
    """Print a confusion matrix for the module-level ``model``.

    Rows are true labels and columns are predicted labels; the cells of a
    row are tab-separated.

    NOTE(review): ``model`` is read as a global and is not assigned in the
    visible part of this file -- confirm it exists before calling.

    Args:
        data: input handed to ``model.predict``. Either a single array or a
            list of arrays for a multi-input model; the sample count is
            taken from the predictions, so both forms work.
        labels: true class of each sample; each entry is converted with
            ``int()``.
        num_classes: size of the matrix. Defaults to 10 as before; pass 4
            for the four-direction models built in this file.
    """
    out = model.predict(data)
    predictions = np.argmax(out, axis=1)
    mat = [[0] * num_classes for _ in range(num_classes)]
    for i, predicted in enumerate(predictions):
        mat[int(labels[i])][int(predicted)] += 1
    for row in mat:
        print("\t".join(str(count) for count in row))
# Full Dense Layer for P1's Agent
# Model Creation
# NUM_PROCESSES = 10
# PROCESS_DATA_LIMIT = 100
# BATCH_SIZE = 100
# NUM_TRAINING_MAZES = 1000 * NUM_PROCESSES
# NUM_TESTING_MAZES = 100
# MAZE_DIM = 50
# DENSITY = 0.3
# SOLVABILITY = True
# IS_TERRAIN_ENABLED = False
# source = (0,0)
# target = (MAZE_DIM-1, MAZE_DIM-1)
# total_grid_states = []
# total_actions = []
# total_test_grid_states = []
# total_test_grid_actions = []
# def gen_grid_data(idx, num_iter):
# all_grids = []
# all_sources = []
# all_actions = []
# for i in range(num_iter):
# if i % 100 == 0:
# print("Maze", i)
# grid, _, _ = gen_gridworld(MAZE_DIM, DENSITY, SOLVABILITY, IS_TERRAIN_ENABLED, source, target)
# explored_grid, path, grid_states, actions, len_path, visited_cells, *_ = repeated_a_star(grid)
# grids, sources = zip(*grid_states)
# all_grids.extend(grids)
# all_sources.extend(sources)
# all_actions.extend(actions)
# return (np.array(all_grids), np.array(all_sources), np.array(all_actions))
# def gen_grid_data_train(idx=0):
# return gen_grid_data(idx, 250)
# def gen_grid_data_test(idx=0):
# return gen_grid_data(idx, PROCESS_DATA_LIMIT)
# arrs = [gen_grid_data_test()]
# total_test_grids = np.array([x for sublist in arrs for x in sublist[0]])
# total_test_sources = np.array([x for sublist in arrs for x in sublist[1]])
# total_test_actions = np.array([x for sublist in arrs for x in sublist[2]])
### BLOCK 4 ###
# import tensorflow as tf
# # To Enable GPU
# # NOTE: Deepnote has no GPU available for free tier
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# print("NUM GPUs Available: ", len(physical_devices))
# tf.config.experimental.set_memory_growth(physical_devices[0], True)
# def generate_confusion_matrix( data, labels ):
# mat = [ [ 0 for i in range(4) ] for j in range(4) ]
# print(np.shape(data[0]))
# print(np.shape(data[1]))
# out = model.predict(data)
# predictions = np.argmax( out , axis = 1 )
# unique, counts = np.unique(predictions, return_counts=True)
# print("predict", dict(zip(unique, counts)))
# unique, counts = np.unique(np.array([int(x) for x in labels]), return_counts=True)
# print("labels", dict(zip(unique, counts)))
# for i in range( data[0].shape[0] ):
# mat[ int(labels[i]) ][ predictions[i] ] += 1
# for i in range(4):
# print( "\t".join( [ str(c) for c in mat[i] ] ) )
# grids_train = np.reshape(total_grids, (-1, MAZE_DIM, MAZE_DIM))
# sources_train = total_sources
# output_train = tf.keras.utils.to_categorical(
# np.array([int(x) for x in total_actions]),
# 4,
# )
# grids_test = np.array([x[0] for x in total_test_grid_states])
# grids_test = np.reshape(total_test_grids, (-1, MAZE_DIM, MAZE_DIM) )
# sources_test = total_test_sources
# output_test = np.array([int(x) for x in total_test_actions])
# print("bottle 3")
# Grid input
# Models
def generate_dense_model():
    """Build and compile the fully connected two-input classifier.

    The model takes a ``MAZE_DIM`` x ``MAZE_DIM`` grid and a 14-value source
    vector (the same length as the ``source_neighbors`` list built in
    ``repeated_a_star``: 2 coordinates plus 3 values for each of 4
    neighbours). It outputs a softmax over 4 classes and is compiled with
    Adam and categorical cross-entropy, so targets must be one-hot encoded.

    Returns:
        The compiled ``tf.keras.Model``; its summary is printed as a side effect.
    """
    # Grid input
    grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM))
    flatten_grids = tf.keras.layers.Flatten()(grid_input)
    # Source input; shape must be a tuple, and `(14)` is just the int 14
    source_input = tf.keras.layers.Input(shape=(14,))
    # Full input
    grids_and_sources = tf.keras.layers.Concatenate()([flatten_grids, source_input])
    dense_1 = tf.keras.layers.Dense(units=2000, activation=tf.nn.relu)(grids_and_sources)
    dense_2 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)(dense_1)
    probabilities = tf.keras.layers.Dense(units=4, activation=tf.nn.softmax)(dense_2)
    model = tf.keras.Model(inputs=[grid_input, source_input], outputs=probabilities)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    return model
def generate_cnn_model():
    """Build and compile the convolutional two-input classifier.

    The grid is flattened, concatenated with the 14-value source vector and
    given two trailing singleton axes so that ``Conv2D`` accepts it. Three
    Conv2D + MaxPooling2D stages are followed by a 500-unit dense layer and
    a 4-way softmax. The model is compiled with Adam and categorical
    cross-entropy, so targets must be one-hot encoded.

    Returns:
        The compiled ``tf.keras.Model``; its summary is printed as a side effect.
    """
    # Grid input
    grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM))
    flatten_grids = tf.keras.layers.Flatten()(grid_input)
    # Source input; shape must be a tuple, and `(14)` is just the int 14
    source_input = tf.keras.layers.Input(shape=(14,))
    # Full input
    grids_and_sources = tf.keras.layers.Concatenate()([flatten_grids, source_input])
    # Conv2D needs a 4-D tensor, so append two axes of size 1.
    # NOTE(review): calling tf ops directly on Keras tensors depends on the
    # Keras version in use -- confirm against the installed release.
    grids_and_sources = tf.expand_dims(grids_and_sources, axis=-1)
    grids_and_sources = tf.expand_dims(grids_and_sources, axis=-1)
    conv2d_1 = tf.keras.layers.Conv2D(filters=256, kernel_size=2, padding="same", activation="relu")(grids_and_sources)
    maxpool2d_1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(conv2d_1)
    conv2d_2 = tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding="same", activation="relu")(maxpool2d_1)
    maxpool2d_2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(conv2d_2)
    conv2d_3 = tf.keras.layers.Conv2D(filters=4, kernel_size=2, padding="same", activation="relu")(maxpool2d_2)
    maxpool2d_3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(conv2d_3)
    flatten = tf.keras.layers.Flatten()(maxpool2d_3)
    dense_1 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)(flatten)
    # Named `logits` in the original, but this layer already applies softmax.
    logits = tf.keras.layers.Dense(units=4, activation="softmax")(dense_1)
    model = tf.keras.Model(inputs=[grid_input, source_input], outputs=logits)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    return model
# model = generate_cnn_model()
# print("bottle 5")
# model = tf.keras.models.load_model('./test_data/')
# generate_confusion_matrix( [grids_test, sources_test], output_test )
# NUM_PROCESSES = 10
# PROCESS_DATA_LIMIT = 100
# BATCH_SIZE = 250
# from sklearn.utils import class_weight
# NUM_TRAINING_MAZES = 1000
# total_grids = []
# total_sources = []
# total_actions = []
# for i in range(0, NUM_TRAINING_MAZES, BATCH_SIZE):
# if i != 0:
# model = tf.keras.models.load_model('./test_data/')
# print(f"Training starting on maze {i}")
# total_grids, total_sources, total_actions = gen_grid_data_train()
# grids_train = np.reshape(total_grids, (-1, MAZE_DIM, MAZE_DIM))
# sources_train = total_sources
# output_train = tf.keras.utils.to_categorical(
# np.array([int(x) for x in total_actions]),
# 4,
# )
# # y_ints = [y.argmax() for y in output_train]
# # class_weights = dict(enumerate(class_weight.compute_class_weight('balanced', classes=np.unique(y_ints), y=y_ints)))
# north_multiplier = 3
# west_multiplier = 2
# class_weights = {
# 0 : north_multiplier,
# 1 : 1,
# 2 : 1,
# 3 : west_multiplier,
# }
# # multiplier = 0 if multiplier == 0 else multiplier - 1
# history = model.fit( [grids_train, sources_train], output_train, epochs = 10, validation_split=.1)
# model.save('./test_data/')
# generate_confusion_matrix( [grids_test, sources_test], output_test )
# model.save('./test_data/')
# Dataset will be of shape (# of inputs, # rows, # cols)
# CNN for P1's Agent
def p1_cnn_model():
    """Build and compile the convolutional classifier with a 2-value source input.

    Same layer stack as ``generate_cnn_model``, except that the source input
    holds 2 values and the loss is sparse categorical cross-entropy, so
    targets are integer class ids rather than one-hot vectors.

    Returns:
        The compiled ``tf.keras.Model``; its summary is printed as a side effect.
    """
    # Grid input
    grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM))
    flatten_grids = tf.keras.layers.Flatten()(grid_input)
    # Source input; shape must be a tuple, and `(2)` is just the int 2
    source_input = tf.keras.layers.Input(shape=(2,))
    # Full input
    grids_and_sources = tf.keras.layers.Concatenate()([flatten_grids, source_input])
    # Conv2D requires a minimum ndim of 4, so add two more axes of size 1.
    # NOTE(review): calling tf ops directly on Keras tensors depends on the
    # Keras version in use -- confirm against the installed release.
    grids_and_sources = tf.expand_dims(grids_and_sources, axis=-1)
    grids_and_sources = tf.expand_dims(grids_and_sources, axis=-1)
    # Model
    conv2d_1 = tf.keras.layers.Conv2D(filters=256, kernel_size=2, padding="same", activation="relu")(grids_and_sources)
    maxpool2d_1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(conv2d_1)
    conv2d_2 = tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding="same", activation="relu")(maxpool2d_1)
    maxpool2d_2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(conv2d_2)
    conv2d_3 = tf.keras.layers.Conv2D(filters=4, kernel_size=2, padding="same", activation="relu")(maxpool2d_2)
    maxpool2d_3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(conv2d_3)
    flatten = tf.keras.layers.Flatten()(maxpool2d_3)
    dense_1 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)(flatten)
    # Named `logits` in the original, but this layer already applies softmax.
    logits = tf.keras.layers.Dense(units=4, activation="softmax")(dense_1)
    model = tf.keras.Model(inputs=[grid_input, source_input], outputs=logits)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself and returns None, so it is not wrapped in print().
    model.summary()
    return model
# # Data
# nn_path_data =[
# [172, 171, 161, 143, 150, 146, 125, 127, 176, 311, 137, 148, 179, 316, 143, 157, 221, 152, 213, 184, 156, 225, 148, 191, 173, 216, 158, 168, 164, 127, 167, 191, 174, 179, 171, 160, 194, 163, 154, 137, 165, 147, 179, 208, 181, 157, 199, 145, 168, 153, 188, 145, 183, 126, 135, 155, 189, 239, 164, 214, 152, 155, 160, 397, 177, 173, 145, 134, 157, 151, 171, 144, 219, 162, 224, 199, 155, 154, 156, 192, 183, 166, 157, 188, 166, 150, 177, 157, 162, 146, 155, 206, 133, 123, 162, 160, 157, 175, 176, 171],
# [167, 156, 131, 212, 186, 129, 151, 179, 154, 136, 160, 172, 194, 128, 118, 163, 246, 129, 162, 136, 191, 144, 140, 191, 150, 148, 159, 163, 192, 141, 263, 149, 171, 139, 175, 189, 173, 129, 150, 137, 163, 182, 184, 127, 164, 114, 202, 171, 132, 141, 150, 190, 125, 172, 180, 200, 193, 138, 130, 142, 187, 173, 159, 146, 200, 193, 172, 127, 142, 167, 193, 198, 190, 163, 169, 143, 196, 198, 224, 276, 149, 174, 157, 141, 160, 175, 162, 134, 125, 211, 115, 152, 143, 233, 276, 184, 157, 155, 160, 136],
# [143, 206, 150, 156, 189, 126, 149, 139, 158, 130, 160, 163, 203, 157, 139, 154, 201, 198, 220, 156, 186, 157, 127, 180, 178, 135, 176, 160, 180, 159, 173, 172, 180, 132, 184, 165, 133, 217, 227, 152, 189, 239, 166, 154, 163, 283, 142, 188, 151, 248, 152, 153, 163, 132, 147, 152, 221, 149, 165, 186, 151, 139, 193, 184, 181, 175, 210, 144, 139, 188, 144, 167, 157, 171, 175, 169, 180, 142, 125, 124, 191, 168, 186, 190, 203, 193, 183, 184, 176, 182, 168, 201, 147, 153, 162, 187, 162, 209, 141, 183],
# [236, 178, 137, 121, 138, 160, 181, 136, 132, 139, 150, 168, 124, 211, 169, 182, 245, 165, 216, 275, 171, 195, 220, 170, 174, 165, 140, 138, 148, 125, 171, 179, 204, 216, 252, 172, 177, 133, 191, 131, 179, 157, 137, 256, 185, 174, 143, 149, 123, 151, 137, 211, 255, 135, 169, 208, 137, 180, 140, 195, 192, 227, 236, 133, 183, 212, 163, 131, 162, 142, 151, 150, 180, 149, 145, 124, 160, 198, 152, 165, 198, 156, 136, 273, 147, 145, 144, 174, 187, 148, 169, 148, 172, 192, 174, 155, 149, 136, 181, 157],
# [231, 151, 135, 207, 225, 150, 151, 212, 140, 239, 171, 187, 248, 235, 145, 228, 196, 197, 287, 166, 156, 140, 159, 187, 201, 166, 175, 202, 140, 190, 133, 151, 148, 157, 184, 188, 144, 122, 196, 171, 257, 178, 113, 202, 138, 135, 161, 163, 206, 171, 163, 143, 177, 148, 170, 212, 141, 194, 146, 134, 175, 164, 196, 168, 138, 143, 186, 147, 264, 127, 136, 135, 176, 191, 148, 123, 177, 164, 131, 152, 162, 208, 169, 170, 230, 136, 151, 182, 154, 163, 292, 137, 189, 171, 213, 162, 247, 136, 151, 170],
# [171, 165, 164, 147, 225, 151, 133, 149, 144, 192, 158, 244, 160, 188, 173, 138, 122, 178, 201, 140, 208, 231, 149, 168, 154, 149, 148, 175, 197, 171, 234, 153, 152, 140, 355, 213, 148, 174, 142, 126, 155, 174, 202, 174, 143, 146, 145, 203, 160, 154, 231, 153, 165, 195, 178, 151, 191, 133, 192, 136, 188, 201, 172, 257, 146, 174, 193, 163, 153, 142, 174, 321, 140, 134, 244, 212, 147, 154, 176, 132, 160, 175, 234, 138, 163, 118, 193, 172, 195, 213, 214, 160, 113, 148, 129, 225, 220, 188, 182, 213],
# [136, 165, 177, 201, 138, 168, 156, 240, 243, 152, 136, 188, 134, 162, 183, 163, 166, 216, 135, 129, 164, 160, 129, 133, 144, 154, 205, 142, 198, 181, 157, 208, 138, 163, 196, 201, 177, 176, 221, 141, 115, 161, 139, 232, 175, 133, 158, 183, 210, 125, 150, 137, 152, 201, 161, 171, 142, 143, 204, 167, 159, 175, 150, 170, 110, 159, 172, 158, 148, 135, 202, 223, 170, 246, 186, 124, 143, 163, 211, 151, 187, 213, 186, 179, 173, 172, 153, 225, 148, 193, 272, 115, 180, 160, 176, 150, 148, 208, 157, 132],
# [177, 332, 147, 207, 132, 162, 133, 133, 195, 184, 174, 153, 158, 149, 171, 189, 108, 203, 131, 162, 225, 176, 143, 154, 279, 234, 181, 152, 223, 142, 139, 193, 138, 150, 168, 149, 155, 131, 136, 167, 126, 157, 209, 156, 153, 155, 136, 134, 147, 139, 149, 139, 159, 214, 130, 211, 143, 148, 137, 133, 156, 200, 201, 174, 161, 205, 131, 137, 157, 154, 170, 161, 237, 178, 180, 172, 124, 138, 143, 185, 126, 126, 151, 156, 159, 247, 177, 167, 168, 139, 142, 126, 150, 164, 156, 120, 172, 184, 148, 185],
# [159, 187, 206, 237, 147, 198, 224, 130, 167, 208, 240, 156, 216, 188, 141, 152, 131, 167, 146, 146, 178, 175, 139, 152, 183, 144, 248, 173, 132, 173, 180, 187, 148, 164, 159, 170, 139, 169, 231, 153, 156, 197, 204, 118, 190, 208, 137, 134, 188, 164, 132, 148, 124, 204, 235, 188, 144, 152, 188, 122, 177, 141, 143, 181, 168, 209, 191, 157, 152, 180, 184, 196, 132, 116, 123, 166, 184, 127, 164, 177, 148, 175, 135, 152, 141, 230, 159, 164, 188, 159, 143, 143, 141, 162, 183, 157, 147, 142, 148, 188],
# ]
# nn_traj_data = [
# [198, 206, 178, 158, 174, 172, 138, 136, 202, 416, 148, 168, 210, 560, 164, 184, 302, 202, 268, 210, 180, 298, 160, 238, 190, 356, 196, 184, 218, 140, 190, 254, 216, 226, 216, 190, 234, 188, 176, 146, 208, 168, 218, 276, 238, 184, 264, 162, 196, 190, 224, 170, 238, 134, 152, 174, 216, 350, 190, 246, 192, 174, 186, 682, 220, 204, 154, 144, 188, 176, 196, 162, 292, 190, 268, 228, 192, 216, 182, 232, 226, 194, 180, 224, 218, 166, 220, 184, 182, 158, 184, 310, 150, 130, 180, 184, 178, 188, 206, 194],
# [194, 194, 140, 276, 216, 142, 172, 204, 168, 164, 194, 232, 242, 132, 126, 194, 344, 142, 186, 148, 232, 152, 150, 248, 166, 170, 188, 196, 244, 160, 296, 168, 200, 178, 214, 236, 218, 140, 180, 166, 188, 224, 260, 130, 190, 122, 294, 208, 144, 168, 198, 238, 136, 218, 248, 234, 232, 156, 138, 158, 258, 188, 208, 164, 280, 250, 220, 142, 154, 182, 258, 236, 216, 176, 196, 202, 274, 268, 302, 400, 158, 212, 198, 152, 202, 202, 182, 142, 134, 248, 118, 168, 160, 314, 370, 204, 176, 178, 186, 150],
# ]
# nn_time_data = [
# [1.4462645419989713, 1.7088955150029506, 1.361946345998149, 1.3695014600016293, 1.5671857359993737, 1.481699056996149, 1.0677907689969288, 1.0918150110010174, 1.6448368850033148, 2.956772263001767, 1.130766610003775, 1.3165569060001872, 1.492839238999295, 3.1228196939991903, 1.3154978949969518, 1.6214799270019284, 2.2099633420002647, 1.5267259890024434, 2.034569035000459, 1.5807912589953048, 1.461429915005283, 2.1437796460013487, 1.3408771809990867, 1.958894108000095, 1.637375147001876, 2.5072193649975816, 1.4501445119967684, 1.7286867209986667, 1.610526900003606, 1.2609322850039462, 1.4866449640030623, 1.7460907530039549, 1.6500354439995135, 1.6812347960003535, 1.572385696999845, 1.3791274210016127, 1.7791926530044293, 1.4637708570007817, 1.3210894119984005, 1.238126561001991, 1.5753058370028157, 1.4708553770033177, 1.7796351209981367, 2.026341176999267, 1.4498691010012408, 1.5006212630032678, 2.0487539819951053, 1.2980627099968842, 1.5662024880002718, 1.5562543800042477, 1.7270862340010353, 1.5350803129986161, 1.7540661050006747, 1.1820562849970884, 1.2153055049930117, 1.5176550920004956, 1.8629052899996168, 2.285796620999463, 1.6908715270037646, 1.991654160003236, 1.3634275820004405, 1.5014306609955383, 1.632629353000084, 4.531791525994777, 1.863866973995755, 1.5935716189997038, 1.1897914500004845, 1.2630833810035256, 1.567113402998075, 1.427508641994791, 1.6500954360017204, 1.367469896998955, 2.382199089995993, 1.5909635170028196, 2.2305142379991594, 1.711697260994697, 1.6914257719981833, 1.5398958540026797, 1.4747449179994874, 1.7230304740005522, 1.7504821960028494, 1.5616723990024184, 1.3164274230002775, 1.6892336719974992, 1.6094090199985658, 1.3447162219963502, 1.5157135109984665, 1.5700635639950633, 1.5621553220044007, 1.3887742759980028, 1.3669046880022506, 2.3299648619940854, 1.2999092739992193, 1.0635495370006538, 1.3798666419970687, 1.5076081209990662, 1.3359034300010535, 1.5524928880040534, 1.558548189997964, 1.6895183870001347],
# [1.4504115989984712, 1.5137215349968756, 1.1622854930028552, 2.041229182999814, 1.711910677004198, 1.1075217110046651, 1.3842580400014413, 1.5851321219961392, 1.4574365169974044, 1.375445090001449, 1.6561897060018964, 1.8218259740024223, 1.8567485329986084, 1.133061666994763, 1.0824113999988185, 1.4250733490043785, 2.421184942999389, 1.261483969996334, 1.3674899180041393, 1.1437649149956997, 1.7144728459970793, 1.2725698549984372, 1.260526967002079, 1.8355907760051196, 1.2751189100017655, 1.3319198090030113, 1.4852830349991564, 1.5170181389985373, 2.0113240930004395, 1.3794324949994916, 2.297262980006053, 1.4246883660016465, 1.6256173689980642, 1.3537732810000307, 1.6936436590040103, 1.7586271629988914, 1.7012998569989577, 1.093642608000664, 1.4618213540015859, 1.1387302250004723, 1.3610037839971483, 1.5712911130030989, 1.9385262889991282, 1.0183718569969642, 1.4017351840011543, 1.0468262499998673, 1.97225828399678, 1.6544202570003108, 1.2952454659971409, 1.3765625430023647, 1.4951180589996511, 2.111654845000885, 1.3042532300023595, 1.7988734469981864, 1.8699254210005165, 2.0605206780019216, 1.6530067200001213, 1.32013470500533, 1.191049490000296, 1.3372293620050186, 1.8475971349980682, 2.3003795240001637, 2.4239870899982634, 2.3174276360005024, 1.9914600330012036, 1.9868471890004002, 1.6869703020056477, 1.2174907840017113, 1.2701826710035675, 1.5897011569977622, 1.837059165998653, 1.7019206950062653, 1.703163393001887, 1.6077323589997832, 1.4664895879977848, 1.5578441969992127, 2.254363836997072, 2.014105900998402, 2.288929766000365, 2.720100979000563, 1.367125153003144, 1.557388383000216, 1.3633819360038615, 1.4611007810017327, 1.7401004769999417, 1.648301409004489, 1.2716110490000574, 1.2832517150018248, 1.1009501360022114, 1.9707238739938475, 1.089886841000407, 1.4893638449939317, 1.3898896640021121, 2.3699478070047917, 2.840002678000019, 1.6281013840052765, 1.633849243997247, 1.4584663139976328, 1.2724822099989979, 1.3325593629997456],
# ]
# cnn_path_data =[
# [178, 182, 139, 146, 185, 129, 129, 130, 218, 292, 159, 205, 174, 295, 142, 156, 167, 152, 192, 196, 132, 213, 155, 190, 164, 188, 154, 199, 134, 117, 165, 187, 185, 136, 165, 158, 208, 149, 150, 140, 164, 147, 172, 198, 183, 131, 177, 147, 172, 159, 163, 136, 177, 126, 145, 149, 183, 214, 164, 205, 123, 157, 159, 214, 168, 168, 150, 134, 191, 152, 145, 145, 164, 216, 169, 245, 195, 208, 188, 365, 205, 217, 181, 186, 138, 151, 144, 149, 171, 146, 188, 175, 259, 122, 145, 204, 172, 183, 147, 274],
# [169, 154, 181, 175, 162, 130, 173, 185, 161, 131, 175, 192, 134, 126, 139, 159, 164, 194, 162, 133, 177, 144, 138, 179, 149, 167, 173, 155, 160, 158, 247, 134, 166, 126, 171, 179, 156, 130, 134, 133, 180, 179, 165, 129, 130, 114, 196, 150, 163, 160, 146, 185, 128, 166, 194, 204, 187, 141, 130, 155, 185, 142, 168, 192, 158, 141, 143, 165, 135, 170, 270, 197, 178, 177, 149, 162, 199, 231, 147, 219, 150, 148, 157, 142, 150, 188, 162, 142, 142, 210, 115, 157, 153, 299, 173, 199, 144, 168, 143, 142],
# ]
# cnn_traj_data = [
# [198, 210, 152, 164, 210, 136, 140, 146, 268, 338, 170, 282, 192, 416, 158, 186, 208, 186, 212, 234, 150, 256, 178, 228, 180, 240, 180, 256, 150, 126, 192, 252, 216, 152, 200, 178, 246, 162, 162, 146, 186, 162, 206, 256, 216, 140, 238, 162, 190, 192, 180, 158, 224, 134, 168, 160, 202, 298, 184, 224, 136, 174, 186, 288, 190, 188, 154, 142, 234, 184, 164, 160, 190, 324, 198, 284, 256, 260, 208, 466, 248, 252, 230, 222, 152, 166, 160, 166, 190, 162, 216, 218, 338, 130, 158, 256, 212, 214, 152, 364],
# [206, 176, 196, 198, 180, 138, 194, 210, 178, 146, 214, 214, 142, 132, 150, 188, 178, 236, 184, 140, 210, 152, 146, 214, 158, 198, 190, 182, 210, 178, 274, 148, 182, 134, 190, 222, 170, 138, 140, 146, 196, 214, 196, 138, 142, 120, 260, 172, 176, 186, 174, 214, 142, 186, 236, 258, 210, 164, 138, 180, 266, 166, 216, 246, 178, 154, 156, 192, 140, 182, 330, 226, 214, 210, 166, 188, 244, 292, 168, 292, 158, 160, 186, 150, 182, 204, 174, 152, 154, 232, 118, 182, 180, 396, 218, 244, 158, 202, 172, 154],
# ]
# cnn_time_data = [
# [3.1681611710009747, 3.4488704670002335, 2.5972534750035265, 2.39050128999952, 3.2114720739991753, 2.21953018700151, 2.2693315770011395, 2.1936211850043037, 4.004001472996606, 5.266598957001406, 2.719514608994359, 3.7566687229991658, 2.8973439629990025, 5.8524786829948425, 2.54256507899845, 3.073332918997039, 2.9624054540036013, 2.679070022000815, 3.3946765509972465, 3.393356473003223, 2.131488144004834, 3.7603600369984633, 2.657604979998723, 3.5237622869972256, 2.9966810870027984, 3.287517933000345, 2.8187359319999814, 3.647578831994906, 2.334226947998104, 2.425244407997525, 2.562130715996318, 3.546778969001025, 3.306786970999383, 2.3190064980008174, 2.7759447139978874, 2.748246379996999, 3.6393377860003966, 2.2729815380007494, 2.6706092660024296, 2.4872900230038795, 2.9862147189996904, 2.463295052999456, 3.235506531003921, 3.5093516210035887, 3.307985501996882, 2.0174351430032402, 2.876303734999965, 2.428748362995975, 3.375652690003335, 3.0004235390006215, 2.686729392000416, 2.7738840309975785, 3.5097064369983855, 2.193867797999701, 2.1587642910017166, 2.7350801730062813, 3.175484421000874, 3.7803448620034033, 2.961349263001466, 3.4810102789997472, 1.93788139600656, 3.102783415997692, 3.281143915002758, 4.547659856994869, 3.6292848410012084, 3.3342111149977427, 2.7366149719964596, 2.5005846380008734, 3.641561188000196, 2.79611100400507, 2.5080093610013137, 2.587433703993156, 3.0924017799989088, 4.064349804000813, 3.0644004969944945, 4.892293147000601, 4.038530581994564, 4.295947452999826, 3.74842912999884, 6.897448293995694, 3.745304677002423, 3.878832918999251, 3.138111257998389, 3.409589037000842, 2.3090972069985582, 2.6011260119994404, 2.420690259998082, 2.8477310970047256, 2.9397604209952988, 2.478690045994881, 3.305076871998608, 3.2803818749962375, 4.56525243500073, 2.1952713069986203, 2.466344464002759, 3.717819346995384, 2.593732893001288, 3.0792648589995224, 2.2692328090051888, 4.546211059001507],
# [2.9352778350003064, 2.6295722520007985, 3.304458691003674, 3.1032361170000513, 2.7665672970033484, 2.111409345998254, 3.2600351959990803, 3.3958402250063955, 3.2778234599973075, 2.549487885000417, 3.539209075999679, 3.3255862110017915, 2.154346262002946, 2.0528118280053604, 2.5815510890024598, 3.2778984589967877, 2.6950241529993946, 3.757954674998473, 2.6924402369986637, 2.1856680310011143, 3.2522206470021047, 2.450288427993655, 2.393204070001957, 3.5483487370001967, 2.582777345996874, 2.8061551520004286, 3.0597442100042826, 2.7193629489993327, 2.8052823940015514, 2.733398912998382, 4.248718199996802, 2.4132908469982794, 2.932154548994731, 2.230320798997127, 3.108970108005451, 3.3919590239966055, 2.628881643002387, 2.1138404999946943, 2.108195180000621, 2.3336283349999576, 3.257250654998643, 3.336392330005765, 3.467787073001091, 2.1202006910025375, 1.9102215129969409, 1.9083302529979846, 3.6138152300045476, 2.789393573999405, 2.7651154149934882, 2.514866298995912, 2.7682248769997386, 3.5333717159955995, 2.360732624998491, 3.465297296999779, 3.942351227000472, 3.847256548004225, 3.1378387049990124, 2.409698789997492, 2.125815118997707, 2.5057235290005337, 4.536037727004441, 4.019048437003221, 5.548169278001296, 5.136147662997246, 3.0572892850032076, 2.479010286995617, 2.4478035350039136, 3.142310513998382, 2.4767421310025384, 2.9505696230044123, 4.980051755002933, 3.187295167001139, 3.057738083000004, 2.9806741420034086, 2.854782365000574, 3.5710944370002835, 3.596790516996407, 4.576518321999174, 2.740291005000472, 4.123230744000466, 2.472417308999866, 2.5566429489990696, 2.7170862979983212, 2.6116838620000635, 2.5988648330021533, 3.3835079269993003, 2.5776510770010646, 2.963742710002407, 2.7303960749995895, 3.8479974700021558, 1.9144705860016984, 2.8661852720033494, 2.8295660099975066, 5.925010934006423, 3.803017469996121, 3.5116002290014876, 2.236997373998747, 2.729357700001856, 2.672384099998453, 2.2958169620033004],
# ]
# a_path_data =[
# [172, 201, 134, 137, 153, 129, 129, 127, 174, 283, 159, 144, 174, 231, 142, 240, 164, 150, 184, 184, 153, 206, 144, 189, 161, 188, 147, 186, 134, 116, 162, 185, 175, 162, 161, 160, 207, 149, 147, 140, 162, 144, 172, 192, 229, 131, 177, 147, 150, 163, 162, 153, 176, 126, 144, 149, 183, 211, 164, 216, 129, 157, 196, 185, 166, 168, 150, 143, 177, 152, 144, 143, 179, 181, 155, 177, 177, 148, 162, 188, 150, 206, 157, 186, 193, 151, 144, 149, 174, 146, 166, 157, 221, 120, 145, 190, 152, 179, 154, 164],
# [165, 155, 157, 191, 175, 129, 179, 163, 158, 129, 154, 186, 134, 123, 139, 148, 164, 192, 160, 133, 168, 144, 137, 198, 149, 172, 155, 155, 159, 158, 256, 136, 166, 124, 158, 176, 156, 130, 134, 136, 180, 176, 168, 128, 132, 114, 181, 171, 138, 155, 153, 165, 128, 164, 160, 191, 185, 137, 133, 155, 182, 146, 165, 181, 158, 140, 143, 159, 135, 168, 265, 194, 178, 150, 147, 162, 185, 207, 147, 253, 148, 148, 158, 142, 163, 166, 162, 140, 142, 204, 115, 199, 150, 285, 162, 183, 141, 168, 134, 142],
# ]
# a_traj_data = [
# [246, 315, 202, 204, 233, 185, 189, 185, 266, 426, 221, 213, 268, 427, 214, 384, 262, 245, 267, 281, 239, 300, 211, 302, 249, 315, 226, 310, 206, 178, 235, 317, 275, 269, 263, 262, 332, 206, 216, 208, 242, 203, 282, 319, 414, 182, 307, 210, 237, 248, 240, 231, 294, 180, 202, 221, 275, 361, 252, 327, 187, 242, 313, 315, 266, 253, 207, 208, 283, 232, 212, 209, 286, 316, 242, 325, 317, 246, 261, 304, 238, 337, 234, 298, 299, 233, 211, 237, 275, 225, 273, 244, 366, 168, 205, 324, 221, 287, 240, 237],
# [261, 243, 241, 302, 266, 178, 276, 246, 245, 205, 230, 289, 190, 164, 203, 225, 251, 303, 231, 181, 260, 209, 200, 307, 215, 265, 233, 239, 260, 241, 363, 209, 256, 189, 248, 295, 222, 182, 185, 211, 273, 280, 274, 175, 186, 163, 309, 305, 198, 243, 256, 243, 200, 240, 274, 298, 279, 211, 201, 235, 326, 226, 292, 299, 254, 211, 215, 250, 186, 256, 449, 291, 287, 224, 208, 262, 312, 350, 217, 446, 213, 204, 260, 206, 249, 265, 223, 200, 197, 312, 156, 302, 239, 476, 263, 282, 201, 251, 186, 201],
# ]
# a_time_data = [
# [0.03764109400071902, 0.0665087639936246, 0.055336618002911564, 0.040562173002399504, 0.04211132999625988, 0.03701496499706991, 0.029527172999223694, 0.03046992699819384, 0.052500756006338634, 0.11247568399994634, 0.03238277600030415, 0.04376925800170284, 0.036870178002573084, 0.09187774400197668, 0.03047523599525448, 0.08022449100099038, 0.0431034740031464, 0.04383408099965891, 0.0401312269968912, 0.03998642199439928, 0.039533506002044305, 0.057588183000916615, 0.03489538400026504, 0.05877347600471694, 0.04503441700217081, 0.06283477300166851, 0.03367466400231933, 0.05791689600300742, 0.03354507999756606, 0.04060867600492202, 0.03304345200012904, 0.07326998899952741, 0.05628500199964037, 0.05682707199594006, 0.053689710002799984, 0.05700245999469189, 0.05808545000036247, 0.040780000999802724, 0.04950576599367196, 0.05084559500392061, 0.0429189580027014, 0.048003872005210724, 0.041323500001453795, 0.04930623900145292, 0.08477202700305497, 0.031486784006119706, 0.11072091100504622, 0.034518994005338755, 0.05179514999326784, 0.06562065599428024, 0.03846069800056284, 0.07542236299923388, 0.04859672699967632, 0.03368868500547251, 0.027400225997553207, 0.04908360500121489, 0.04184313500445569, 0.07303698199393693, 0.035634817999380175, 0.10317792299611028, 0.03038611500232946, 0.034473084000637755, 0.05132676699577132, 0.06860959399637068, 0.10102555200137431, 0.03964684499806026, 0.033419158004107885, 0.07612029400479514, 0.05185671000072034, 0.06676756500382908, 0.030711729006725363, 0.03168183499656152, 0.03670132899424061, 0.07147242700011702, 0.041541025006154086, 0.07393235699419165, 0.06053043800056912, 0.037619652997818775, 0.05440364500100259, 0.06293175499740755, 0.05568182199931471, 0.06096842400438618, 0.03585370800283272, 0.048104785004397854, 0.07540274199709529, 0.0364063380038715, 0.039604245001100935, 0.03865141999995103, 0.06167639000341296, 0.04609003199584549, 0.038729963001969736, 0.05075645900069503, 0.07289155200123787, 0.03347561300324742, 
# 0.02754861499852268, 0.047877056000288576, 0.0515510799959884, 0.08050083499983884, 0.04470267800206784, 0.04011809099756647],
# [0.06059766700491309, 0.044321524001134094, 0.044852432001789566, 0.05600145799689926, 0.04371619399898918, 0.034987535000254866, 0.04593473400018411, 0.035924358999182004, 0.05700145500304643, 0.048054200000478886, 0.036137909002718516, 0.052905071002896875, 0.036891600997478236, 0.024210422998294234, 0.030869342997903004, 0.03456927899969742, 0.05486399600340519, 0.049145509001391474, 0.0292262840011972, 0.031160965001618024, 0.04314926200459013, 0.03736786099761957, 0.03871891900053015, 0.058568262997141574, 0.034802846996171866, 0.06539985199924558, 0.03965758399863262, 0.06704498599719955, 0.05524629499996081, 0.040190864994656295, 0.07270387899916386, 0.0407109540028614, 0.04606305899505969, 0.03648358999635093, 0.04739266200340353, 0.056367481003690045, 0.0343653070012806, 0.028153966006357223, 0.031423766995430924, 0.040090016002068296, 0.04929436400561826, 0.04215431299962802, 0.0690132480012835, 0.03548841000156244, 0.034863436994783115, 0.029178277996834368, 0.09049790399876656, 0.07526266600325471, 0.031259764000424184, 0.03934472100081621, 0.06213217500044266, 0.04164046500227414, 0.04282726800011005, 0.03346603699901607, 0.0558953189975, 0.07174053600465413, 0.04245580400311155, 0.050032202001602855, 0.03799430999788456, 0.044824462005635723, 0.11715058999834582, 0.06697170199913671, 0.14550606600096216, 0.04877648500405485, 0.0633019150045584, 0.03407352700014599, 0.03742479899665341, 0.038927824003621936, 0.042967663997842465, 0.0609712329969625, 0.09092714000144042, 0.04700433099787915, 0.05556646500190254, 0.0371584710010211, 0.02929646100528771, 0.04669775700313039, 0.0539834939991124, 0.05742871400434524, 0.03565712099953089, 0.07697771699895384, 0.029064542999549303, 0.02602607200242346, 0.03956666799786035, 0.040727680003328715, 0.038130745000671595, 0.04621806100476533, 0.02929288100131089, 0.06483432200184325, 0.028797492996091023, 0.05589781500020763, 0.02958510299504269, 0.058332075001089834, 0.11833356100396486, 0.16752738200011663, 
# 0.05765397800132632, 0.05956849700305611, 0.03988447099982295, 0.05109297700255411, 0.028284004001761787, 0.030339450000610668],
# ]
# # Remaining data not yet put into arrays (groups of 100)
# NN Traj [158, 286, 188, 174, 242, 140, 156, 152, 180, 142, 192, 196, 258, 178, 158, 176, 266, 288, 308, 182, 226, 210, 144, 232, 230, 156, 236, 178, 224, 188, 214, 216, 232, 152, 228, 194, 144, 302, 304, 190, 250, 340, 196, 156, 184, 360, 152, 248, 172, 350, 172, 194, 204, 144, 160, 170, 288, 168, 208, 218, 184, 154, 226, 206, 210, 232, 280, 160, 158, 232, 166, 204, 186, 218, 246, 234, 206, 164, 132, 140, 258, 218, 220, 256, 232, 288, 216, 274, 208, 242, 190, 236, 174, 180, 172, 228, 198, 254, 164, 268]
# NN Time [1.3212577890008106, 1.9925428720016498, 1.3859341220013448, 1.4921062669964158, 1.965525075996993, 1.0918242190018645, 1.383629554999061, 1.2273182930002804, 1.6578422710008454, 1.0917626789960195, 1.6642442839947762, 1.5164902140022605, 1.8275394200027222, 1.401741535999463, 1.289632284999243, 1.3692463170009432, 2.0517160320014227, 2.1291743889960344, 2.3463153379998403, 1.2555686500054435, 1.8674906399974134, 1.4713658559994656, 1.2121894880037871, 1.6551104230020428, 1.7551386889972491, 1.238382997995359, 1.9709356849998585, 1.3084131499999785, 1.6703547579963924, 1.5418989840036375, 1.5615753020028933, 1.7052438819955569, 1.6515617219993146, 1.2367673230037326, 1.5349027129996102, 1.4816381900018314, 1.1258776110043982, 2.4193257439983427, 2.3871968629973708, 2.6267324489963357, 2.7794182319994434, 4.116087060996506, 3.523471344000427, 1.7980555339963757, 1.4649449959979393, 2.96255233699776, 1.3810303479986032, 2.073908033002226, 1.392418313997041, 2.508210179003072, 1.7352602049941197, 1.5681698729968048, 1.9273307109979214, 1.1843873539983178, 1.56123356900207, 1.8738530909977271, 2.1801255000027595, 1.3677090589990257, 1.8160756040015258, 1.751122685993323, 1.1401434490035172, 1.197604050001246, 1.7129510140002822, 1.5621954550006194, 1.5126225510030054, 1.6719115700034308, 1.8650925140027539, 1.3389352570011397, 1.1960232900019037, 1.8027359010011423, 1.2123910179943778, 1.6604682120014331, 1.36814405000041, 1.5568391190026887, 1.7133855099964421, 1.4436364029970719, 1.6422913339993102, 1.2569241259989212, 1.0541661610041047, 1.0454005009960383, 1.5958859980019042, 1.6708123909993446, 1.6902598020024016, 1.7733599159982987, 1.833881590006058, 2.0393629140016856, 1.719179918000009, 2.0140843690023758, 1.7394078900033492, 1.8313406050001504, 1.7330921989996568, 1.8716963350016158, 1.3781554410015815, 1.48177477200079, 1.36401301099977, 1.9833854219978093, 1.428101715995581, 1.6270873530011158, 1.2846593180001946, 1.8905200650042389]
# CNN Path [155, 186, 156, 153, 189, 128, 134, 122, 167, 135, 155, 186, 193, 150, 133, 141, 162, 161, 234, 156, 173, 127, 115, 147, 161, 137, 193, 158, 152, 168, 184, 165, 173, 131, 135, 162, 128, 145, 215, 156, 151, 153, 161, 154, 177, 218, 154, 187, 155, 265, 147, 142, 163, 123, 145, 141, 198, 184, 185, 180, 149, 138, 185, 165, 130, 128, 204, 154, 150, 198, 135, 163, 147, 159, 147, 190, 165, 152, 133, 123, 187, 156, 212, 191, 174, 200, 191, 195, 226, 161, 168, 186, 136, 232, 152, 174, 157, 209, 146, 151]
# CNN Traj [182, 238, 196, 170, 246, 144, 138, 128, 194, 142, 170, 228, 218, 168, 152, 156, 180, 194, 312, 174, 202, 142, 122, 154, 208, 154, 220, 176, 174, 186, 232, 190, 210, 144, 150, 184, 132, 160, 258, 174, 178, 172, 188, 156, 196, 258, 178, 244, 172, 358, 158, 164, 216, 130, 158, 156, 228, 224, 214, 200, 170, 150, 214, 184, 144, 138, 250, 202, 178, 226, 148, 176, 164, 178, 170, 252, 186, 178, 138, 130, 238, 176, 284, 246, 184, 246, 232, 266, 286, 184, 196, 216, 142, 300, 194, 198, 176, 248, 176, 162]
# CNN Time [2.711082063993672, 3.360159816002124, 2.4103589779988397, 2.814428728997882, 3.9095874690028722, 2.343037851002009, 2.5263848089962266, 1.9235784070042428, 2.7977815010017366, 2.1468019550011377, 2.9434080020000692, 3.6278398399954312, 3.4232137280050665, 2.75045240700274, 2.039091312995879, 2.5355182910061558, 3.0564265240027453, 3.133683671003382, 4.309712808004406, 2.823043731004873, 3.2579220879997592, 2.2821742500018445, 2.0027810159954242, 2.3793237799982307, 2.9308623199976864, 2.175165319000371, 3.5040336800011573, 2.508222181000747, 2.4399698410052224, 2.951628137998341, 3.430839189000835, 2.9938548259960953, 3.3077243320003618, 2.3336603689967887, 2.30668538199825, 2.82677543599857, 1.9464514600040275, 2.4404622389993165, 7.249775797004986, 5.745624082999711, 5.3269884979963535, 8.641017690999433, 7.737131839996437, 2.895719125997857, 3.799083691003034, 4.672995979999541, 2.816895960997499, 3.9647072269945056, 2.9998880150014884, 6.232357878005132, 3.347415946998808, 3.205842149000091, 2.9803414500056533, 1.9407362510028179, 2.9318480500005535, 2.7663469029939733, 3.44031678999454, 3.7135396090016, 3.7496643609993043, 3.2510515229951125, 2.5566094040041207, 2.381697887998598, 3.298104030000104, 2.65121609599737, 2.169672404997982, 1.950294741000107, 3.5918463979978696, 2.581729618999816, 2.8209513889960363, 3.49545944500278, 2.094565125000372, 2.8403869599933387, 2.3321344229989336, 2.8226805529993726, 2.587849434996315, 3.3409201279937406, 2.9403476399966166, 2.4944126210029935, 2.092010244996345, 1.954739177999727, 3.301532457000576, 2.7501605249999557, 3.8731951910012867, 3.2537436069978867, 3.0512338699991233, 3.6181718919979176, 3.409836382998037, 3.5232774070027517, 4.1151843030020245, 2.8276591839967296, 2.934140084995306, 3.068190054000297, 2.4336851369953365, 4.413591144999373, 2.8378188950009644, 3.071988056995906, 2.6329255890013883, 3.6279318000015337, 2.5155894560011802, 2.8506642369975452]
# A Path [148, 185, 157, 159, 155, 128, 134, 122, 170, 135, 163, 177, 190, 153, 133, 141, 162, 159, 228, 155, 162, 138, 115, 147, 161, 137, 194, 156, 210, 160, 161, 166, 167, 139, 135, 161, 128, 145, 210, 156, 151, 152, 164, 154, 164, 212, 186, 168, 167, 241, 147, 139, 157, 144, 151, 161, 197, 179, 161, 180, 145, 138, 183, 175, 130, 128, 201, 151, 139, 185, 132, 163, 134, 145, 147, 205, 165, 149, 126, 123, 158, 167, 227, 183, 173, 200, 175, 186, 222, 156, 165, 127, 136, 144, 167, 167, 156, 208, 141, 151]
# A Traj [219, 323, 256, 248, 237, 190, 192, 158, 260, 177, 261, 284, 286, 239, 202, 222, 255, 273, 383, 223, 272, 230, 171, 204, 273, 200, 319, 228, 365, 249, 228, 263, 263, 227, 201, 243, 165, 209, 350, 232, 236, 240, 262, 200, 253, 338, 283, 265, 250, 410, 217, 211, 247, 223, 226, 250, 306, 274, 264, 269, 209, 200, 279, 262, 192, 172, 324, 239, 215, 291, 187, 250, 210, 216, 227, 352, 252, 221, 174, 170, 239, 283, 393, 297, 247, 332, 290, 344, 401, 251, 271, 185, 192, 232, 256, 254, 231, 315, 217, 237]
# A Time [0.04326045200286899, 0.051850573006959166, 0.05430974200135097, 0.05303594500583131, 0.04570116399554536, 0.039044782002747525, 0.03518680100387428, 0.02227607099484885, 0.06342742999549955, 0.035134876998199616, 0.04856938600278227, 0.06087633300194284, 0.04494999199960148, 0.042730002001917455, 0.04018744900531601, 0.040613445002236404, 0.04560205199959455, 0.06288264899922069, 0.08604311799717834, 0.031193105998681858, 0.061960535997059196, 0.05067266499827383, 0.034087302003172226, 0.039914455999678466, 0.06953652700030943, 0.038544707000255585, 0.062337283001397736, 0.031694805998995434, 0.07207159799872898, 0.04633512000145856, 0.036190220998832956, 0.05062284800078487, 0.04506821499671787, 0.03770229600195307, 0.03713996899750782, 0.040979500001412816, 0.017390916000294965, 0.03442163900035666, 0.08519965400046203, 0.06763657999545103, 0.11181949300225824, 0.1360999480020837, 0.08034329499787418, 0.031582003000949044, 0.05417806899640709, 0.10522602999844821, 0.06454800099891145, 0.042256225999153685, 0.046528282000508625, 0.0780488050004351, 0.05487411700596567, 0.05783274599525612, 0.05470834100560751, 0.03906335899955593, 0.057480971998302266, 0.03930255400337046, 0.04133783899305854, 0.04202608700143173, 0.0506966890025069, 0.05153069700463675, 0.03172251100477297, 0.039150333002908155, 0.039017419003357645, 0.061327715004154015, 0.030289039001218043, 0.024384370997722726, 0.056557789001089986, 0.05720518199814251, 0.0410596829970018, 0.07504594799684128, 0.026770820004458074, 0.04000819499924546, 0.04284277799888514, 0.041523981002683286, 0.03336424199369503, 0.06809560300462181, 0.04300578399852384, 0.0521691899994039, 0.030581840001104865, 0.027566729004320223, 0.039313431007030886, 0.046262224997917656, 0.08027823999873362, 0.05241790999571094, 0.04242302400234621, 0.053181338000285905, 0.06403886499902, 0.06517877699661767, 0.07930802500050049, 0.03836063499329612, 0.05043060699972557, 0.029304352996405214, 0.028831475996412337, 
# 0.04270634899876313, 0.04144349200214492, 0.04035905199998524, 0.03629380899656098, 0.053951046997099183, 0.030079632000706624, 0.03998750800383277]
# starting iteration 3
# NN Traj [316, 242, 144, 130, 148, 198, 204, 148, 148, 156, 180, 184, 130, 308, 194, 242, 374, 188, 304, 382, 212, 244, 290, 212, 192, 206, 156, 154, 166, 132, 194, 240, 264, 268, 366, 200, 214, 150, 252, 144, 204, 184, 146, 368, 254, 204, 152, 178, 126, 172, 148, 254, 328, 148, 194, 278, 150, 224, 160, 258, 248, 304, 324, 154, 242, 262, 192, 152, 198, 168, 170, 178, 254, 190, 162, 140, 196, 252, 168, 192, 258, 180, 164, 370, 178, 168, 150, 236, 210, 168, 182, 168, 238, 268, 206, 174, 190, 154, 220, 184]
# NN Time [2.3131831509963376, 1.9592696899999282, 1.1757495279962313, 1.1162934629974188, 1.2056505209984607, 1.628381795999303, 1.6193115950009087, 1.184142060999875, 1.1214932040020358, 1.1213788320019376, 1.3345970379959908, 1.5084329419987625, 1.0881288469972787, 2.5816097439965233, 1.466197301997454, 1.9190272750056465, 2.5244659680029145, 1.577415909996489, 2.190823803000967, 3.0233920050013694, 1.7538702970050508, 1.8991334689999348, 2.1044306859985227, 1.4652353779965779, 1.6574595999991288, 1.6027719570047338, 1.0961469290050445, 1.2187497180057107, 1.3944048289995408, 1.1185819960010122, 1.540080181999656, 1.7789756970014423, 1.9221053319997736, 1.9515726099998574, 2.581804976995045, 1.6516621420014417, 1.6605321960014408, 1.1264536750022671, 1.853592708997894, 1.2055239629989956, 1.463603886004421, 1.525715436997416, 1.3319383150010253, 2.6544540179966134, 2.0457256009976845, 1.5161012000025949, 1.4016927749980823, 1.4130931040053838, 1.005016158997023, 1.3371255149977515, 1.1590161040003295, 1.91730419400119, 2.4123229239994544, 1.0638457710010698, 1.5192504820006434, 2.149744362999627, 1.270587869999872, 1.9520225710002705, 1.1762785620012437, 1.9623558989987941, 1.745537955997861, 2.3166746979986783, 2.115075314002752, 1.203904574002081, 1.89231960199686, 1.7100835259989253, 1.543843484003446, 1.1807803330011666, 1.5204758899999433, 1.3458891150003183, 1.2529576779998024, 1.359813001996372, 1.7120116930018412, 1.4216985430030036, 1.3822036589990603, 1.2067782660014927, 1.4894425560050877, 1.7831931350010564, 1.3058568600026774, 1.3054844119978952, 2.059465950995218, 1.4581472459976794, 1.270023143995786, 2.62894125800085, 1.3782471529993927, 1.3241366399961407, 1.2472128329973202, 1.5665804310046951, 1.6826258090004558, 1.189605464001943, 1.4288445909987786, 1.5222388010006398, 1.7531571130020893, 1.815441214996099, 1.4357464810018428, 1.3139854370019748, 1.5090739529987331, 1.1563964579981985, 1.6351650880023954, 1.614636698999675]
# CNN Path [260, 192, 135, 121, 147, 180, 170, 136, 142, 139, 138, 167, 124, 213, 228, 232, 202, 233, 176, 267, 188, 194, 320, 176, 158, 142, 142, 143, 149, 146, 169, 129, 194, 171, 252, 138, 172, 130, 162, 130, 190, 149, 137, 251, 259, 150, 158, 142, 120, 152, 138, 200, 214, 134, 148, 220, 137, 183, 153, 197, 152, 176, 182, 139, 132, 193, 150, 137, 147, 136, 150, 133, 197, 272, 265, 123, 164, 175, 126, 165, 219, 155, 131, 262, 137, 141, 142, 219, 174, 148, 164, 139, 170, 150, 161, 158, 133, 131, 177, 161]
# CNN Traj [342, 228, 154, 130, 154, 216, 192, 146, 156, 150, 154, 184, 130, 270, 288, 282, 302, 308, 206, 330, 228, 244, 516, 206, 176, 156, 158, 160, 166, 156, 186, 142, 224, 190, 344, 152, 202, 136, 200, 142, 234, 182, 146, 354, 364, 176, 174, 160, 122, 170, 150, 232, 256, 144, 168, 258, 152, 216, 172, 230, 166, 218, 202, 164, 148, 228, 160, 158, 174, 154, 160, 140, 244, 362, 354, 132, 186, 218, 132, 180, 296, 174, 138, 334, 172, 160, 146, 308, 182, 168, 174, 154, 206, 176, 182, 176, 146, 146, 202, 178]
# CNN Time [4.945114360998559, 3.3420685659948504, 2.351106369002082, 2.0140302859945223, 2.2618696770005045, 3.3318038859943044, 2.9137950820004335, 2.2930109279986937, 2.416666309996799, 2.1032558379956754, 2.400186106999172, 3.0540676079981495, 1.9901796339981956, 4.049380112999643, 4.105204333005531, 4.355358982000325, 3.895965024996258, 4.640226133997203, 2.9836677450002753, 5.148707915002888, 3.4758050960008404, 3.5822699530035607, 5.769960819998232, 3.045340117998421, 2.765636860000086, 2.522148705997097, 2.52376502900006, 2.425705886002106, 2.7079717130036443, 2.4153163349983515, 2.866817722999258, 2.2354098970026826, 3.408465298001829, 2.8238023110025097, 4.924985402998573, 2.415222811003332, 2.899627496000903, 2.254958543002431, 2.996228123003675, 2.2130638179951347, 3.364184245998331, 2.618534016000922, 2.434872620004171, 5.046626000999822, 5.355156862999138, 2.6763259959989227, 2.737001096997119, 2.3698084880015813, 1.7947202659997856, 2.6278330109998933, 2.2366199679963756, 3.7698755279998295, 4.018575535999844, 2.0024352010004804, 2.3574645780026913, 4.100198995998653, 2.2973408230027417, 3.713311583000177, 2.720891043994925, 3.9146145330014406, 2.378712023004482, 3.0677492530012387, 2.9915382440012763, 2.2120918209984666, 2.3063199410025845, 3.295006487998762, 2.5729798860047595, 2.084448137997242, 2.584196632000385, 2.4350837120000506, 2.4216039719976834, 2.2817018979985733, 3.85980581700278, 4.994538202001422, 4.647876504001033, 2.0678355990021373, 2.8701243640025496, 3.372562086005928, 2.1334525019992725, 2.4640426599944476, 4.0792334619982284, 3.0081565590007813, 2.0114090879942523, 4.944621315000404, 2.532898714998737, 2.4456944779958576, 2.198961814996437, 3.749500978003198, 2.7139484970030026, 2.5658725239991327, 2.8190564640026423, 2.2420506100024795, 3.3405181810012436, 2.5494875040021725, 2.5462262720029685, 2.7616324390037335, 2.3030392820000998, 2.2553941719961585, 2.986115419997077, 2.7509792650016607]
# A Path [244, 189, 132, 121, 139, 163, 169, 136, 142, 139, 138, 177, 124, 204, 193, 230, 199, 217, 179, 223, 186, 178, 193, 166, 159, 142, 137, 137, 149, 146, 169, 129, 194, 171, 226, 137, 156, 132, 175, 130, 193, 143, 137, 221, 157, 138, 157, 188, 120, 152, 136, 199, 203, 134, 172, 198, 134, 176, 141, 196, 152, 173, 190, 134, 147, 190, 165, 126, 145, 136, 148, 144, 188, 204, 250, 123, 141, 173, 126, 165, 184, 155, 127, 231, 148, 140, 142, 154, 176, 148, 165, 167, 164, 146, 160, 159, 153, 131, 169, 162]
# A Traj [428, 304, 201, 177, 197, 281, 253, 196, 216, 195, 206, 272, 176, 332, 307, 373, 367, 363, 271, 389, 304, 294, 333, 266, 245, 222, 199, 208, 225, 214, 246, 197, 295, 248, 401, 199, 240, 185, 299, 186, 307, 218, 205, 388, 261, 195, 227, 335, 147, 231, 196, 316, 324, 191, 251, 314, 202, 289, 202, 306, 217, 289, 299, 195, 231, 296, 248, 193, 234, 215, 208, 215, 308, 336, 420, 185, 216, 287, 172, 228, 308, 238, 176, 378, 259, 212, 195, 232, 245, 227, 242, 261, 276, 225, 237, 249, 235, 203, 265, 249]
# A Time [0.08892450200073654, 0.04267270600394113, 0.03693932099849917, 0.03570451599807711, 0.028758701002516318, 0.055100133999076206, 0.0348225839989027, 0.030272582996985875, 0.03910133199678967, 0.03030911899986677, 0.0352981020041625, 0.040246546996058896, 0.031535709000309, 0.05188021000503795, 0.054147657996509224, 0.07277082899963716, 0.06854295400262345, 0.07194811500085052, 0.04628036900248844, 0.07148547999531729, 0.06051465299970005, 0.05260770599852549, 0.05164230200171005, 0.0488976909982739, 0.04197708900028374, 0.04144760000053793, 0.0310659519964247, 0.035501523001585156, 0.03954965100274421, 0.0472352819997468, 0.03594075000000885, 0.043282657003146596, 0.041589045998989604, 0.07888542299770052, 0.07613752400357043, 0.03402561000257265, 0.03721763499925146, 0.026358759998402093, 0.049896197000634857, 0.02677972499805037, 0.057379759004106745, 0.042101825994905084, 0.03989081600593636, 0.07822594600293087, 0.042355479003163055, 0.032610895003017504, 0.032807711002533324, 0.07284381700446829, 0.02887890899728518, 0.05217055500543211, 0.03990672800136963, 0.04818862899992382, 0.059414356997876894, 0.03587033299845643, 0.05952012199850287, 0.05879090500093298, 0.04983734000416007, 0.054740993000450544, 0.03298489000007976, 0.04058156700193649, 0.029498951997084077, 0.056697539999731816, 0.06145005200232845, 0.034602992003783584, 0.04549926899926504, 0.03797245500027202, 0.040333171004022006, 0.03361676799977431, 0.04597364999790443, 0.04196097499516327, 0.034056587996019516, 0.04042021700297482, 0.05231480499787722, 0.058599515999958385, 0.0799129700026242, 0.03608046600129455, 0.04029892299877247, 0.05233828999917023, 0.03652066300128354, 0.02799215000413824, 0.07322925300104544, 0.04320804900635267, 0.029960290994495153, 0.06793834899872309, 0.05853194300289033, 0.03499812399968505, 0.03141881200281205, 0.04696276300091995, 0.03242510099516949, 0.03809566899872152, 0.03656515300099272, 0.04940078499930678, 0.046909805001632776, 
0.04238029699627077, 0.0328849369980162, 0.04852876499353442, 0.040025950998824555, 0.05995137900026748, 0.04450982499838574, 0.057083924002654385]
# starting iteration 4
# NN Traj [348, 176, 150, 256, 316, 166, 188, 260, 158, 368, 196, 216, 320, 336, 164, 304, 256, 282, 420, 204, 202, 158, 180, 248, 254, 186, 206, 232, 158, 244, 150, 162, 160, 198, 252, 232, 160, 130, 256, 200, 374, 212, 122, 250, 148, 152, 178, 198, 272, 208, 210, 164, 206, 182, 200, 264, 160, 220, 166, 148, 228, 182, 266, 208, 156, 166, 242, 180, 338, 136, 152, 178, 204, 226, 168, 126, 204, 186, 144, 174, 188, 254, 208, 202, 314, 150, 172, 212, 178, 208, 402, 152, 238, 196, 246, 178, 312, 154, 170, 230]
# NN Time [2.2027992329967674, 1.3496584290041937, 1.2511746649979614, 1.9258080270010396, 2.297494682003162, 1.4011027509986889, 1.4612887149996823, 1.8652982280036667, 1.1928860250045545, 2.2067909889956354, 1.6024519569982658, 1.6348459639993962, 2.3107727609967696, 2.2126650779973716, 1.206642662997183, 2.2855880289935158, 2.0508350759992027, 1.9387622610011022, 2.7747432500036666, 1.536026983994816, 1.4865079990049708, 1.2625442179996753, 1.4433376750021125, 1.7465356449974934, 1.6743906060000882, 1.5666758549996302, 1.7048464360050275, 1.9297468770018895, 1.2783846980019007, 2.063063664994843, 1.7902150690060807, 1.48318125300284, 1.280140492002829, 1.6232332970030257, 1.745448296998802, 2.138889529000153, 1.3564424059950397, 0.941764121002052, 2.0047557989964844, 1.6171188630032702, 2.3869185160001507, 1.9074787249992369, 1.0690885690055438, 2.0264085900053033, 1.300367825999274, 1.1253905150006176, 1.4126482819992816, 1.3693520669985446, 2.0307225479991757, 1.5534390920001897, 1.616009170000325, 1.3197858780040406, 1.6217127839991008, 1.3694937489999575, 1.6088452300027711, 1.9757093849984813, 1.3265166159981163, 1.74275315500563, 1.3771848640026292, 1.1222375700017437, 1.8447507720047724, 1.4301859229963156, 1.9911471720042755, 1.4906387540031574, 1.4893984380032634, 1.306276584000443, 1.6511586339984206, 1.0509032679983648, 2.430802507995395, 1.0685883480036864, 1.1775424309962546, 1.2918931480016909, 1.6031489229935687, 1.8234451669995906, 1.3474417820034432, 0.9648501389965531, 1.6704511479983921, 1.4652484999969602, 1.2575578830001177, 1.3079627839979366, 1.5615625799982809, 1.8951313349971315, 1.4507103740033926, 1.5670142020026105, 2.312560852995375, 1.1428042599945911, 1.4247600620001322, 1.7924320040037856, 1.4753199100014172, 1.384682238996902, 2.7915558620006777, 1.2578519099988625, 1.8367122230047244, 1.5772400779969757, 1.9524013420013944, 1.4599449800007278, 2.243018419998407, 1.1355832609988283, 1.4245310400001472, 1.6082047180025256]
# CNN Path [230, 152, 144, 229, 209, 150, 170, 219, 131, 199, 171, 291, 216, 280, 144, 162, 202, 286, 260, 183, 154, 140, 157, 154, 200, 178, 162, 150, 204, 207, 132, 158, 145, 157, 212, 201, 149, 129, 160, 179, 240, 171, 116, 170, 139, 135, 145, 159, 202, 159, 159, 146, 149, 237, 184, 212, 135, 258, 152, 158, 156, 151, 193, 186, 135, 266, 250, 154, 231, 131, 135, 131, 166, 129, 145, 131, 174, 179, 131, 143, 162, 164, 159, 215, 289, 152, 153, 179, 142, 168, 246, 149, 165, 141, 229, 160, 174, 136, 162, 155]
# CNN Traj [290, 168, 168, 300, 254, 164, 222, 270, 134, 280, 194, 330, 264, 346, 160, 178, 256, 382, 344, 236, 178, 158, 178, 172, 232, 206, 184, 166, 242, 254, 148, 174, 156, 184, 274, 250, 172, 144, 186, 208, 296, 200, 130, 204, 154, 152, 178, 184, 240, 192, 202, 162, 156, 326, 206, 264, 150, 286, 186, 182, 176, 178, 230, 248, 148, 324, 292, 172, 264, 146, 148, 138, 198, 136, 162, 138, 226, 238, 144, 152, 186, 192, 190, 280, 338, 170, 174, 210, 164, 214, 288, 170, 192, 162, 286, 178, 210, 150, 176, 218]
# CNN Time [4.121950959000969, 2.4739014669976314, 2.254333705001045, 4.264367643998412, 3.605567761005659, 2.5821802359932917, 3.083163796000008, 3.6505684449948603, 2.0159359399985988, 3.6912219480000203, 2.977504452006542, 5.107955277999281, 3.8440609450044576, 4.810199710998859, 2.375150357001985, 2.8396554209975875, 3.7134365969977807, 5.374689226002374, 5.293794142002298, 3.1462977999981376, 2.9073227379994933, 2.433296903000155, 2.954053813999053, 2.777429623005446, 3.6533715529949404, 3.216809267003555, 3.2877588720002677, 2.3821069140030886, 3.796705814995221, 3.7881551359969308, 2.448793587995169, 2.610754826993798, 2.741181309997046, 2.7758836049979436, 4.7229290879986365, 3.854869475006126, 2.9167685770007665, 2.20317499399971, 3.05708192900056, 3.2668319509975845, 4.540833936000126, 3.1070797560023493, 2.1224256900022738, 2.8935779680032283, 2.6327961729984963, 2.3364262960021733, 2.6886856829951284, 2.4763251449985546, 3.2651733480015537, 2.635341139997763, 3.0092442120003398, 2.689875016003498, 2.4454356639980688, 4.60895407100179, 3.4382438870015903, 3.5837015680008335, 2.466151869004534, 4.408192898001289, 3.066398479997588, 2.924130482999317, 2.7638447680001264, 2.5502520539957914, 3.6947263129986823, 3.2421087399998214, 2.320765696997114, 4.831214552999882, 4.617021992002265, 2.207500575001177, 4.298402240005089, 2.2649938089962234, 2.3035048349993303, 2.115088165002817, 2.927317444002256, 2.212804894996225, 2.4491248600024846, 2.1222120949969394, 2.896480433999386, 2.9202714470011415, 2.3340579139985493, 2.385282154995366, 2.8244400639960077, 2.61119146800047, 2.5370966019982006, 3.6080376059981063, 4.9770983300040825, 2.495409365998057, 2.729621592006879, 3.1903509399999166, 2.358411459994386, 3.125815417006379, 4.70694298500166, 2.736367004996282, 2.8657825960035552, 2.3026864290004596, 4.176635495001392, 2.6332684380031424, 3.4680435440022848, 2.326727026993467, 2.666200731997378, 2.7874691410033847]
# A Path [227, 150, 144, 196, 202, 193, 166, 226, 131, 194, 171, 192, 214, 245, 142, 154, 201, 178, 242, 187, 158, 140, 158, 153, 199, 184, 157, 150, 140, 171, 132, 154, 143, 157, 187, 208, 143, 129, 199, 171, 242, 177, 115, 170, 137, 135, 133, 159, 196, 158, 158, 146, 149, 147, 177, 197, 135, 167, 145, 134, 148, 148, 197, 173, 135, 135, 175, 154, 230, 131, 135, 131, 158, 122, 163, 128, 163, 178, 129, 133, 153, 164, 159, 165, 191, 134, 151, 170, 170, 166, 254, 152, 152, 141, 197, 161, 171, 136, 162, 162]
# A Traj [374, 222, 207, 334, 327, 303, 259, 346, 175, 314, 265, 292, 351, 387, 215, 238, 322, 313, 424, 286, 256, 208, 248, 231, 311, 307, 244, 216, 217, 277, 206, 222, 201, 247, 321, 363, 234, 181, 347, 262, 407, 262, 172, 264, 204, 199, 220, 237, 306, 246, 244, 220, 213, 220, 264, 306, 204, 245, 237, 195, 224, 254, 335, 282, 206, 188, 275, 227, 357, 190, 204, 185, 255, 179, 238, 178, 248, 272, 187, 184, 235, 254, 246, 256, 303, 191, 230, 256, 281, 277, 436, 237, 238, 216, 323, 249, 280, 198, 240, 275]
# A Time [0.07061816699570045, 0.03561878899927251, 0.03560868399654282, 0.08036649100540671, 0.0560449240001617, 0.04730196500167949, 0.050683164001384284, 0.047603803999663796, 0.03566566300287377, 0.08309137099422514, 0.04636236299847951, 0.040926813999249134, 0.05069077700318303, 0.055791013000998646, 0.037827076004759874, 0.05790038000122877, 0.05771599100262392, 0.09435293600108707, 0.0865061010044883, 0.050664098002016544, 0.04386282800260233, 0.062197267994633876, 0.03787423099856824, 0.0394925489963498, 0.0462508609998622, 0.11970190900319722, 0.05399023799691349, 0.04693971700180555, 0.041112494000117294, 0.051874250995751936, 0.049439210997661576, 0.03212486099801026, 0.0384016710013384, 0.04030989099555882, 0.056449920004524756, 0.06599484099569963, 0.04911540600005537, 0.024399866997555364, 0.07492404300137423, 0.03671874499559635, 0.10938456500298344, 0.038540469999134075, 0.03353274899563985, 0.04326756700174883, 0.042594409998855554, 0.032850046998646576, 0.04279405799752567, 0.036016707999806385, 0.03893460600374965, 0.04731934500159696, 0.03964214200095739, 0.04277049400116084, 0.03337579300568905, 0.050479965000704397, 0.0401813200005563, 0.06482292000146117, 0.032854410994332284, 0.05456081799638923, 0.040959198006021325, 0.0308768710019649, 0.037501573002373334, 0.044902601999638136, 0.054274786001769826, 0.057376698998268694, 0.05095900499873096, 0.025253388004784938, 0.04284013599681202, 0.038211708997550886, 0.060764082998503, 0.03160395699524088, 0.03608346200053347, 0.031007767000119202, 0.056647854995389935, 0.031723635001981165, 0.0380931289982982, 0.03402677999838488, 0.04247030400438234, 0.04365488600160461, 0.030421755000134, 0.02743830699910177, 0.05615469099575421, 0.04903912099689478, 0.04046323800139362, 0.04238558799988823, 0.04669607499818085, 0.02970003499649465, 0.04160853000212228, 0.0392998110037297, 0.05072656899574213, 0.050890824997622985, 0.0865952149979421, 0.04368172699469142, 0.049587899004109204, 0.03931668499717489, 
0.05433915400499245, 0.038693433001753874, 0.06851091200223891, 0.029626386996824294, 0.043197765000513755, 0.07501617199886823]
# starting iteration 5
# NN Path
# NN Traj [212, 190, 188, 166, 284, 184, 148, 164, 162, 232, 178, 318, 194, 234, 230, 144, 124, 224, 254, 156, 258, 334, 166, 216, 170, 176, 168, 236, 242, 226, 322, 174, 166, 172, 602, 252, 160, 234, 176, 138, 178, 198, 252, 202, 162, 168, 156, 254, 184, 186, 312, 172, 212, 254, 234, 176, 254, 150, 300, 154, 266, 252, 206, 372, 158, 202, 234, 204, 214, 160, 202, 576, 160, 144, 406, 246, 166, 174, 204, 150, 188, 208, 312, 154, 188, 130, 244, 216, 252, 292, 308, 184, 118, 160, 140, 344, 290, 234, 224, 250]
# NN Time [1.7318937059972086, 1.4335537569932058, 1.481015829005628, 1.2245296309993137, 2.0541704220013344, 1.4664190749972477, 1.1336085560033098, 1.3133260289978352, 1.3056308459999855, 1.7602681030039093, 1.3435901560005732, 2.261394819004636, 1.3946657330016023, 1.7946401799999876, 1.6341051390045322, 1.1778619249962503, 0.9469989229983184, 1.5845646759989904, 2.088121784996474, 1.2608444840006996, 1.969210896997538, 2.211122686996532, 1.3583837909973226, 1.770012218003103, 1.3484398729997338, 1.3351294419990154, 1.3852043299993966, 1.5318532149976818, 1.8388704309982131, 1.5982249250009772, 2.4146564190014033, 1.2711292489984771, 1.3859235600029933, 1.3612371350027388, 3.6241708419984207, 1.982974253995053, 1.443185812997399, 1.5735662049992243, 1.3825470020019566, 1.1467948270001216, 1.4344114399937098, 1.6381163590049255, 1.9830864439936704, 1.8438345969989314, 1.2448430719960015, 1.3227277500045602, 1.227147245001106, 1.9331816080011777, 1.550973730001715, 1.4519829899945762, 2.229411315995094, 1.4160848210012773, 1.6417998309989343, 1.9293250359987724, 1.9279892179984017, 1.360531059995992, 1.99507791200449, 1.2646567350020632, 1.9889172780021909, 1.222799565999594, 1.9458842120002373, 1.937217958002293, 1.516264065001451, 2.4854955570044694, 1.3630947579949861, 1.5332191120032803, 1.8487608170034946, 1.5912401040041004, 1.6033044950017938, 1.3078651609976077, 1.6031437489946256, 3.515783171998919, 1.2799995390014374, 1.1814900530007435, 2.7752617849982926, 2.0488686169992434, 1.277882377995411, 1.3345362229956663, 1.4688696429948322, 1.2639524319965858, 1.4412028690057923, 1.7220611199954874, 2.326739140000427, 1.229206288997375, 1.5049531369950273, 0.9448685579991434, 1.7980523830046877, 1.6246618699951796, 1.9249063589959405, 2.128387037999346, 2.0297466589981923, 1.5331655710033374, 1.0335912729933625, 1.2292772520013386, 1.082191690999025, 2.3991355129983276, 2.100800722000713, 1.7480480480007827, 1.6459673539939104, 1.8535421679989668]
# CNN Path [136, 161, 158, 146, 193, 143, 232, 165, 170, 187, 154, 203, 157, 191, 154, 140, 122, 220, 233, 147, 338, 204, 163, 143, 180, 195, 145, 150, 174, 169, 214, 206, 150, 147, 295, 325, 147, 149, 153, 129, 156, 159, 172, 203, 148, 165, 169, 198, 160, 150, 148, 150, 127, 144, 180, 147, 201, 149, 188, 132, 152, 211, 143, 282, 140, 169, 176, 160, 130, 163, 161, 222, 176, 128, 248, 175, 146, 158, 168, 136, 180, 235, 224, 144, 213, 118, 177, 214, 166, 220, 197, 174, 136, 147, 124, 192, 191, 175, 172, 190]
# CNN Traj [144, 176, 172, 162, 240, 168, 292, 196, 200, 212, 170, 240, 174, 218, 174, 146, 124, 270, 312, 168, 496, 278, 184, 154, 216, 222, 158, 164, 196, 208, 260, 234, 164, 170, 412, 446, 166, 164, 168, 144, 186, 176, 202, 242, 164, 200, 194, 242, 186, 174, 164, 168, 136, 164, 228, 160, 260, 166, 234, 150, 174, 230, 164, 406, 154, 194, 214, 202, 144, 186, 168, 300, 198, 134, 364, 192, 160, 188, 190, 150, 204, 278, 262, 166, 250, 124, 198, 248, 208, 274, 292, 214, 148, 156, 132, 258, 230, 212, 204, 218]
# CNN Time [2.2101541810043273, 2.637010778998956, 2.5148024830050417, 2.4190002270042896, 3.396242760005407, 2.3799349370019627, 4.162605943005474, 2.822780873000738, 2.906385271002364, 3.18623689100059, 2.4163168919985765, 3.5311093279960915, 2.657739037000283, 3.367911016001017, 2.515011299001344, 2.3102042370010167, 1.7405141150011332, 3.8562724539951887, 4.378857431001961, 2.558615517002181, 6.022772251999413, 3.478601411996351, 2.697987617000763, 2.3514243499957956, 3.074113174996455, 3.3259353060057038, 2.6350858119985787, 2.426415112000541, 2.691685393001535, 2.8859519450052176, 3.805203438998433, 3.6894731450011022, 2.4064747049997095, 2.3802964120040997, 5.2067420759995, 5.970456706003461, 2.6696785789972637, 2.320921177000855, 2.379995045994292, 2.4055494209969766, 3.02250561099936, 2.721476475999225, 3.123631636000937, 3.8403694999942672, 2.6347766790058813, 2.8287294629990356, 2.923654164005711, 3.6146745799997007, 2.8577235659977305, 2.7448356649983907, 2.570392150999396, 2.5478749419999076, 2.144086560001597, 2.4167775850000908, 3.1480459779995726, 2.417114740004763, 3.7557246459982707, 2.5084239709976828, 2.9786898179954733, 2.1974716910044663, 2.602295091994165, 3.724559860995214, 2.5673090419950313, 5.213749285998347, 2.275853310005914, 2.815390806994401, 3.298990324001352, 2.987071087998629, 2.513460675996612, 2.5933449000003748, 2.5926799070002744, 4.1394474179978715, 3.119122399999469, 1.907888608999201, 5.067462979000993, 3.0553873230019235, 2.365879789002065, 2.7746517429986852, 2.816144748001534, 2.149704080999072, 3.1956593799986877, 4.206670930994733, 3.8237263439950766, 2.403771751000022, 3.8073872759996448, 1.751524725004856, 3.4297701160030556, 3.7890257309991284, 3.0357705610003904, 3.7582790619999287, 3.4080161890015006, 3.141221568999754, 2.2923136629979126, 2.4817621479960508, 1.9796879350033123, 3.436639658000786, 3.3001433519966668, 3.233613862998027, 3.1434947789966827, 3.1697211589998915]
# A Path [136, 157, 158, 145, 188, 142, 123, 162, 142, 184, 153, 205, 155, 166, 155, 140, 122, 179, 222, 142, 265, 198, 163, 143, 152, 193, 145, 146, 173, 170, 249, 197, 152, 147, 274, 188, 147, 149, 141, 136, 152, 159, 158, 148, 148, 164, 169, 171, 156, 144, 148, 150, 181, 144, 318, 146, 196, 148, 184, 132, 149, 191, 144, 250, 140, 160, 173, 150, 140, 163, 153, 284, 176, 128, 222, 171, 145, 157, 168, 135, 155, 168, 207, 138, 213, 118, 175, 179, 165, 132, 186, 160, 140, 147, 124, 192, 185, 171, 173, 172]
# A Traj [193, 233, 214, 220, 301, 220, 187, 248, 226, 282, 224, 328, 233, 253, 244, 193, 155, 278, 389, 226, 479, 345, 248, 210, 227, 295, 224, 204, 260, 272, 386, 293, 229, 225, 455, 299, 230, 208, 230, 189, 244, 237, 236, 231, 220, 253, 253, 264, 235, 223, 218, 228, 326, 219, 551, 213, 350, 222, 299, 200, 224, 302, 226, 435, 205, 240, 282, 248, 216, 247, 224, 587, 275, 175, 417, 254, 201, 256, 250, 204, 241, 271, 323, 208, 348, 157, 277, 304, 278, 186, 318, 244, 220, 202, 175, 317, 290, 287, 290, 252]
# A Time [0.04449745299643837, 0.0499301279996871, 0.030559844999515917, 0.0354758449975634, 0.04056366799341049, 0.04552153999975417, 0.05028945500089321, 0.03354224099894054, 0.05221825400076341, 0.04771601199900033, 0.044496802998764906, 0.07260918799875071, 0.040860060005798005, 0.04043629400257487, 0.04402799600211438, 0.03447472400148399, 0.020043145996169187, 0.04546789699816145, 0.09500543400645256, 0.050246030994458124, 0.10467027599952416, 0.06986970800062409, 0.040680590995179955, 0.03835360500670504, 0.0329648469996755, 0.050207001993840095, 0.034167395999247674, 0.029536854999605566, 0.041167174000293016, 0.0729934740011231, 0.09081981800409267, 0.041368320999026764, 0.03939165799965849, 0.039837293996242806, 0.09427890699589625, 0.044115978002082556, 0.04538470499392133, 0.02494836500409292, 0.037714707003033254, 0.030131371997413225, 0.04325044799770694, 0.05431523500010371, 0.05986156799917808, 0.07682564000424463, 0.03787324999575503, 0.04072606399859069, 0.0391911490005441, 0.039919814000313636, 0.05572232099802932, 0.0423446509958012, 0.038163180004630703, 0.04028728700359352, 0.05871944899990922, 0.0477580040023895, 0.202177605999168, 0.03479230599623406, 0.06428913399577141, 0.05074831199453911, 0.044198663999850396, 0.036942740000085905, 0.03617479099921184, 0.04943173200445017, 0.0407507580021047, 0.09610760800569551, 0.03497853400040185, 0.03526463299931493, 0.06048700000246754, 0.03905233700061217, 0.0468341409941786, 0.048260819996357895, 0.03965395600243937, 0.11315015100262826, 0.04635277100169333, 0.03200141400157008, 0.06743575599830365, 0.03643822899903171, 0.027300309993734118, 0.04081176999898162, 0.04102819800027646, 0.05218553799932124, 0.05725372300366871, 0.052269101004640106, 0.06302268600120442, 0.02876451699557947, 0.05879529999947408, 0.031920091998472344, 0.04976333100057673, 0.05436643100256333, 0.05668076599977212, 0.03161362399987411, 0.053109335000044666, 0.0405068609979935, 0.047271130002627615, 0.028409696999005973, 
0.027190622000489384, 0.055082999999285676, 0.04414089400233934, 0.057515104999765754, 0.058819163001317065, 0.04812885799765354]
# starting iteration 6
# NN Path
# NN Traj [146, 200, 216, 252, 148, 196, 176, 340, 348, 180, 148, 212, 142, 196, 210, 190, 184, 298, 144, 138, 190, 200, 134, 144, 160, 170, 270, 156, 256, 232, 170, 244, 148, 186, 236, 316, 220, 212, 276, 158, 126, 190, 152, 374, 210, 142, 184, 214, 282, 140, 176, 154, 164, 254, 196, 202, 154, 160, 242, 192, 208, 214, 176, 202, 112, 182, 206, 186, 170, 150, 276, 300, 194, 310, 226, 142, 150, 182, 294, 162, 232, 266, 226, 210, 218, 244, 172, 290, 172, 238, 396, 118, 214, 186, 234, 184, 182, 246, 176, 152]
# NN Time [1.1149351420026505, 1.4811068419949152, 1.5165383339990512, 1.9384041290031746, 1.1467902410004172, 1.4381154779985081, 1.4291763910005102, 2.500877730999491, 2.224375024998153, 1.3764177769990056, 1.1850680150018889, 1.6451054219942307, 1.1388070329994662, 1.4917024899987155, 1.572938882993185, 1.4913743370052543, 1.5184605800022837, 2.157181976006541, 1.1021235680018435, 1.0911988579973695, 1.5533793559952755, 1.4720166700062691, 1.129858364998654, 1.0594572889967822, 1.3727566589950584, 1.3737552339953254, 1.9196749050024664, 1.2919927039984032, 1.9292794310022146, 1.7457893769969814, 1.4582073920028051, 1.8872339699955774, 1.1712532390010892, 1.4543080009971163, 1.7680023229986546, 2.096813962001761, 1.2957369950017892, 1.7124924290037598, 2.017736394001986, 1.191560876002768, 1.1812019669960137, 1.5108307269983925, 1.1311051459997543, 2.2904118439982994, 1.5790212790016085, 1.1655253260032623, 1.3912572900007945, 1.6705253249965608, 2.1377452599990647, 1.1760041560046375, 1.4473163100046804, 1.146689712004445, 1.231930679998186, 1.7184390539987362, 1.4184770789943286, 1.6632612530011102, 1.2980123849993106, 1.2224357969971607, 1.8653242150030565, 1.4298148329980904, 1.4636210929966182, 1.7967133519996423, 1.2393788370027323, 1.549794546001067, 0.8944337189968792, 1.3612778059978154, 1.6282998230017256, 1.3992377079994185, 1.3589007279952057, 1.2963516200034064, 1.9453079730010359, 2.2174458500012406, 1.5875909039968974, 2.1339393820016994, 1.7762733609997667, 1.179785827996966, 1.2066435149972676, 1.4527653620025376, 1.9922190199940815, 1.3705445149971638, 1.7693909290028387, 1.8863785060020746, 1.839578225997684, 1.6042604169997503, 1.7565246619997197, 1.5666319630036014, 1.3674310050046188, 2.1612405329942703, 1.3481402819961659, 1.6560824519983726, 2.72246604400425, 0.8784658040021895, 1.536717884002428, 1.5125418300012825, 1.6845413140035816, 1.341191199993773, 1.2980386229974101, 1.90287083399744, 1.2444833980043768, 1.2667808179976419]
# CNN Path [135, 145, 151, 198, 143, 153, 156, 208, 207, 149, 133, 157, 177, 158, 219, 156, 158, 205, 183, 127, 181, 152, 135, 134, 134, 161, 177, 151, 191, 191, 183, 183, 141, 176, 158, 197, 159, 157, 194, 145, 179, 145, 145, 134, 192, 134, 155, 170, 231, 131, 143, 146, 138, 188, 203, 160, 138, 184, 246, 148, 155, 186, 149, 174, 111, 159, 160, 172, 149, 137, 196, 201, 146, 203, 160, 115, 160, 162, 164, 132, 177, 194, 217, 172, 192, 169, 168, 216, 168, 197, 408, 115, 208, 150, 163, 149, 141, 230, 146, 150]
# CNN Traj [150, 166, 168, 238, 158, 162, 176, 262, 268, 172, 142, 174, 214, 186, 260, 172, 170, 242, 218, 130, 204, 164, 138, 142, 152, 210, 204, 170, 252, 248, 222, 216, 154, 224, 184, 276, 180, 178, 228, 164, 206, 166, 164, 158, 222, 144, 168, 190, 276, 140, 168, 164, 154, 244, 244, 190, 150, 206, 332, 158, 178, 240, 166, 196, 114, 178, 176, 198, 168, 154, 254, 234, 164, 246, 182, 122, 174, 200, 206, 152, 208, 234, 276, 188, 276, 194, 196, 272, 206, 248, 618, 118, 266, 166, 198, 158, 162, 290, 152, 194]
# CNN Time [2.2531712709969725, 2.68887854499917, 2.323070287995506, 3.462233958998695, 2.3181391499965684, 2.4211861840012716, 2.6106587309986935, 4.145914416003507, 3.5182748300067033, 2.6847495820038603, 2.0701015449958504, 2.5418479569998453, 2.920845141998143, 2.872441034996882, 3.7209890160011128, 2.697945407999214, 2.7637551719963085, 3.553757328998472, 2.8042781940021086, 1.9595503689997713, 3.303078215001733, 2.5046079239982646, 2.066357688003336, 2.1182428739994066, 2.4064575009979308, 2.7978699000013876, 3.042420931997185, 2.7911358429992106, 3.724931992997881, 3.6191866750014015, 3.223847048000607, 3.2435151529934956, 2.3760645860020304, 3.1952832469978603, 2.9118310189951444, 3.25472532799904, 2.633994279000035, 2.6013931540001067, 3.647335462999763, 2.346596508003131, 3.1993014000036055, 2.7195605500019155, 2.3794298419961706, 2.388540756001021, 3.4273303610025323, 2.298682740001823, 2.8371589189991937, 3.088824086000386, 3.9742962010059273, 2.203048099996522, 2.5154155749987694, 2.4084279790040455, 2.263793631995213, 2.9475570040012826, 3.230900151997048, 2.7916285729952506, 2.3434383010026067, 3.069099611006095, 4.656929138996929, 2.4421627260016976, 2.825427881005453, 3.466584891000821, 2.501660940994043, 2.8684969460009597, 1.6743089290030184, 2.7157667970022885, 2.7180007839997415, 3.1678285020025214, 2.481655273004435, 2.3880968330049654, 3.4039784029955626, 3.6466976200026693, 2.546758240998315, 3.1436925559974043, 2.782386929000495, 1.8701692190006725, 2.512443555002392, 3.0044330739983707, 3.046715520998987, 2.258136149001075, 2.8144807559947367, 3.5057239570014644, 3.8315123499996844, 2.895725881004182, 3.5629063689993927, 2.6970567269963794, 2.825887111001066, 4.05135087799863, 3.1410043420037255, 3.4636315749958158, 7.264295638000476, 1.5959760010009632, 3.572325301996898, 2.6108212060062215, 2.958364662998065, 2.393289677995199, 2.5044039729982615, 4.2478509309949, 2.1864486839986057, 2.795851668997784]
# A Path [135, 150, 132, 200, 136, 156, 156, 221, 185, 148, 133, 156, 170, 160, 192, 155, 156, 201, 179, 127, 180, 150, 144, 133, 134, 148, 176, 144, 177, 184, 183, 178, 192, 176, 158, 170, 153, 157, 199, 139, 168, 163, 143, 169, 173, 134, 154, 157, 209, 131, 148, 145, 162, 177, 189, 162, 138, 181, 203, 148, 156, 182, 147, 174, 111, 159, 160, 172, 146, 137, 191, 203, 146, 200, 144, 115, 144, 152, 163, 135, 181, 231, 200, 170, 205, 166, 146, 205, 167, 194, 248, 115, 203, 152, 165, 148, 137, 209, 146, 132]
# A Traj [204, 234, 207, 325, 197, 227, 236, 366, 289, 238, 192, 227, 248, 264, 300, 230, 229, 325, 271, 175, 279, 218, 209, 185, 209, 232, 289, 213, 294, 297, 287, 281, 285, 305, 241, 263, 237, 238, 323, 199, 266, 247, 211, 271, 279, 196, 218, 232, 330, 193, 243, 217, 253, 275, 303, 264, 211, 265, 365, 213, 254, 303, 211, 259, 149, 237, 254, 272, 221, 220, 318, 328, 222, 326, 213, 168, 202, 248, 279, 212, 287, 387, 313, 246, 346, 245, 221, 334, 268, 311, 410, 141, 327, 238, 285, 204, 215, 338, 199, 201]
# A Time [0.03786751100415131, 0.040371073002461344, 0.041894098998454865, 0.04628341400530189, 0.03198538099968573, 0.04007523700420279, 0.03669220300071174, 0.07410292299755383, 0.04419955399498576, 0.03870122200169135, 0.03971343799639726, 0.03059276299609337, 0.0358567570001469, 0.05306914900575066, 0.05162957299762638, 0.04027401700295741, 0.03086629199970048, 0.04734099699999206, 0.04127791400242131, 0.04525191100401571, 0.05912800000078278, 0.0505961090020719, 0.0338247620020411, 0.06594135200430173, 0.03618428900517756, 0.05202605199883692, 0.041186697999364696, 0.03036705699923914, 0.05603425499430159, 0.0479708419952658, 0.058814239004277624, 0.04932733099849429, 0.057926223998947535, 0.06848752299993066, 0.039820705002057366, 0.05127776000153972, 0.0448667189994012, 0.034822952999093104, 0.048342895999667235, 0.03260853599931579, 0.04578517800109694, 0.041522814994095825, 0.03101451299880864, 0.05137504999584053, 0.043555987002037, 0.04053794799983734, 0.05392430300707929, 0.04455017599684652, 0.05910811400099192, 0.03786500400019577, 0.05768575399997644, 0.03777574400010053, 0.040561504996730946, 0.057067731002462097, 0.050537934999738354, 0.05308939699898474, 0.044160234996525105, 0.04804736499499995, 0.07938856900000246, 0.03392071399866836, 0.048853686996153556, 0.06342391399812186, 0.029776835996017326, 0.034815309998521116, 0.027736692005419172, 0.03250991400273051, 0.04637848899437813, 0.04007335999631323, 0.04138940799748525, 0.04257640500145499, 0.06167567399825202, 0.05324266399838962, 0.035929852005210705, 0.09784056500211591, 0.04952715399849694, 0.03209116899961373, 0.0367989270016551, 0.036450633000640664, 0.05603955099650193, 0.0447045659966534, 0.056116098996426445, 0.07133448300010059, 0.06492134099971736, 0.043981292998068966, 0.06168718499975512, 0.03994407699792646, 0.04217830800189404, 0.06577156799903605, 0.03999322800518712, 0.08106746400153497, 0.08633195599395549, 0.023085134002030827, 0.0685209430012037, 0.03749219999735942, 
# 0.048772641996038146, 0.03270064500247827, 0.03453439699660521, 0.06070374599948991, 0.029329254997719545, 0.0620462349979789]
# starting iteration 7
# NN Path
# NN Traj [212, 464, 170, 278, 146, 202, 148, 146, 234, 230, 204, 196, 170, 178, 212, 234, 110, 248, 142, 184, 278, 236, 168, 188, 380, 324, 244, 168, 332, 156, 160, 254, 144, 160, 210, 178, 220, 152, 160, 192, 148, 172, 314, 206, 190, 194, 162, 146, 168, 154, 170, 148, 194, 262, 132, 266, 160, 208, 148, 142, 184, 234, 258, 204, 208, 252, 136, 152, 190, 180, 218, 190, 320, 206, 218, 192, 134, 148, 162, 232, 136, 136, 176, 182, 180, 304, 190, 198, 192, 158, 178, 136, 170, 172, 186, 138, 242, 236, 160, 232]
# NN Time [1.5754261949987267, 3.336326790995372, 1.2986436310020508, 1.9172266490058973, 1.1771187009944697, 1.5470458929994493, 1.262703751999652, 1.1490406919983798, 1.7547465309980907, 1.6314402149946545, 1.5534527290001279, 1.3502708060041186, 1.374598877999233, 1.3217835309988004, 1.6978266639998765, 1.747476328004268, 0.8965261089979322, 1.7800073300022632, 1.135191499000939, 1.4998046699984116, 1.8583269370064954, 1.7729348189968732, 1.3651930499981972, 1.4208261499952641, 2.6178016330013634, 2.380779626000731, 1.6476998149955762, 1.3194611029975931, 2.3312305829967954, 1.1563692679992528, 1.2486832350041368, 1.8846347149956273, 1.1256364060027408, 1.2282854430013685, 1.481729589002498, 1.3592747489965404, 1.5788376969940146, 1.2732560570002533, 1.1638433190018986, 1.4393587049999041, 1.1607202900049742, 1.379754128996865, 1.997019784997974, 1.5277147810047609, 1.4411157300055493, 1.3835800729939365, 1.1348314839997329, 1.1626543829988805, 1.2608721599972341, 1.2244211949946475, 1.3367154770021443, 1.2879862420013524, 1.5422616239957279, 2.000302214997646, 1.0211651450008503, 1.9803685909937485, 1.3385964589979267, 1.5606801620015176, 1.2508726400046726, 1.0007301090008696, 1.2450511169954552, 1.9419121989994892, 1.7693614539966802, 1.5117936700044083, 1.5733792989994981, 2.064589817004162, 1.06404575099441, 1.2382377999965684, 1.448909195001761, 1.3585062109996215, 1.638282485997479, 1.6085974689995055, 2.3296008839970455, 1.6760424120002426, 1.6011429450009018, 1.6466537109954515, 1.0999795950046973, 1.095948448004492, 1.3639748810019228, 1.8372712100026547, 1.0504212580053718, 1.1474866160060628, 1.4008254630025476, 1.4010239599956549, 1.4244074410016765, 2.1907156230008695, 1.6808610760053853, 1.4584251729975222, 1.4625113130023237, 1.124634095998772, 1.3310427169999457, 1.0430530250014272, 1.284545968999737, 1.3196609019942116, 1.5488269589986885, 1.0346394520020112, 1.6147548129956704, 1.623109811996983, 1.2488083909993293, 1.728096805003588]
# CNN Path [177, 342, 164, 191, 125, 149, 133, 174, 159, 189, 158, 158, 155, 134, 275, 183, 108, 179, 137, 162, 182, 172, 181, 154, 273, 203, 177, 197, 149, 146, 135, 179, 133, 166, 142, 142, 151, 149, 137, 140, 130, 148, 203, 150, 229, 146, 148, 155, 141, 139, 261, 150, 164, 200, 130, 255, 142, 180, 136, 130, 132, 195, 172, 171, 149, 412, 131, 167, 164, 156, 194, 160, 205, 171, 189, 155, 124, 132, 143, 163, 125, 135, 149, 163, 176, 253, 175, 167, 145, 149, 123, 124, 139, 167, 170, 119, 143, 163, 145, 197]
# CNN Traj [210, 424, 190, 244, 148, 170, 146, 206, 200, 244, 178, 204, 162, 146, 324, 222, 110, 208, 154, 182, 230, 226, 214, 182, 356, 252, 204, 230, 166, 156, 156, 228, 134, 194, 150, 162, 176, 172, 164, 150, 150, 166, 288, 180, 274, 172, 174, 168, 160, 156, 344, 168, 182, 236, 132, 306, 160, 202, 144, 136, 136, 222, 202, 196, 166, 570, 134, 200, 196, 180, 244, 180, 254, 186, 250, 186, 130, 136, 158, 188, 132, 140, 168, 170, 210, 296, 184, 194, 156, 164, 136, 128, 154, 174, 186, 126, 166, 192, 154, 244]
# CNN Time [2.8551359349949053, 6.064577789002215, 2.872663202993863, 3.1324545230017975, 2.034793136997905, 2.577578871998412, 2.2976538529983372, 2.888164873998903, 2.74046450300375, 3.4004971309987013, 2.581954572997347, 2.5778944840058102, 2.8037395360006485, 2.259713516003103, 4.641530842993234, 3.154780775003019, 1.5345199570001569, 3.2510183039994445, 2.4233820069930516, 2.8041783119988395, 2.9199947390006855, 2.93169881200447, 3.3840087040007347, 2.768172695999965, 5.009917176001181, 3.76318300699495, 3.0456492899975274, 3.2612001599991345, 2.6651505249974434, 2.383811285995762, 2.2301275040008477, 3.2794513530025142, 1.861158225001418, 2.558983574999729, 2.3571275970025454, 2.3616339080035686, 2.7032535030011786, 2.81773910800257, 2.1697992449990124, 2.3476024880001205, 2.231797918997472, 2.5013874169962946, 4.013388671002758, 2.61229881500185, 4.087144181001349, 2.6875054289994296, 2.7116194530026405, 2.4924810390002676, 2.0413856799932546, 2.3161762219970115, 4.759205219997966, 2.51064109300205, 3.0007747909985483, 3.576725113001885, 1.8941297729979851, 4.394356471995707, 2.4631747719977284, 2.8950516340046306, 2.1752127229992766, 2.149483399996825, 2.071520875004353, 3.2606715129950317, 2.7977014019998023, 2.7967861580036697, 2.6598376980036846, 7.082856476998131, 2.1368903850016068, 2.831740637004259, 2.7880054819979705, 2.4785531890011043, 3.2983912820054684, 2.9841862190005486, 3.689757894993818, 2.7712325880056596, 3.236111551996146, 2.7215346270022565, 2.1097945860019536, 1.9951008619973436, 2.3214141719945474, 2.875163112003065, 2.179631792998407, 2.1776381800009403, 2.46772718399734, 2.671910762997868, 2.938783274999878, 4.05967489300383, 2.997166898996511, 2.9256380980004906, 2.3999262810029904, 2.209495645998686, 2.114637781000056, 1.926434042994515, 2.2410515939991456, 2.575210650000372, 2.9213277020026, 1.865340094002022, 2.393091714999173, 2.815965974004939, 2.268281291006133, 3.331038060001447]
# A Path [177, 322, 156, 189, 123, 156, 138, 134, 171, 183, 158, 157, 155, 134, 185, 175, 108, 172, 143, 165, 181, 167, 143, 154, 268, 203, 175, 187, 149, 146, 140, 181, 133, 166, 142, 140, 159, 149, 137, 140, 134, 146, 206, 150, 192, 145, 148, 135, 141, 139, 247, 149, 165, 200, 130, 231, 140, 163, 137, 130, 132, 180, 170, 164, 147, 243, 131, 139, 156, 179, 194, 160, 204, 168, 168, 146, 124, 132, 140, 158, 125, 135, 148, 163, 154, 219, 175, 160, 160, 146, 123, 124, 139, 167, 170, 119, 160, 156, 142, 174]
# A Traj [270, 515, 248, 311, 180, 248, 213, 196, 255, 326, 235, 256, 220, 200, 292, 284, 137, 270, 220, 255, 284, 275, 232, 246, 465, 347, 269, 276, 232, 198, 220, 336, 165, 246, 207, 208, 240, 249, 214, 196, 217, 213, 377, 239, 313, 229, 233, 206, 199, 207, 414, 224, 259, 314, 169, 365, 216, 243, 201, 176, 168, 269, 257, 241, 228, 395, 172, 216, 229, 288, 327, 251, 343, 242, 261, 229, 182, 176, 210, 253, 177, 192, 224, 232, 242, 346, 254, 248, 234, 202, 190, 166, 206, 230, 256, 158, 270, 232, 204, 269]
# A Time [0.050531716995465104, 0.23455288200057112, 0.03977635900082532, 0.07589692100009415, 0.036470442995778285, 0.039032915999996476, 0.03729295900120633, 0.029906997995567508, 0.04388091999862809, 0.06471825499465922, 0.04032361899589887, 0.05539827099710237, 0.035116418002871796, 0.03135401000326965, 0.05619480799941812, 0.04674162199808052, 0.022096959997725207, 0.045198447005532216, 0.06251419299951522, 0.043601479002973065, 0.050777611999365035, 0.055914060001668986, 0.04764662699744804, 0.04367006099346327, 0.12137611000071047, 0.07626634200278204, 0.03947034700104268, 0.04422362699551741, 0.0471987370037823, 0.034515993000240996, 0.046360355998331215, 0.0651923559998977, 0.017581664003955666, 0.03539825600455515, 0.043514811994100455, 0.028869344998383895, 0.04970002699701581, 0.04443005400389666, 0.03909466000186512, 0.027286999000352807, 0.04221708600380225, 0.03599352800665656, 0.10266916000546189, 0.03962358099670382, 0.06201037999562686, 0.037677610001992434, 0.055570434997207485, 0.043382839001424145, 0.028179722001368646, 0.0352031389993499, 0.07826448600098956, 0.04167796699766768, 0.06175006100238534, 0.04671906099974876, 0.027843258998473175, 0.07001252399641089, 0.04542425299587194, 0.04444617699482478, 0.03096443400136195, 0.027229748004174326, 0.022624878001806792, 0.03680260399414692, 0.03983225399861112, 0.041831567003100645, 0.039433779995306395, 0.08490565900137881, 0.026728138996986672, 0.051449044003675226, 0.034147389997087885, 0.058189980998577084, 0.0569877070010989, 0.051290770003106445, 0.06625866000103997, 0.04196164799941471, 0.06162099899665918, 0.04235526199772721, 0.034618507997947745, 0.023962669001775794, 0.03633897200052161, 0.03832576499553397, 0.028912580004543997, 0.030854280994390137, 0.03547737099870574, 0.03618219800409861, 0.03579009399982169, 0.058342871998320334, 0.04423345900431741, 0.038824843999464065, 0.032981878001010045, 0.023222429997986183, 0.03900691899616504, 0.026825471002666745, 0.03159118499752367, 
# 0.036084988998482004, 0.04439543299668003, 0.025249803999031428, 0.07893754999531666, 0.04941809799493058, 0.032239012005447876, 0.049319025994918775]
# starting iteration 8
# NN Traj [172, 254, 266, 322, 172, 224, 282, 140, 202, 242, 354, 166, 294, 216, 156, 182, 142, 200, 178, 174, 214, 214, 154, 178, 220, 154, 382, 202, 146, 192, 228, 230, 176, 188, 196, 220, 158, 216, 288, 162, 180, 282, 260, 120, 230, 248, 168, 152, 224, 202, 140, 182, 132, 252, 262, 252, 168, 174, 238, 128, 216, 166, 154, 232, 184, 240, 240, 190, 168, 234, 224, 224, 144, 120, 130, 202, 224, 134, 194, 202, 172, 230, 142, 176, 162, 330, 190, 194, 234, 192, 154, 156, 158, 198, 206, 168, 176, 158, 166, 222]
# NN Time [1.2684204199977103, 1.8228114590019686, 1.5331745039948146, 2.3316027199980454, 1.3455352519958979, 1.6307125449966406, 2.2331789450036013, 1.1197107429979951, 1.4146612140029902, 2.002509171004931, 2.4400168040010612, 1.252838675995008, 2.105098956999427, 1.6640191420010524, 1.3469121119996998, 1.4517710789950797, 1.1547975200010114, 1.5181056450019241, 1.3986370980055653, 1.247902969000279, 1.7834685240013641, 1.7724245599965798, 1.184269358003803, 1.3679420749976998, 1.7474894680053694, 1.216215081003611, 2.5212978390045464, 1.521372285002144, 1.130177585000638, 1.5787158399980399, 1.7434368159956648, 1.971135735999269, 1.477967731996614, 1.4121514660000685, 1.5490313050031546, 1.7353673290053848, 1.26297253100347, 1.6154227829974843, 2.1172996000022977, 1.2577777989936294, 1.4135609849981847, 2.0910920030000852, 1.9000767950055888, 0.8760685019951779, 1.850052247995336, 1.8058516539967968, 1.2899275890013087, 1.2370926229996257, 1.6833438760004356, 1.5808868130043265, 1.1945615810036543, 1.4352423860036652, 1.155885929998476, 2.035421086999122, 2.1274853789946064, 1.8163081079983385, 1.2929659059955156, 1.4133630900032585, 1.8945118489937158, 1.1016036119981436, 1.5874790320012835, 1.3362761249954929, 1.1966663030034397, 1.576974584000709, 1.481250042001193, 2.049047978995077, 1.715135852995445, 1.4049845879999339, 1.3046532940061297, 1.8441160200018203, 1.7206598409975413, 1.6537260629993398, 1.201082886000222, 0.9574259549990529, 1.1489067939983215, 1.4501126239993027, 1.7292896229992039, 0.9086505439991015, 1.5726451640002779, 1.5577079430004233, 1.3258341590044438, 1.7318455549975624, 1.179254218004644, 1.441342182995868, 1.1950698860018747, 2.2136275400043814, 1.5891690089993062, 1.5451006090006558, 1.8456328259999282, 1.4044872679951368, 1.3034479979978641, 1.116914503996668, 1.0063585300013074, 1.5087586180015933, 1.7098756300038076, 1.4622240249955212, 1.3154739849996986, 1.2220294680009829, 1.3442675949991099, 1.871124516997952]
# CNN Path [166, 197, 190, 135, 146, 181, 221, 130, 154, 238, 268, 143, 175, 187, 141, 157, 126, 213, 197, 175, 155, 172, 147, 144, 183, 133, 233, 181, 137, 188, 157, 167, 152, 149, 162, 213, 160, 148, 198, 162, 170, 199, 185, 121, 179, 194, 131, 142, 190, 194, 137, 149, 152, 176, 272, 176, 276, 144, 191, 152, 176, 159, 123, 169, 165, 167, 188, 143, 156, 236, 157, 174, 167, 120, 123, 152, 182, 137, 188, 173, 143, 155, 159, 156, 144, 220, 142, 150, 228, 152, 198, 131, 128, 170, 204, 157, 143, 139, 152, 182]
# CNN Traj [178, 244, 224, 150, 164, 204, 272, 138, 182, 292, 412, 156, 210, 212, 154, 176, 142, 272, 244, 208, 184, 198, 166, 152, 202, 144, 306, 212, 156, 220, 176, 186, 178, 154, 188, 302, 194, 182, 250, 178, 204, 242, 210, 122, 198, 222, 136, 162, 236, 256, 142, 180, 164, 198, 328, 216, 360, 164, 232, 180, 214, 188, 126, 184, 176, 194, 208, 170, 174, 308, 180, 190, 192, 128, 130, 178, 210, 148, 222, 200, 154, 182, 178, 182, 158, 290, 160, 174, 288, 166, 226, 138, 130, 206, 252, 168, 160, 148, 172, 212]
# CNN Time [2.6704374269975233, 3.505839064004249, 2.911259936998249, 2.180455597997934, 2.409049000001687, 3.0501074549974874, 3.911407332001545, 2.140579514998535, 2.538729357002012, 3.9980211109941592, 5.172530140996969, 2.43108099699748, 3.491170195004088, 3.3437694359963643, 3.0644980069992016, 2.421026341995457, 2.085454041000048, 4.194453367999813, 3.369813042001624, 2.8376835280068917, 2.4433241939987056, 3.1427288909981144, 2.3838865059951786, 2.4056971300014993, 3.2235172240034444, 2.1680799390014727, 4.375665706000291, 3.24519988099928, 2.1532423489989014, 3.376121566994698, 3.0144240820009145, 2.9419734260009136, 2.7443936109993956, 2.97678118100157, 3.026735430001281, 3.964888560003601, 2.863421592002851, 2.748570956995536, 3.562310161003552, 2.796323886999744, 2.7517923710038303, 3.767676917996141, 2.9895964870011085, 1.675749863999954, 3.0768044409996946, 3.2364617380007985, 2.0801553329947637, 2.618228719999024, 3.415536095002608, 3.650174030000926, 2.1564223910027067, 2.653540080995299, 2.3848935209971387, 2.779957074999402, 4.877344479005842, 3.144529720004357, 4.59780621399841, 2.395530827001494, 3.166151600999001, 2.6472506069985684, 3.054132056000526, 2.7645992420002585, 1.8952519070007838, 2.738783444001456, 2.6621919449971756, 3.2637458780009183, 2.9877347020010347, 2.4631866339987027, 2.5562974080021377, 4.328842538998288, 2.641174113996385, 2.769536836000043, 2.7568086859973846, 1.8502386420004768, 2.008952367999882, 2.4433885460020974, 3.1745433039977797, 2.0117158380016917, 3.197861909000494, 2.775255474000005, 2.166439321998041, 2.680792138999095, 2.4367987169971457, 2.778845670996816, 2.000644123996608, 3.9917754990019603, 2.336132201999135, 2.532640290002746, 4.181219568003144, 2.542870495002717, 3.327081368996005, 2.0558502189960564, 1.8063579019944882, 3.120889999998326, 3.6197153700049967, 2.6810765159971197, 2.3571304270008113, 2.5641484730003867, 2.6175722379994113, 3.060543380997842]
# A Path [136, 139, 196, 137, 145, 194, 217, 130, 148, 185, 263, 143, 157, 187, 141, 155, 119, 206, 159, 137, 151, 166, 153, 144, 182, 133, 208, 172, 132, 185, 138, 164, 152, 149, 159, 149, 134, 148, 177, 167, 153, 221, 189, 121, 179, 192, 130, 142, 186, 191, 133, 147, 152, 172, 226, 189, 138, 167, 185, 151, 176, 157, 123, 168, 165, 150, 183, 143, 144, 196, 161, 172, 167, 117, 123, 152, 178, 136, 184, 179, 143, 154, 152, 152, 144, 205, 135, 162, 209, 151, 205, 131, 128, 157, 274, 154, 143, 137, 145, 179]
# A Traj [193, 221, 298, 213, 226, 322, 358, 188, 220, 285, 507, 214, 261, 289, 202, 224, 170, 345, 246, 211, 225, 259, 229, 208, 290, 197, 371, 270, 185, 286, 207, 249, 250, 215, 255, 233, 197, 231, 272, 233, 224, 398, 281, 151, 276, 291, 182, 221, 290, 332, 183, 243, 218, 244, 369, 332, 205, 250, 312, 237, 283, 249, 164, 242, 241, 240, 268, 226, 208, 328, 243, 240, 250, 161, 172, 231, 286, 185, 300, 281, 199, 241, 218, 236, 200, 364, 202, 257, 337, 226, 312, 182, 161, 257, 444, 224, 220, 199, 222, 282]
# A Time [0.027050021999457385, 0.037320210998586845, 0.04006789000413846, 0.05457702500279993, 0.03917543699935777, 0.05667156299750786, 0.06858573600038653, 0.03330866699980106, 0.03735478599992348, 0.04536983300204156, 0.09973413599800551, 0.037971752004523296, 0.04541141500521917, 0.04615031999856001, 0.038665587999275886, 0.029722539999056607, 0.055161980002594646, 0.10250408999854699, 0.04878109900164418, 0.03338833300222177, 0.033451626004534774, 0.05685451199678937, 0.040931408999313135, 0.035970363998785615, 0.054668907992891036, 0.034738030000880826, 0.07270171900017885, 0.05825601900141919, 0.027550251994398423, 0.09002513399900636, 0.038188949998584576, 0.04506101599690737, 0.04706571000133408, 0.03873533200385282, 0.0419156429998111, 0.051717695001570974, 0.039843225000367966, 0.03872554599365685, 0.04734459600149421, 0.029870957994717173, 0.05080988300323952, 0.07234241100377403, 0.03902049600583268, 0.03639800800010562, 0.04064132200437598, 0.0458469170043827, 0.03760941300424747, 0.04801510800461983, 0.042511943996942136, 0.04919732800044585, 0.030854970995278563, 0.04782653299480444, 0.030662123004731257, 0.044858114000817295, 0.05846513400319964, 0.06314379000104964, 0.028965353994863108, 0.06476067900075577, 0.06963493300281698, 0.03742511099699186, 0.0520925390010234, 0.043396465000114404, 0.027106553003250156, 0.03631799499999033, 0.03585076400486287, 0.03836904200579738, 0.03919798199785873, 0.038534988001629245, 0.031646209994505625, 0.052061110000067856, 0.04425075100152753, 0.03308086000470212, 0.042430501998751424, 0.03093948400055524, 0.028182520996779203, 0.04519180500210496, 0.03967958599969279, 0.02630404099909356, 0.050833007000619546, 0.04738311000255635, 0.02795725299802143, 0.05465244600054575, 0.03612060900195502, 0.04636515399761265, 0.026405684002384078, 0.06526302400015993, 0.04315802700148197, 0.0390529959986452, 0.05565513300098246, 0.03525301700574346, 0.06316653699468588, 0.03154392199940048, 0.0208376730006421, 
# 0.04717710699333111, 0.07901727299758932, 0.036467961996095255, 0.03840744400076801, 0.03064909800013993, 0.04653409200545866, 0.04447939099918585]
# starting iteration 9
from random import choice
from timeit import default_timer as timer
import tensorflow as tf
import numpy as np
# Dimension handed to gen_gridworld (which builds a dim-by-dim grid) by the analysis code below.
MAZE_DIM = 50
# ANALYSIS
def run_nn_agent(complete_grid, model):
    """Walk a maze from the top-left to the bottom-right cell using a model.

    The agent keeps its own view of the maze: ``0.5`` marks a cell it has not
    observed yet, any other value is copied from ``complete_grid`` when the
    agent tries to enter the cell, and ``1`` is treated as blocked.  On every
    step the view, the current position and the four neighbours are fed to
    ``model`` and the arg-max of its output selects the move.

    * If the predicted move is not currently possible, a random possible
      move is taken instead.
    * If the model predicts the same move it already made from this cell,
      the agent instead follows an A* path computed on its current view
      until that path ends or runs into a newly observed block.

    Args:
        complete_grid: 2-D sequence holding the true maze; ``1`` marks a
            blocked cell.  Rectangular grids are supported.
        model: callable taking ``(grid_batch, neighbour_batch)`` and
            returning an object whose ``.numpy()`` yields one row of four
            action scores (North, East, South, West).

    Returns:
        tuple: ``(path, trajectory, num_actions)`` -- the cells occupied in
        order (starting with the source), the number of successful moves,
        and the number of times the model was queried.

    Raises:
        ValueError: if the A* fallback finds no path on the agent's view.
        IndexError: if no move is possible from the current cell (raised by
            ``random.choice`` on an empty list).
    """
    # Model output index -> (row delta, column delta).
    action_movement_dict = {
        0: (-1, 0),  # North
        1: (0, 1),  # East
        2: (1, 0),  # South
        3: (0, -1),  # West
    }
    n, m = len(complete_grid), len(complete_grid[0])
    source = (0, 0)
    goal = (n - 1, m - 1)
    num_actions = trajectory = 0
    complete_grid = np.array(complete_grid)
    grid = np.full((n, m), 0.5)  # 0.5 == "not observed yet"
    grid[source[0]][source[1]] = 0.0
    path = [source]
    # Last prediction acted on from each cell; -1 means "none yet".
    movement_grid = [[-1] * m for _ in range(n)]

    while source != goal:
        # Build the model input: current cell followed by one
        # (row, col, value) triple per neighbour, padded with (0, 0, 1)
        # for neighbours that fall outside the grid.
        valid_actions = []
        source_neighbors_data = [source[0], source[1]]
        for dx, dy in [(0, 1), (-1, 0), (0, -1), (1, 0)]:
            new_x, new_y = source[0] + dx, source[1] + dy
            # Rows are bounded by n and columns by m (the original compared
            # both against the row count, which broke non-square grids).
            if 0 <= new_x < n and 0 <= new_y < m:
                source_neighbors_data.extend([new_x, new_y, grid[new_x][new_y]])
                if grid[new_x][new_y] != 1:
                    valid_actions.append((dx, dy))
            else:
                source_neighbors_data.extend([0, 0, 1])
        input_data = (np.array([grid]), np.array([source_neighbors_data]))

        out = model(input_data).numpy()
        prediction = np.argmax(out, axis=1)[0]
        num_actions += 1

        action = action_movement_dict[prediction]
        if action not in valid_actions:
            # Predicted move is off-grid or into a known block.
            action = choice(valid_actions)
        elif movement_grid[source[0]][source[1]] == prediction:
            # Same prediction from the same cell again: the model is looping,
            # so let A* (planning on the current view) lead the way out.
            a_star_path, *_ = a_star(deepcopy(grid), source, goal)
            if not a_star_path:
                raise ValueError(
                    f"A* found no path from {source} to {goal} on the known grid"
                )
            for x, y in a_star_path[1:]:  # element 0 is the current cell
                grid[x][y] = complete_grid[x][y]
                if grid[x][y] == 1:  # newly discovered block: stop following
                    break
                trajectory += 1
                source = (x, y)
                path.append(source)
            continue

        movement_grid[source[0]][source[1]] = prediction
        # Try to enter the chosen cell; this reveals its true value.
        x, y = source[0] + action[0], source[1] + action[1]
        grid[x][y] = complete_grid[x][y]
        if grid[x][y] == 1:
            continue
        trajectory += 1
        source = (x, y)
        path.append(source)
    return path, trajectory, num_actions
# Configuration for the agent comparison below.
NUM_ITERATIONS = 30  # how many times run_iteration() is called
NUM_TRIALS = 100  # mazes generated inside each run_iteration() call
MAZE_DIM = 50
# The next three values are passed positionally to gen_gridworld as
# prob, solvable and is_terrain_enabled.
DENSITY = 0.3
SOLVABILITY = True
IS_TERRAIN_ENABLED = False
source = (0,0)
target = goal = (MAZE_DIM-1, MAZE_DIM-1)
# Saved Keras models loaded from disk; run_iteration() hands them to run_nn_agent().
p1_dense_model = tf.keras.models.load_model('./test_data_dense_p1/')
p1_cnn_model = tf.keras.models.load_model('./test_data_cnn_p1/')
def run_iteration(i=0):
    """Benchmark the dense-model, CNN-model and repeated-A* agents.

    Generates ``NUM_TRIALS`` gridworlds and, on each one, runs
    ``run_nn_agent`` with the dense model, ``run_nn_agent`` with the CNN
    model, and ``repeated_a_star``, timing every run with ``timer``.

    Args:
        i: iteration number; only used in the progress message.

    Returns:
        tuple: three ``(average path size, average trajectory, average
        seconds)`` tuples, in the order dense model, CNN model, repeated A*.
        Path size is the number of distinct cells on the path for the model
        agents and the length of the second value returned by
        ``repeated_a_star`` for A*.
    """
    def _timed(func, *args):
        # Run one call and report (result, elapsed seconds).
        began = timer()
        result = func(*args)
        return result, timer() - began

    # Running totals per agent: [path size, trajectory, seconds].
    nn_totals = [0, 0, 0]
    cnn_totals = [0, 0, 0]
    a_totals = [0, 0, 0]

    print(f"starting iteration {i}")
    for _trial in range(NUM_TRIALS):
        grid, _, _ = gen_gridworld(MAZE_DIM, DENSITY, SOLVABILITY, IS_TERRAIN_ENABLED, source, target)

        # Both learned agents share the same driver; only the model differs.
        for totals, agent_model in ((nn_totals, p1_dense_model), (cnn_totals, p1_cnn_model)):
            (cells, moves, _), seconds = _timed(run_nn_agent, grid, agent_model)
            totals[0] += len(set(cells))
            totals[1] += moves
            totals[2] += seconds

        a_result, seconds = _timed(repeated_a_star, grid)
        _, a_cells, _, _, _, _, moves, _, _ = a_result
        a_totals[0] += len(a_cells)
        a_totals[1] += moves
        a_totals[2] += seconds

    return tuple(
        tuple(total / NUM_TRIALS for total in totals)
        for totals in (nn_totals, cnn_totals, a_totals)
    )
# print("NN Path", nn_path_lens)
# print("NN Traj", nn_trajs)
# print("NN Time", nn_times)
# print("CNN Path", cnn_path_lens)
# print("CNN Traj", cnn_trajs)
# print("CNN Time", cnn_times)
# print("A Path", a_path_lens)
# print("A Traj", a_trajs)
# print("A Time", a_times)
# Per-iteration averages for each agent, filled by the loop below.
nn_path_avgs, nn_traj_avgs, nn_time_avgs = [], [], []
cnn_path_avgs, cnn_traj_avgs, cnn_time_avgs = [], [], []
a_path_avgs, a_traj_avgs, a_time_avgs = [], [], []

# Same nesting as run_iteration()'s return value: agent -> (path, traj, time).
_avg_lists = (
    (nn_path_avgs, nn_traj_avgs, nn_time_avgs),
    (cnn_path_avgs, cnn_traj_avgs, cnn_time_avgs),
    (a_path_avgs, a_traj_avgs, a_time_avgs),
)

for i in range(NUM_ITERATIONS):
    rets = run_iteration(i)
    for _agent_lists, _agent_stats in zip(_avg_lists, rets):
        for _avg_list, _stat in zip(_agent_lists, _agent_stats):
            _avg_list.append(_stat)

# Dump every list, agent by agent (path, trajectory, time).
for _agent_lists in _avg_lists:
    for _avg_list in _agent_lists:
        print(_avg_list)
# Hard-coded averages (10 values per list) that replace the lists filled by
# the loop above; the plots below use these values.
# NOTE(review): presumably copied from an earlier 10-iteration run -- confirm.
nn_path_avgs = [177.4, 157.6, 175.6, 156.2, 189.4, 188.0, 162.4, 164.8, 151.8, 184.8]
cnn_path_avgs = [226.0, 176.2, 157.0, 142.6, 158.8, 175.4, 172.0, 161.0, 158.2, 215.4]
a_path_avgs = [130.2, 132.2, 127.4, 123.0, 123.4, 133.4, 125.8, 127.0, 120.6, 139.8]
nn_traj_avgs = [222.0, 180.8, 217.6, 184.8, 251.6, 245.6, 205.6, 192.4, 176.8, 226.0]
cnn_traj_avgs = [298.4, 206.8, 180.4, 157.2, 184.8, 207.2, 205.2, 184.8, 188.8, 267.2]
a_traj_avgs = [258.2, 251.2, 230.6, 206.4, 250.0, 263.0, 269.6, 234.0, 231.2, 287.4]
nn_time_avgs = [1.934705733403098, 1.645674962400517, 1.8060789418013883, 1.6095205746009014, 2.0631797833993915, 2.0514500179997412, 1.8833281233994057, 1.617990298998484, 1.6302852630004054, 2.040595890399709]
cnn_time_avgs = [4.266449646399996, 3.369963846601604, 2.9987810671998885, 2.577833686601662, 3.0480733748001514, 3.3660197225995945, 3.5337548686002265, 3.025409155401576, 3.049412638200738, 4.3841205936012555]
a_time_avgs = [0.04768102299858583, 0.04122461159859085, 0.03712173440144397, 0.03736739440064411, 0.04637555099907331, 0.0517265602000407, 0.0513785564005957, 0.034284745600598396, 0.0453821864008205, 0.056263956600741946]
import matplotlib.pyplot as plt
# One figure per metric, each comparing the three agents across iterations.
its = list(range(10))
_agent_labels = ["NN", "CNN", "A Star"]
_figures = (
    ("Average Path", "Average Path For Each Agent",
     (nn_path_avgs, cnn_path_avgs, a_path_avgs)),
    ("Average Trajectory", "Average Trajectory For Each Agent",
     (nn_traj_avgs, cnn_traj_avgs, a_traj_avgs)),
    ("Average Times", "Average Times For Each Agent",
     (nn_time_avgs, cnn_time_avgs, a_time_avgs)),
)
for _y_label, _title, _series in _figures:
    for _values in _series:
        plt.plot(its, _values)
    plt.legend(_agent_labels)
    plt.xlabel("Iterations")
    plt.ylabel(_y_label)
    plt.title(_title)
    plt.grid()
    plt.show()
class CellState(Enum):
    """Status a knowledge-base entry can hold for a grid cell."""

    BLOCKED = 1
    OPEN = 2
    UNKNOWN = 3

    def __int__(self):
        """Return the member's numeric value so ``int(state)`` works."""
        numeric_value = self.value
        return numeric_value
class Intel:
    """Knowledge-base record for one grid cell.

    Instance attributes (all created in ``__init__``; ``INTEL_SIZE`` counts
    them, so adding one also changes the encoded size expected elsewhere):
        N: number of neighbours the cell has.
        visited: whether the agent has visited the cell.
        status: the cell's ``CellState`` (starts as ``UNKNOWN``).
        C: number of neighbours sensed to be blocked; ``-1`` until the cell
           has been visited.
        B: number of neighbours confirmed blocked.
        E: number of neighbours confirmed empty.
        H: number of neighbours whose status is still hidden (starts at N).
    """

    def __init__(self, N: int = 0):
        self.N = N
        self.visited = False
        self.status = CellState.UNKNOWN
        self.C = -1  # Value of -1 means we don't know this value (cell hasn't been visited yet)
        self.B = 0
        self.E = 0
        self.H = N
        # self.blocked_probability = 0.5 # Used for agent 5

    def __str__(self):
        return f"N: {self.N}, visited: {self.visited}, status: {self.status}, C: {self.C}, B: {self.B}, " \
               f"E: {self.E}, H: {self.H}"

    def __repr__(self):
        return self.__str__()

    def convert(self):
        """Return the record as the array ``[N, visited, status, C, B, E, H]``.

        ``visited`` and ``status`` are converted with ``int()``.  This is an
        ordinary method, so both ``intel.convert()`` and the pre-existing
        call form ``Intel.convert(intel)`` work (it used to be a
        ``@staticmethod`` taking ``self``, which made the instance form
        raise ``TypeError``).
        """
        # didn't add blocked_probability bc we only care about agent 3
        return np.array([self.N, int(self.visited), int(self.status), self.C, self.B, self.E, self.H])
# Number of instance attributes a freshly constructed Intel carries.
INTEL_SIZE = len(vars(Intel()))
def _run_single_inference_agent3(grid_len: int, kb: List[List[Intel]]) -> bool:
    """Sweep the knowledge base once, applying agent 3's two inference rules.

    For every visited cell x that still has hidden neighbours:
      * Cx == Bx        -> every still-unknown neighbour is marked OPEN;
      * Nx - Cx == Ex   -> every still-unknown neighbour is marked BLOCKED.

    Args:
        grid_len: grid size passed through to the neighbour helper.
        kb: 2-D knowledge base of ``Intel`` records; updated in place via
            ``update_cell_status``.

    Returns:
        bool: True if at least one neighbour's status was changed.
    """
    any_update = False
    for r, kb_row in enumerate(kb):
        for c, cell in enumerate(kb_row):
            # C is only set once a cell has been visited, and H == 0 means
            # nothing is left to infer around it.
            if not (cell.visited and cell.H != 0):
                continue

            # The rule is chosen once, before any neighbour is updated.
            if cell.C == cell.B:
                inferred = CellState.OPEN
            elif cell.N - cell.C == cell.E:
                inferred = CellState.BLOCKED
            else:
                continue

            for nb_r, nb_c in _get_neighbor_indices(grid_len, r, c):
                if kb[nb_r][nb_c].status == CellState.UNKNOWN:
                    update_cell_status(inferred, kb, nb_r, nb_c, grid_len)
                    any_update = True
    return any_update
def _agent3_known_grid(grid, knowledge_base, unvisited_blocked=False):
    """Build the 0/1 planning grid implied by the knowledge base.

    A cell is 1 when the knowledge base marks it BLOCKED, or -- when
    ``unvisited_blocked`` is set -- when it has not been visited.
    """
    known = np.zeros((len(grid), len(grid[0])))
    for row_idx, row in enumerate(grid):
        for col_idx, _ in enumerate(row):
            intel = knowledge_base[row_idx][col_idx]
            if intel.status == CellState.BLOCKED or (unvisited_blocked and not intel.visited):
                known[row_idx][col_idx] = 1
    return known


def agent3(grid: List[List[int]], dataformat:int = None, inference: Callable = _run_single_inference_agent3):
    """Run the inference agent from (0, 0) to the bottom-right cell.

    The agent repeatedly plans with ``a_star`` on what it currently believes
    (only cells known to be BLOCKED are obstacles), walks the plan while
    sensing and running inference on each newly visited cell, and replans
    when it bumps into a block or inference shows the rest of the plan is
    blocked.

    Args:
        grid: the true maze; ``0`` is an open cell and ``1`` a blocked cell.
        dataformat: if falsy no training data is recorded.  ``1`` records,
            for each step, a flat list describing the 3x3 neighbourhood of
            the cell stepped into (``x, y`` plus the ``Intel`` fields per
            cell, zeros for cells outside the grid).  ``2`` records the
            whole knowledge base as a ``(rows, cols, INTEL_SIZE)`` array.
        inference: single-sweep inference function given to ``run_inference``.

    Returns:
        A 9-tuple, with the same layout whether or not the maze is solvable:
            list(tuple(int, int)): shortest path through visited cells
                (``[]`` if unsolvable).
            int: length of that path.
            list: recorded grid states (see ``dataformat``).
            list: the ``source`` value recorded with each grid state.
            list(int): the action recorded with each grid state.
            list(tuple(int, int)): every cell the agent moved through.
            int: trajectory, i.e. the length of the previous list.
            int: number of times the agent walked into a blocked cell.
            list(tuple(int, int)): every cell attempted, blocks included.
    """
    knowledge_base = create_initial_kb(grid)
    grid_len = len(grid)
    n_rows, n_cols = len(grid), len(grid[0])
    source, goal = (0, 0), (grid_len - 1, grid_len - 1)
    full_path = [source]
    full_path_and_blocks = []
    solvable = True
    num_bumps = 0
    grid_states = []
    grid_sources = []
    actions = []

    while solvable and goal != full_path[-1]:
        # Plan on current knowledge: only known blocks are obstacles.
        current_grid = _agent3_known_grid(grid, knowledge_base)
        path, _, _, _, _ = a_star(current_grid, source, goal)
        if path == []:
            solvable = False

        for path_idx, (row_idx, col_idx) in enumerate(path):
            ran_inference = False
            intel = knowledge_base[row_idx][col_idx]
            if not intel.visited:
                # Only an open cell can sense how many neighbours are blocked.
                if grid[row_idx][col_idx] == 0:
                    intel.C = _get_num_blocked_neighbors(grid, row_idx, col_idx)
                intel.visited = True
                if intel.status == CellState.UNKNOWN:
                    cell_state = \
                        CellState.BLOCKED if grid[row_idx][col_idx] == 1 else CellState.OPEN
                    update_cell_status(cell_state, knowledge_base, row_idx, col_idx, grid_len)
                run_inference(grid_len, knowledge_base, inference)
                ran_inference = True

            full_path_and_blocks.append((row_idx, col_idx))

            # If we hit a block, restart from the last cell we stood on.
            # (full_path always holds at least the start cell.)
            if grid[row_idx][col_idx] == 1:
                num_bumps += 1
                source = full_path[-1]
                break

            # Only add to the path if we aren't at the source
            # (otherwise we would have duplicates on restarts)
            if (row_idx, col_idx) != source:
                if dataformat:
                    if dataformat == 1:
                        source_and_neighbors = []
                        for dx, dy in [(0,0), (0,1), (1,0), (0,-1), (-1,0), (1,1), (-1,-1), (1,-1), (-1,1)]:
                            new_x, new_y = row_idx + dx, col_idx + dy
                            # Bounds are the grid size itself; comparing with
                            # the goal index (size - 1) wrongly treated the
                            # last row and column as outside the grid.
                            if 0 <= new_x < n_rows and 0 <= new_y < n_cols:
                                source_and_neighbors.extend([new_x, new_y, *(Intel.convert(knowledge_base[new_x][new_y]))])
                            else:
                                source_and_neighbors.extend([new_x, new_y, *([0] * INTEL_SIZE)])
                        grid_states.append(source_and_neighbors)
                    elif dataformat == 2:
                        grid_state = np.empty(shape=(n_rows, n_cols, INTEL_SIZE))
                        for i in range(n_rows):
                            for j in range(n_cols):
                                grid_state[i, j, :] = Intel.convert(knowledge_base[i][j])
                        grid_states.append(grid_state)
                    # NOTE(review): `source` is the start of the current A*
                    # plan and is not advanced while the plan is followed, so
                    # after the first step it is not the cell just left --
                    # confirm this is the intended origin for the action.
                    grid_sources.append((source))
                    actions.append(int(_get_cardinal_direction(*source, row_idx, col_idx)))
                full_path.append((row_idx, col_idx))

            # If there is a block in our current path, but the current cell was not blocked
            if ran_inference and any(knowledge_base[x][y].status == CellState.BLOCKED for (x, y) in path[path_idx+1:]):
                source = full_path[-1]
                break

    if solvable:
        # Shortest path restricted to cells the agent actually visited.
        current_grid = _agent3_known_grid(grid, knowledge_base, unvisited_blocked=True)
        shortest_path, _, _, _, _ = a_star(current_grid, (0,0), goal)
        return shortest_path, len(shortest_path), grid_states, grid_sources, actions, full_path, len(full_path), num_bumps, full_path_and_blocks
    # Same 9-field layout as the solvable case (`actions` was listed twice).
    return [], 0, grid_states, grid_sources, actions, full_path, len(full_path), num_bumps, full_path_and_blocks
def create_initial_kb(grid: List[List[int]]) -> List[List[Intel]]:
    """Create a knowledge base holding one fresh ``Intel`` per grid cell.

    Each record's neighbour count N comes from ``_get_neighbor_indices``
    evaluated with ``len(grid)`` as the grid size.

    Returns:
        A NumPy object array shaped like ``grid`` and filled with ``Intel``
        instances.
    """
    side = len(grid)
    kb = np.zeros((side, len(grid[0]))).astype('object')
    for r, grid_row in enumerate(grid):
        for c in range(len(grid_row)):
            neighbor_count = len(_get_neighbor_indices(side, r, c))
            kb[r][c] = Intel(neighbor_count)
    return kb
def run_inference(
    grid_len: int,
    kb: List[List[Intel]],
    single_inference_func : Callable = _run_single_inference_agent3
):
    """Apply ``single_inference_func`` until a sweep reports no change.

    The function is always called at least once and is called again for as
    long as it returns a truthy value.
    """
    while single_inference_func(grid_len, kb):
        pass
def update_cell_status(
    new_state: CellState,
    kb: List[List[Intel]],
    row_idx: int,
    col_idx: int,
    grid_len: int
) -> None:
    """Settle a cell's status and update its neighbours' counters.

    Does nothing if the cell's status is already known.  Otherwise the cell
    takes ``new_state`` and every neighbour loses one hidden neighbour (H)
    and gains one blocked (B) or one empty (E) neighbour accordingly.

    Raises:
        Exception: if ``new_state`` is ``CellState.UNKNOWN`` while the cell's
            status is still unknown.
    """
    cell = kb[row_idx][col_idx]
    # A status that is already known is never overwritten.
    if cell.status != CellState.UNKNOWN:
        return
    if new_state == CellState.UNKNOWN:
        raise Exception("Cell state to change to cannot be unknown")

    cell.status = new_state
    now_blocked = new_state == CellState.BLOCKED
    now_open = new_state == CellState.OPEN
    for n_row, n_col in _get_neighbor_indices(grid_len, row_idx, col_idx):
        neighbor = kb[n_row][n_col]
        neighbor.H -= 1
        if now_blocked:
            neighbor.B += 1
        elif now_open:
            neighbor.E += 1
def _get_num_blocked_neighbors(
    grid: List[List[int]],
    row_idx: int,
    col_idx: int
) -> int:
    """Count the neighbours of ``(row_idx, col_idx)`` whose grid value is 1."""
    return sum(
        1
        for n_row, n_col in _get_neighbor_indices(len(grid), row_idx, col_idx)
        if grid[n_row][n_col] == 1
    )
def _get_neighbor_indices(grid_len: int, row_idx: int, col_idx: int) -> List[Tuple[int, int]]:
neighbors = []
for row in range(row_idx - 1, row_idx + 2):
if row < 0 or row >= grid_len:
continue
for col in range(col_idx - 1, col_idx + 2):
if row_idx == row and col_idx == col:
continue
elif col < 0 or col >= grid_len:
continue
neighbors.append((row, col))
return neighbors
# --- Experiment configuration -------------------------------------------------
# Maze counts for the experiments below.
NUM_OF_MAZES_PER_TEST = 30
NUM_TRAINING_MAZES = 50
NUM_TESTING_MAZES = 50
# Arguments handed to gen_gridworld(dim, prob, solvable, is_terrain_enabled,
# source, target) by the data-generation helper below.
MAZE_DIM = 50             # side length of the square gridworld
DENSITY = 0.3             # per-cell probability of being blocked
SOLVABILITY = True
IS_TERRAIN_ENABLED = False
source = (0,0)            # start cell (top-left corner)
target = (MAZE_DIM-1, MAZE_DIM-1)  # goal cell (bottom-right corner)
# Scratch code kept for reference (not executed):
# grid, _, _ = gen_gridworld(dim, density, solvability, is_terrain_enabled, source, target)
# discovered_path_3, length_discovered_path_3, grid_states, grid_sources, actions, full_path_3, trajectory_3, num_bumps_3 = agent3(grid, 2)
# print(f"Agent 3: { length_discovered_path_3}, trajectory: {trajectory_3}, Num of walls bumped into: {num_bumps_3}")
# # pretty_print_explored_region(grid, full_path_3, discovered_path_3)
# print(len(grid_states)) #, grid_states)
# print(len(actions)) #, actions)
def gen_grid_data_agent3(num_iter, agent, arg=None):
    """Collect training samples by running ``agent`` on freshly generated mazes.

    Args:
        num_iter: number of mazes to generate and solve.
        agent: callable ``agent(grid)`` or ``agent(grid, arg)`` whose result is
            a sequence with the recorded grid states, sources and actions at
            positions 2, 3 and 4.
        arg: optional second argument forwarded to ``agent``; it is only
            forwarded when truthy.

    Returns:
        A tuple ``(grids, sources, actions)`` of NumPy arrays holding the
        samples of all mazes concatenated in generation order.
    """
    all_grids = []
    all_sources = []
    all_actions = []
    for i in range(num_iter):
        if i % 100 == 0:
            print("Maze", i)
        grid, _, _ = gen_gridworld(MAZE_DIM, DENSITY, SOLVABILITY, IS_TERRAIN_ENABLED, source, target)
        agent_args = (grid, arg) if arg else (grid,)
        # Only the recorded states/sources/actions are needed for training.
        _, _, grid_states, grid_sources, actions, *_ = agent(*agent_args)
        all_grids.extend(grid_states)
        # Fixed: this used the undefined name `sources` instead of the
        # per-maze `grid_sources` returned by the agent.
        all_sources.extend(grid_sources)
        all_actions.extend(actions)
    return (np.array(all_grids), np.array(all_sources), np.array(all_actions))
# Model Building
def create_p2_dense_model():
    """Build and compile the fully connected action classifier.

    The network takes a flat vector of 81 features, passes it through three
    ReLU layers of 2000 units and ends in a 4-way softmax. It is compiled
    with Adam and sparse categorical cross-entropy.

    NOTE(review): 81 presumably matches the 9-cell neighbourhood feature
    vector built for ``data_func == 1`` in ``run_nn_agent_agent3`` — confirm.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # `(81)` is just the int 81; Keras expects a shape tuple.
    source_input = tf.keras.layers.Input(shape=(81,))
    hidden = source_input
    for _ in range(3):
        hidden = tf.keras.layers.Dense(units=2000, activation=tf.nn.relu)(hidden)
    probabilities = tf.keras.layers.Dense(units=4, activation=tf.nn.softmax)(hidden)
    model = tf.keras.Model(inputs=[source_input], outputs=probabilities)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself and returns None; wrapping it in print()
    # only added a stray "None" line.
    model.summary()
    return model
# Training
def training_p2_dense():
    """Train the dense model on 250 mazes, generated 25 at a time.

    Each round generates a fresh batch of mazes with ``agent3`` (argument 1)
    and fits the model on the recorded grid states and actions for 3 epochs,
    holding out 10% for validation.

    NOTE(review): ``model`` is not created or received here; it is resolved
    from module scope at call time — confirm which model is meant.
    """
    mazes_per_round = 25
    total_mazes = 250
    for first_maze in range(0, total_mazes, mazes_per_round):
        print(f"Training starting on maze {first_maze}")
        grids, sources, actions = gen_grid_data_agent3(mazes_per_round, agent3, 1)
        print(np.shape(grids), np.shape(sources), np.shape(actions))
        # The dense model is fed the grid-state features only.
        model.fit(grids, actions, epochs = 3, validation_split=.1)
# Testing
# Load the previously saved dense model from disk.
model_dense_p2 = tf.keras.models.load_model('./test_data_dense_p2/')
# Hyperparameter Tuning
def create_tuning_model():
    """Set up a random hyperparameter search over the CNN and return the best model.

    The search space covers the kernel/stride size, the filter counts of the
    three convolutions, the width of the first dense layer, the activation of
    the second dense layer and the learning rate.

    NOTE(review): ``RandomSearch`` is not imported in the visible part of the
    file (presumably ``keras_tuner.RandomSearch`` — confirm). The
    ``tuner.search(...)`` call is commented out, so no trials run before
    ``get_best_models`` is queried — confirm this is intended.

    Returns:
        The best model reported by the tuner.
    """
    def build_model(hp):
        """Build one candidate model from the hyperparameters in ``hp``."""
        # `(2)` is just the int 2; Keras expects a shape tuple.
        source_input = tf.keras.layers.Input(shape=(2,))
        # Grid input
        grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM, INTEL_SIZE))
        # The same value is used for kernel size, stride and pool size.
        i = hp.Choice('kernel_and_stride_size', values=[2,3])
        conv2d_1 = tf.keras.layers.Conv2D(filters=hp.Int("filters1", min_value=256, max_value=2048, step=128), kernel_size=(i,i), strides=(i,i), padding="same",activation="relu")(grid_input)
        maxpool2d_1 = tf.keras.layers.MaxPooling2D(pool_size=(i, i), padding='same')(conv2d_1)
        conv2d_2 = tf.keras.layers.Conv2D(filters=hp.Int("filters2", min_value=32, max_value=256, step=64),kernel_size=i,strides=(i,i), padding="same",activation="relu")(maxpool2d_1)
        maxpool2d_2 = tf.keras.layers.MaxPooling2D(pool_size=(i, i), padding='same')(conv2d_2)
        conv2d_3 = tf.keras.layers.Conv2D(filters=hp.Int("filters3", min_value=4, max_value=32, step=4),kernel_size=i, strides=(i,i), padding="same",activation="relu")(maxpool2d_2)
        maxpool2d_3 = tf.keras.layers.MaxPooling2D(pool_size=(i, i), padding='same')(conv2d_3)
        flatten = tf.keras.layers.Flatten()(maxpool2d_3)
        # The agent position is appended to the flattened grid features.
        grids_and_sources = tf.keras.layers.Concatenate()([flatten, source_input])
        dense_1 = tf.keras.layers.Dense(units=hp.Int("units", min_value=1000, max_value=3000, step=500), activation='relu')(grids_and_sources)
        dense_2 = tf.keras.layers.Dense(units = 500, activation = hp.Choice('activation function', values=['relu', 'sigmoid']))(dense_1)
        logits = tf.keras.layers.Dense(units = 4, activation = "softmax")(dense_2)
        model = tf.keras.Model(inputs=[grid_input, source_input], outputs=logits)
        model.compile(optimizer = keras.optimizers.Adam(hp.Choice('learning_rate', values=[1e-2,1e-4])), loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
        return model

    tuner = RandomSearch(
        build_model,
        objective='val_accuracy',
        max_trials=10,
        executions_per_trial=1,
        # directory
    )
    tuner.search_space_summary()
    # tuner.search((train_total_grids, train_total_sources), train_total_actions, epochs=5, validation_data=((test_total_grids, test_total_sources), test_total_actions))
    tuner.results_summary()
    models = tuner.get_best_models(num_models=2)
    best_model = models[0]
    # summary() prints by itself; print(summary()) added a stray "None".
    best_model.summary()
    return best_model
def create_p2_cnn_model():
    """Build and compile the convolutional action classifier (no pooling).

    Inputs are the knowledge-base grid of shape
    ``(MAZE_DIM, MAZE_DIM, INTEL_SIZE)`` and a 2-element source vector. The
    grid goes through five same-padded 2x2 convolutions (256, 128, 64, 32, 4
    filters), is flattened, concatenated with the source and classified by
    two dense layers and a 4-way softmax.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # `(2)` is just the int 2; Keras expects a shape tuple.
    source_input = tf.keras.layers.Input(shape=(2,))
    # Grid input
    grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM, INTEL_SIZE))
    features = grid_input
    for num_filters in (256, 128, 64, 32, 4):
        features = tf.keras.layers.Conv2D(filters=num_filters, kernel_size=2, padding="same", activation="relu")(features)
    flatten = tf.keras.layers.Flatten()(features)
    grids_and_sources = tf.keras.layers.Concatenate()([flatten, source_input])
    dense_1 = tf.keras.layers.Dense(units=2500, activation=tf.nn.relu)(grids_and_sources)
    dense_2 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)(dense_1)
    logits = tf.keras.layers.Dense(units=4, activation="softmax")(dense_2)
    cnn_model = tf.keras.Model(inputs=[grid_input, source_input], outputs=logits)
    cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself; print(summary()) added a stray "None".
    cnn_model.summary()
    return cnn_model
# Training
def training_p2_cnn():
    """Train the CNN model on 250 mazes, generated 10 at a time.

    Each round generates a fresh batch of mazes with ``agent3`` (argument 2)
    and fits the model on the recorded grid states and sources against the
    recorded actions for 3 epochs, holding out 10% for validation.

    NOTE(review): ``model`` is not created or received here; it is resolved
    from module scope at call time — confirm which model is meant.
    """
    mazes_per_round = 10
    total_mazes = 250
    for first_maze in range(0, total_mazes, mazes_per_round):
        print(f"Training starting on maze {first_maze}")
        grids, sources, actions = gen_grid_data_agent3(mazes_per_round, agent3, 2)
        print(np.shape(grids), np.shape(sources), np.shape(actions))
        # Two inputs: the grid tensor and the agent position.
        model.fit([grids,sources], actions, epochs = 3, validation_split=.1)
# Load the previously saved CNN model from disk.
model_cnn_p2 = tf.keras.models.load_model('./test_data_cnn_p2_regular/')
from random import choice
from timeit import default_timer as timer
import tensorflow as tf
import numpy as np
# Maps a predicted class index to the (row delta, col delta) that is added to
# the agent's current position.
action_movement_dict = {
0 : (-1,0), # North
1 : (0, 1), # East
2 : (1, 0), # South
3 : (0, -1) # West
}
# Leading (frame) dimension used when reshaping grid states for data_func == 3.
FRAME_SIZE = 1
def run_nn_agent_agent3(complete_grid, model, full_path, data_func=None):
    """Replay a reference trajectory and score the model's move predictions.

    At every step the knowledge base is updated for the current cell,
    inference is run, the model predicts a move from the resulting state and
    the prediction is compared with the next cell of ``full_path``. The agent
    always moves to the reference cell, so a wrong prediction is counted but
    never followed.

    Args:
        complete_grid: fully revealed square grid; 1 marks a blocked cell,
            0 an open one.
        model: callable returning an object whose ``.numpy()`` yields
            per-action scores of shape ``(1, 4)``.
        full_path: reference trajectory indexed by step number (entry 0 is
            the start cell); it may contain blocked cells the reference
            agent bumped into.
        data_func: input encoding — 1 for the flat 9-cell neighbourhood
            vector, 2 for the full grid tensor, 3 for the full grid tensor
            with a leading frame dimension.

    Returns:
        ``(path, trajectory, correct, miss)``: the open cells entered, the
        number of moves attempted, and the counts of matching and
        non-matching predictions.

    Raises:
        ValueError: if ``data_func`` is not 1, 2 or 3.
        IndexError: if ``full_path`` ends before the goal is reached.
    """
    knowledge_base = create_initial_kb(complete_grid)
    grid_len = len(complete_grid)
    n, m = grid_len - 1, grid_len - 1
    source, goal = (0, 0), (n, m)
    # Defined up front so stepping back from a blocked cell can never hit an
    # unbound name.
    prev = source
    path = []
    trajectory = miss = correct = 0
    rounds = 0
    while source != goal:
        rounds += 1

        # Update KB
        row_idx, col_idx = source
        if not knowledge_base[row_idx][col_idx].visited:
            if complete_grid[row_idx][col_idx] == 0:
                # Sensing reads the real maze. Counting blocks in the
                # knowledge-derived grid (as before) only counted blocks that
                # were already known, which misled the inference rules.
                knowledge_base[row_idx][col_idx].C = _get_num_blocked_neighbors(complete_grid, row_idx, col_idx)
            knowledge_base[row_idx][col_idx].visited = True
            if knowledge_base[row_idx][col_idx].status == CellState.UNKNOWN:
                cell_state = CellState.BLOCKED if complete_grid[row_idx][col_idx] == 1 else CellState.OPEN
                update_cell_status(cell_state, knowledge_base, row_idx, col_idx, grid_len)
                if cell_state == CellState.BLOCKED:
                    # We bumped into a wall: step back to where we came from.
                    source = prev
            run_inference(grid_len, knowledge_base, _run_single_inference_agent3)

        # Dense
        if data_func == 1:
            grid_state = []
            for dx, dy in [(0,0), (0,1), (1,0), (0,-1), (-1,0), (1,1), (-1,-1), (1,-1), (-1,1)]:
                new_x, new_y = row_idx + dx, col_idx + dy
                # NOTE(review): n and m are grid_len - 1, so the last row and
                # column are treated as out of bounds here. Left unchanged
                # because the features must match whatever the training data
                # used — confirm against the data-generating agent.
                if 0 <= new_x < n and 0 <= new_y < m:
                    grid_state.extend([new_x, new_y, *(Intel.convert(knowledge_base[new_x][new_y]))])
                else:
                    grid_state.extend([new_x, new_y, *([0] * INTEL_SIZE)])
            input_data = np.array([grid_state]) # didn't use source for input data model
        # CNN
        elif data_func == 2 or data_func == 3:
            num_rows, num_cols = len(complete_grid), len(complete_grid[0])
            grid_state = np.empty(shape=(num_rows, num_cols, INTEL_SIZE))
            for i in range(num_rows):
                for j in range(num_cols):
                    for k, v in enumerate(Intel.convert(knowledge_base[i][j])):
                        grid_state[i][j][k] = v
            if data_func == 3:
                grid_state = tf.reshape(grid_state, (FRAME_SIZE, MAZE_DIM, MAZE_DIM, INTEL_SIZE))
            input_data = (np.array([grid_state]), np.array([source]))
        else:
            raise ValueError(f"Unsupported data_func: {data_func!r} (expected 1, 2 or 3)")

        ### Prediction ###
        out = model(input_data).numpy()
        prediction = np.argmax(out, axis = 1)[0]
        action = action_movement_dict[prediction]
        x, y = (source[0] + action[0], source[1] + action[1])
        # Score the prediction, then follow the reference trajectory regardless.
        if (x, y) != full_path[rounds]:
            miss += 1
            (x, y) = full_path[rounds]
        else:
            correct += 1
        trajectory += 1
        prev = source
        source = (x, y)
        # Blocked cells are handled (and backed out of) on the next iteration.
        if complete_grid[x][y] == 1:
            continue
        path.append(source)
    return path, trajectory, correct, miss
# Recorded results, hard-coded from an earlier run. Because the evaluation is
# slow, only 10 (trials) x 10 (mazes) were run.
# nn_*  : dense network, cnn_* : convolutional network, a_* : A*-based agent.
nn_means_path = [175.4, 175.5, 169.2, 160.8, 153.5, 158.0, 163.8, 160.4, 171.0, 166.7]
nn_means_traj = [322.2, 312.8, 311.0, 281.0, 272.0, 283.1, 295.0, 287.6, 329.9, 303.4]
nn_means_times =[3.8288494416963657, 3.524037164998299, 3.4668822037012434, 3.129986681700393, 3.169014308699116, 3.3250332264986353, 4.867095653699653, 4.109915263901348, 3.8719698475979385, 4.09942249039741]
nn_means_corrects =[82.5, 77.8, 75.1, 69.0, 67.5, 71.6, 72.0, 69.9, 82.0, 74.1]
nn_means_misses = [239.7, 235.0, 235.9, 212.0, 204.5, 211.5, 223.0, 217.7, 247.9, 229.3]
# NOTE(review): the CNN path and trajectory values are identical to the NN
# ones above; only times, corrects and misses differ — confirm this is
# expected rather than a copy-paste slip.
cnn_means_path = [175.4, 175.5, 169.2, 160.8, 153.5, 158.0, 163.8, 160.4, 171.0, 166.7]
cnn_means_traj = [322.2, 312.8, 311.0, 281.0, 272.0, 283.1, 295.0, 287.6, 329.9, 303.4]
cnn_means_times =[31.878190686796735, 30.192351302296448, 29.74190882630064, 27.12558812719799, 26.970018435301608, 27.708891010099613, 45.04423980590072, 34.643869927999916, 33.82419218010182, 35.71657393649948]
cnn_means_corrects = [89.6, 87.7, 87.9, 80.0, 77.3, 80.2, 83.7, 81.0, 94.2, 84.4]
cnn_means_misses = [232.6, 225.1, 223.1, 201.0, 194.7, 202.9, 211.3, 206.6, 235.7, 219.0]
a_means_path = [124.6, 138.4, 131.6, 128.0, 127.2, 121.6, 127.8, 126.6, 129.8, 128.6]
a_means_traj = [209.6, 204.2, 197.0, 182.6, 173.6, 182.8, 188.0, 183.8, 203.8, 192.8]
a_means_times = [0.41573393699800365, 0.3616427813001792, 0.3975740644018515, 0.32227378389798106, 0.3559665934983059, 0.338859436500934, 0.6192624611008796, 0.45207660400046734, 0.4427161724001053, 0.4525190754982759]
# The reference agent is scored as always correct: its "corrects" series is
# the same list object as its trajectory series and its misses are all zero.
a_means_corrects = a_means_traj
a_means_misses = [0] * len(a_means_traj)
import matplotlib.pyplot as plt
# Plot every recorded metric as one figure with one line per agent.
agents = ["NN", "CNN", "A Star"]
paths = [nn_means_path, cnn_means_path, a_means_path]
trajs = [nn_means_traj, cnn_means_traj, a_means_traj]
times = [nn_means_times, cnn_means_times, a_means_times]
corrects = [nn_means_corrects, cnn_means_corrects, a_means_corrects]
misses = [nn_means_misses, cnn_means_misses, a_means_misses]

# One table drives all five figures instead of five copy-pasted stanzas.
for metric_name, per_agent_series in (
    ("Path", paths),
    ("Trajectory", trajs),
    ("Time", times),
    ("Corrects", corrects),
    ("Misses", misses),
):
    for series in per_agent_series:
        plt.plot(list(range(len(series))), series)
    plt.legend(agents)
    # The x axis is the position in each list of means (one entry per
    # recorded iteration), not the agent; the agents are the plotted lines.
    plt.xlabel("Iteration")
    plt.ylabel(f"Average {metric_name}")
    plt.title(f"Average {metric_name} For Each Agent")
    plt.grid()
    plt.show()
def run_nn_agent_agent3_v2(complete_grid, model):
    """Walk from the top-left to the bottom-right cell, steered by ``model``.

    The agent keeps its own view of the maze (0.5 = unknown, 0 = open,
    1 = blocked). At each step the model is given that view plus the current
    position and its four neighbors, and predicts a move. An invalid move is
    replaced by a random valid one. If the model repeats the move it already
    made from the current cell, the agent follows an A* path over its current
    view until the path ends or it discovers a new block.

    Args:
        complete_grid: fully revealed grid; 1 marks a blocked cell.
        model: callable returning an object whose ``.numpy()`` yields
            per-action scores of shape ``(1, 4)``.

    Returns:
        ``(path, trajectory, num_actions)``: every cell the agent stood on
        (starting with the source), the number of successful moves, and the
        number of model predictions made.

    Raises:
        IndexError: if no valid move exists at the current cell, or if A*
            returns an empty path.
    """
    # Built once; the original rebuilt this mapping on every iteration.
    action_for_prediction = {
        0: (-1, 0),  # North
        1: (0, 1),   # East
        2: (1, 0),   # South
        3: (0, -1),  # West
    }
    source = (0, 0)
    n, m = (len(complete_grid), len(complete_grid[0]))
    goal = (n - 1, m - 1)
    num_actions = trajectory = 0
    complete_grid = np.array(complete_grid)
    grid = np.full((n, m), 0.5)
    grid[source[0]][source[1]] = 0.0
    path = [source]
    # Last prediction taken from each cell; -1 means "none yet".
    movement_grid = [[-1] * m for _ in range(n)]

    while source != goal:  # execution step
        valid_actions = []
        # Generate input data for agent - grid, source + neighbors tuple
        source_neighbors_data = [source[0], source[1]]
        for dx, dy in [(0, 1), (-1, 0), (0, -1), (1, 0)]:
            new_x, new_y = (source[0] + dx, source[1] + dy)
            # Rows are bounded by n and columns by m (the column check used
            # the row count before, which is only right for square grids).
            if 0 <= new_x < n and 0 <= new_y < m:
                source_neighbors_data.extend([new_x, new_y, grid[new_x][new_y]])
                if grid[new_x][new_y] != 1:
                    valid_actions.append((dx, dy))
            else:
                # Out-of-bounds neighbors are encoded as blocked.
                source_neighbors_data.extend([0, 0, 1])
        input_data = (np.array([grid]), np.array([source_neighbors_data]))

        # Give agent grid and (try to) take step by agent
        out = model(input_data).numpy()
        prediction = np.argmax(out, axis=1)[0]
        num_actions += 1

        action = action_for_prediction[prediction]
        if action not in valid_actions:
            # The predicted move is impossible: pick a random valid one.
            action = choice(valid_actions)
        elif movement_grid[source[0]][source[1]] == prediction:
            # Same move from the same cell again: we are looping, so let A*
            # lead the way over the current view of the maze.
            a_star_path, *_ = a_star(deepcopy(grid), source, goal)
            if len(a_star_path) == 0:
                raise IndexError("a_star returned an empty path")
            for x, y in a_star_path[1:]:
                grid[x][y] = complete_grid[x][y]
                if grid[x][y] == 1:  # If we hit a new block, stop following
                    break
                trajectory += 1
                source = (x, y)
                path.append(source)
            continue

        movement_grid[source[0]][source[1]] = prediction
        # Move to new cell
        x, y = (source[0] + action[0], source[1] + action[1])
        grid[x][y] = complete_grid[x][y]
        if grid[x][y] == 1:
            continue
        trajectory += 1
        source = (x, y)
        path.append(source)
    return path, trajectory, num_actions
def _run_single_inference_agent3(grid_len: int, kb: List[List[Intel]]) -> bool:
    """Run one inference pass over every cell of the knowledge base.

    For each visited cell that still has hidden neighbors:
      * if ``C == B``, every still-unknown neighbor is marked open;
      * otherwise, if ``N - C == E``, every still-unknown neighbor is marked
        blocked.

    Returns:
        True if at least one cell status was updated during this pass.
    """
    changed = False
    for row_idx, row in enumerate(kb):
        for col_idx, intel in enumerate(row):
            # Unvisited cells have no sensed count yet, and cells with no
            # hidden neighbors have nothing left to infer.
            if not intel.visited or intel.H == 0:
                continue

            if intel.C == intel.B:
                # All sensed blocks are accounted for: the rest are empty.
                inferred_state = CellState.OPEN
            elif intel.N - intel.C == intel.E:
                # All empty neighbors are accounted for: the rest are blocked.
                inferred_state = CellState.BLOCKED
            else:
                continue

            for n_row, n_col in _get_neighbor_indices(grid_len, row_idx, col_idx):
                if kb[n_row][n_col].status == CellState.UNKNOWN:
                    update_cell_status(inferred_state, kb, n_row, n_col, grid_len)
                    changed = True
    return changed
def agent3_nn(grid: List[List[int]], inference: Callable = _run_single_inference_agent3):
    """Solve the maze with repeated A* planning plus neighbor-count inference.

    The agent plans over the cells it does not know to be blocked, walks the
    plan while sensing and running inference at each newly visited cell, and
    replans whenever it bumps into a block or inference reveals a block
    further along the current plan.

    Args:
        grid: the full square gridworld; 0 is open, 1 is blocked.
        inference: single-pass inference function handed to ``run_inference``.

    Returns:
        A 5-tuple:
            list(tuple(int, int)): shortest path from (0, 0) through visited
                cells only (empty if no plan could be found).
            int: length of that path.
            list(tuple(int, int)): every cell the agent moved onto, in order.
            int: trajectory, i.e. the length of the list above.
            int: number of times the agent walked into a blocked cell.
    """
    knowledge_base = create_initial_kb(grid)
    grid_len = len(grid)
    source, goal = (0, 0), (grid_len - 1, grid_len - 1)
    full_path = [source]
    solvable = True
    num_bumps = 0

    while goal != full_path[-1]:
        # Create current grid based on current knowledge base
        current_grid = [
            [
                1 if knowledge_base[row_idx][col_idx].status == CellState.BLOCKED else 0
                for col_idx in range(len(row))
            ]
            for row_idx, row in enumerate(grid)
        ]
        path, _, _, _, _ = a_star(current_grid, manhattan_heuristic, source)
        if path is None or len(path) == 0:
            # Assumes a_star signals "no route" with an empty path — TODO
            # confirm. Without this check the loop below would never advance
            # and this while-loop would spin forever.
            solvable = False
            break

        for path_idx, (row_idx, col_idx) in enumerate(path):
            ran_inference = False
            intel = knowledge_base[row_idx][col_idx]
            if not intel.visited:
                if grid[row_idx][col_idx] == 0:
                    # Sense how many neighbors are really blocked.
                    intel.C = _get_num_blocked_neighbors(grid, row_idx, col_idx)
                intel.visited = True
                if intel.status == CellState.UNKNOWN:
                    cell_state = \
                        CellState.BLOCKED if grid[row_idx][col_idx] == 1 else CellState.OPEN
                    update_cell_status(cell_state, knowledge_base, row_idx, col_idx, grid_len)
                run_inference(grid_len, knowledge_base, inference)
                ran_inference = True

            # If we hit a block, restart and try again. Otherwise, continue
            if grid[row_idx][col_idx] == 1:
                num_bumps += 1
                source = full_path[-1]
                break

            # Only add to the path if we aren't at the source
            # (otherwise we would have duplicates on restarts)
            if (row_idx, col_idx) != source:
                full_path.append((row_idx, col_idx))

            # If there is a block in our current path, but the current cell was not blocked
            if ran_inference and any(
                knowledge_base[x][y].status == CellState.BLOCKED for (x, y) in path[path_idx + 1:]
            ):
                source = full_path[-1]
                break

    if not solvable:
        return [], 0, full_path, len(full_path), num_bumps

    # Shortest path through explored territory: unvisited cells count as blocked.
    current_grid = [
        [
            1
            if knowledge_base[row_idx][col_idx].status == CellState.BLOCKED
            or not knowledge_base[row_idx][col_idx].visited
            else 0
            for col_idx in range(len(row))
        ]
        for row_idx, row in enumerate(grid)
    ]
    shortest_path, _, _, _, _ = a_star(current_grid, manhattan_heuristic, (0, 0))
    return shortest_path, len(shortest_path), full_path, len(full_path), num_bumps
def create_initial_kb(grid: List[List[int]]) -> List[List[Intel]]:
    """Return a knowledge base (list of lists) mirroring the shape of ``grid``.

    Every entry is a fresh ``Intel`` constructed with the number of neighbors
    that cell has.
    """
    return [
        [
            Intel(len(_get_neighbor_indices(len(grid), r, c)))
            for c in range(len(cells))
        ]
        for r, cells in enumerate(grid)
    ]
def run_inference(
    grid_len: int,
    kb: List[List[Intel]],
    single_inference_func : Callable = _run_single_inference_agent3
):
    """Keep running inference passes over ``kb`` until a pass changes nothing.

    ``single_inference_func(grid_len, kb)`` is called at least once and must
    return whether it modified the knowledge base.
    """
    while True:
        if not single_inference_func(grid_len, kb):
            break
def update_cell_status(
    new_state: CellState,
    kb: List[List[Intel]],
    row_idx: int,
    col_idx: int,
    grid_len: int
) -> None:
    """Set the status of a still-unknown cell and propagate it to neighbors.

    Cells with a known status are ignored. For a real update, each neighbor
    loses one hidden cell (``H``) and gains one blocked (``B``) or one empty
    (``E``) cell, depending on ``new_state``.

    Raises:
        Exception: if asked to set a cell to ``CellState.UNKNOWN``.
    """
    target = kb[row_idx][col_idx]
    status_is_known = target.status != CellState.UNKNOWN
    if status_is_known:
        return
    if new_state == CellState.UNKNOWN:
        raise Exception("Cell state to change to cannot be unknown")
    target.status = new_state

    became_blocked = new_state == CellState.BLOCKED
    became_open = new_state == CellState.OPEN
    # Propagate the newly learned status to the surrounding cells.
    for (n_row, n_col) in _get_neighbor_indices(grid_len, row_idx, col_idx):
        info = kb[n_row][n_col]
        info.H -= 1
        if became_blocked:
            info.B += 1
        elif became_open:
            info.E += 1
def _get_num_blocked_neighbors(
    grid: List[List[int]],
    row_idx: int,
    col_idx: int
) -> int:
    """Return how many cells adjacent to ``(row_idx, col_idx)`` equal 1."""
    adjacent = _get_neighbor_indices(len(grid), row_idx, col_idx)
    blocked = [1 for (r, c) in adjacent if grid[r][c] == 1]
    return len(blocked)
def _get_neighbor_indices(grid_len: int, row_idx: int, col_idx: int) -> List[Tuple[int, int]]:
neighbors = []
for row in range(row_idx - 1, row_idx + 2):
if row < 0 or row >= grid_len:
continue
for col in range(col_idx - 1, col_idx + 2):
if row_idx == row and col_idx == col:
continue
elif col < 0 or col >= grid_len:
continue
neighbors.append((row, col))
return neighbors
# CNN w/ Max Pooling
# Model Building
def create_p2_cnn_model():
    """Build and compile the convolutional action classifier with max pooling.

    Inputs are the knowledge-base grid of shape
    ``(MAZE_DIM, MAZE_DIM, INTEL_SIZE)`` and a 2-element source vector. The
    grid passes through five stages of a same-padded 2x2 convolution
    (256, 128, 64, 32, 4 filters) followed by 2x2 max pooling, is flattened,
    concatenated with the source and classified by two dense layers and a
    4-way softmax.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # Source input; `(2)` is just the int 2, Keras expects a shape tuple.
    source_input = tf.keras.layers.Input(shape=(2,))
    # Grid input
    grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM, INTEL_SIZE))
    features = grid_input
    for num_filters in (256, 128, 64, 32, 4):
        features = tf.keras.layers.Conv2D(filters=num_filters, kernel_size=2, padding="same", activation="relu")(features)
        features = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(features)
    flatten = tf.keras.layers.Flatten()(features)
    grids_and_sources = tf.keras.layers.Concatenate()([flatten, source_input])
    dense_1 = tf.keras.layers.Dense(units=2500, activation=tf.nn.relu)(grids_and_sources)
    dense_2 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)(dense_1)
    logits = tf.keras.layers.Dense(units=4, activation="softmax")(dense_2)
    model = tf.keras.Model(inputs=[grid_input, source_input], outputs=logits)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself; print(summary()) added a stray "None".
    model.summary()
    return model
# RNN
# Size of the leading frame dimension of the recurrent model's grid input.
FRAME_SIZE = 1
def creating_rnn_p2():
    """Build and compile the recurrent (ConvLSTM) action classifier.

    Inputs are a grid sequence of shape
    ``(FRAME_SIZE, MAZE_DIM, MAZE_DIM, INTEL_SIZE)`` and a 2-element source
    vector. The grid sequence goes through four ConvLSTM2D layers
    (128, 64, 32, 16 filters), each followed by batch normalization, is
    flattened, concatenated with the source and classified by a 100-unit
    dense layer and a 4-way softmax.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # Source Input; `(2)` is just the int 2, Keras expects a shape tuple.
    source_input = tf.keras.layers.Input(shape=(2,))
    # Grid Input
    grid_input = tf.keras.layers.Input(shape=(FRAME_SIZE, MAZE_DIM, MAZE_DIM, INTEL_SIZE))
    features = grid_input
    for num_filters in (128, 64, 32, 16):
        features = tf.keras.layers.ConvLSTM2D(filters=num_filters, kernel_size=2, padding="same", return_sequences=True)(features)
        features = tf.keras.layers.BatchNormalization()(features)
    flatten = tf.keras.layers.Flatten()(features)
    grids_and_sources = tf.keras.layers.Concatenate()([flatten, source_input])
    dense_1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(grids_and_sources)
    logits = tf.keras.layers.Dense(units=4, activation="softmax")(dense_1)
    rnn_model = tf.keras.Model(inputs=[grid_input, source_input], outputs=logits)
    rnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself; print(summary()) added a stray "None".
    rnn_model.summary()
    return rnn_model
## Testing
def run_iteration_bonus(i=0):
    """Run one iteration of ``NUM_TRIALS`` mazes and record the mean results.

    For each maze the reference ``agent3`` is timed first; its trajectory
    (including bumped blocks) is then replayed by the RNN model
    (``data_func`` 3) and the max-pool CNN model (``data_func`` 2). The
    per-trial numbers are printed and their means over the iteration are
    appended to the module-level ``*_means_*`` lists.

    Args:
        i: iteration number, used only in progress messages.
    """
    def timed(func, *args):
        """Call ``func(*args)`` and return ``(result, elapsed_seconds)``."""
        begin = timer()
        result = func(*args)
        end = timer()
        return result, end - begin

    def mean(values):
        return sum(values)/NUM_TRIALS

    rnn_path_lens, rnn_trajs, rnn_times, rnn_corrects, rnn_misses = [], [], [], [], []
    cnn_max_path_lens, cnn_max_trajs, cnn_max_times = [], [], []
    cnn_max_corrects, cnn_max_misses = [], []
    agent_path_lens, agent_trajs, agent_times = [], [], []

    print(f"starting iteration {i}")
    for j in range(NUM_TRIALS):
        print(f"trial {(i, j)}")
        grid, _, _ = gen_gridworld(MAZE_DIM, DENSITY, SOLVABILITY, IS_TERRAIN_ENABLED, source, target)

        # Reference agent
        agent_result, elapsed = timed(agent3, grid)
        shortest_path, len_path, _, _, _, full_path, traj, _, full_path_and_blocks = agent_result
        agent_path_lens.append(len_path)
        agent_trajs.append(traj)
        agent_times.append(elapsed)

        # RNN replaying the reference trajectory
        (path, traj, corrects, misses), elapsed = timed(
            run_nn_agent_agent3, grid, p2_rnn_model, full_path_and_blocks, 3)
        rnn_path_lens.append(len(set(path)))
        rnn_trajs.append(traj)
        rnn_times.append(elapsed)
        rnn_corrects.append(corrects)
        rnn_misses.append(misses)

        # Max-pool CNN replaying the reference trajectory
        (path, traj, corrects, misses), elapsed = timed(
            run_nn_agent_agent3, grid, p2_cnn_model_ec, full_path_and_blocks, 2)
        cnn_max_path_lens.append(len(set(path)))
        cnn_max_trajs.append(traj)
        cnn_max_times.append(elapsed)
        cnn_max_corrects.append(corrects)
        cnn_max_misses.append(misses)

    # Per-trial numbers of this iteration
    for label, values in (
        ("RNN Path", rnn_path_lens),
        ("RNN Traj", rnn_trajs),
        ("RNN Time", rnn_times),
        ("RNN Corrects", rnn_corrects),
        ("RNN Misses", rnn_misses),
        ("CNN Maxpool Path", cnn_max_path_lens),
        ("CNN Maxpool Traj", cnn_max_trajs),
        ("CNN Maxpool Time", cnn_max_times),
        ("CNN Maxpool Corrects", cnn_max_corrects),
        ("CNN Maxpool Misses", cnn_max_misses),
        ("Agent Path", agent_path_lens),
        ("Agent Traj", agent_trajs),
        ("Agent Time", agent_times),
    ):
        print(label, values)
    print("")

    # Means of this iteration go into the module-level accumulators
    rnn_means_path.append(mean(rnn_path_lens))
    rnn_means_traj.append(mean(rnn_trajs))
    rnn_means_times.append(mean(rnn_times))
    rnn_means_corrects.append(mean(rnn_corrects))
    rnn_means_misses.append(mean(rnn_misses))
    cnn_max_means_path.append(mean(cnn_max_path_lens))
    cnn_max_means_traj.append(mean(cnn_max_trajs))
    cnn_max_means_times.append(mean(cnn_max_times))
    cnn_max_means_corrects.append(mean(cnn_max_corrects))
    cnn_max_means_misses.append(mean(cnn_max_misses))
    agent_means_path.append(mean(agent_path_lens))
    agent_means_traj.append(mean(agent_trajs))
    agent_means_times.append(mean(agent_times))
# Module-level accumulators: run_iteration_bonus appends one mean per
# iteration to each of these lists.
rnn_means_path = []
rnn_means_traj = []
rnn_means_times = []
rnn_means_corrects = []
rnn_means_misses = []
cnn_max_means_path = []
cnn_max_means_traj = []
cnn_max_means_times = []
cnn_max_means_corrects = []
cnn_max_means_misses = []
agent_means_path = []
agent_means_traj = []
agent_means_times = []
agent_means_corrects = []
agent_means_misses = []
# Models under test, loaded from their saved directories.
p2_rnn_model = tf.keras.models.load_model('./test_data_rnn_p2/')
p2_cnn_model_ec = tf.keras.models.load_model('./test_data_cnn_p2_ec')
# Experiment size; the trailing comments hold alternative, larger values.
NUM_ITERATIONS = 5 #30
NUM_TRIALS = 5 #50
# Maze settings passed to gen_gridworld inside run_iteration_bonus.
MAZE_DIM = 50
DENSITY = 0.3
SOLVABILITY = True
IS_TERRAIN_ENABLED = False
source = (0,0)
target = (MAZE_DIM-1, MAZE_DIM-1)
for i in range(NUM_ITERATIONS):
    run_iteration_bonus(i)
# The reference agent is scored as always correct: its "corrects" series is
# the same list object as its trajectory series and its misses are all zero.
agent_means_corrects = agent_means_traj
agent_means_misses = [0] * len(agent_means_traj)
#
print(rnn_means_path,rnn_means_traj,rnn_means_times,rnn_means_corrects,rnn_means_misses,cnn_max_means_path,cnn_max_means_traj,cnn_max_means_times,cnn_max_means_corrects,cnn_max_means_misses,agent_means_path,agent_means_traj,agent_means_times,agent_means_corrects,agent_means_misses)
# due to how long run the simulations take we were only able to do 5 (trials) x 5 (mazes)
# Plot every collected metric as one figure with one line per agent.
agents = ["RNN", "CNN Maxpool", "A Star"]
paths = [rnn_means_path, cnn_max_means_path, agent_means_path]
trajs = [rnn_means_traj, cnn_max_means_traj, agent_means_traj]
times = [rnn_means_times, cnn_max_means_times, agent_means_times]
corrects = [rnn_means_corrects, cnn_max_means_corrects, agent_means_corrects]
misses = [rnn_means_misses, cnn_max_means_misses, agent_means_misses]

# One table drives all five figures instead of five copy-pasted stanzas.
for metric_name, per_agent_series in (
    ("Path", paths),
    ("Trajectory", trajs),
    ("Time", times),
    ("Corrects", corrects),
    ("Misses", misses),
):
    for series in per_agent_series:
        plt.plot(list(range(len(series))), series)
    plt.legend(agents)
    # Each list holds one mean per iteration, so the x axis is the iteration
    # index; the agents are the plotted lines (see the legend).
    plt.xlabel("Iteration")
    plt.ylabel(f"Average {metric_name}")
    plt.title(f"Average {metric_name} For Each Agent")
    plt.grid()
    plt.show()
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
def build_model(states, actions):
    """Build the (uncompiled) Q-network used by the DQN agent.

    The network takes the ``(MAZE_DIM, MAZE_DIM)`` grid and a 14-element
    source vector, concatenates the flattened grid with the source and maps
    them through two ReLU layers to 4 linear outputs (one per move).

    NOTE(review): ``states`` and ``actions`` are accepted but not used; the
    output width stays hard-coded to 4 as in the original — confirm whether
    it should follow ``actions``.

    Args:
        states: unused.
        actions: unused.

    Returns:
        The ``tf.keras.Model`` with inputs ``[grid, source]``.
    """
    # CREATE NN MODEL
    grid_input = tf.keras.layers.Input(shape=(MAZE_DIM, MAZE_DIM))
    flatten_grids = tf.keras.layers.Flatten()(grid_input)
    # Source input; `(14)` is just the int 14, Keras expects a shape tuple.
    source_input = tf.keras.layers.Input(shape=(14,))
    # Full input
    grids_and_sources = tf.keras.layers.Concatenate()([flatten_grids, source_input])
    dense_1 = tf.keras.layers.Dense(units=2000, activation=tf.nn.relu)(grids_and_sources)
    dense_2 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)(dense_1)
    # tf.nn has no `linear`; the string name selects Keras' identity activation.
    q_values = tf.keras.layers.Dense(units=4, activation="linear")(dense_2)
    # The original returned an undefined `model`; assemble it from the layers.
    model = tf.keras.Model(inputs=[grid_input, source_input], outputs=q_values)
    return model
# NOTE(review): `states` and `actions` are not defined in the visible part of
# the file — confirm they are set before this runs.
model = build_model(states, actions)
# `model.summary` without parentheses only referenced the method; call it.
model.summary()
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
def build_agent(model, actions, nb_steps_warmup=1000):
    """Create a DQN agent around ``model``.

    Args:
        model: the Q-network.
        actions: number of discrete actions (passed as ``nb_actions``).
        nb_steps_warmup: steps collected before training starts. The original
            call named this argument without a value (a syntax error); the
            default here is keras-rl's own default — tune as needed.

    Returns:
        The (uncompiled) ``DQNAgent``.
    """
    memory = SequentialMemory(limit=50000, window_length=1)
    # The replay memory was created but never handed to the agent.
    dqn = DQNAgent(
        model=model,
        memory=memory,
        nb_actions=actions,
        nb_steps_warmup=nb_steps_warmup,
        target_model_update=1e-2,
    )
    return dqn
# Build, compile and train the DQN agent.
# NOTE(review): `env` is not defined in the visible part of the file before
# this point (GridEnv is declared further down) — confirm execution order.
dqn = build_agent(model, actions)
# `learning_rate` replaces the deprecated `lr` keyword.
dqn.compile(Adam(learning_rate=1e-3), metrics=['mae'])
# The closing parenthesis was missing.
dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)
# Create the environment
from gym import Env
from gym.spaces import Discrete, Box
import numpy as np
import random
# Maps an action index to the (row delta, col delta) it moves the agent by.
action_movement_dict = {
0 : (-1,0), # North
1 : (0, 1), # East
2 : (1, 0), # South
3 : (0, -1) # West
}
class GridEnv(Env):
    """Minimal gridworld environment: walk from (0, 0) to the far corner.

    The state is the agent's ``(row, col)`` position on an ``n`` x ``n``
    board. No obstacles are modelled yet.
    """

    def __init__(self, n=50):
        # TODO: Create grid?
        self.grid_len = n
        self.action_space = Discrete(4)
        # NOTE(review): the state is a (row, col) tuple while the observation
        # space is Discrete(n*n) — confirm which encoding is intended.
        self.observation_space = Discrete(n*n)
        self.state = (0,0)

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        Moves that would leave the board keep the agent in place.
        """
        dx, dy = action_movement_dict[action]
        # The state is an immutable tuple, so build the new position instead
        # of updating in place (the original also added dy to the row).
        row, col = self.state[0] + dx, self.state[1] + dy
        if 0 <= row < self.grid_len and 0 <= col < self.grid_len:
            self.state = (row, col)
        done = self.state == (self.grid_len - 1, self.grid_len - 1)
        # Placeholder reward (the original never defined one): every step
        # costs 1 until the goal is reached. Adjust once the task is fixed.
        reward = 0 if done else -1
        info = {}
        return self.state, reward, done, info

    def render(self):
        pass

    def reset(self):
        """Put the agent back on the start cell and return the new state."""
        self.state = (0, 0)
        return self.state