Source code for neuronunit.tests.base

"""Base classes and attributes for many neuronunit tests.

None of the classes here are meant for direct use in testing.
"""

from types import MethodType

import numpy as np
import quantities as pq

import sciunit
from sciunit.tests import ProtocolToFeaturesTest
import sciunit.scores as scores
import neuronunit.capabilities as ncap
import sciunit.capabilities as scap
from neuronunit import neuroelectro


[docs]class VmTest(ProtocolToFeaturesTest): """Base class for tests involving the membrane potential of a model.""" def __init__(self, observation={'mean': None, 'std': None}, name=None, **params): super(VmTest, self).__init__(observation, name, **params) cap = [] for cls in self.__class__.__bases__: cap += cls.required_capabilities self.required_capabilities += tuple(cap) self._extra() required_capabilities = (scap.Runnable, ncap.ProducesMembranePotential,) name = '' units = pq.Dimensionless ephysprop_name = '' observation_schema = [("Mean, Standard Deviation, N", {'mean': {'units': True, 'required': True}, 'std': {'units': True, 'min': 0, 'required': True}, 'n': {'type': 'integer', 'min': 1}}), ("Mean, Standard Error, N", {'mean': {'units': True, 'required': True}, 'sem': {'units': True, 'min': 0, 'required': True}, 'n': {'type': 'integer', 'min': 1, 'required': True}})] default_params = {'amplitude': 0.0*pq.pA, 'delay': 100.0*pq.ms, 'duration': 300.0*pq.ms, 'dt': 0.025*pq.ms, 'padding': 200*pq.ms} params_schema = {'dt': {'type': 'time', 'min': 0, 'required': False}, 'tmax': {'type': 'time', 'min': 0, 'required': False}, 'delay': {'type': 'time', 'min': 0, 'required': False}, 'duration': {'type': 'time', 'min': 0, 'required': False}, 'amplitude': {'type': 'current', 'required': False}, 'padding': {'type': 'time', 'min': 0, 'required': False}} def _extra(self): pass
[docs] def compute_params(self): self.params['tmax'] = (self.params['delay'] + self.params['duration'] + self.params['padding'])
[docs] def validate_observation(self, observation): super(VmTest, self).validate_observation(observation) # Catch another case that is trickier if 'std' not in observation: observation['std'] = observation['sem'] * np.sqrt(observation['n']) return observation
[docs] def condition_model(self, model): model.set_run_params(t_stop=self.params['tmax'])
[docs] def bind_score(self, score, model, observation, prediction): score.related_data['vm'] = model.get_membrane_potential() score.related_data['model_name'] = '%s_%s' % (model.name, self.name) def plot_vm(self, ax=None, ylim=(None, None)): """A plot method the score can use for convenience.""" import matplotlib.pyplot as plt if ax is None: ax = plt.gca() vm = score.related_data['vm'].rescale('mV') ax.plot(vm.times, vm) y_min = float(vm.min()-5.0*pq.mV) if ylim[0] is None else ylim[0] y_max = float(vm.max()+5.0*pq.mV) if ylim[1] is None else ylim[1] ax.set_xlim(vm.times.min(), vm.times.max()) ax.set_ylim(y_min, y_max) ax.set_xlabel('Time (s)') ax.set_ylabel('Vm (mV)') score.plot_vm = MethodType(plot_vm, score) # Bind to the score. score.unpicklable.append('plot_vm')
[docs] @classmethod def neuroelectro_summary_observation(cls, neuron, cached=False): reference_data = neuroelectro.NeuroElectroSummary( neuron=neuron, # Neuron type lookup using the NeuroLex ID. ephysprop={'name': cls.ephysprop_name}, # Ephys property name in # NeuroElectro ontology. cached=cached ) # Get and verify summary data from neuroelectro.org. reference_data.get_values(quiet=not cls.verbose) if hasattr(reference_data, 'mean'): observation = {'mean': reference_data.mean*cls.units, 'std': reference_data.std*cls.units, 'n': reference_data.n} else: observation = None return observation
[docs] @classmethod def neuroelectro_pooled_observation(cls, neuron, cached=False, quiet=True): reference_data = neuroelectro.NeuroElectroPooledSummary( neuron=neuron, # Neuron type lookup using the NeuroLex ID. # Ephys property name in NeuroElectro ontology. ephysprop={'name': cls.ephysprop_name}, cached=cached ) # Get and verify summary data from neuroelectro.org. reference_data.get_values(quiet=quiet) observation = {'mean': reference_data.mean*cls.units, 'std': reference_data.std*cls.units, 'n': reference_data.n} return observation
[docs] def sanity_check(self, rheobase, model): self.params['injected_square_current']['delay'] = self.params['delay'] self.params['injected_square_current']['duration'] = \ self.params['duration'] self.params['injected_square_current']['amplitude'] = rheobase model.inject_square_current(self.params['injected_square_current']) mp = model.results['vm'] if np.any(np.isnan(mp)) or np.any(np.isinf(mp)): return False sws = ncap.spike_functions.get_spike_waveforms( model.get_membrane_potential()) for i, s in enumerate(sws): s = np.array(s) dvdt = np.diff(s) for j in dvdt: if np.isnan(j): return False return True
[docs] @classmethod def get_default_injected_square_current(cls): current = {key: cls.default_params[key] for key in ['duration', 'delay', 'amplitude']} return current
[docs] def get_injected_square_current(self): current = {key: self.default_params[key] for key in ['duration', 'delay', 'amplitude']} return current
@property def state(self): state = super(VmTest, self).state return self._state(state=state, exclude=['unpicklable', 'verbose'])
class FakeTest(sciunit.Test):
    """Fake test class.

    Just computes agreement between an observation key and a model attribute.
    The key is the second underscore-separated token of the test name
    (``"test_vr"`` -> ``"vr"``).

    e.g.
    observation = {'a': [0.8, 0.3],
                   'b': [0.5, 0.1],
                   'vr': [-70*pq.mV, 5*pq.mV]}
    fake_test_a = FakeTest(observation=observation, name="test_a")
    fake_test_b = FakeTest(observation=observation, name="test_b")
    fake_test_vr = FakeTest(observation=observation, name="test_vr")
    """

    def _key_from_name(self):
        """Return the second underscore-separated token of ``self.name``."""
        return self.name.split('_')[1]

    def generate_prediction(self, model):
        """Return ``model.attrs[key]`` and remember the key in ``key_param``."""
        self.key_param = self._key_from_name()
        return model.attrs[self.key_param]

    def compute_score(self, observation, prediction):
        """Return a ``ZScore`` of the prediction against the observation.

        ``observation[key]`` is read as ``[mean, std]``.
        """
        # Previously this raised AttributeError unless generate_prediction
        # had already run; derive the key from the name when it is missing.
        key = getattr(self, 'key_param', None)
        if key is None:
            key = self.key_param = self._key_from_name()
        mean = observation[key][0]
        std = observation[key][1]
        z = (prediction - mean)/std
        return scores.ZScore(z)