強化学習 で迷路課題を解くアニメーション
float ALPHA = 0.5; float GAMMA = 0.9; float EPSILON = 0.5;
int Nsuccess = 0; // 成功した回数
int Ntrial = 1; // 試行数
int Nstep = 1; // ステップ数
int Naction = 4; //行動の数
String[] arrow = {"↑","→","↓","←"}; //up,right,down,left
int x,y,s,a,r;
int[][] map = { // 状態空間の定義:報酬
{-1,-1,-1,-1,-1,-1,-1,-1,-1},
{-1, 0, 0, 0, 0, 0, 0, 0,-1},
{-1, 0,-1,-1,-1, 0,-1, 0,-1},
{-1, 0, 0, 0, 0, 0,-1, 0,-1},
{-1, 0,-1, 0,-1, 0,-1, 0,-1},
{-1, 0,-1, 0, 0, 0, 0, 0,-1},
{-1, 0,-1,-1, 0,-1, 0,-1,-1},
{-1, 0, 0, 0, 0,-1, 0,10,-1},
{-1,-1,-1,-1,-1,-1,-1,-1,-1},
};
int Nmap = map.length; // 迷路のサイズ
int start_x = 1;int start_y = 1; // スタート位置
int goal_x = Nmap-1;int goal_y = Nmap-1; //ゴール位置
float[][] q = new float[Nmap*Nmap][Naction]; // Q値
void setup(){
noStroke(); frameRate(1); size(1050, 300);
ellipseMode(CORNER);
x = start_x; y = start_y;
s = xy2s(x,y); // 座標から状態番号へと変換
for(int i=0;i<Nmap*Nmap;i++){
for(int j=0;j<Naction;j++){
q[i][j] = random(1); //Q値の初期化(乱数)
}
}
}
void draw(){
background(255); scale(20);
translate(1,1); textSize(1);
if(Nstep > 1){
a = select_action(s); // 行動選択
switch(a){
case 0: y--; break; // UP
case 1: x++; break; // RIGHT
case 2: y++; break; // DOWN
case 3: x--; break; // LEFT
}
}
draw_arrow(); //矢印を描画
draw_map(); //迷路を描画
draw_agent(); //エージェントを描画
pushMatrix(); translate(1,0);
// 4つの行動に対するQ値を描画
for(int i=0;i<Naction;i++){
translate(Nmap+1,0);
draw_q(i);
draw_map();
}
popMatrix();
r = map[x][y]; // 現在位置の報酬を得る
// Q値を更新する
q[s][a] = (1-ALPHA)*q[s][a] + ALPHA*(r+GAMMA*max(q[xy2s(x,y)]));
if(r != 0){ // 0以外の報酬を得た場合
if(r > 0){ Nsuccess++; }; // 成功!
x = start_x; y = start_y; // スタート位置に戻る
Nstep = 0; // ステップ数をリセット
Ntrial++; // 試行数に1追加
}
// next state
s = xy2s(x,y);
Nstep++; // ステップ数に1追加
fill(0);
text("Ntrial="+Ntrial+", Nstep="+Nstep+"-> Nsuccess="+Nsuccess,0,11);
if(Ntrial>10) frameRate(60);
}
// 座標から状態番号を得る
int xy2s(int x,int y){
int s = x + y*Nmap;
return s;
}
// 状態番号の最大価値を得る
int max_a(int s){
int a=0;
for(int i=0;i<Naction;i++){
if(q[s][i] == max(q[s])){
a = i;
break;
}
}
return a;
}
// エージェントを描画する
void draw_agent(){
fill(0,255,0,100); //緑
ellipse(x,y,1,1);
}
// 迷路を描画する
void draw_map(){
for(int i=0; i<Nmap; i++){
for(int j=0; j<Nmap; j++){
if(map[i][j]<0){ // WALL
fill(0,0,255); rect(i,j,1,1);
}else if(map[i][j] > 0){ // GOAL
fill(255,0,0); rect(i,j,1,1);
}
}
}
}
// 矢印を描画する
void draw_arrow(){
fill(0);
for(int i=0; i<Nmap; i++){
for(int j=0; j<Nmap; j++){
int s = xy2s(i,j);
text(arrow[max_a(s)],i,j+0.5);
}
}
}
// Q値を描画する
void draw_q(int a){
for(int i=0; i<Nmap; i++){
for(int j=0; j<Nmap; j++){
fill(map(q[xy2s(i,j)][a],min_q(),max_q(),0,255));
rect(i,j,1,1);
fill(0); text(arrow[a],i,j+1);
}
}
}
// 行動を選択する:ε-greedy法
int select_action(int s){
if(random(0,1) > EPSILON){
a = max_a(s);
}else{
a = (int)random(Naction);
}
return a;
}
// 全ての状態の中で最大のQ値を得る
float max_q(){
float m1 = 0;
float m2 = 0;
for(int i=0;i<Nmap*Nmap;i++){
m2 = max(q[i]);
if(m2 > m1) m1 = m2;
}
return m1;
}
// 全ての状態の中で最小のQ値を得る
float min_q(){
float m1 = 999;
float m2 = 0;
for(int i=0;i<Nmap*Nmap;i++){
m2 = min(q[i]);
if(m2 < m1) m1 = m2;
}
return m1;
}