00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 #ifndef NEURON_H
00024 #define NEURON_H
00025 
00026 using namespace std;
00027 #include <amygdala/types.h>
00028 #include <vector>
00029 #if GCC_VERSION >= 30000
00030     #include <ext/hash_map>
00031 #else
00032     #include <hash_map>
00033 #endif
00034 #include <cmath>
00035 #include <amygdala/network.h>
00036 #include <amygdala/mpspikeinput.h>
00037 
00038 
// Forward declarations.  Neuron is declared here because Synapse and
// SMPSynapse below use Neuron* before the Neuron class definition appears
// later in this header; a redundant forward declaration is harmless.
class Neuron;
class MpNetwork;
class Network;
class FunctionLookup;
class SpikeOutput;
00043 
/**
 * Spike-capture mode passed to Neuron::CaptureOutput().
 *
 * Judging by the enumerator names (NOTE(review): confirm against the
 * CaptureOutput implementation): capture nothing, capture only neurons in
 * output layers, or capture every neuron.
 */
enum OutputMode {
    OFF = 0,
    OUTPUT_LAYERS = 1,
    ALL = 2
};
00046 
/**
 * Spatial attributes of a neuron: currently a single three-component
 * location.
 *
 * Constructor, destructor and both accessors are defined out of line.
 */
class PhysicalProperties {
public:
    PhysicalProperties();
    ~PhysicalProperties();

    /// Store a new location.  `_newVal` presumably points to three floats
    /// (matching `location`) -- NOTE(review): confirm in the .cpp.
    void SetLocation(float* _newVal);

    /// Copy the location out into `loc` (presumably three floats) and
    /// return a float pointer -- NOTE(review): confirm whether it is `loc`.
    float* GetLocation(float* loc);

protected:
    /// Three location coordinates.
    float location[3];
};
00061 
00062 
00063 
00067 class Synapse {
00068 public:
00069     Synapse(Neuron* _postNrn, float _weight, AmTimeInt _delay = 0);
00070     ~Synapse() {}
00071 
00072     float GetWeight() { return weight; }
00073     void SetWeight(float _weight) { weight = _weight; }
00074     const AmTimeInt& GetOffset() const { return offset; }
00075     AmTimeInt GetDelay() const;
00076     Neuron* GetPostNeuron() const { return postNrn; }
00077 protected:
00078     float weight;
00079     Neuron* postNrn;
00080     AmTimeInt offset;
00081     friend class CompareSynapse;
00082 };
00083 
00084 
00088 class SMPSynapse : public Synapse {
00089 public:
00090     SMPSynapse(MpSpikeInput *_connector, Neuron* _postNrn, float _weight, AmTimeInt _delay = 0);
00092     void SpikeEvent();
00093 protected:
00095     MpSpikeInput *connector;
00096 };
00097 
00098 
00104 class CompareSynapse {
00105 public:
00106     CompareSynapse() {}
00107     ~CompareSynapse() {}
00108 
00109 
00110     bool operator()(const Synapse* lsyn, const Synapse* rsyn) {
00111         return lsyn->postNrn < rsyn->postNrn; }
00112 };
00113 
00114 typedef vector<Synapse*>::iterator SynapseItr;
00115 
00116 
00149 class Neuron {
00150 public:
00151     friend void MpSpikeInput::ReadInputBuffer();
00152 
00155     Neuron(AmIdInt neuronId);
00156     virtual ~Neuron();
00157 
00165     void SetAxonSize(int size);
00166 
00168     int GetAxonSize() const { return axon.size(); }
00170     vector <SMPSynapse*>::iterator smpaxon_begin() { return smpAxon.begin(); }
00172     vector <SMPSynapse*>::iterator smpaxon_end() { return smpAxon.end(); }
00173 
00182     AmIdInt GetAxonID(unsigned int index) { return axon[index]->GetPostNeuron()->GetID(); }
00183 
00187     
00188     
00189     float GetAxonWeight(unsigned int index) { return axon[index]->GetWeight(); }
00190 
00199     float GetAxonDelay(unsigned int index) { return axon[index]->GetDelay(); }
00200 
00203     void SetParentNet(Network* const ParNet) { SNet = ParNet; }
00204 
00206     AmIdInt GetID() { return(nId); }
00207 
00211     LayerType GetLayerType() { return(layerType); }
00212 
00217     void SetLayerType(LayerType _layerType) { layerType = _layerType; }
00218 
00225     void SetTimeConstants(float synapticConst, float membraneConst);
00226 
00229     float GetMembraneConst() const { return memTimeConst; }
00230 
00233     float GetSynapticConst() const { return synTimeConst; }
00234 
00238     void SetRefractory(float period) { refPeriod = (period * 1000); }
00239 
00241     float GetRefractory() const { return (refPeriod / 1000); }
00242 
00246     void SetLearningConst(float learnConst) { learningConst = learnConst; }
00247 
00249     float GetLearningConst() const { return learningConst; }
00250 
00253     void SetRestPotential(float ptnl) { restPtnl = ptnl; }
00254 
00258     float GetRestPotential() const { return restPtnl; }
00259 
00262     void SetThresholdPotential(float ptnl) { thresholdPtnl = ptnl; }
00263 
00265     float GetThresholdPotential() const { return thresholdPtnl; }
00266 
00271     void TrainingOn(bool mode) { trainingMode = mode; }
00272 
00275     bool IsTraining() const { return trainingMode; }
00276 
00284     void Inhibitory(bool val) { inhibitory = val; }
00285 
00291     bool Inhibitory() const { return inhibitory; }
00292 
00302     static void CaptureOutput(OutputMode mode);
00303 
00311     static void SetSpikeOutput(SpikeOutput* output);
00312 
00317     static SpikeOutput* GetSpikeOutput() { return spikeOutput; }
00318 
00324     static void EnforceSign(bool enforce) { enforceSign = enforce; }
00325 
00331     static bool EnforceSign() { return enforceSign; }
00332 
00338     static bool GetMPMode() { return mpMode; }
00339 
00340 
00342     PhysicalProperties * GetPhysicalProperties();
00343 
00348     virtual const char* ClassId() = 0;
00349 
00361     static void EnableSpikeBatching();
00362 
00371     virtual float* InitializeLookupTable(int index) = 0;
00372 
00381     virtual float* GetTableParams(int index, int& numParams) = 0;
00382 
00384     typedef vector<Synapse*>::iterator iterator;
00386     typedef vector<Synapse*>::const_iterator const_iterator;
00388     typedef vector<Synapse*>::reverse_iterator reverse_iterator;
00390     typedef vector<Synapse*>::const_reverse_iterator const_reverse_iterator;
00391 
00395     iterator begin() { return axon.begin(); }
00398     iterator end() { return axon.end(); }
00401     const_iterator begin() const { return axon.begin(); }
00404     const_iterator end() const { return axon.end(); }
00407     reverse_iterator rbegin() { return axon.rbegin(); }
00410     reverse_iterator rend() { return axon.rend(); }
00413     const_reverse_iterator rbegin() const { return axon.rbegin(); }
00416     const_reverse_iterator rend() const { return axon.rend(); }
00417 
00418 protected:
00419     
00420 
00425     static void SetMPMode(bool mode) { mpMode = mode; }
00426 
00439     virtual void InputSpike(SynapseItr& inSynapse,
00440                             AmTimeInt inTime,
00441                             unsigned int numSyn = 0) = 0;
00442 
00451     virtual void SetMaxScaledWeight();
00452 
00461     virtual void Train(AmTimeInt& spikeTime);
00462 
00468     void SendSMPSpike(AmTimeInt& now);
00469 
00476     virtual int SetLookupTables(FunctionLookup* funcRef) = 0;
00477 
00482     inline void RoundTime(AmTimeInt& time);
00483 
00484     
00486     AmIdInt nId;
00487 
00489     LayerType layerType;
00490     bool inhibitory;
00491 
00492     static SpikeOutput* spikeOutput;  
00493     static bool defaultOutputObj;   
00494                                     
00495     static bool enforceSign;
00496 
00497     
00498     float memTimeConst;             
00499     float synTimeConst;             
00500     float refPeriod;                
00501 
00502     
00503     float synPotConst;              
00504     float synDepConst;              
00505     float learningMax;              
00506     float posLearnTimeConst;        
00507     float negLearnTimeConst;        
00508     float learningConst;            
00509 
00510     
00511     float thresholdPtnl;            
00512     float restPtnl;                 
00513     float membranePtnl;             
00514 
00515     
00516     float maxThreshCrs;
00517     float convergeRes;
00518 
00519     
00520 
00521 
00522 
00523 
00524 
00525     AmTimeInt schedSpikeTime;    
00526     AmTimeInt spikeTime;         
00527     AmTimeInt currTime;          
00528     AmTimeInt inputTime;         
00529 
00530     static AmTimeInt simStepSize;       
00531     static bool recordOutput;
00532     static OutputMode outputMode;
00533 
00534     
00535     
00536     bool trainingMode;              
00537     
00538     Network* SNet;                  
00539 
00540     int axonSize;                   
00541     
00542     int initAxonSize;               
00543     float maxScaledWeight;          
00544 
00545     vector<Synapse*> axon;
00546 
00551     struct SynapseHist {
00552         AmTimeInt time;
00553         Synapse* syn;
00554     };
00555     vector<SynapseHist> synapseHist;
00556 
00562     struct InputHist {
00563         float weight;
00564         AmTimeInt time;
00565     };
00566 
00567     vector<InputHist> inputHist;
00568     unsigned int histBeginIdx;
00569 
00575     vector <SMPSynapse*> smpAxon;
00576 
00577     unsigned int pspLSize;       
00578     unsigned int pspStepSize;    
00579 
00580     
00581     
00582     
00583     bool usePspLookup;
00584 
00585     PhysicalProperties physicalProperties;
00586 
00587 private:
00588 
00593     void SendSpike(AmTimeInt& now);
00594 
00601     void AddSynapse(Synapse* synapse);
00602 
00607     void AddSMPSynapse(SMPSynapse* synapse);
00608 
00617     void SetTableDimensions(int tblSize, int tblRes);
00618 
00620     void SetDefaults();
00621 
00626     float GetMaxScaledWeight() const { return maxScaledWeight; }
00627 
00628     static bool mpMode;
00629     static bool spikeDelaysOn;
00630 
00631     friend class Network;
00632     friend class MpNetwork;
00633 };
00634 
00635 inline void Neuron::RoundTime(AmTimeInt& time)
00636 {
00637     float tmpTime;
00638     float roundTime;
00639 
00640     tmpTime = float(time) / pspStepSize;
00641     tmpTime = modff(tmpTime, &roundTime);
00642     if (tmpTime > 0.5) {
00643         time = (int(roundTime) * pspStepSize) + pspStepSize;
00644     }
00645     else {
00646         time = int(roundTime) * pspStepSize;
00647     }
00648 }
00649 
00650 #endif