!apt update
!apt install ffmpeg -y
from Maze import Maze # uses the uploaded Maze.py file
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
# Design the maze (0 = free cell, 1 = wall)
arr=np.array([[0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],
[0,1,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,1,0,0,1,0,0,1,1,1,1,1,0,1,1,0,1,1,1,0],
[0,1,0,0,1,0,0,0,0,0,1,0,0,1,0,0,1,0,0,0],
[0,0,0,0,1,0,0,1,1,1,0,0,1,1,0,0,1,0,0,0],
[0,0,0,0,0,0,0,1,0,0,0,1,0,1,0,1,1,0,1,1],
[1,1,1,0,1,1,0,1,0,0,1,0,0,1,0,0,1,0,0,0],
[0,0,1,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,1,0],
[0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0],
[0,0,0,0,1,0,0,1,0,0,0,0,0,1,1,1,0,0,0,0],
[1,0,1,1,1,0,1,0,0,1,0,0,0,1,0,0,0,1,0,0],
[1,0,1,1,1,0,1,0,0,1,0,0,1,1,0,0,0,1,0,0],
[1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0],
[0,0,0,0,1,0,1,0,0,1,1,0,1,0,0,0,1,1,1,0],
[0,0,1,1,1,0,1,0,0,1,0,1,0,0,1,1,0,0,0,0],
[0,1,1,0,0,0,0,1,0,1,0,0,1,1,0,1,0,1,1,1],
[0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0],
[0,0,1,1,1,0,1,1,0,0,1,0,1,0,0,1,1,0,0,0],
[1,0,0,0,1,0,1,0,0,0,1,0,1,0,0,1,1,1,0,0],
[1,1,0,0,0,0,1,0,0,0,1,0,1,0,0,0,0,1,0,0]
],dtype=float)
# The rat starts in the top-left corner; by default the goal is the bottom-right corner
rat=(0,0)
maze=Maze(arr,rat)
maze.show_maze()
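# Quick look at the Maze attributes the agent below relies on
# (nrow, ncol and actions come from the uploaded Maze.py, as used in SarsaAgent).
print('maze size:',maze.nrow,'x',maze.ncol)
print('available actions:',maze.actions)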
# Reinforcement-learning agent that builds the Q table (numpy is already imported above)
import random
class SarsaAgent():
    def __init__(self,maze):
        self.maze=maze
        nrow=maze.nrow
        ncol=maze.ncol
        state_number=nrow*ncol
        action_number=len(maze.actions)
        self.q={}
        # The Q table is a dict of the form {(0,0):[0.5,0.5,0.5,0.5],(0,1):[0.5,0.5,0.5,0.5],...}:
        # one entry per cell (20x20=400 states), four action values [up,down,left,right] per entry, all defaulting to 0.5.
        # The dict starts out empty; get_q(...) fills in missing entries on demand.
    # Look up the Q value(s) for a state (row of the table) and, optionally, a single action (column);
    # entries that do not exist yet are created with the default value 0.5.
    def get_q(self,state,action=None):
        row,col,_=state
        if action is None:
            if (row,col) in self.q.keys():
                return self.q[(row,col)]
            else:
                self.q[(row,col)]=np.full((4,),0.5)
                return self.q[(row,col)]
        else:
            if (row,col) in self.q.keys():
                return self.q[(row,col)][action]
            else:
                self.q[(row,col)]=np.full((4,),0.5)
                return self.q[(row,col)][action]
    # Set (update) one value in the Q table, addressed by state (row) and action (column).
    def set_q(self,state,action,value):
        row,col,_=state
        if (row,col) not in self.q.keys():
            self.q[(row,col)]=np.full((4,),0.5)
        self.q[(row,col)][action]=value
    # Core function: the SARSA training loop.
    # Update rule: Q(s,a) <- Q(s,a) + alpha*(r + gamma*Q(s',a') - Q(s,a))
    def learn(self,episodes=100,alpha=.5,gamma=.99,epsilon=0.1):
        for i in range(episodes):
            self.maze.reset() # reset the maze at the start of every episode
            done=False # whether the goal has been reached
            while not done:
                state=self.maze.state
                p=random.random()
                if p<epsilon: # explore: take a random action
                    action=random.choice(self.maze.actions)
                else: # exploit: take the best action according to the Q table
                    action=np.argmax(self.get_q(state))
                # maze.step returns the next state, the reward and the done flag
                next_state,reward,done=self.maze.step(action)
                # choose the next action with the same epsilon-greedy rule
                p=random.random()
                if p<epsilon: # explore: take a random action
                    next_action=random.choice(self.maze.actions)
                else: # exploit: take the best action according to the Q table
                    next_action=np.argmax(self.get_q(next_state))
                # SARSA update: use the Q value of the action actually chosen next (not the max)
                new_val=self.get_q(state,action)+alpha*(reward+gamma
                        *self.get_q(next_state,next_action)-self.get_q(state,action))
                # write the updated value back into the Q table
                self.set_q(state,action,new_val)
    # The two functions below use the learned Q table to walk the maze after training.
    def get_policy(self,cell):
        row,col=cell
        state=(row,col,None)
        action=np.argmax(self.get_q(state))
        return action
    # Best action for every free cell; walls (value 1) are marked with 5.
    def get_policy_matrix(self):
        policy=np.copy(self.maze.maze)
        policy[self.maze.maze==1]=5
        for row,col in zip(*np.where(self.maze.maze==0)):
            policy[row,col]=self.get_policy((row,col))
        return policy
# Train the agent to build the Q table
agent=SarsaAgent(maze)
agent.learn(episodes=1000)
print('done')
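# Optional sanity check: inspect the learned Q values for the start cell and print the
# greedy policy matrix (5 marks walls, values 0-3 are the best action index per free cell).
print('Q values at the start cell (0,0):',agent.get_q((0,0,None)))
print(agent.get_policy_matrix())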
# Plot the maze; this figure is reused below to generate the animation
nrow=maze.nrow
ncol=maze.ncol
fig=plt.figure()
ax=fig.gca() # get the current axes
ax.set_xticks(np.arange(0.5,ncol,1))
ax.set_yticks(np.arange(0.5,nrow,1))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.grid(True)
img=ax.imshow(maze.maze,cmap="gray")
# Walk the maze with the learned Q table and turn the frames into an animation
def gen_func():
    maze=Maze(arr,rat)
    done=False
    while not done:
        row,col,_=maze.state
        cell=(row,col)
        action=agent.get_policy(cell)
        # maze.step already reports whether the goal has been reached
        _,_,done=maze.step(action)
        yield maze.get_canvas()
def update_plot(canvas):
    img.set_data(canvas)
anim=animation.FuncAnimation(fig,update_plot,gen_func)
HTML(anim.to_html5_video())
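# Besides showing the video inline, the animation can also be written to disk using the
# ffmpeg writer installed at the top; the file name and frame rate here are arbitrary choices.
anim.save('maze_sarsa.mp4',writer='ffmpeg',fps=10)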