強化学習 で迷路課題を解くアニメーション

動作イメージ

maze.mp4

Untitled

ソースコード (Processing)

float ALPHA = 0.5; float GAMMA = 0.9; float EPSILON = 0.5;
int Nsuccess = 0; // 成功した回数
int Ntrial = 1; // 試行数
int Nstep = 1; // ステップ数
int Naction = 4; //行動の数
String[] arrow = {"↑","→","↓","←"}; //up,right,down,left
int x,y,s,a,r;
int[][] map = { // 状態空間の定義:報酬
  {-1,-1,-1,-1,-1,-1,-1,-1,-1},
  {-1, 0, 0, 0, 0, 0, 0, 0,-1},
  {-1, 0,-1,-1,-1, 0,-1, 0,-1},
  {-1, 0, 0, 0, 0, 0,-1, 0,-1},
  {-1, 0,-1, 0,-1, 0,-1, 0,-1},
  {-1, 0,-1, 0, 0, 0, 0, 0,-1},
  {-1, 0,-1,-1, 0,-1, 0,-1,-1},
  {-1, 0, 0, 0, 0,-1, 0,10,-1},
  {-1,-1,-1,-1,-1,-1,-1,-1,-1},
};
int Nmap = map.length; // 迷路のサイズ
int start_x = 1;int start_y = 1; // スタート位置
int goal_x = Nmap-1;int goal_y = Nmap-1; //ゴール位置
float[][] q = new float[Nmap*Nmap][Naction]; // Q値

void setup(){
  noStroke(); frameRate(1); size(1050, 300);
  ellipseMode(CORNER);
  x = start_x; y = start_y;
  s = xy2s(x,y); // 座標から状態番号へと変換
  for(int i=0;i<Nmap*Nmap;i++){
    for(int j=0;j<Naction;j++){
      q[i][j] = random(1); //Q値の初期化(乱数)
    }
  }
}

void draw(){
  background(255); scale(20);
  translate(1,1); textSize(1);
  if(Nstep > 1){
    a = select_action(s); // 行動選択
    switch(a){
      case 0: y--; break; // UP
      case 1: x++; break; // RIGHT
      case 2: y++; break; // DOWN
      case 3: x--; break; // LEFT
    }
  }
  draw_arrow(); //矢印を描画
  draw_map(); //迷路を描画
  draw_agent(); //エージェントを描画

  pushMatrix(); translate(1,0);
  // 4つの行動に対するQ値を描画
  for(int i=0;i<Naction;i++){
    translate(Nmap+1,0);
    draw_q(i);
    draw_map();
  }
  popMatrix();

  r = map[x][y]; // 現在位置の報酬を得る
  // Q値を更新する
  q[s][a] = (1-ALPHA)*q[s][a] + ALPHA*(r+GAMMA*max(q[xy2s(x,y)]));

  if(r != 0){ // 0以外の報酬を得た場合
    if(r > 0){ Nsuccess++; }; // 成功!
    x = start_x; y = start_y; // スタート位置に戻る
    Nstep = 0; // ステップ数をリセット
    Ntrial++; // 試行数に1追加
  }
  // next state
  s = xy2s(x,y);
  Nstep++; // ステップ数に1追加
  fill(0);
  text("Ntrial="+Ntrial+", Nstep="+Nstep+"-> Nsuccess="+Nsuccess,0,11);
  if(Ntrial>10) frameRate(60);
}
// 座標から状態番号を得る
int xy2s(int x,int y){
  int s = x + y*Nmap;
  return s;
}
// 状態番号の最大価値を得る
int max_a(int s){
  int a=0;
  for(int i=0;i<Naction;i++){
    if(q[s][i] == max(q[s])){
      a = i;
      break;
    }
  }
  return a;
}
// エージェントを描画する
void draw_agent(){
  fill(0,255,0,100); //緑
  ellipse(x,y,1,1);
}
// 迷路を描画する
void draw_map(){
  for(int i=0; i<Nmap; i++){
    for(int j=0; j<Nmap; j++){
      if(map[i][j]<0){ // WALL
        fill(0,0,255); rect(i,j,1,1);
      }else if(map[i][j] > 0){ // GOAL
        fill(255,0,0); rect(i,j,1,1);
      }
    }
  }
}
// 矢印を描画する
void draw_arrow(){
  fill(0);
  for(int i=0; i<Nmap; i++){
    for(int j=0; j<Nmap; j++){
      int s = xy2s(i,j);
      text(arrow[max_a(s)],i,j+0.5);
    }
  }
}
// Q値を描画する
void draw_q(int a){
  for(int i=0; i<Nmap; i++){
    for(int j=0; j<Nmap; j++){
      fill(map(q[xy2s(i,j)][a],min_q(),max_q(),0,255));
      rect(i,j,1,1);
      fill(0); text(arrow[a],i,j+1);
    }
  }
}
// 行動を選択する:ε-greedy法
int select_action(int s){
  if(random(0,1) > EPSILON){
    a = max_a(s);
  }else{
    a = (int)random(Naction);
  }
  return a;
}
// 全ての状態の中で最大のQ値を得る
float max_q(){
  float m1 = 0;
  float m2 = 0;
  for(int i=0;i<Nmap*Nmap;i++){
    m2 = max(q[i]);
    if(m2 > m1) m1 = m2;
  }
  return m1;
}
// 全ての状態の中で最小のQ値を得る
float min_q(){
  float m1 = 999;
  float m2 = 0;
  for(int i=0;i<Nmap*Nmap;i++){
    m2 = min(q[i]);
    if(m2 < m1) m1 = m2;
  }
  return m1;
}

参考