自己組織化マップ (SOM) の仕組みは非常に単純なのですが、教科書的な説明 (例えば Wikipediaの説明) では、その動作原理をイメージしにくいと思います。

このプログラムはSOMが一様乱数を学習する例をアニメーションで示しています。

動作イメージ

SOM1.mp4

Untitled

ソースコード (Processing)

float ALPHA = 0.2; float SIGMA = 0.8; //parameter
int n1 = 10; int n2 = n1; // コホネン層のニューロン数
int t = 0; int T = 100000; // 最大時間
Dot m[][] = new Dot[n1][n2]; // 参照ベクトル
Dot x[] = new Dot[T]; // 入力ベクトル

void setup(){
  size(500,500); textSize(16); frameRate(1);
  for(int i=0;i<n1;i++){
    for(int j=0;j<n2;j++){
      // 参照ベクトルの初期化
      m[i][j] = new Dot(random(width/10)+width/2, random(height/10)+height/2);
    }
  }
}

void draw(){
  background(255);
  // 入力サンプル
  x[t] = new Dot(random(width), random(height));
  drawSample(); // サンプルを描画
  drawNeuron(); // ニューロンの参照ベクトルを描画
  //最も入力サンプル(s[t])に近い参照ベクトルを持つニューロンを選ぶ
  float dist_min = 10000; int i_min = 0; int j_min = 0;
  for(int i=0;i<n1;i++){
    for(int j=0;j<n2;j++){
      float dist_tmp = dist(x[t].x, x[t].y, m[i][j].x, m[i][j].y);
      if(dist_tmp < dist_min){
        dist_min = dist_tmp; i_min = i; j_min = j;
      }
    }
  }
  // 学習処理
  for(int i=0;i<n1;i++){
    for(int j=0;j<n2;j++){
      float dist_tmp = sq(i-i_min) + sq(j-j_min);
      float denom = 2.0*sq(SIGMA);
      float hci = ALPHA*exp(-dist_tmp/denom);
      m[i][j].x = m[i][j].x + hci * (x[t].x - m[i][j].x);
      m[i][j].y = m[i][j].y + hci * (x[t].y - m[i][j].y);
    }
  }
  if(t > 20) frameRate(30); //画面の更新速度をあげる
  fill(0,0,255); text("step="+t,10,20); t++;
}

// 入力サンプルを表示する関数
void drawSample(){
  strokeWeight(5); stroke(150, 150);
  for(int i=0;i<t;i++){x[i].display();}
}
// ニューロンの参照ベクトルを表示する関数
void drawNeuron(){
  strokeWeight(7); stroke(255, 0, 0, 150);
  for(int i=0;i<n1;i++){
    for(int j=0;j<n2;j++) m[i][j].display();
  }
  // 隣接するニューロンと線を繋ぐ
  strokeWeight(2); stroke(255, 0, 0, 150);
  for(int i=0;i<n1;i++){
    for(int j=1;j<n2;j++){
      line(m[i][j-1].x, m[i][j-1].y, m[i][j].x, m[i][j].y);
    }
  }
  for(int j=0;j<n2;j++){
    for(int i=1;i<n1;i++){
      line(m[i-1][j].x, m[i-1][j].y, m[i][j].x, m[i][j].y);
    }
  }
}
// 位置を記録するためのオブジェクト
class Dot{
  float x; float y;
  Dot(float x_, float y_){
    x = x_; y = y_;
  }
  void display(){point(x,y);}
}

他の分布を使う場合は、20行目を書き換えます。例えば正規乱数の場合は、下記の通りです。

  x[t] = new Dot(map(randomGaussian(),-3,3,0,width), map(randomGaussian(),-3,3,0,height));

参考