Mex files - an example
Sometimes, when a matlab *.m file is too slow, one option is to rewrite part of it in C, C++ or Fortran. Sometimes it's even possible to write faster versions of matlab built-in commands. Matlab's routines have to cope with many situations (complex numbers, extreme values, etc) safely and accurately. Your data may have a narrow range and you may be prepared to sacrifice some safety and accuracy for speed.
This document shows how to produce a near order-of-magnitude faster (though less accurate) alternative to matlab's exp function. To understand how it works you need to know how computers store numbers. Such knowledge isn't usually necessary when writing mex files. If you don't care why this particular algorithm works, but just want a working example that you can adapt to your own needs, skip the sections on numerical analysis and go straight to the How to do it section.
The algorithm, maths and macros come from "A Fast, Compact Approximation of the Exponential Function" by Nicol N. Schraudolph, Neural Computation 11, 853-862 (1999).
How numbers are stored in memory
Integers
The smallest unit of computer memory is the bit, which can have two values - 0 or 1. It's straightforward to represent positive integers using bits - the bits are digits of a base 2 representation, so in an 8 bit byte
0 = 00000000
1 = 00000001
2 = 00000010
3 = 00000011
...
255 = 11111111
The least significant bit is known as bit 0, the next bit is bit 1, etc. It's less trivial to represent signed values. The following scheme (called two's complement) is usually used
0 = 00000000
1 = 00000001
2 = 00000010
3 = 00000011
...
127 = 01111111
-128 = 10000000
...
-1 = 11111111
Of course, this isn't the only scheme that is possible, but it has convenient features - for example the left-most bit shows you the sign of the number.
16 bit signed values work in a similar way
0 = 00000000 00000000
1 = 00000000 00000001
...
255 = 00000000 11111111
256 = 00000001 00000000
...
32767 = 01111111 11111111
-32768 = 10000000 00000000
...
-1 = 11111111 11111111
C's ints are usually 32 bit values and follow the same scheme.
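The tables above can be reproduced with a few lines of C. This is a minimal sketch; the function name bits8 is an illustrative choice and not part of the algorithm described later.

```c
#include <stdint.h>

/* Write the 8 bits of v into out, most significant bit first,
   followed by a terminating NUL character. */
void bits8(int8_t v, char out[9])
{
    uint8_t u = (uint8_t)v;   /* conversion is modulo 256, i.e. the two's complement pattern */
    int b;
    for (b = 7; b >= 0; b--)
        out[7 - b] = ((u >> b) & 1u) ? '1' : '0';
    out[8] = '\0';
}
```

Calling bits8(-1, s) fills s with 11111111, and bits8(-128, s) gives 10000000, matching the signed table.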
Doubles
A double uses 8 bytes (i.e. 64 bits) of memory and has this layout
sxxxxxxxxxxxmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm
Each character here represents a bit
- 's' means the bit represents the sign of the number
- 'x' means the bit represents part of the exponent (11 bits)
- 'm' means that bit represents part of the mantissa - the value (52 bits)
The double's value is
(-1)**s * (1+mantissa) * 2**(exponent-1023)
so to represent the number 1.0, s would be 0, the mantissa would be 0, and the exponent would be 1023, so the bit pattern would be
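0011 1111 1111 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000
(shown in groups of 4 bits: the first 3 groups hold the sign and exponent, the remaining 13 hold the mantissa)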
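To check the layout on your own machine, the three fields can be pulled out of a double. The sketch below assumes that doubles are 64-bit IEEE-754 values and that integers and doubles use the same byte order (true on mainstream platforms); the names double_fields and split_double are invented for this illustration.

```c
#include <stdint.h>

/* The three fields of an IEEE-754 double. */
struct double_fields {
    unsigned sign;       /* 1 bit */
    unsigned exponent;   /* 11 bits, stored with a bias of 1023 */
    uint64_t mantissa;   /* 52 bits */
};

struct double_fields split_double(double d)
{
    union { double d; uint64_t bits; } u;   /* both members share the same 8 bytes */
    struct double_fields f;
    u.d = d;
    f.sign     = (unsigned)(u.bits >> 63);
    f.exponent = (unsigned)((u.bits >> 52) & 0x7FF);
    f.mantissa = u.bits & (((uint64_t)1 << 52) - 1);
    return f;
}
```

For 1.0 this gives sign 0, exponent 1023 and mantissa 0, as described above.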
Setting bits
One way of placing the bit pattern above into memory in C is to do the straightforward
double d=1.0;
but to have fine control over the bits we need to use other means. The following C fragment is used in the algorithm. What it does is create a new variable called _eco.
static union { double d; struct { int i,j; } n; } _eco;
This variable is rather complicated, especially if you're new to C. It has things inside it - a double d and a structure called n. n in turn has 2 integers inside it - i and j. If you wanted to set its d component to 1.0 you could do
_eco.d = 1.0;
If you wanted to set i to 7 you could do
_eco.n.i = 7;
(i is a component of n which is in turn inside _eco). n is a struct containing 2 integers that are stored side by side in memory. _eco is a union of d and n. The difference in C between a union and a struct is that the components of a union occupy the same area of memory - changing i overwrites the double d.
Usually this sort of memory corruption is unwanted, but here it's useful for our purposes. i overlays the first 32 bits of d, so if we set the value of i carefully, we can change the exponent part (and only that part) of d. For example, to give d the value 1.0 we need to set i to be
00111111111100000000000000000000
which is
00000000000000000000001111111111
(i.e. 1023) shifted left 20 times. Because this is base 2, shifting left 20 times is equivalent to doubling 20 times. 2^20 is 1048576, a number you'll be seeing more of later.
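The effect of writing 1023 shifted left 20 times into the top half of a double can be tried directly. The sketch below is a variation on the union above rather than the code used by the algorithm: it overlays the double with a single 64-bit integer, so that the position of the upper 32 bits does not depend on byte order. It assumes 64-bit IEEE-754 doubles, and the names overlay and from_high_word are made up for the example.

```c
#include <stdint.h>

/* A double and a 64-bit integer sharing the same 8 bytes. */
union overlay {
    double   d;
    uint64_t bits;
};

/* Build a double whose upper 32 bits are hi and whose lower 32 bits are zero. */
double from_high_word(uint32_t hi)
{
    union overlay u;
    u.bits = (uint64_t)hi << 32;
    return u.d;
}
```

from_high_word(1023u << 20) returns 1.0, and raising the exponent by one with from_high_word(1024u << 20) returns 2.0.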
Theory into practice
We're now ready to start taking advantage of our knowledge. Suppose that, given an integer called num, we want to create a double with the value 2^num. We could do
_eco.n.i = (num+1023)*1048576;
and read the value from _eco.d.
To compute e^num rather than 2^num, num must first be divided by ln(2). So after
_eco.n.i = (num/ln(2)+1023)*1048576;
_eco.d has approximately the value e^num.
An aside: why ln(2)? Suppose e^num = 2^(num/x). Taking the natural log of each side gives num = ln(2^(num/x)). One of the rules of logs says that log_a(b^c) = c*log_a(b), so the expression becomes num = ln(2)*(num/x), leading to x = ln(2).
If num has a fractional part, then multiplying by 2^20 will leave bits in the part of i that overlaps with the mantissa part of d. This is a fortunate result of the layout of doubles - the bigger the fraction the greater the mantissa, so it acts as interpolation to make the result more accurate.
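The interpolation is easy to see with a power-of-2 version of the trick. This is a minimal sketch, assuming 64-bit IEEE-754 doubles and a num between roughly -1022 and 1023; it uses a 64-bit integer overlay instead of the two-int union so that byte order doesn't matter, and approx_pow2 is a name invented for the example.

```c
#include <stdint.h>

/* Approximate 2^num by writing (num+1023)*2^20 into the upper 32 bits of a double.
   Exact for integer num; fractional parts land in the mantissa. */
double approx_pow2(double num)
{
    union { double d; uint64_t bits; } u;
    int32_t hi = (int32_t)((num + 1023.0) * 1048576.0);
    u.bits = (uint64_t)(uint32_t)hi << 32;   /* lower 32 bits stay zero */
    return u.d;
}
```

approx_pow2(3.0) is exactly 8, while approx_pow2(0.5) gives 1.5 where the true value is about 1.414: the fraction is interpolated linearly between 1 and 2.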
Theory and experiment confirm that adding a "fudge factor" improves the mean error. The algorithm uses the following
_eco.n.i = (num/ln(2)+1023)*1048576-fudge_factor;
which (when re-organised, and the optimum fudge factor determined) turns into
_eco.n.i = num*1048576/0.69314718055994530942+1072693248-60801;
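Written as an ordinary function, the final expression might look like the sketch below. It is not the macro used later in this document: it swaps the two-int union for a 64-bit integer overlay so that no byte-order switch is needed, and the names fast_exp and FAST_EXP_* are illustrative. It assumes 64-bit IEEE-754 doubles and a num within about -700 to 700.

```c
#include <stdint.h>

#define FAST_EXP_A (1048576.0 / 0.69314718055994530942)  /* 2^20 / ln(2) */
#define FAST_EXP_B 1072693248.0                           /* 1023 * 2^20 */
#define FAST_EXP_C 60801.0                                /* fudge factor from the text */

/* Approximate e^num; no range checking is done. */
double fast_exp(double num)
{
    union { double d; uint64_t bits; } u;
    int32_t hi = (int32_t)(FAST_EXP_A * num + (FAST_EXP_B - FAST_EXP_C));
    u.bits = (uint64_t)(uint32_t)hi << 32;   /* lower 32 bits stay zero */
    return u.d;
}
```

Note that the fudge factor means even fast_exp(0.0) is not exactly 1; it comes out near 0.971, which is the price paid for a lower error averaged over all inputs.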
Accuracy, portability and limitations
Speed has its price
- The code becomes harder to understand
- On some systems (Pentium based ones, for example) integers are stored with their least significant byte first. These are called LITTLE_ENDIAN systems. The details of the code need to be slightly different for such machines. The code below has an optional line "#define LITTLE_ENDIAN 1" that can be enabled by removing the '/*' and '*/' symbols.
- If doubles aren't laid out as described above (i.e. if they don't conform to IEEE-754) the code won't work. IEEE-754 conformance is very common nowadays.
- The code goes wrong if the magnitude of num is too great, but it doesn't do any error checking. The acceptable range is about -700 to 700.
- The average error is about 1.5%.
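If the missing error checking is a concern, a guarded version can be written at the cost of a few comparisons per call. The sketch below is one possible approach, not part of the original algorithm: the cut-offs of -700 and 700 come from the list above, while the choice to return 0 and infinity outside that range, and the name fast_exp_guarded, are assumptions made for the example. It assumes 64-bit IEEE-754 doubles and a C99 math.h that defines INFINITY.

```c
#include <math.h>     /* INFINITY */
#include <stdint.h>

/* Approximate e^num, refusing inputs outside the range the bit trick can handle. */
double fast_exp_guarded(double num)
{
    union { double d; uint64_t bits; } u;
    int32_t hi;

    if (num != num)   return num;        /* NaN in, NaN out */
    if (num < -700.0) return 0.0;        /* assumed behaviour: underflow to zero */
    if (num >  700.0) return INFINITY;   /* assumed behaviour: overflow to infinity */

    hi = (int32_t)(1048576.0 / 0.69314718055994530942 * num
                   + (1072693248.0 - 60801.0));
    u.bits = (uint64_t)(uint32_t)hi << 32;
    return u.d;
}
```

The extra tests eat into the speed advantage, so whether they are worth having depends on how well you know the range of your data.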
How to do it
Let's take it a step at a time, starting without matlab. Put this code into a file called "fastexp.c" and compile it (using a line like "gcc -o fastexp fastexp.c -lm"). On a little-endian machine, enable the LITTLE_ENDIAN line first. If you run the resulting program you should find that the 2 values printed out are very similar.
#include <stdio.h>
#include <math.h>

/* #define LITTLE_ENDIAN 1 */

static union {
  double d;
  struct {
#ifdef LITTLE_ENDIAN
    int j,i;
#else
    int i,j;
#endif
  } n;
} _eco;

/* If M_LN2 isn't defined in math.h, replace it by the value */
#define EXP_A (1048576/M_LN2)
#define EXP_C 60801
#define EXP(y) (_eco.n.i = EXP_A*(y) + (1072693248 - EXP_C), _eco.d)
/* note that EXP is a macro rather than a function to avoid the function call overhead */

int main()
{
  double g=1.27;
  printf("%f\n", (double)EXP(g));
  printf("%f\n", exp(g)); /* should be 3.560853 */
  return 0;
}
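If you're not sure whether your machine needs the LITTLE_ENDIAN line, a small test can tell you. The sketch below stores 1.0 in a union like _eco and looks for the word that holds the sign and exponent bits. It assumes 64-bit IEEE-754 doubles and 32-bit integers; the function name high_word_index is invented for this example.

```c
#include <stdint.h>

/* Return the index (0 or 1) of the 32-bit word that overlays the
   sign and exponent bits of a double. */
int high_word_index(void)
{
    union { double d; int32_t w[2]; } u;
    u.d = 1.0;                 /* upper word is 0x3FF00000, lower word is 0 */
    return (u.w[1] == 0x3FF00000) ? 1 : 0;
}
```

A result of 1 means the upper word is stored second, so the LITTLE_ENDIAN line should be enabled; a result of 0 means the code can be used as it stands.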
If that works we'll now try to integrate it into matlab. When we write a matlab routine in C we have to do some work first to deal with the input and output variables. That code needs to be in a routine called mexFunction. The code will be much the same in all mex files. In this example our routine mexexp is given pointers to the start of the input values and the start of the output values. The elements of the output matrix are filled in one at a time in the loop. To keep things simple, the code deals only with vectors but it can easily be extended to cope with matrices.
To prepare the code to run with matlab put the following into a file called mexexp.c and run mex mexexp.c (depending on your system you may need to type the mex line on the command line rather than at the matlab prompt).
#include <math.h>
#include "mex.h"

/* these 2 #define lines help make the later code more readable */
/* Input Arguments */
#define PARAMETER_IN prhs[0]
/* Output Arguments */
#define RESULT_OUT plhs[0]

/* #define LITTLE_ENDIAN 1 */

static union {
  double d;
  struct {
#ifdef LITTLE_ENDIAN
    int j,i;
#else
    int i,j;
#endif
  } n;
} _eco;

#define EXP_A (1048576/0.69314718055994530942)
#define EXP_C 60801
#define EXP(y) (_eco.n.i = EXP_A*(y) + (1072693248 - EXP_C), _eco.d)

/* Fill yp[0..m-1] with approximations to exp(y[0..m-1]) */
void mexexp(double*y, double*yp, int m)
{
  while(m>0) {
    m--;
    yp[m]=EXP(y[m]);
  }
}

void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
{
  double *yp;
  double *y;
  unsigned int m,n;

  /* Check for proper number of arguments */
  if (nrhs != 1) {
    mexErrMsgTxt("One input argument required.");
  } else if (nlhs > 1) {
    mexErrMsgTxt("Too many output arguments.");
  }

  m = mxGetM(PARAMETER_IN);
  n = mxGetN(PARAMETER_IN);

  /* Create a matrix for the return argument */
  RESULT_OUT = mxCreateDoubleMatrix(m, n, mxREAL);

  /* Assign pointers to the various parameters */
  yp = mxGetPr(RESULT_OUT);
  y = mxGetPr(PARAMETER_IN);

  /* Do the actual computation. Passing m*n (rather than n) makes
     column vectors work as well as row vectors */
  mexexp(y,yp,m*n);
  return;
}
Then test it by running the following matlab script. I get an 8x speed-up.
x=1:.0001:70;
t=cputime; exp(x); cputime-t
t=cputime; mexexp(x); cputime-t
To test accuracy, try this
x=1:.0001:70;
a=exp(x);
b=mexexp(x);
percent_errors=100*abs(a-b)./a;
disp('max percentage error is ');
max(percent_errors)
Note however that there can be performance disadvantages to using a mex file - there are overheads and matlab's accelerator may be less effective. In the following example, exp and mexexp are being passed scalars. The mexexp implementation takes at least twice as long.
t=cputime;
for x=1:.0001:70, exp(x); end
cputime-t
clear all
t=cputime;
for x=1:.0001:70, mexexp(x); end
cputime-t