ee905\sofm.java


   1  

   2  import java.awt.*;

   3  import java.applet.*;

   4  import java.util.*;

   5  

   6  /*

   7   A colourful demonstration of SOFM

   8  

   9   Written by Simon Lucas (email sml@essex.ac.uk)

  10   Permission is granted to copy and modify this code as desired.

  11  

  12   The aim here is to capture the spirit of SOFM - I've

  13   not really paid much attention to the details e.g.

  14   choice of neighbourhood function - you can play with

  15   this by modifying the weight equation in the synapse constructor.

  16  

  17   The design of this is quite neat and simple - but I lose

  18   marks for passing a Graphics object around the SOFM implementation.

  19   This was done for reasons of efficiency - better to just update

  20   the cell colours as they change rather than repaint the entire map each

  21   time.

  22  

  23   A better way to do this would be to define cell, synapse and map without

  24   any reference to a Graphics object, then subclass them with graphical versions.

  25  

  26  */

  27  

  28  class synapse {

  29    // this class models a weighted connection to a cell

  30    double weight;

  31    static double sharpness = 2.0; // higher values give sharper cutoff

  32    cell c;

  33  

  34    public static double sigmoid(double x) {

  35      return 1.0 / (1.0 + Math.exp(-x));

  36    }

  37  

  38    public synapse(cell c, double dist) {

  39      weight = 0.2 * sigmoid((sofm.limit - sharpness*dist)/sharpness);

  40      // weight = 0.2 / (1.0 + dist); // use simple function

  41      this.c = c;

  42    }

  43  

  44    public void update(double[] v, Graphics g) {

  45      c.update(v, map.rate*weight, g);

  46    }

  47  }

  48  

  49  class cell {

  50    // has three basic properties: a grid position (gridv)

  51    // an input space position (ipv)

  52    // and a set of neighbours (stored in a java.util.Vector)

  53  

  54    double[] gridv;

  55    double[] ipv;

  56    int d;

  57    static int neighboursEstimate = 30; // a rough estimate of number of neighbours - will grow automatically if needed to

  58    Vector neighbours;  // more of a list than a vector!

  59  

  60    public cell(int d, int x, int y) {

  61      // construct a cell with a d-dimensional randomly initialised vector at

  62      // point x,y on the grid

  63      ipv = new double[d];

  64      gridv = new double[2];

  65      gridv[0] = (double) x;

  66      gridv[1] = (double) y;

  67      this.d = d;

  68      randCell();

  69      neighbours = new Vector(neighboursEstimate);

  70    }

  71  

  72    public void randCell() {

  73      for (int i=0; i<d; i++)

  74        ipv[i] = Math.random();

  75    }

  76  

  77    public void addNeighbour(cell c) {

  78      neighbours.addElement(new synapse(c, dist(gridv, c.gridv)));

  79    }

  80  

  81    public void removeNeighbours() {

  82      neighbours.removeAllElements();

  83    }

  84  

  85    public void updateNeighbours(double[] v, Graphics g) {

  86      // iterate over set of neighbours updating them

  87      // in proportion to the connection weight

  88      // now clearly identify the winner i.e. this cell

  89      g.setColor(map.v2c(v));

  90      g.fillOval((int)gridv[0]*map.size, (int)gridv[1]*map.size, map.size, map.size);

  91      g.setColor(Color.black);

  92      g.drawOval(1+(int)gridv[0]*map.size, 1+(int)gridv[1]*map.size, map.size-2, map.size-2);

  93      try {Thread.sleep(100); }

  94      catch (Exception e) {}

  95      for (Enumeration e = neighbours.elements(); e.hasMoreElements(); )

  96        ((synapse) e.nextElement()).update(v, g);

  97    }

  98  

  99    public void update(double[] v, double w, Graphics g) {

 100      // updates the vector of this cell using a weighted average

 101      // of the current cell vector (ipv) and v

 102      for (int i=0; i<d; i++) {

 103        ipv[i] = (1.0 - w) * ipv[i] + w * v[i];

 104        g.setColor(map.v2c(ipv));

 105        g.fillRect((int)gridv[0]*map.size, (int)gridv[1]*map.size, map.size, map.size);

 106      }

 107    }

 108  

 109    public static double sqr (double x) {

 110      return x * x;

 111    }

 112  

 113    public static double dist(double[] x, double[] y) {

 114      double dis = 0.0;

 115      for (int i=0; i<x.length; i++)

 116        dis += sqr(x[i] - y[i]);

 117      return Math.sqrt(dis);

 118    }

 119  

 120    public double dist(double[] v) {

 121      return dist(ipv, v);

 122    }

 123  

 124  }

 125  

 126  class map {

 127    // this class is to model the 2-d map with the underlying vectors

 128    // can be set up to be a grid, each with a random neuron

 129    // then each time, update it according to the chosen input vector

 130  

 131    // this follows the cycle of pick winner, update neighbours

 132    // (each cell includes itself in its set of neighbours)

 133  

 134    int d; // d is the number of dimensions in input

 135    int n; // n is the number of points on the grid

 136    // cell[][] points;

 137    static int size = 40;  // number of pixels to use for each cell when drawing it

 138    static double rate;

 139    static double decayFac = 0.99; // no weight decay at present

 140  

 141    Vector cells;

 142  

 143    public map(int d, int n) {

 144      this.n = n;

 145      this.d = d;

 146      cells = new Vector(n*n);

 147      rate = 1.0;

 148      makeMap();

 149      setNeighbours();

 150    }

 151  

 152    public void makeMap() {

 153      for (int i=0; i<n; i++)

 154        for (int j=0; j<n; j++)

 155          cells.addElement(new cell(d, i, j));

 156    }

 157  

 158    public void reset() {

 159      rate = 1.0;

 160      for (Enumeration e = cells.elements(); e.hasMoreElements();)

 161        ((cell) e.nextElement()).randCell();

 162    }

 163  

 164  

 165    public void setNeighbours() {

 166      for (Enumeration e = cells.elements(); e.hasMoreElements();)

 167        cellNeighbours((cell) e.nextElement());

 168    }

 169  

 170    public void cellNeighbours(cell c) {

 171      c.removeNeighbours();

 172      for (Enumeration e = cells.elements(); e.hasMoreElements();) {

 173        cell cand = (cell) e.nextElement(); // cand is current candidate

 174        if (cell.dist(c.gridv, cand.gridv) < sofm.limit)

 175          c.addNeighbour(cand);

 176      }

 177    }

 178  

 179    public cell getWinner(double[] v) {

 180      double minDist = 100000.0; // choose this to be much bigger than a feasible distance

 181      cell choice = null;

 182      for (Enumeration e = cells.elements(); e.hasMoreElements();) {

 183        cell cand = (cell) e.nextElement(); // cand is current candidate

 184        double curDist = cand.dist(v);

 185        if (curDist < minDist) {

 186          choice = cand;

 187          minDist = curDist;

 188        }

 189      }

 190      return choice;

 191    }

 192  

 193    public void updateMap(Color c, Graphics g) {

 194      update(c2v(c), g); // convert the colour to a vector

 195    }

 196  

 197    public void update(double[] v, Graphics g) {

 198      cell winner = getWinner(v);

 199      winner.updateNeighbours(v, g); // query this

 200      rate *= decayFac;

 201    }

 202  

 203    public static double[] c2v(Color c) {

 204      double[] v = new double[3];

 205      v[0] = (double) c.getRed() / 255.0;

 206      v[1] = (double) c.getGreen() / 255.0;

 207      v[2] = (double) c.getBlue() / 255.0;

 208      return v;

 209    }

 210  

 211    public static Color v2c(double[] v) {

 212      return new Color((float) v[0], (float) v[1], (float) v[2]);

 213    }

 214  

 215    public void paint(Graphics g) {

 216      for (Enumeration e = cells.elements(); e.hasMoreElements();) {

 217        cell c = (cell) e.nextElement();

 218        g.setColor(v2c(c.ipv));

 219        g.fillRect((int)c.gridv[0]*size, (int)c.gridv[1]*size, size, size);

 220      }

 221    }

 222  }

 223  

 224  public class sofm extends Applet implements Runnable {

 225    map sorg;

 226    Thread animator;

 227  

 228    int nCols = 8;

 229    Color[] cols = new Color[8];

 230    static double limit=5.0;

 231    int sleepTime = 10;

 232  

 233    Button start, stop, reset;

 234    Choice neighbourhood, decay, speed; // decay not used

 235    Panel buttons;

 236  

 237  

 238  

 239    public void init() {

 240      setControlPanel();

 241      initColors();

 242      sorg = new map(3,10);

 243      animator = new Thread(this);

 244      animator.start();

 245    }

 246  

 247    public void setControlPanel() {

 248      setLayout(new BorderLayout());

 249      buttons = new Panel();

 250      start = new Button("start");

 251      stop = new Button("stop");

 252      reset = new Button("reset");

 253  

 254      neighbourhood = new Choice();

 255      neighbourhood.addItem("1.0");

 256      neighbourhood.addItem("2.0");

 257      neighbourhood.addItem("3.0");

 258      neighbourhood.addItem("5.0");

 259      neighbourhood.addItem("7.0");

 260      neighbourhood.addItem("10.0");

 261      neighbourhood.select("5.0");

 262  

 263      speed = new Choice();

 264      speed.addItem("fast");

 265      speed.addItem("slow");

 266      speed.select("fast");

 267  

 268      buttons.add(new Label("Radius"));

 269      buttons.add(neighbourhood);

 270      buttons.add(start);

 271      buttons.add(stop);

 272      buttons.add(reset);

 273      buttons.add(speed);

 274      add("South", buttons);

 275    }

 276  

 277    public void initColors() {

 278      cols[0] = Color.red;

 279      cols[1] = Color.blue;

 280      cols[2] = Color.green;

 281      cols[3] = Color.white;

 282      cols[4] = Color.black;

 283      cols[5] = Color.magenta;

 284      cols[6] = Color.pink;

 285      cols[7] = Color.yellow;

 286    }

 287  

 288    public static int randInt(int range) {

 289      return (int) (Math.random() * (double) range);

 290    }

 291  

 292    public void paint(Graphics g) {

 293      validate();

 294      sorg.paint(g);

 295    }

 296  

 297    public boolean handleEvent(Event event) {

 298      // use 1.0 event model for maxmimum browser compatibilty

 299      if (event.id == Event.ACTION_EVENT) {

 300        if (event.target == start) {

 301          // System.out.println("start");

 302          animator.resume();

 303          return true;

 304        }

 305        if (event.target == stop) {

 306          animator.suspend();

 307          return true;

 308        }

 309        if (event.target == reset) {

 310          sorg.reset();

 311          repaint();

 312          return true;

 313        }

 314        if (event.target == speed) {

 315          String s = speed.getSelectedItem();

 316          if (s.equals("slow"))

 317            sleepTime = 1000;

 318          else

 319            sleepTime = 10;

 320          return true;

 321        }

 322  

 323        if (event.target == neighbourhood) {

 324          limit = new Double(neighbourhood.getSelectedItem()).doubleValue();

 325          // System.out.println(limit);

 326          sorg.setNeighbours();

 327          return true;

 328        }

 329      }

 330      return super.handleEvent(event);

 331    }

 332  

 333    public void run() {

 334      sorg.paint(getGraphics());

 335      while(true) {

 336        Color current = cols[randInt(nCols)];

 337        Graphics g = getGraphics();

 338        g.setColor(current);

 339        g.fillRect(180, 410, 40, 30); // lazy - ouch!

 340        sorg.updateMap(current, g);

 341        try { Thread.sleep(sleepTime); }

 342        catch (Exception e) {

 343          System.out.println(e);

 344        }

 345      }

 346    }

 347  }


end of ee905\sofm.java