Gibbs sampler in Matlab using mexme

Darren Wilkinson has a nice post up comparing different programming languages (C, Java, Scala, Python, and R) for writing Gibbs samplers. Unsurprisingly, C is fastest, although it is certainly not the easiest language to program in. In particular, I/O is a pain.

Others have suggested an interesting solution: write the core of the Gibbs sampler in C and embed it in a high-level language, which saves you from writing the I/O in C. Rcpp and Cython do this for R and Python, respectively. Here I will show how to do the same in Matlab by using mexme to automagically write a mex file around a numeric C snippet.

The C Gibbs sampler

Here is the original Gibbs sampler written in C:

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>

int main(void)
{
    int N = 50000;
    int thin = 1000;
    int i, j;
    gsl_rng *r = gsl_rng_alloc(gsl_rng_mt19937);
    double x = 0;
    double y = 0;
    printf("Iter x y\n");
    for (i = 0; i < N; i++) {
        for (j = 0; j < thin; j++) {
            /* x | y ~ Gamma(shape 3, scale 1/(y^2+4)) */
            x = gsl_ran_gamma(r, 3.0, 1.0/(y*y+4));
            /* y | x ~ Normal(mean 1/(x+1), sd 1/sqrt(2x+2)) */
            y = 1.0/(x+1) + gsl_ran_gaussian(r, 1.0/sqrt(2*x+2));
        }
        printf("%d %f %f\n", i, x, y);
    }
    gsl_rng_free(r);
    return 0;
}

This is a straightforward application of the GSL (GNU Scientific Library). To create a mex file out of it, I extracted the numeric core of the code and saved it as djwsampler.csnip:

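int i, j;
/* Allocated on every call from GSL's default seed,
   so repeated calls return the same chain */
gsl_rng *r = gsl_rng_alloc(gsl_rng_mt19937);
double x = 0;
double y = 0;

for (i = 0; i < N; i++) {
    for (j = 0; j < thin; j++) {
        x = gsl_ran_gamma(r, 3.0, 1.0/(y*y+4));
        y = 1.0/(x+1) + gsl_ran_gaussian(r, 1.0/sqrt(2*x+2));
    }
    samples[i] = x;     /* column 1 of the N-by-2 output */
    samples[i+N] = y;   /* column 2 (Matlab arrays are column-major) */
}

gsl_rng_free(r);  /* release the generator before returning to Matlab */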

Next, I created a djwsampler.includes file containing the #include directives required for working with the GSL:

#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>

The Matlab interface

That is all the C code we need. As you can see above, I've removed the hard-coded definitions of N and thin, and instead of printing the samples with printf, I store them in an N-by-2 array called samples. In Matlab, I then declare the two input arguments and the output argument, call mexme to automagically write the mex wrapper, and compile the wrapper:

inputargs = [InputNum('N',true,true,'int32');
             InputNum('thin',true,true,'int32')];

outputargs = [OutputNum('samples','N,2')];

opts.extraincludes = readfile('djwsampler.includes');
cfile = mexme('djwsampler.csnip',inputargs,outputargs,opts);
writefile('djwsamplermex.c',cfile)
mex -lgsl -lgslcblas djwsamplermex.c
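The snippet writes x to samples[i] and y to samples[i+N] because Matlab stores numeric arrays column by column, so the N-by-2 output is one column of N x values followed by one column of N y values. Here is a small sketch of that index arithmetic in C; colmajor_offset and print_samples are hypothetical helpers, the latter printing the buffer in the original program's "Iter x y" format so the output of the two versions can be compared:

```c
#include <stddef.h>
#include <stdio.h>

/* Offset of element (row, col) in an array with nrows rows stored
   column-major, which is how Matlab lays out numeric arrays in memory.
   Rows and columns are 0-based here, unlike Matlab's 1-based indexing. */
size_t colmajor_offset(size_t row, size_t col, size_t nrows)
{
    return row + col * nrows;
}

/* Print an n-by-2 column-major buffer in the "Iter x y" format used by
   the original command-line program. */
void print_samples(FILE *out, const double *samples, size_t n)
{
    fprintf(out, "Iter x y\n");
    for (size_t i = 0; i < n; i++) {
        fprintf(out, "%zu %f %f\n", i,
                samples[colmajor_offset(i, 0, n)],
                samples[colmajor_offset(i, 1, n)]);
    }
}
```

So the Matlab element results(i,2) corresponds to samples[(i-1) + N] in the snippet.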

Matlab has now compiled a mex function called djwsamplermex, which can be called like any other Matlab function:

tic;
results = djwsamplermex(int32(50000),int32(1000));
toc;

You can then manipulate the results like any other Matlab variable. For example, plot(results(:,1),results(:,2),'.') draws a scatter plot of the (x, y) samples.

Speed

How fast is this? On a Core i7 920, the mex version takes 9.6 seconds, while the original pure C version, run from the command line, took 9.7 seconds. So calling the sampler through a mex function adds essentially no overhead, and returning the raw values to Matlab is, if anything, slightly faster than writing them out with printf (although a 0.1 second gap may well be within run-to-run noise). Pretty cool, right?

Generated code

You might be wondering what mexme actually does here. It simply takes the definition of the mex function's interface and generates the corresponding wrapper, including some basic checks that the supplied arguments have the expected count, type and size. Here's what it looks like:

/* C file autogenerated by mexme.m */
#include <mex.h>
#include <math.h>
#include <matrix.h>
#include <stdlib.h>
#include <float.h>
#include <string.h>

/* Translate matlab types to C */
#define uint64 unsigned long int
#define int64 long int
#define uint32 unsigned int
#define int32 int
#define uint16 unsigned short
#define int16 short
#define uint8 unsigned char
#define int8 char
#define single float

#include "mexmetypecheck.c"

/* Your extra includes and function definitions here */
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>


void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[] )
{

    /*Input output boilerplate*/
    if(nlhs != 1 || nrhs != 2)
        mexErrMsgTxt("Function must be called with 2 arguments and has 1 return values");

    const mxArray *N_ptr = prhs[0];
    mexmetypecheck(N_ptr,mxINT32_CLASS,"Argument N (#1) is expected to be of type int32");
    if(mxGetNumberOfElements(N_ptr) != 1)
        mexErrMsgTxt("Argument N (#1) must be scalar");
    const int32   N = (int32) mxGetScalar(N_ptr);
    const mxArray *thin_ptr = prhs[1];
    mexmetypecheck(thin_ptr,mxINT32_CLASS,"Argument thin (#2) is expected to be of type int32");
    if(mxGetNumberOfElements(thin_ptr) != 1)
        mexErrMsgTxt("Argument thin (#2) must be scalar");
    const int32   thin = (int32) mxGetScalar(thin_ptr);


    mwSize samples_dims[] = {N,2};
    plhs[0] = mxCreateNumericArray(2,samples_dims,mxDOUBLE_CLASS,mxREAL);
    mxArray **samples_ptr = &plhs[0];
    double   *samples = (double *) mxGetData(*samples_ptr);



    /*Actual function*/
#include "djwsampler.csnip"


}
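One caveat about the type translation macros at the top of the generated file: they assume particular widths for the built-in C types. long int is 64 bits on LP64 platforms such as 64-bit Linux and macOS, but only 32 bits on 64-bit Windows, and plain char is unsigned on some targets, so int8 is not guaranteed to be signed. Here is a sketch of a more portable mapping using the C99 fixed-width types from <stdint.h> plus C11 compile-time size checks. This is my suggestion rather than what mexme generates, and it assumes none of these names clash with anything in the Matlab headers:

```c
#include <stdint.h>

/* Typedefs on fixed-width types: each name gets the width its Matlab
   class implies, regardless of the platform's data model. */
typedef uint64_t uint64;
typedef int64_t  int64;
typedef uint32_t uint32;
typedef int32_t  int32;
typedef uint16_t uint16;
typedef int16_t  int16;
typedef uint8_t  uint8;
typedef int8_t   int8;     /* always signed, unlike plain char */
typedef float    single;

/* Fail the build early if a mapping has an unexpected size. */
_Static_assert(sizeof(int64) == 8, "int64 must be 8 bytes");
_Static_assert(sizeof(int32) == 4, "int32 must be 4 bytes");
_Static_assert(sizeof(int16) == 2, "int16 must be 2 bytes");
_Static_assert(sizeof(single) == 4, "single must be 4 bytes");
```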

Of course, it's not that hard to write the wrapper by hand, but if you are more comfortable with Matlab than with C, mexme can save you some work.
