Main Page   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members  

neuron.h

00001 /***************************************************************************
00002                           neuron.h  -  description
00003                              -------------------
00004     copyright            : (C) 2001, 2002 by Matt Grover
00005     email                : mgrover@amygdala.org
00006  ***************************************************************************/
00007 
00008 /***************************************************************************
00009  *                                                                         *
00010  *   This program is free software; you can redistribute it and/or modify  *
00011  *   it under the terms of the GNU General Public License as published by  *
00012  *   the Free Software Foundation; either version 2 of the License, or     *
00013  *   (at your option) any later version.                                   *
00014  *                                                                         *
00015  ***************************************************************************/
00016 
00017 /***************************************************************************
00018     Implement a spiking neuron based on the Integrate-and-fire model.  This
00019     class will be converted into an abstract base class in the near
00020     future and much of the code will be moved to BasicNeuron.
00021  ***************************************************************************/
00022 
00023 #ifndef NEURON_H
00024 #define NEURON_H
00025 
// Standard headers come first so that namespace std is already declared when
// the using-directive below is reached (a using-directive naming a namespace
// that has not been declared yet is ill-formed, which bites whenever this is
// the first header a translation unit includes).  The directive itself stays
// ahead of the project headers, exactly as before, in case they rely on it.
#include <cmath>
#include <functional>
#include <vector>
using namespace std;
#include <amygdala/types.h>
// NOTE(review): GCC_VERSION is not defined in this header; it may come from
// amygdala/types.h, so this conditional is kept after that include.
#if GCC_VERSION >= 30000
    #include <ext/hash_map>
#else
    #include <hash_map>
#endif
#include <amygdala/network.h>
#include <amygdala/mpspikeinput.h>
00037 
00038 
class MpNetwork;
class Network;
// Synapse and SMPSynapse below refer to Neuron before its definition, so
// declare it here instead of relying on a transitive include.
class Neuron;
class FunctionLookup;
class SpikeOutput;

/**
 * Spike-capture mode passed to Neuron::CaptureOutput().
 * NOTE(review): the per-value meanings below are inferred from the names;
 * confirm against Neuron::CaptureOutput()'s definition.
 */
enum OutputMode {
    OFF,            // presumably: capture no spikes
    OUTPUT_LAYERS,  // presumably: capture spikes from output layers only
    ALL             // presumably: capture spikes from every neuron
};
00046 
/**
 * Physical (spatial) attributes of a neuron.  At present this is only a
 * location held as three floats; all member functions are defined out of line.
 */
class PhysicalProperties {
public:
    PhysicalProperties();
    ~PhysicalProperties();

    /// Replace the stored location.
    /// NOTE(review): presumably reads three values from newLocation — confirm.
    void SetLocation(float newLocation[]);

    /// Fetch the stored location.
    /// NOTE(review): presumably copies into dest and returns it — confirm.
    float* GetLocation(float dest[]);

protected:
    float location[3];   // three location coordinates
};
00061 
00062 //class CompareSynapse;
00063 
00067 class Synapse {
00068 public:
00069     Synapse(Neuron* _postNrn, float _weight, AmTimeInt _delay = 0);
00070     ~Synapse() {}
00071 
00072     float GetWeight() { return weight; }
00073     void SetWeight(float _weight) { weight = _weight; }
00074     const AmTimeInt& GetOffset() const { return offset; }
00075     AmTimeInt GetDelay() const;
00076     Neuron* GetPostNeuron() const { return postNrn; }
00077 protected:
00078     float weight;
00079     Neuron* postNrn;
00080     AmTimeInt offset;
00081     friend class CompareSynapse;
00082 };
00083 
00084 
00088 class SMPSynapse : public Synapse {
00089 public:
00090     SMPSynapse(MpSpikeInput *_connector, Neuron* _postNrn, float _weight, AmTimeInt _delay = 0);
00092     void SpikeEvent();
00093 protected:
00095     MpSpikeInput *connector;
00096 };
00097 
00098 
00104 class CompareSynapse {
00105 public:
00106     CompareSynapse() {}
00107     ~CompareSynapse() {}
00108 //    bool operator()(const Synapse* lsyn, const Synapse* rsyn) {
00109 //        return lsyn->GetPostNeuron() < rsyn->GetPostNeuron(); }
00110     bool operator()(const Synapse* lsyn, const Synapse* rsyn) {
00111         return lsyn->postNrn < rsyn->postNrn; }
00112 };
00113 
00114 typedef vector<Synapse*>::iterator SynapseItr;
00115 
00116 
00149 class Neuron {
00150 public:
00151     friend void MpSpikeInput::ReadInputBuffer();
00152 
00155     Neuron(AmIdInt neuronId);
00156     virtual ~Neuron();
00157 
00165     void SetAxonSize(int size);
00166 
00168     int GetAxonSize() const { return axon.size(); }
00170     vector <SMPSynapse*>::iterator smpaxon_begin() { return smpAxon.begin(); }
00172     vector <SMPSynapse*>::iterator smpaxon_end() { return smpAxon.end(); }
00173 
00182     AmIdInt GetAxonID(unsigned int index) { return axon[index]->GetPostNeuron()->GetID(); }
00183 
00187     // FIXME: This has to return the normalized weight -- this will return
00188     // the scaled weight as it is now.  See the source of the old Neuron::GetWeight()
00189     float GetAxonWeight(unsigned int index) { return axon[index]->GetWeight(); }
00190 
00199     float GetAxonDelay(unsigned int index) { return axon[index]->GetDelay(); }
00200 
00203     void SetParentNet(Network* const ParNet) { SNet = ParNet; }
00204 
00206     AmIdInt GetID() { return(nId); }
00207 
00211     LayerType GetLayerType() { return(layerType); }
00212 
00217     void SetLayerType(LayerType _layerType) { layerType = _layerType; }
00218 
00225     void SetTimeConstants(float synapticConst, float membraneConst);
00226 
00229     float GetMembraneConst() const { return memTimeConst; }
00230 
00233     float GetSynapticConst() const { return synTimeConst; }
00234 
00238     void SetRefractory(float period) { refPeriod = (period * 1000); }
00239 
00241     float GetRefractory() const { return (refPeriod / 1000); }
00242 
00246     void SetLearningConst(float learnConst) { learningConst = learnConst; }
00247 
00249     float GetLearningConst() const { return learningConst; }
00250 
00253     void SetRestPotential(float ptnl) { restPtnl = ptnl; }
00254 
00258     float GetRestPotential() const { return restPtnl; }
00259 
00262     void SetThresholdPotential(float ptnl) { thresholdPtnl = ptnl; }
00263 
00265     float GetThresholdPotential() const { return thresholdPtnl; }
00266 
00271     void TrainingOn(bool mode) { trainingMode = mode; }
00272 
00275     bool IsTraining() const { return trainingMode; }
00276 
00284     void Inhibitory(bool val) { inhibitory = val; }
00285 
00291     bool Inhibitory() const { return inhibitory; }
00292 
00302     static void CaptureOutput(OutputMode mode);
00303 
00311     static void SetSpikeOutput(SpikeOutput* output);
00312 
00317     static SpikeOutput* GetSpikeOutput() { return spikeOutput; }
00318 
00324     static void EnforceSign(bool enforce) { enforceSign = enforce; }
00325 
00331     static bool EnforceSign() { return enforceSign; }
00332 
00338     static bool GetMPMode() { return mpMode; }
00339 
00340 
00342     PhysicalProperties * GetPhysicalProperties();
00343 
00348     virtual const char* ClassId() = 0;
00349 
00361     static void EnableSpikeBatching();
00362 
00371     virtual float* InitializeLookupTable(int index) = 0;
00372 
00381     virtual float* GetTableParams(int index, int& numParams) = 0;
00382 
00384     typedef vector<Synapse*>::iterator iterator;
00386     typedef vector<Synapse*>::const_iterator const_iterator;
00388     typedef vector<Synapse*>::reverse_iterator reverse_iterator;
00390     typedef vector<Synapse*>::const_reverse_iterator const_reverse_iterator;
00391 
00395     iterator begin() { return axon.begin(); }
00398     iterator end() { return axon.end(); }
00401     const_iterator begin() const { return axon.begin(); }
00404     const_iterator end() const { return axon.end(); }
00407     reverse_iterator rbegin() { return axon.rbegin(); }
00410     reverse_iterator rend() { return axon.rend(); }
00413     const_reverse_iterator rbegin() const { return axon.rbegin(); }
00416     const_reverse_iterator rend() const { return axon.rend(); }
00417 
00418 protected:
00419     // Functions:
00420 
00425     static void SetMPMode(bool mode) { mpMode = mode; }
00426 
00439     virtual void InputSpike(SynapseItr& inSynapse,
00440                             AmTimeInt inTime,
00441                             unsigned int numSyn = 0) = 0;
00442 
00451     virtual void SetMaxScaledWeight();
00452 
00461     virtual void Train(AmTimeInt& spikeTime);
00462 
00468     void SendSMPSpike(AmTimeInt& now);
00469 
00476     virtual int SetLookupTables(FunctionLookup* funcRef) = 0;
00477 
00482     inline void RoundTime(AmTimeInt& time);
00483 
00484     // Data:
00486     AmIdInt nId;
00487 
00489     LayerType layerType;
00490     bool inhibitory;
00491 
00492     static SpikeOutput* spikeOutput;  // Output class shared by all instances of Neuron
00493     static bool defaultOutputObj;   // Flag indicating that the default amygOutput
00494                                     // object is being used.
00495     static bool enforceSign;
00496 
00497     // Time constants
00498     float memTimeConst;             // Membrane time constant (ms)
00499     float synTimeConst;             // Synaptic time constant (ms)
00500     float refPeriod;                // Refractory time period (ms)
00501 
00502     // Learning parameters
00503     float synPotConst;              // Synaptic potentiation constant
00504     float synDepConst;              // Synaptic depression constant
00505     float learningMax;              // Relative time of max weight increase (ms)
00506     float posLearnTimeConst;        // Positive learning time const (t+) (ms)
00507     float negLearnTimeConst;        // Negative learning time const (t-) (ms)
00508     float learningConst;            // Learning constant
00509 
00510     // Potentials have mV units
00511     float thresholdPtnl;            // Threshold potential (theta)
00512     float restPtnl;                 // u-r  (unused)
00513     float membranePtnl;             // u-i  (unused)
00514 
00515     // Needed for calculation of spike times -- see InputSpike() for comments
00516     float maxThreshCrs;
00517     float convergeRes;
00518 
00519     /*
00520      * Times are measured in microseconds since the begining
00521      * of the simulation. These will eventually be changed
00522      * to 64-bit integers to allow for longer simulations
00523      * (32-bit integers will roll over after about an hour).
00524      */
00525     AmTimeInt schedSpikeTime;    // time of next scheduled spike
00526     AmTimeInt spikeTime;         // time of last spike
00527     AmTimeInt currTime;          // current simulation time - set in InputSpike()
00528     AmTimeInt inputTime;         // time of last received spike
00529 
00530     static AmTimeInt simStepSize;       // simulation step size in us
00531     static bool recordOutput;
00532     static OutputMode outputMode;
00533 
00534     // Also need some vars relevant to the training rule...
00535     // Get those out of the Hebbian Learning paper
00536     bool trainingMode;              // Training indicator
00537     
00538     Network* SNet;                  // Parent Network
00539 
00540     int axonSize;                   // number of outputs
00541     //int dendriteSize;               // number of inputs
00542     int initAxonSize;               // initial size of output
00543     float maxScaledWeight;          // scaling factor for normalized weights
00544 
00545     vector<Synapse*> axon;
00546 
00551     struct SynapseHist {
00552         AmTimeInt time;
00553         Synapse* syn;
00554     };
00555     vector<SynapseHist> synapseHist;
00556 
00562     struct InputHist {
00563         float weight;
00564         AmTimeInt time;
00565     };
00566 
00567     vector<InputHist> inputHist;
00568     unsigned int histBeginIdx;
00569 
00575     vector <SMPSynapse*> smpAxon;
00576 
00577     unsigned int pspLSize;       // array size
00578     unsigned int pspStepSize;    // size of time increment in pspLookup (in us)
00579 
00580     // Obsolete -- now used as a flag in SetTimeConstants to indicate
00581     // that the lookup tables have been filled.  A new method
00582     // for doing that should be developed.
00583     bool usePspLookup;
00584 
00585     PhysicalProperties physicalProperties;
00586 
00587 private:
00588 
00593     void SendSpike(AmTimeInt& now);
00594 
00601     void AddSynapse(Synapse* synapse);
00602 
00607     void AddSMPSynapse(SMPSynapse* synapse);
00608 
00617     void SetTableDimensions(int tblSize, int tblRes);
00618 
00620     void SetDefaults();
00621 
00626     float GetMaxScaledWeight() const { return maxScaledWeight; }
00627 
00628     static bool mpMode;
00629     static bool spikeDelaysOn;
00630 
00631     friend class Network;
00632     friend class MpNetwork;
00633 };
00634 
00635 inline void Neuron::RoundTime(AmTimeInt& time)
00636 {
00637     float tmpTime;
00638     float roundTime;
00639 
00640     tmpTime = float(time) / pspStepSize;
00641     tmpTime = modff(tmpTime, &roundTime);
00642     if (tmpTime > 0.5) {
00643         time = (int(roundTime) * pspStepSize) + pspStepSize;
00644     }
00645     else {
00646         time = int(roundTime) * pspStepSize;
00647     }
00648 }
00649 
00650 #endif

Generated on Wed Sep 4 02:30:35 2002 for Amygdala by doxygen1.2.14 written by Dimitri van Heesch, © 1997-2002