自己組織化マップ (SOM) がアルファベットを分類するアニメーション

動作イメージ

SOM3.mp4

Untitled

ソースコード (Processing)

int nu = 10; // number of units
float cc = 0.1; // learning rate
int ns; // number of sample
int nr_v = 5; // vertical size
int nr_h = 7; // horizontal size
int nr = nr_v*nr_h; // number of reference vector
int nn = 1; // number of neighboring unit
float[][][]c = new float[nu][nu][nr]; // reference vectors

// display parameters
float us = 0.9;
int[][] sample = new int[ns][nr];

//初期設定
void setup() {
  sample = getLetterDigit(); //文字
  ns = sample.length;
  frameRate(1);
  size(500, 600);
  ellipseMode(CENTER);
  for (int i=0;i<nu;i++) {
    for (int j=0;j<nu;j++) {
      for (int k=0;k<nr;k++) {
        c[i][j][k] = random(0, 1); //乱数で初期化
      }
    }
  }
}

// 繰返し処理
void draw() {
  drawUnits();
  drawWinners();
  updateW();
  //描画速度(フレームレート)の変更
  if (frameCount > 15) {
    frameRate(30);
  }
}

// ユニットと参照ベクトルを描画
void drawUnits() {
  background(0);
  translate(25, 25);
  scale(50);
  for (int i=0;i<nu;i++) {
    for (int j=0;j<nu;j++) {
      noFill();
      strokeWeight(0.01);
      ellipse(i, j, us, us);
      // reference vectorを表示
      scale(0.1);
      stroke(255);
      strokeWeight(0.1);
      int ss = 0;
      for (int ii=0;ii<nr_h;ii++) {
        for (int jj=0;jj<nr_v;jj++) {
          fill(c[i][j][ss]*255);
          ellipse(i*10 + jj-5/2, j*10 + ii-7/2, 1, 1);
          ss++;
        }
      }
      scale(10);
    }
  }
}

// 各アルファベットに対する勝者ユニットを描画
void drawWinners() {
  for (int l=0;l<ns;l++) {
    int[] ii = new int[2];
    float d = 9999.9;
    for (int i=0;i<nu;i++) {
      for (int j=0;j<nu;j++) {
        float tmp = 0;
        for (int k=0;k<nr;k++) {
          tmp += abs(sample[l][k] - c[i][j][k]);
        }
        if (tmp < d) {
          ii[0] = i;
          ii[1] = j;
          d = tmp;
        }
      }
    }
    // 勝者ユニットを緑円で囲む
    noFill();
    stroke(0, 255, 0, 100);
    strokeWeight(0.03);
    ellipse(ii[0], ii[1], us, us);
    // 対応するアルファベットを表示
    scale(0.1);
    fill(0, 255, 0);
    textSize(3);
    text(getLetter(l), ii[0]*10-5, ii[1]*10-3);
    scale(10);
  }
}

// 入力と重みの更新
void updateW() {
  //入力サンプルを乱数で決定
  int[]input = new int[nr];
  int i_s = int(random(0, ns-1));
  for (int k=0;k<nr;k++) {
    input[k] = sample[i_s][k];
  }
  //入力サンプルを描画
  stroke(255);
  strokeWeight(0.1);
  pushMatrix();
  translate(0, nu-0.1);
  scale(0.2);
  fill(255);
  int ss = 0;
  for (int i=0;i<nr_h;i++) {
    for (int j=0;j<nr_v;j++) {
      switch(input[ss]) {
      case 0:
        noFill();
        break;
      case 1:
        fill(255);
        break;
      }
      ellipse(j, i, 1, 1);
      ss++;
    }
  }
  scale(0.1);
  textSize(12);
  fill(255);
  text("step = "+frameCount, 60, 2);
  textSize(20);
  fill(0, 255, 0);
  text(getLetter(i_s), 60, 35);
  popMatrix();

  // 勝者ユニットの選択
  int[] ii = new int[2];
  float d = 9999.9;
  for (int i=0;i<nu;i++) {
    for (int j=0;j<nu;j++) {
      float tmp = 0;
      for (int k=0;k<nr;k++) {
        tmp += abs(input[k] - c[i][j][k]);
      }
      if (tmp < d) {
        ii[0] = i;
        ii[1] = j;
        d = tmp;
      }
    }
  }
  stroke(255);
  strokeWeight(0.1);
  noFill();
  ellipse(ii[0], ii[1], us, us);
  strokeWeight(0.03);
  line(0.4, nu-0.3, ii[0], ii[1]);
  // 近傍ユニットの重みを更新
  for (int i=ii[0]-nn;i<=ii[0]+nn;i++) {
    for (int j=ii[1]-nn;j<=ii[1]+nn;j++) {
      int[] tmp = new int[2];
      tmp[0] = i;
      tmp[1] = j;
      for (int l=0;l<2;l++) {
        if (tmp[l] < 0) {
          tmp[l] = nu-1;
        }
        else if (tmp[l] > nu-1) {
          tmp[l] = 0;
        }
      }
      stroke(255);
      strokeWeight(0.05);
      ellipse(tmp[0], tmp[1], us, us);
      for (int k=0;k<nr;k++) {
        float dd = (input[k] - c[tmp[0]][tmp[1]][k]) * cc;
        c[tmp[0]][tmp[1]][k] += dd;
      }
    }
  }
}

// 数字からアルファベットを得る
char getLetter(int i) {
  char[] letter = {
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
  };
  return letter[i];
}

// アルファベットのデジタル表現
int[][] getLetterDigit() {
  int[][] letter = {
    {
      0, 0, 1, 0, 0, //A
      0, 1, 0, 1, 0,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1
    }
    ,
    {
      1, 1, 1, 1, 0, //B
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 0,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 0
    }
    ,
    {
      0, 1, 1, 1, 0, //C
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 1,
      0, 1, 1, 1, 0
    }
    ,
    {
      1, 1, 1, 1, 0, //D
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 0
    }
    ,
    {
      1, 1, 1, 1, 1, //E
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 1, 1, 1, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 1, 1, 1, 1
    }
    ,
    {
      1, 1, 1, 1, 1, //F
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 1, 1, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 1, 1, 1, 1
    }
    ,
    {
      0, 1, 1, 1, 0, //G
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 0,
      1, 0, 0, 1, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      0, 1, 1, 1, 1
    }
    ,
    {
      1, 0, 0, 0, 1, //H
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1
    }
    ,
    {
      0, 1, 1, 1, 0, //I
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 1, 1, 1, 0
    }
    ,
    {
      0, 0, 1, 1, 1, //J
      0, 0, 0, 1, 0,
      0, 0, 0, 1, 0,
      0, 0, 0, 1, 0,
      1, 0, 0, 1, 0,
      1, 0, 0, 1, 0,
      0, 1, 1, 0, 0
    }
    ,
    {
      1, 0, 0, 0, 1, //K
      1, 0, 0, 1, 0,
      1, 0, 1, 0, 0,
      1, 1, 0, 0, 0,
      1, 0, 1, 0, 0,
      1, 0, 0, 1, 0,
      1, 0, 0, 0, 1
    }
    ,
    {
      1, 0, 0, 0, 0, //L
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 1, 1, 1, 1
    }
    ,
    {
      1, 0, 0, 0, 1, //M
      1, 1, 0, 1, 1,
      1, 0, 1, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1
    }
    ,
    {
      1, 0, 0, 0, 1, //N
      1, 1, 0, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 0, 1, 1,
      1, 0, 0, 0, 1
    }
    ,
    {
      0, 1, 1, 1, 0, //O
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      0, 1, 1, 1, 0
    }
    ,
    {
      1, 1, 1, 1, 0, //P
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0,
      1, 0, 0, 0, 0
    }
    ,
    {
      0, 1, 1, 1, 0, //Q
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 0, 1, 1,
      0, 1, 1, 1, 1
    }
    ,
    {
      1, 1, 1, 1, 0, //R
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 1, 1, 1, 0,
      1, 0, 1, 0, 0,
      1, 0, 0, 1, 0,
      1, 0, 0, 0, 1
    }
    ,
    {
      0, 1, 1, 1, 0, //S
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 0,
      0, 1, 1, 1, 0,
      0, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      0, 1, 1, 1, 0
    }
    ,
    {
      1, 1, 1, 1, 1, //T
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0
    }
    ,
    {
      1, 0, 0, 0, 1, //U
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      0, 1, 1, 1, 0
    }
    ,
    {
      1, 0, 0, 0, 1, //V
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      0, 1, 0, 1, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0
    }
    ,
    {
      1, 0, 0, 0, 1, //W
      1, 0, 0, 0, 1,
      1, 0, 0, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 1, 0, 1,
      1, 0, 1, 0, 1,
      0, 1, 0, 1, 0
    }
    ,
    {
      1, 0, 0, 0, 1, //X
      0, 1, 0, 1, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 1, 0, 1, 0,
      1, 0, 0, 0, 1
    }
    ,
    {
      1, 0, 0, 0, 1, //Y
      0, 1, 0, 1, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0
    }
    ,
    {
      1, 1, 1, 1, 1, //Z
      0, 0, 0, 1, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 0, 1, 0, 0,
      0, 1, 0, 0, 0,
      1, 1, 1, 1, 1
    }
  };
  return letter;
}

// マウスイベント
void mousePressed() {
  switch(mouseButton) {
  case LEFT: //左クリックで一時停止
    loop();
    break;
  case RIGHT: //右クリックで再開
    noLoop();
    break;
  }
}

参考