# Source code for neuronunit.tests.fi

"""F/I neuronunit tests.

For example, tests that investigate firing rates and firing patterns
as a function of injected current.
"""

import os
import multiprocessing
import copy

import dask.bag as db

import neuronunit
from neuronunit.optimization.data_transport_container import DataTC
from neuronunit.models.reduced import ReducedModel
from .base import np, pq, ncap, VmTest, scores

# Number of CPUs on this machine. The parallel rheobase search uses it as the
# number of dask partitions (find_rheobase) and to size each batch of trial
# currents (check_fix_range).
N_CPUS = multiprocessing.cpu_count()


class RheobaseTest(VmTest):
    """Serial implementation of a binary search to test the rheobase.

    Strengths: under important but limited circumstances this serial
    algorithm is faster than the parallel class (``RheobaseTestP``, also in
    this file), namely for model backends that are able to use numba JIT
    optimization.

    Weaknesses: this serial class is significantly slower for many backend
    implementations, including raw NEURON, NEURON via PyNN, and possibly
    GLIF.
    """

    def _extra(self):
        # Per-instance search state: the initial probe currents and the Vm
        # trace recorded at the lowest suprathreshold current seen so far.
        self.prediction = {}
        self.high = 300*pq.pA
        self.small = 0*pq.pA
        self.rheobase_vm = None

    required_capabilities = (ncap.ReceivesSquareCurrent,
                             ncap.ProducesSpikes)

    name = "Rheobase test"

    description = ("A test of the rheobase, i.e. the minimum injected current "
                   "needed to evoke at least one spike.")

    units = pq.pA

    ephysprop_name = 'Rheobase'

    score_type = scores.RatioScore

    default_params = dict(VmTest.default_params)
    default_params.update({'amplitude': 100*pq.pA,
                           'duration': 1000*pq.ms,
                           'tolerance': 1.0*pq.pA})

    params_schema = dict(VmTest.params_schema)
    params_schema.update({'tolerance': {'type': 'current',
                                        'min': 1,
                                        'required': False}})

    def condition_model(self, model):
        """Set the model's simulation stop time to ``self.params['tmax']``."""
        model.set_run_params(t_stop=self.params['tmax'])

    def generate_prediction(self, model):
        """Implement sciunit.Test.generate_prediction.

        Returns a dict whose ``'value'`` is the lowest suprathreshold current
        found by ``threshold_FI`` (in the units of the observation), or
        ``None`` when the search did not test both a subthreshold and a
        suprathreshold current.
        """
        # Method implementation guaranteed by
        # ProducesActionPotentials capability.
        self.condition_model(model)
        prediction = {'value': None}
        try:
            units = self.observation['value'].units
        except KeyError:
            units = self.observation['mean'].units
        lookup = self.threshold_FI(model, units)
        sub = np.array([x for x in lookup if lookup[x] == 0])*units
        supra = np.array([x for x in lookup if lookup[x] > 0])*units
        if self.verbose:
            if len(sub):
                print("Highest subthreshold current is %s"
                      % (float(sub.max())*units))
            else:
                print("No subthreshold current was tested.")
            if len(supra):
                print("Lowest suprathreshold current is %s" % supra.min())
            else:
                print("No suprathreshold current was tested.")
        if len(sub) and len(supra):
            rheobase = supra.min()
        else:
            rheobase = None
        prediction['value'] = rheobase
        return prediction

    def threshold_FI(self, model, units, guess=None):
        """Use binary search to generate an FI curve including rheobase.

        Args:
            model: model that receives square current injections and
                reports spike counts.
            units: units in which the returned lookup keys are expressed
                (the units of the observation).
            guess: unused; kept for backward compatibility.

        Returns:
            dict mapping each injected amplitude (a float in ``units``) to
            the number of spikes it evoked.
        """
        lookup = {}  # A lookup table global to the function below.

        def f(ampl):
            # Key every amplitude in ``units``. ``sub``/``supra`` below are
            # rebuilt as ``keys * units``, so mixing the first probe
            # (self.high, in pA) with later probes in other units would
            # corrupt the search whenever ``units`` is not pA.
            key = float(ampl.rescale(units))
            if key not in lookup:
                current = self.get_injected_square_current()
                current['amplitude'] = ampl
                model.inject_square_current(current)
                n_spikes = model.get_spike_count()
                if self.verbose >= 2:
                    print("Injected %s current and got %d spikes" %
                          (ampl, n_spikes))
                lookup[key] = n_spikes
                spike_counts = \
                    np.array([n for x, n in lookup.items() if n > 0])
                # Keep the Vm trace of the least-spiking suprathreshold run.
                if n_spikes and n_spikes <= spike_counts.min():
                    self.rheobase_vm = model.get_membrane_potential()

        max_iters = 25

        high = self.high
        small = self.small
        # Evaluate once at the upper initial probe current (self.high).
        f(high)
        i = 0
        while True:
            # sub: currents that evoked no spikes (below threshold).
            sub = np.array([x for x in lookup if lookup[x] == 0])*units
            # supra: currents that evoked spikes (above threshold, but
            # possibly far above it).
            supra = np.array([x for x in lookup if lookup[x] > 0])*units
            # Stop once the threshold is bracketed to within tolerance;
            # delta and tolerance are both expressed in ``units``.
            if len(supra) and len(sub):
                delta = float(supra.min()) - float(sub.max())
                tolerance = float(self.params['tolerance'].rescale(units))
                if delta < tolerance or (str(supra.min()) == str(sub.max())):
                    break
            if i >= max_iters:
                break
            if len(sub) and len(supra):
                # Bisect the bracket.
                f((supra.min() + sub.max())/2)
            elif len(sub):
                # Nothing has spiked yet: double the largest current tried.
                f(max(small, sub.max()*2))
            elif len(supra):
                # Everything tried so far spiked: probe at -self.small, or
                # at twice the lowest spiking current if that is lower.
                f(min(-small, supra.min()*2))
            i += 1
        return lookup

    def compute_score(self, observation, prediction):
        """Implement sciunit.Test.score_prediction.

        Returns an InsufficientDataScore when no rheobase was predicted,
        otherwise defers to the parent implementation.
        """
        if prediction is None or \
                (isinstance(prediction, dict) and prediction['value'] is None):
            score = scores.InsufficientDataScore(None)
        else:
            score = super(RheobaseTest, self).\
                compute_score(observation, prediction)
        return score

    def bind_score(self, score, model, observation, prediction):
        """Bind additional attributes to the test score.

        Attaches the Vm trace recorded at the rheobase current (if any) as
        ``score.related_data['vm']``.
        """
        super(RheobaseTest, self).bind_score(score, model,
                                             observation, prediction)
        if self.rheobase_vm is not None:
            score.related_data['vm'] = self.rheobase_vm
class RheobaseTestP(RheobaseTest):
    """Parallel implementation of a binary search to test the rheobase.

    Strengths: this algorithm is faster than the serial class
    (``RheobaseTest``) for model backends that are not able to use numba JIT
    optimization, which happens to be true of a significant number of
    backends.
    """

    name = "Rheobase test"

    description = ("A test of the rheobase, i.e. the minimum injected current "
                   "needed to evoke at least one spike.")

    units = pq.pA

    ephysprop_name = 'Rheobase'

    score_type = scores.RatioScore

    # When True, run one extra simulation at the found rheobase current to
    # record its membrane-potential trace in self.rheobase_vm.
    get_rheobase_vm = True

    def condition_model(self, model):
        """Set the model's simulation stop time to ``self.params['tmax']``."""
        model.set_run_params(t_stop=self.params['tmax'])

    def generate_prediction(self, model):
        """Generate the test prediction.

        Copies the model's attributes into a DataTC, runs the parallel
        rheobase search (``find_rheobase``) and returns
        ``{'value': rheobase * pA}``, or ``None`` if no rheobase was found.
        """
        self.condition_model(model)
        dtc = DataTC()
        dtc.attrs = {k: v for k, v in model.attrs.items()}
        # init_dtc seeds the search state (initial current steps and flags)
        # and returns the container to use from here on.
        dtc = init_dtc(dtc)
        if model.orig_lems_file_path:
            dtc.model_path = model.orig_lems_file_path
        dtc.backend = model.backend
        assert os.path.isfile(dtc.model_path),\
            "%s is not a file" % dtc.model_path
        prediction = {}
        rheobase = find_rheobase(self, dtc).rheobase
        if rheobase is not None:
            # NOTE(review): check_fix_range reports "0 pA already spikes" as
            # rheobase == -1, which ends up here as a -1 pA prediction;
            # confirm that sentinel is meant to be scored.
            prediction['value'] = float(rheobase) * pq.pA
            if self.get_rheobase_vm:
                # Gated like every other diagnostic print in this module.
                if self.verbose:
                    print("Getting rheobase vm")
                c = self.get_injected_square_current()
                c['amplitude'] = prediction['value']
                model.inject_square_current(c)
                self.rheobase_vm = model.get_membrane_potential()
        else:
            prediction = None
            self.rheobase_vm = None
        return prediction
""" Functions to support the parallel rheobase search. """
def check_fix_range(dtc):
    """Choose the next batch of currents for the parallel rheobase search.

    Inputs:
        dtc: a DataTC whose ``lookup`` maps previously injected currents
            (floats, presumably in pA) to the number of spikes evoked.

    Outputs:
        The same ``dtc``. ``dtc.rheobase`` is reset to ``None`` and
        ``dtc.current_steps`` is set to new candidate currents (pA) strictly
        inside the next search range:

        * between the highest subthreshold and lowest suprathreshold
          current, when both are known;
        * between the highest subthreshold current and twice that value,
          when only subthreshold currents are known;
        * between 100 below the lowest suprathreshold current and that
          current, when only suprathreshold currents are known.

        If 0 already evoked spikes and no subthreshold current is known,
        ``dtc.boolean`` is set to ``True`` and ``dtc.rheobase`` to the
        sentinel ``-1`` instead, and ``current_steps`` is left unchanged.

    Raises:
        ValueError: if ``dtc.lookup`` holds no subthreshold or suprathreshold
            entries (this indicates a bug in the caller).
    """
    dtc.rheobase = None
    sub, supra = get_sub_supra(dtc.lookup)
    if 0. in supra and len(sub) == 0:
        dtc.boolean = True
        dtc.rheobase = -1
        return dtc
    if len(sub) == 0 and len(supra) == 0:
        raise ValueError("No subthreshold or suprathreshold current has been "
                         "tested; cannot choose a rheobase search range.")
    if len(sub) and len(supra):
        low, high = sub.max(), supra.min()
    elif len(sub):
        low, high = sub.max(), 2*sub.max()
    else:
        low, high = supra.min()-100, supra.min()
    # Ask for at least one interior point: with a single CPU,
    # linspace(low, high, 2)[1:-1] would be empty and the search would stall.
    n_points = max(N_CPUS, 2) + 1
    # Keep only points strictly inside (low, high), tagged once with pA.
    dtc.current_steps = np.linspace(low, high, n_points)[1:-1]*pq.pA
    return dtc
def get_sub_supra(lookup):
    """Get subthreshold and suprathreshold current values.

    Args:
        lookup: mapping of injected current -> number of spikes it evoked.

    Returns:
        ``(sub, supra)``: sorted numpy arrays of the distinct currents that
        evoked zero spikes and at least one spike, respectively. Entries
        whose spike count is negative appear in neither array.
    """
    silent = {amp for amp, count in lookup.items() if count == 0}
    spiking = {amp for amp, count in lookup.items() if count > 0}
    return np.array(sorted(silent)), np.array(sorted(spiking))
def check_current(dtc):
    """Check the response to the proposed current and count spikes.

    Inputs:
        dtc: a DataTC carrying ``ampl`` (the amplitude to test, interpreted
            in pA), ``lookup`` (amplitude -> spike count), ``backend``,
            ``attrs`` and ``run_number``.

    Outputs:
        The same ``dtc`` with ``dtc.boolean`` reset to ``False`` and, if
        ``float(ampl)`` was not already in ``dtc.lookup``, one simulation run
        whose spike count is stored in ``dtc.lookup[float(ampl)]``.
    """
    dtc.boolean = False
    # NOTE(review): this always overwrites dtc.model_path with the bundled
    # LEMS_2007One.xml, even when the caller set it from the model under
    # test; the parameters themselves come from dtc via the backend. Confirm
    # this is intended.
    LEMS_MODEL_PATH = str(neuronunit.__path__[0]) + \
        str('/models/NeuroML2/LEMS_2007One.xml')
    dtc.model_path = LEMS_MODEL_PATH
    model = ReducedModel(dtc.model_path, name='vanilla',
                         backend=(dtc.backend, {'DTC': dtc}))
    backend = model._backend
    # NEURON and jNeuroML backends are required to expose a current source
    # name; other backends record one only when they have it. Compare with
    # ``in`` (equality), not ``is`` (identity), for strings.
    if dtc.backend in ('NEURON', 'jNEUROML') or \
            hasattr(backend, 'current_src_name'):
        dtc.current_src_name = backend.current_src_name
        assert dtc.current_src_name is not None
        dtc.cell_name = backend.cell_name
    ampl = float(dtc.ampl)
    # Only simulate amplitudes that have not been tested before.
    if ampl not in dtc.lookup:
        current = RheobaseTest.get_default_injected_square_current()
        uc = {'amplitude': ampl*pq.pA}
        current.update(uc)
        dtc.run_number += 1
        model.inject_square_current(current)
        dtc.previous = ampl
        n_spikes = model.get_spike_count()
        dtc.lookup[float(ampl)] = n_spikes
    return dtc
def init_dtc(dtc):
    """Exploit memory of last model in genes.

    A fresh container (``dtc.initiated is False``) gets an initial sweep of
    seven currents evenly spaced from 1 to 250 pA.

    An already-initiated container (``dtc.initiated is True``) is first
    re-checked at its stored ``ampl`` via ``check_current``. If that reports
    success through ``dtc.boolean``, it is returned immediately. Otherwise its
    remembered ``current_steps`` seed a search near the ancestor's rheobase:
    a float ``x`` becomes ``[0.75 * x, 1.25 * x]``, and a list has every
    element scaled by 1.25.

    Containers whose ``initiated`` flag is neither ``True`` nor ``False`` are
    returned unchanged.
    """
    if dtc.initiated is True:
        dtc = check_current(dtc)
        if dtc.boolean:
            return dtc
        # If this model has a lineage, assume it did not mutate far from its
        # ancestor, so consult a narrow range around the remembered value(s).
        remembered = dtc.current_steps
        if isinstance(remembered, float):
            dtc.current_steps = [0.75 * remembered, 1.25 * remembered]
        elif isinstance(remembered, list):
            dtc.current_steps = [step * 1.25 for step in remembered]
        dtc.initiated = True
    elif dtc.initiated is False:
        dtc.boolean = False
        dtc.current_steps = [value*pq.pA for value in np.linspace(1, 250, 7)]
        dtc.initiated = True
    return dtc
def find_rheobase(self, dtc):
    """Search in parallel for the rheobase of the model described by ``dtc``.

    Each round injects every current in ``dtc.current_steps``: there is one
    ``check_current`` call per step, distributed over a dask bag with
    ``N_CPUS`` partitions. The round then merges the resulting spike counts
    into ``dtc.lookup`` and asks ``check_fix_range`` for the next set of
    steps.

    The search stops when any of these holds:

    * the gap between the highest subthreshold and the lowest suprathreshold
      current falls below ``self.params['tolerance']``;
    * ``check_fix_range`` marks the search as done;
    * 40 rounds have run;
    * a current above 1500 has been tested without evoking a spike.

    Args:
        self: the calling test; supplies ``params['tolerance']`` and
            ``verbose``.
        dtc: DataTC prepared by ``init_dtc`` with a valid ``model_path``.

    Returns:
        The searched ``dtc``. ``dtc.rheobase`` holds the rheobase (a pA
        quantity), ``-1`` if ``check_fix_range`` found that 0 pA already
        spikes, or ``None`` if the search failed.
    """
    assert os.path.isfile(dtc.model_path),\
        "%s is not a file" % dtc.model_path
    cnt = 0
    sub = np.array([])
    while dtc.boolean is False and cnt < 40:
        # Give up once a large current still failed to evoke a spike.
        if len(sub) and sub.max() > 1500.0:
            dtc.rheobase = None
            dtc.boolean = False
            return dtc
        # One shallow copy of dtc per candidate amplitude; NaN amplitudes
        # are skipped.
        clones = []
        for step in dtc.current_steps:
            clone = copy.copy(dtc)
            clone.ampl = step
            if not np.isnan(clone.ampl):
                clones.append(clone)
        bag = db.from_sequence(clones, npartitions=N_CPUS)
        results = list(bag.map(check_current).compute())
        for result in results:
            if result.boolean is True:
                return result
        # Merge every clone's measurements into the container being searched.
        # A distinct loop variable keeps ``dtc`` from being rebound to a clone.
        for result in results:
            dtc.lookup.update(result.lookup)
        dtc = check_fix_range(dtc)
        cnt += 1
        sub, supra = get_sub_supra(dtc.lookup)
        if len(supra) and len(sub):
            delta = float(supra.min()) - float(sub.max())
            tolerance = float(self.params['tolerance'].rescale(pq.pA))
            if delta < tolerance or (str(supra.min()) == str(sub.max())):
                dtc.rheobase = supra.min()*pq.pA
                dtc.boolean = True
                return dtc
        if self.verbose >= 2:
            print("Try %d: SubMax = %s; SupraMin = %s" %
                  (cnt, sub.max() if len(sub) else None,
                   supra.min() if len(supra) else None))
    return dtc