00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #ifndef NEURON_H
00024 #define NEURON_H
00025
using namespace std;
#include <amygdala/types.h>
#include <vector>
#include <functional>
#if GCC_VERSION >= 30000
#include <ext/hash_map>
#else
#include <hash_map>
#endif
#include <cmath>
#include <amygdala/network.h>
#include <amygdala/mpspikeinput.h>
00037
00038
// Forward declarations: these classes are only used by pointer or named as
// friends in this header. Neuron is declared here as well because Synapse and
// SMPSynapse (below) take and store Neuron* before the Neuron class is defined,
// so the header must not rely on another include to introduce the name.
class MpNetwork;
class Network;
class Neuron;
class FunctionLookup;
class SpikeOutput;
00043
/// Output-capture modes; passed to Neuron::CaptureOutput().
enum OutputMode {
    OFF = 0,
    OUTPUT_LAYERS = 1,
    ALL = 2
};
00046
/**
 * Spatial attributes attached to a neuron.
 *
 * Currently holds only a three-component float location. The constructor,
 * destructor, SetLocation() and GetLocation() are defined out of line.
 */
class PhysicalProperties
{
public:
    PhysicalProperties();
    ~PhysicalProperties();

    /// Replace the stored location from @p _newVal
    /// (presumably three floats are read — confirm in the implementation).
    void SetLocation(float _newVal[]);

    /// Retrieve the stored location; @p loc is caller-provided storage
    /// (presumably filled and returned — confirm in the implementation).
    float* GetLocation(float loc[]);

protected:
    /// Location coordinates (three components).
    float location[3];
};
00061
00062
00063
00067 class Synapse {
00068 public:
00069 Synapse(Neuron* _postNrn, float _weight, AmTimeInt _delay = 0);
00070 ~Synapse() {}
00071
00072 float GetWeight() { return weight; }
00073 void SetWeight(float _weight) { weight = _weight; }
00074 const AmTimeInt& GetOffset() const { return offset; }
00075 AmTimeInt GetDelay() const;
00076 Neuron* GetPostNeuron() const { return postNrn; }
00077 protected:
00078 float weight;
00079 Neuron* postNrn;
00080 AmTimeInt offset;
00081 friend class CompareSynapse;
00082 };
00083
00084
00088 class SMPSynapse : public Synapse {
00089 public:
00090 SMPSynapse(MpSpikeInput *_connector, Neuron* _postNrn, float _weight, AmTimeInt _delay = 0);
00092 void SpikeEvent();
00093 protected:
00095 MpSpikeInput *connector;
00096 };
00097
00098
00104 class CompareSynapse {
00105 public:
00106 CompareSynapse() {}
00107 ~CompareSynapse() {}
00108
00109
00110 bool operator()(const Synapse* lsyn, const Synapse* rsyn) {
00111 return lsyn->postNrn < rsyn->postNrn; }
00112 };
00113
00114 typedef vector<Synapse*>::iterator SynapseItr;
00115
00116
00149 class Neuron {
00150 public:
00151 friend void MpSpikeInput::ReadInputBuffer();
00152
00155 Neuron(AmIdInt neuronId);
00156 virtual ~Neuron();
00157
00165 void SetAxonSize(int size);
00166
00168 int GetAxonSize() const { return axon.size(); }
00170 vector <SMPSynapse*>::iterator smpaxon_begin() { return smpAxon.begin(); }
00172 vector <SMPSynapse*>::iterator smpaxon_end() { return smpAxon.end(); }
00173
00182 AmIdInt GetAxonID(unsigned int index) { return axon[index]->GetPostNeuron()->GetID(); }
00183
00187
00188
00189 float GetAxonWeight(unsigned int index) { return axon[index]->GetWeight(); }
00190
00199 float GetAxonDelay(unsigned int index) { return axon[index]->GetDelay(); }
00200
00203 void SetParentNet(Network* const ParNet) { SNet = ParNet; }
00204
00206 AmIdInt GetID() { return(nId); }
00207
00211 LayerType GetLayerType() { return(layerType); }
00212
00217 void SetLayerType(LayerType _layerType) { layerType = _layerType; }
00218
00225 void SetTimeConstants(float synapticConst, float membraneConst);
00226
00229 float GetMembraneConst() const { return memTimeConst; }
00230
00233 float GetSynapticConst() const { return synTimeConst; }
00234
00238 void SetRefractory(float period) { refPeriod = (period * 1000); }
00239
00241 float GetRefractory() const { return (refPeriod / 1000); }
00242
00246 void SetLearningConst(float learnConst) { learningConst = learnConst; }
00247
00249 float GetLearningConst() const { return learningConst; }
00250
00253 void SetRestPotential(float ptnl) { restPtnl = ptnl; }
00254
00258 float GetRestPotential() const { return restPtnl; }
00259
00262 void SetThresholdPotential(float ptnl) { thresholdPtnl = ptnl; }
00263
00265 float GetThresholdPotential() const { return thresholdPtnl; }
00266
00271 void TrainingOn(bool mode) { trainingMode = mode; }
00272
00275 bool IsTraining() const { return trainingMode; }
00276
00284 void Inhibitory(bool val) { inhibitory = val; }
00285
00291 bool Inhibitory() const { return inhibitory; }
00292
00302 static void CaptureOutput(OutputMode mode);
00303
00311 static void SetSpikeOutput(SpikeOutput* output);
00312
00317 static SpikeOutput* GetSpikeOutput() { return spikeOutput; }
00318
00324 static void EnforceSign(bool enforce) { enforceSign = enforce; }
00325
00331 static bool EnforceSign() { return enforceSign; }
00332
00338 static bool GetMPMode() { return mpMode; }
00339
00340
00342 PhysicalProperties * GetPhysicalProperties();
00343
00348 virtual const char* ClassId() = 0;
00349
00361 static void EnableSpikeBatching();
00362
00371 virtual float* InitializeLookupTable(int index) = 0;
00372
00381 virtual float* GetTableParams(int index, int& numParams) = 0;
00382
00384 typedef vector<Synapse*>::iterator iterator;
00386 typedef vector<Synapse*>::const_iterator const_iterator;
00388 typedef vector<Synapse*>::reverse_iterator reverse_iterator;
00390 typedef vector<Synapse*>::const_reverse_iterator const_reverse_iterator;
00391
00395 iterator begin() { return axon.begin(); }
00398 iterator end() { return axon.end(); }
00401 const_iterator begin() const { return axon.begin(); }
00404 const_iterator end() const { return axon.end(); }
00407 reverse_iterator rbegin() { return axon.rbegin(); }
00410 reverse_iterator rend() { return axon.rend(); }
00413 const_reverse_iterator rbegin() const { return axon.rbegin(); }
00416 const_reverse_iterator rend() const { return axon.rend(); }
00417
00418 protected:
00419
00420
00425 static void SetMPMode(bool mode) { mpMode = mode; }
00426
00439 virtual void InputSpike(SynapseItr& inSynapse,
00440 AmTimeInt inTime,
00441 unsigned int numSyn = 0) = 0;
00442
00451 virtual void SetMaxScaledWeight();
00452
00461 virtual void Train(AmTimeInt& spikeTime);
00462
00468 void SendSMPSpike(AmTimeInt& now);
00469
00476 virtual int SetLookupTables(FunctionLookup* funcRef) = 0;
00477
00482 inline void RoundTime(AmTimeInt& time);
00483
00484
00486 AmIdInt nId;
00487
00489 LayerType layerType;
00490 bool inhibitory;
00491
00492 static SpikeOutput* spikeOutput;
00493 static bool defaultOutputObj;
00494
00495 static bool enforceSign;
00496
00497
00498 float memTimeConst;
00499 float synTimeConst;
00500 float refPeriod;
00501
00502
00503 float synPotConst;
00504 float synDepConst;
00505 float learningMax;
00506 float posLearnTimeConst;
00507 float negLearnTimeConst;
00508 float learningConst;
00509
00510
00511 float thresholdPtnl;
00512 float restPtnl;
00513 float membranePtnl;
00514
00515
00516 float maxThreshCrs;
00517 float convergeRes;
00518
00519
00520
00521
00522
00523
00524
00525 AmTimeInt schedSpikeTime;
00526 AmTimeInt spikeTime;
00527 AmTimeInt currTime;
00528 AmTimeInt inputTime;
00529
00530 static AmTimeInt simStepSize;
00531 static bool recordOutput;
00532 static OutputMode outputMode;
00533
00534
00535
00536 bool trainingMode;
00537
00538 Network* SNet;
00539
00540 int axonSize;
00541
00542 int initAxonSize;
00543 float maxScaledWeight;
00544
00545 vector<Synapse*> axon;
00546
00551 struct SynapseHist {
00552 AmTimeInt time;
00553 Synapse* syn;
00554 };
00555 vector<SynapseHist> synapseHist;
00556
00562 struct InputHist {
00563 float weight;
00564 AmTimeInt time;
00565 };
00566
00567 vector<InputHist> inputHist;
00568 unsigned int histBeginIdx;
00569
00575 vector <SMPSynapse*> smpAxon;
00576
00577 unsigned int pspLSize;
00578 unsigned int pspStepSize;
00579
00580
00581
00582
00583 bool usePspLookup;
00584
00585 PhysicalProperties physicalProperties;
00586
00587 private:
00588
00593 void SendSpike(AmTimeInt& now);
00594
00601 void AddSynapse(Synapse* synapse);
00602
00607 void AddSMPSynapse(SMPSynapse* synapse);
00608
00617 void SetTableDimensions(int tblSize, int tblRes);
00618
00620 void SetDefaults();
00621
00626 float GetMaxScaledWeight() const { return maxScaledWeight; }
00627
00628 static bool mpMode;
00629 static bool spikeDelaysOn;
00630
00631 friend class Network;
00632 friend class MpNetwork;
00633 };
00634
00635 inline void Neuron::RoundTime(AmTimeInt& time)
00636 {
00637 float tmpTime;
00638 float roundTime;
00639
00640 tmpTime = float(time) / pspStepSize;
00641 tmpTime = modff(tmpTime, &roundTime);
00642 if (tmpTime > 0.5) {
00643 time = (int(roundTime) * pspStepSize) + pspStepSize;
00644 }
00645 else {
00646 time = int(roundTime) * pspStepSize;
00647 }
00648 }
00649
00650 #endif