File size: 4,640 Bytes
918d1df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import numpy as np
import torch
import gym
from models.attention_model_wrapper import Agent
# Device the agent is built on and moved to; the checkpoint is remapped to it on load.
device = 'cpu'
ckpt_path = './runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt'
agent = Agent(device=device, name='tsp').to(device)
# map_location keeps loading working when the checkpoint stores tensors that
# were saved from a different device (e.g. CUDA) than the one used here.
agent.load_state_dict(torch.load(ckpt_path, map_location=device))
from wrappers.syncVectorEnvPomo import SyncVectorEnv
from wrappers.recordWrapper import RecordEpisodeStatistics
# Id and entry point under which the TSP environment is registered with gym.
env_id = 'tsp-v0'
env_entry_point = 'envs.tsp_vector_env:TSPVectorEnv'

gym.envs.register(id=env_id, entry_point=env_entry_point)

# Seed passed along whenever an environment is created further down the file.
seed = 0
def make_env(env_id, seed, cfg=None):
    '''
    Build a zero-argument factory that creates one seeded environment.

    Args:
        env_id: Id passed to gym.make.
        seed: Seed applied to the environment and to its action and
            observation spaces.
        cfg: Optional mapping of keyword arguments forwarded to gym.make.
            None (the default) forwards no extra arguments.

    Returns:
        A callable taking no arguments that builds the environment, wraps it
        in RecordEpisodeStatistics, seeds it and returns it.
    '''
    # Resolve the default per call so no dict object is shared between calls.
    env_kwargs = {} if cfg is None else cfg

    def thunk():
        env = gym.make(env_id, **env_kwargs)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
def inference(data):
    '''
    Roll out the loaded agent on a single instance built from data.

    Args:
        data: Instance passed to the environment as eval_data_from_input;
            len(data) is used as max_nodes. (Presumably one row of node
            coordinates per node -- confirm against the environment.)

    Returns:
        A tuple (resulting_traj, final_return): the recorded actions indexed
        at [:, 0, 0] (one entry per step), and info[0]['episode']['r'] as
        reported by the environment after the last step.
    '''
    env_cfg = dict(
        n_traj=1,
        max_nodes=len(data),
        eval_data='from_input',
        eval_data_from_input=data,
    )
    envs = SyncVectorEnv([make_env(env_id, seed, env_cfg)])

    trajectories = []
    agent.eval()
    obs = envs.reset()
    done = np.array([False])
    while not done.all():
        # ALGO LOGIC: action logic
        with torch.no_grad():
            action, _ = agent(obs)
        # Convert once and reuse the array for both the step and the record.
        action_np = action.cpu().numpy()
        obs, _, done, info = envs.step(action_np)
        trajectories.append(action_np)

    # done starts as False, so the loop ran at least once and info is bound.
    final_return = info[0]['episode']['r']
    # Index 0 on the two trailing axes (presumably environment and trajectory;
    # a single env with n_traj=1 is created above).
    resulting_traj = np.array(trajectories)[:, 0, 0]
    return resulting_traj, final_return
# Default instance used to pre-fill the input table: ten (x, y) rows.
default_data = np.array(
    [
        [0.5488135, 0.71518937],
        [0.60276338, 0.54488318],
        [0.4236548, 0.64589411],
        [0.43758721, 0.891773],
        [0.96366276, 0.38344152],
        [0.79172504, 0.52889492],
        [0.56804456, 0.92559664],
        [0.07103606, 0.0871293],
        [0.0202184, 0.83261985],
        [0.77815675, 0.87001215],
    ]
)
#@title Helper function for plotting
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
def make_segments(x, y):
    '''
    Pair consecutive (x, y) points into line segments for LineCollection.

    Returns an array of the form numlines x (points per line) x 2 (x and y),
    where segment i joins point i to point i + 1.
    '''
    # One entry per point, with a singleton middle axis so that neighbouring
    # points can be joined along it.
    pts = np.array([x, y]).T.reshape(-1, 1, 2)
    starts, ends = pts[:-1], pts[1:]
    return np.hstack((starts, ends))
def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
    '''
    Add a colored line through the coordinates x and y to the current axes.

    z optionally gives the values that cmap and norm turn into colors; when it
    is omitted, len(x) values equally spaced on [0.3, 1.0] are used.
    linewidth and alpha are passed on to the LineCollection, which is returned.
    '''
    if z is None:
        color_values = np.linspace(0.3, 1.0, len(x))
    elif not hasattr(z, "__iter__"):
        # A single number has no __iter__; wrap it so it becomes an array.
        color_values = np.array([z])
    else:
        color_values = z
    color_values = np.asarray(color_values)

    collection = LineCollection(
        make_segments(x, y),
        array=color_values,
        cmap=cmap,
        norm=norm,
        linewidth=linewidth,
        alpha=alpha,
    )
    plt.gca().add_collection(collection)
    return collection
def plot(coords):
    '''
    Draw the path through coords, in the given order, on a new figure.

    coords is transposed and unpacked into x and y, so it is expected to hold
    one (x, y) pair per row. The line is drawn with the 'Reds' colormap and the
    axes are set to 'square'. Returns the new figure; it is not closed here.
    '''
    figure = plt.figure()
    xs, ys = coords.T
    colorline(xs, ys, cmap='Reds')
    plt.axis('square')
    return figure
import gradio as gr
def run_inference(data):
    '''
    Callback for the interface: run inference on the table and report the result.

    data must support .astype(float).to_numpy() (e.g. a pandas DataFrame).
    Returns a two-item list: the figure of the points visited in tour order,
    and a text summary with the tour and its total length.
    '''
    points = data.astype(float).to_numpy()
    resulting_traj, final_return = inference(points)
    result_text = f'Planned Tour:\t{resulting_traj}\nTotal tour length:\t{final_return[0]:.2f}'
    # Index the points by the tour so they are plotted in visiting order.
    tour_figure = plot(points[resulting_traj])
    return [tour_figure, result_text]
# Web UI: an x/y table (pre-filled with default_data) as input; a plot and a
# read-only text panel as outputs.
demo = gr.Interface(
    fn=run_inference,
    inputs=gr.Dataframe(
        label='Input',
        headers=['x', 'y'],
        row_count=10,
        col_count=(2, "fixed"),
        max_rows=10,
        value=default_data.tolist(),
        overflow_row_behaviour='show_ends',
    ),
    outputs=[
        gr.Plot(label='Results Visualization'),
        gr.Code(label='Results', interactive=False),
    ],
)
# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
|