### GLMs and Hidden Markov models for single neurons

I posted recently about modeling neurons with continuous state-space dynamics. It’s also possible to model neurons with Hidden Markov models (HMMs), which are state-space models with discrete rather than continuous states. In this post I’m going to focus on the application of HMMs to single neuron data.

### Single neurons with simple states

Suppose that a neuron has several discrete internal states, and that it switches slowly, at unknown times, between these states. In the simplest scenario, each state corresponds to a different mean firing rate. Such a model is appropriate for analyzing bistable neurons, as might occur in binocular rivalry experiments or in recordings spanning sleep-wake or on-off states.

Denote by $z_n$ the hidden state the neuron is in at time step $n$. The probability of transitioning from state $j$ to state $k$ can then be written as:

$p(z_n=k|z_{n-1}=j) = A_{jk}$

The hidden states are thus discrete and have the Markov property, hence the name Hidden Markov Model (HMM). An HMM can be outfitted with any type of emission, discrete or continuous, and inference remains more or less the same. In our case, if we bin the spike data finely enough, there will be at most one spike per bin. There are then only two possible outputs, 0 for no spike and 1 for a spike, and what differs between hidden states is the probability of emitting each output. Such an HMM with discrete emissions is standard, and you can analyze it with the HMM functions in Matlab’s Statistics Toolbox (hmmtrain, hmmdecode, hmmviterbi).
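The generative model just described can be sketched in a few lines. What follows is a minimal Python illustration, not Matlab's toolbox. The function simulate_bernoulli_hmm and its parameters are hypothetical names chosen for this sketch. It draws each next hidden state from the row of $A$ indexed by the current state, then emits a spike with a state-dependent probability. Unlike a fixed schedule, the switching times here are random, with a mean dwell time of $1/(1-A_{kk})$ bins in state $k$.

```python
import random

def simulate_bernoulli_hmm(A, spike_prob, n_steps, z0=0, seed=None):
    """Sample hidden states z_n and binary spikes y_n from a discrete-state HMM.

    Hypothetical helper for illustration.
    A[j][k]       : p(z_n = k | z_{n-1} = j)
    spike_prob[k] : probability of a spike (1) in a bin while in state k
    """
    rng = random.Random(seed)
    states, spikes = [], []
    z = z0
    for _ in range(n_steps):
        states.append(z)
        # Emission: Bernoulli draw with the current state's spike probability
        spikes.append(1 if rng.random() < spike_prob[z] else 0)
        # Transition: draw the next state from row z of the transition matrix
        z = rng.choices(range(len(A)), weights=A[z])[0]
    return states, spikes

# Two sticky states: 0 = high firing rate, 1 = low firing rate
A = [[0.98, 0.02],
     [0.02, 0.98]]
states, spikes = simulate_bernoulli_hmm(A, spike_prob=[0.3, 0.05], n_steps=1000, seed=1)
```

With a stay probability of 0.98, the average dwell time is 50 bins, comparable to the fixed 50-sample switching used in the Matlab example.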

Here’s an example spike train which switches between two discrete states every 50 samples:

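```matlab
ns = 1000;
%Hidden state flips every 50 samples
a = mod(floor((1:ns)/50),2) ==1;
y = zeros(size(a));
%Spike probability per bin: .3 in one state, .05 in the other
y(a==0) = rand(nnz(a==0),1)<.3;
y(a==1) = rand(nnz(a==1),1)<.05;
plot(y);axis([0,1000,0,1.3]);
```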


Which looks like this:

A call to hmmtrain, which uses the Baum-Welch (EM) algorithm by default, estimates the spike probability in each state and the transition probabilities (“learning”):

```matlab
%%
%Initial guess for the transition matrix: row j gives the probability
%of moving from state j to each state
switchProb = [.9,.1;.1,.9];
%Initial guess for the probability of emitting each symbol
%(0 = no spike, 1 = spike) in each state.
%State 1 corresponds to the low firing rate, state 2 to the high
%firing rate
emissProb = [.9,.1;
             .6,.4];
[T,emissions] = hmmtrain(y,switchProb,emissProb,'symbols',[0,1]);
T
emissions
```


Which gives:

```
T =

    0.9612    0.0388
    0.0445    0.9555

emissions =

    0.9584    0.0416
    0.6658    0.3342
```

The probability that the neuron is in each state at each time step, as well as the most likely sequence of states, are easily found:

```matlab
%%
%Posterior probability of each state at each time step
pstates = hmmdecode(y,T,emissions,'symbols',[0,1]);
subplot(2,1,1);plot(1:length(y),y,1:length(y),pstates(2,:))

%Most probable sequence of states (Viterbi)
ostate = hmmviterbi(y,T,emissions,'symbols',[0,1]);
subplot(2,1,2);plot(1:length(y),y,1:length(y),1.2*ostate-1.3);
axis([0,1000,-.2,1.3]);
```
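hmmviterbi returns the single most probable state sequence. To give some intuition for what it computes, here is a minimal pure-Python sketch of the Viterbi recursion for the same spike/no-spike HMM. The function name viterbi and its argument layout are my own assumptions. One difference from Matlab's HMM functions is worth noting. Matlab assumes the chain starts in state 1 before the first emission, whereas this sketch takes the start distribution p0 as an explicit argument.

```python
import math

def _log(v):
    """Log that maps probability 0 to -inf instead of raising."""
    return math.log(v) if v > 0 else -math.inf

def viterbi(y, A, emis, p0):
    """Most probable hidden-state path for a discrete-emission HMM (sketch).

    y          : observed symbols (0 = no spike, 1 = spike)
    A[j][k]    : transition probability from state j to state k
    emis[k][s] : probability of emitting symbol s in state k
    p0[k]      : probability of starting in state k
    Works in log space so long spike trains do not underflow.
    """
    K = len(p0)
    # delta[k]: log-probability of the best path ending in state k so far
    delta = [_log(p0[k]) + _log(emis[k][y[0]]) for k in range(K)]
    backptr = []
    for obs in y[1:]:
        prev = delta
        ptr, delta = [], []
        for k in range(K):
            # Best predecessor state for ending up in state k at this step
            best_j = max(range(K), key=lambda j: prev[j] + _log(A[j][k]))
            ptr.append(best_j)
            delta.append(prev[best_j] + _log(A[best_j][k]) + _log(emis[k][obs]))
        backptr.append(ptr)
    # Trace back from the best final state
    path = [max(range(K), key=lambda k: delta[k])]
    for ptr in reversed(backptr):
        path.append(ptr[path[-1]])
    return path[::-1]

# Sticky two-state chain with noisy emissions
A = [[0.9, 0.1], [0.1, 0.9]]
emis = [[0.9, 0.1], [0.1, 0.9]]
path = viterbi([0, 0, 0, 1, 1, 1], A, emis, [0.5, 0.5])   # → [0, 0, 0, 1, 1, 1]
```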


With the output:

And that’s pretty much all there is to it.

### Single neurons with complex states

By using a standard HMM, we avoided getting into the nitty-gritty of learning with HMMs. But there are slightly more complex analysis scenarios where you’ll need to do a little more work. In the previous analysis, I assumed that the only thing changing inside the neuron is its firing rate. But what if other aspects of its response change as well, like its burstiness or its receptive field?

In that case you can still assume that the states are discrete and have Markov transition probabilities. However, instead of assuming a direct mapping from state to firing rate, we’ll assume that the output is generated by a Generalized Linear Model (GLM) whose parameters depend on the state. For example, when the cell is in state 1, its output is generated by a GLM with a certain receptive field, while in state 2 it is generated by another GLM with a different receptive field.
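Under this model, the state-dependent part of the data is summarized by $p(y_n \mid z_n = k)$, the likelihood of bin $n$ under state $k$'s GLM. For a Bernoulli GLM with a logistic link, a minimal pure-Python sketch of that computation might look like the block below. The 'binomlogit' family used later in this post is assumed to behave like this. glm_state_likelihoods is a hypothetical helper, not part of any toolbox. It omits a bias term, so append a constant 1 to each stimulus vector if you need one.

```python
import math

def _sigmoid(eta):
    """Numerically stable logistic function."""
    if eta >= 0:
        return 1.0 / (1.0 + math.exp(-eta))
    e = math.exp(eta)
    return e / (1.0 + e)

def glm_state_likelihoods(y, X, W):
    """p(y_n | z_n = k) for a Bernoulli-logistic GLM in each hidden state (sketch).

    y : list of 0/1 spike indicators, length N
    X : list of N stimulus vectors, each of length D
    W : list of K weight vectors (one receptive field per hidden state)
    Returns a K x N nested list of likelihoods.
    """
    lik = []
    for w in W:
        row = []
        for x_n, y_n in zip(X, y):
            eta = sum(xi * wi for xi, wi in zip(x_n, w))  # linear predictor
            p = _sigmoid(eta)                             # spike probability
            row.append(p if y_n == 1 else 1.0 - p)
        lik.append(row)
    return lik

# Two states with different (made-up) 2-D receptive fields
lik = glm_state_likelihoods([1, 0], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]])
```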

How can you analyze this? Chapter 13 of Bishop (Pattern Recognition and Machine Learning) is quite useful here. Parameters of HMMs are typically inferred through EM (expectation-maximization). In the E step, you infer the probability that the model is in each hidden state at every time step. This is done with the forward-backward algorithm. It’s important to understand that this step is essentially independent of the emission mechanism. All you need to perform it is the likelihood of the observation at each time step under each hidden state (state 1 or state 2).
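To see why the emission model drops out of the E step, here is a minimal pure-Python sketch of the scaled forward-backward recursions in the style of Bishop's chapter 13 and Rabiner (1989). It is not PMTK's hmmFwdBack. The name forward_backward and its argument layout are assumptions for this sketch. Its only view of the data is lik[k][n], the likelihood of observation $n$ under state $k$, which can come from any emission model. It also returns the expected transition counts needed to update the transition matrix.

```python
import math

def forward_backward(p0, A, lik):
    """Scaled forward-backward pass for an HMM with arbitrary emissions (sketch).

    p0[k]     : initial state probabilities
    A[j][k]   : transition probability from state j to state k
    lik[k][n] : p(y_n | z_n = k)
    Returns gamma[k][n] = p(z_n = k | y),
            xi_sum[j][k] = sum_n p(z_{n-1} = j, z_n = k | y),
            and the log-likelihood log p(y).
    """
    K, N = len(p0), len(lik[0])
    alpha = [[0.0] * N for _ in range(K)]
    beta = [[1.0] * N for _ in range(K)]
    c = [0.0] * N  # per-step normalizers; log p(y) = sum(log c)

    # Forward pass: alpha[k][n] = p(z_n = k | y_1..y_n)
    for n in range(N):
        for k in range(K):
            if n == 0:
                prior = p0[k]
            else:
                prior = sum(alpha[j][n - 1] * A[j][k] for j in range(K))
            alpha[k][n] = prior * lik[k][n]
        c[n] = sum(alpha[k][n] for k in range(K))
        for k in range(K):
            alpha[k][n] /= c[n]

    # Backward pass, rescaled with the same normalizers
    for n in range(N - 2, -1, -1):
        for j in range(K):
            beta[j][n] = sum(A[j][k] * lik[k][n + 1] * beta[k][n + 1]
                             for k in range(K)) / c[n + 1]

    # Posterior state probabilities
    gamma = [[alpha[k][n] * beta[k][n] for n in range(N)] for k in range(K)]

    # Expected transition counts (numerator of the transition update)
    xi_sum = [[0.0] * K for _ in range(K)]
    for n in range(1, N):
        for j in range(K):
            for k in range(K):
                xi_sum[j][k] += (alpha[j][n - 1] * A[j][k] * lik[k][n]
                                 * beta[k][n] / c[n])

    return gamma, xi_sum, sum(math.log(cn) for cn in c)
```

Row-normalizing xi_sum gives the updated transition matrix.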

In the M step, we maximize the expected complete-data log-likelihood with respect to the model parameters. This sounds scary, but for the emission parameters it amounts to fitting one GLM per hidden state. Each data point is weighted by the probability, computed in the E step, that the model was in that hidden state at that time step. To be concrete, suppose the GLM in question were the normal-identity GLM (ordinary linear regression). Outside the EM context, fitting it would involve minimizing the sum of squared errors:

$\arg\min_w \sum_i \left(y_i - \sum_j X_{ij} w_j\right)^2$

In the M step, this is replaced by a weighted sum of squares, where $\gamma_i$ is the E-step probability that the model is in the target hidden state at time step $i$:

$\arg\min_w \sum_i \gamma_i \left(y_i - \sum_j X_{ij} w_j\right)^2$
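Setting the gradient of this weighted objective to zero gives the weighted normal equations $(X^\top \Gamma X)\,w = X^\top \Gamma y$ with $\Gamma = \mathrm{diag}(\gamma)$. The sketch below is pure Python with a hypothetical function name. It solves those equations directly for a handful of regressors. For a non-Gaussian GLM such as the logistic one used later, there is no closed form. The weighted fit would then be done iteratively, for example by weighted IRLS or a generic optimizer.

```python
def weighted_least_squares(X, y, gamma):
    """argmin_w sum_i gamma_i * (y_i - sum_j X_ij w_j)^2 (sketch).

    Builds the weighted normal equations (X^T G X) w = X^T G y, G = diag(gamma),
    and solves them by Gaussian elimination with partial pivoting.
    """
    D = len(X[0])
    # Augmented matrix [X^T G X | X^T G y]
    M = [[0.0] * (D + 1) for _ in range(D)]
    for g, x, yi in zip(gamma, X, y):
        for a in range(D):
            for b in range(D):
                M[a][b] += g * x[a] * x[b]
            M[a][D] += g * x[a] * yi
    # Forward elimination with partial pivoting
    for col in range(D):
        piv = max(range(col, D), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, D):
            f = M[r][col] / M[col][col]
            for c in range(col, D + 1):
                M[r][c] -= f * M[col][c]
    # Back substitution
    w = [0.0] * D
    for r in range(D - 1, -1, -1):
        w[r] = (M[r][D] - sum(M[r][c] * w[c] for c in range(r + 1, D))) / M[r][r]
    return w

# Intercept plus one regressor; the two halves of the data follow different lines
X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [1.0, 3.0, 3.0, 2.0]
w_first = weighted_least_squares(X, y, [1.0, 1.0, 0.0, 0.0])   # → [1.0, 2.0]
```

Data points with $\gamma_i \approx 0$ barely influence the fit. This is how each state's GLM ends up fit mostly to the time steps attributed to that state.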

Because the E step only needs the per-state likelihoods, it can be done with generic software; I used the hmmFwdBack function in PMTK. For the M step, I added a new “weights” option to my GLM toolbox. I filled in the rest by following chapter 13 of Bishop as well as Rabiner (1989).
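For reference, the quantities involved can be written out as follows. The notation is adapted from Bishop's chapter 13, and the indexing is my own. Let $\pi_k$ be the initial-state probabilities and $x_n$ the $n$-th row of $X$. Let $\gamma_{nk} = p(z_n = k \mid y)$ and $\xi_{n,jk} = p(z_{n-1} = j, z_n = k \mid y)$ be the E-step posteriors. With these definitions, the expected complete-data log-likelihood and its maximizers are:

$$Q = \sum_{k}\gamma_{1k}\ln\pi_k + \sum_{n=2}^{N}\sum_{j,k}\xi_{n,jk}\ln A_{jk} + \sum_{n=1}^{N}\sum_{k}\gamma_{nk}\ln p(y_n \mid x_n, w_k)$$

$$\pi_k = \gamma_{1k}, \qquad A_{jk} = \frac{\sum_{n=2}^{N}\xi_{n,jk}}{\sum_{l}\sum_{n=2}^{N}\xi_{n,jl}}, \qquad w_k = \arg\max_{w}\sum_{n=1}^{N}\gamma_{nk}\ln p(y_n \mid x_n, w)$$

The last term separates into one weighted GLM fit per state. For the normal-identity GLM it is exactly the weighted sum of squares above, with $\gamma_i$ playing the role of $\gamma_{ik}$ for the target state $k$.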

Here’s some data I generated that simulates a neuron whose receptive field switches between two shapes every 50 time steps:

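```matlab
ns = 2000;
X = randn(ns,10);                  %white-noise stimulus, 10 dimensions
a = mod(floor((1:ns)/50),2) ==1;   %hidden state: flips every 50 steps

%Two Gaussian-shaped receptive fields, centered on dimensions 3 and 7
rg = (1:10)';
ws = [exp(-(rg-3).^2/2/.8^2),exp(-(rg-7).^2/2/.8^2)];

%Receptive field in effect at each time step
W = a'*ws(:,1)' + (1-a)'*ws(:,2)';

%Bernoulli GLM with a logistic link and an offset of -1
eta = sum(X.*W,2)-1;
p   = 1./(1+exp(-eta));
y = double(p>rand(size(p)));

plot(ws)
```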


Here are the two receptive fields it changes between:

Here’s a function that performs EM for this data set:

```matlab
function [results] = glhmmfit(y,X,W0,T0,glmopts)
%GLHMMFIT Fit a generalized linear model with hidden Markov states
%through EM. Save as glhmmfit.m. Depends on hmmFwdBack (PMTK) and on
%evalGlmLikelihood and glmfitqp from my GLM toolbox.
thresh = 1e-3;   %convergence threshold on the log-likelihood change
niter = 100;     %maximum number of EM iterations
nstates = size(W0,2);

%Uniform initial state distribution (column vector)
p0 = ones(nstates,1);
p0 = p0/sum(p0);

W = W0;
T = T0;
lls = zeros(nstates,size(X,1));
loglik0 = -Inf;

glmopts.Display = 'off';

for ii = 1:niter
    %Compute the likelihood of each observation for every state
    etas = X*W;
    for jj = 1:nstates
        [~,~,~,~,lls(jj,:)] = evalGlmLikelihood(y,etas(:,jj),1,0,glmopts.family,glmopts.familyextra);
    end
    %lls holds per-observation negative log-likelihoods;
    %exponentiate to get likelihoods
    %lls0 = lls;
    %dl = max(lls);
    %lls = exp(-bsxfun(@minus,lls,dl));
    lls = exp(-lls);

    %E step: forward-backward
    [gamma,alpha,beta,loglik] = hmmFwdBack(p0,T,lls);

    %loglik = loglik + sum(dl);

    %M step: fit one weighted GLM per hidden state
    %Equation (13.17)
    opts2 = glmopts;
    for jj = 1:nstates
        opts2.weights = gamma(jj,:)';
        glmres = glmfitqp(y,X,.001*speye(size(X,2)),opts2);
        W(:,jj) = glmres.w;
    end

    %Update start and transition probabilities
    %eq (13.18)
    p0 = gamma(:,1);

    %eqs (13.19) & (13.43)
    xi = 0;
    for jj = 1:size(X,1)-1
        xi0 = (alpha(:,jj)*(lls(:,jj+1).*beta(:,jj+1))').*T;
        xi = xi + xi0/sum(xi0(:));
    end

    xi = bsxfun(@times,xi,1./sum(gamma,2));
    T = bsxfun(@times,xi,1./sum(xi,2));

    if loglik < loglik0 + thresh
        %converged
        fprintf('Converged in %d iterations\n',ii);
        break;
    end
    loglik0 = loglik;
    fprintf('Iteration %d, loglik = %8.2f\n',ii,loglik);
end

results.W = W;
results.T = T;
results.p = p0;
results.loglik = loglik;
results.gamma = gamma;
end
```


Call this to perform inference:

```matlab
glmopts.family = 'binomlogit';
glmopts.familyextra = 1;
%Initial guess for the transition matrix: states are sticky
T0 = [.9,.1;.1,.9];
results = glhmmfit(y,X,randn(size(ws)),T0,glmopts);
```


And here are the estimated RFs, with the inferred states underneath: