自己組織化マップ (SOM) が色を学習するアニメーション

動作イメージ

SOM2.mp4

Untitled

ソースコード (Processing)

int nu = 10;  // number of units
float cc = 0.1; // learning ratio
int[][][]c = new int[nu][nu][3]; // reference vectors
// display parameter
float us = 0.9;

// 初期設定
void setup() {
  size(500, 600);
  frameRate(1);
  ellipseMode(CENTER);
  for (int i=0;i<nu;i++) {
    for (int j=0;j<nu;j++) {
      for (int k=0;k<3;k++) {
        c[i][j][k] = int(random(0, 255)); //乱数で初期化
      }
    }
  }
}

// 繰返し処理
void draw() {
  background(0);
  // ユニットと参照ベクトルを描画
  translate(25, 25);
  scale(50);
  for (int i=0;i<nu;i++) {
    for (int j=0;j<nu;j++) {
      fill(c[i][j][0], c[i][j][1], c[i][j][2]);
      strokeWeight(0.01);
      noStroke();
      ellipse(i, j, us, us);
    }
  }
  updateW();
  // 描画速度(フレームレート)の変更
  if (frameCount > 20) {
    frameRate(15);
  }
}

// 入力と重みの更新
void updateW() {
  // 入力サンプルを乱数で決定
  int[]input = new int[3];
    for (int k=0;k<3;k++) {
      input[k] = int(random(0, 255));
    }
  fill(input[0], input[1], input[2]);
  stroke(255);
  strokeWeight(0.1);
  ellipse(nu/2-0.5, nu+0.5, 1, 1);
  pushMatrix();
  scale(0.02);
  fill(255);
  text("RGB = ("+input[0]+","+input[1]+","+input[2]+")", (nu/2+0.5)*50, (nu+0.5)*50);
  text("step = "+frameCount, (nu/2+0.5)*50, (nu+1)*50);
  popMatrix();
  // 勝者ユニットの選択
  int[] ii = new int[2];
  float d = 999999.9;
  for (int i=0;i<nu;i++) {
    for (int j=0;j<nu;j++) {
      float tmp = dist(input[0], input[1], input[2], c[i][j][0], c[i][j][1], c[i][j][2]);
      if (tmp < d) {
        ii[0] = i;
        ii[1] = j;
        d = tmp;
      }
    }
  }
  stroke(255);
  strokeWeight(0.1);
  noFill();
  ellipse(ii[0], ii[1], us, us);
  line(nu/2-0.5, nu+0.5, ii[0], ii[1]);
  // 近傍ユニットの重みを更新
  for (int i=ii[0]-1;i<=ii[0]+1;i++) {
    for (int j=ii[1]-1;j<=ii[1]+1;j++) {
      int[] tmp = new int[2];
      tmp[0] = i;
      tmp[1] = j;
      for (int l=0;l<2;l++) {
        if (tmp[l] < 0) {
          tmp[l] = nu-1;
        }
        else if (tmp[l] > nu-1) {
          tmp[l] = 0;
        }
      }
      stroke(255);
      strokeWeight(0.05);
      ellipse(tmp[0], tmp[1], us, us);
      for (int k=0;k<3;k++) {
        float dd = (input[k] - c[tmp[0]][tmp[1]][k]) * cc;
        c[tmp[0]][tmp[1]][k] += dd;
      }
    }
  }
}

// マウスイベント
void mousePressed() {
  switch(mouseButton) {
  case LEFT: //左クリックで一時停止
    loop();
    break;
  case RIGHT: //右クリックで再開
    noLoop();
    break;
  }
}

参考