Creating C closures from Lua closures

Keywords: assembly, volatile, VirtualAlloc

When I set out to bridge the entire Windows C API to Lua, one of the most interesting and fun challenges I had was creating C callbacks from Lua functions. Without them, a significant portion of the API would be useless, such as anything that needs a WNDPROC.

Brief primer on Lua closures

For those unfamiliar with Lua 5.4, its semantics are perhaps 90% similar to JavaScript's. In this case, that includes closures and lexical scoping: when a new function is created that uses variables from an outer scope, it "closes over" them. In Lua, these captured variables are called "upvalues", because they come from "up" in an enclosing lexical scope.

local a = 1
local f = function(b) return a + b end
print(f(2)) -- prints 3

Brief primer on C "closures"

The brutally but beautifully minimalistic language of C has no concept of closures at all. But it does have functions and pointers to raw memory, which are enough to work with. So it's conventional to manually pass data through to callback functions via a pointer to the appropriate memory.

The following trivial example shows the basic concept at work. The outer caller and inner callback pass around a pointer and are aware of its true type, but the middle function is unaware, and simply passes it through as a void *ctx.

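#include <stdio.h>

int cb(void *ctx, int b) {
  int a = *(int*)ctx; // recover the "enclosed" value
  return a + b;
}

// add doesn't know what ctx points to; it just forwards it
int add(void *a, int b, int (*f)(void *ctx, int)) {
  return f(a, b);
}

int main() {
  int a = 1;
  printf("%d\n", add(&a, 2, &cb)); // prints 3
  return 0;
}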

Note: the C expressions cb and &cb both evaluate to the same function address, but I explicitly wrote &cb here to make it clearer that it's an ordinary C address, which ties in better with what we're going to do later.
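Because the callback receives its data through ctx, a single C function can stand in for many "closures": the function pointer is shared, and only the context differs. Here's a minimal sketch of that idea; struct c_closure, add_from_ctx, and call_closure are hypothetical names, not part of any API.

```c
#include <assert.h>

/* Hypothetical pairing of a C function pointer with the context it needs. */
struct c_closure {
  int (*fn)(void *ctx, int b);
  void *ctx;
};

/* Same shape as cb above: recover the "enclosed" value from ctx. */
static int add_from_ctx(void *ctx, int b) {
  int a = *(int *)ctx;
  return a + b;
}

/* The "middle" code only sees an opaque pair and forwards ctx untouched. */
static int call_closure(struct c_closure c, int b) {
  return c.fn(c.ctx, b);
}
```

The catch for this article is that a callback type like WNDPROC has no void * context parameter in its signature, so this convention alone doesn't help there.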

The ideal bridging API

This is the Lua API that I settled on:

local a = 1
local lpfnWndProc = WNDPROC(function(hwnd, umsg, wparam, lparam)
  if umsg == WM_KEYDOWN then
    print(a)
  end
  return DefWindowProc(hwnd, umsg, wparam, lparam)
end)

That is, you pass a Lua function to another function that represents the C callback type, and it returns a void* (a light userdata on the Lua side) pointing to a dynamically created C function.

The naive implementation

Let's make our first example work, with a minor tweak:

local a = 1
local f = CALLBACK(function(b) return a + b end)
print(Add(f, 2)) -- prints 3

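We introduced two new functions:

  1. CALLBACK, which takes a Lua function and returns a C function pointer that calls it
  2. Add, which calls that C function pointer with an integer and returns the result

Together, these are a simplified variant of the pattern used by WNDPROC, WNDCLASS, and CreateWindow.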

We can trivially implement this:

#include <lua.h>
#include <lauxlib.h>
#include <lualib.h>

static int findex;

static int REAL_CALLBACK(lua_State *L, int b) {
  // push the function and arg onto the stack
  lua_rawgeti(L, LUA_REGISTRYINDEX, findex);
  lua_pushinteger(L, b);
  // call the Lua function with 1 arg, 1 result
  lua_call(L, 1, 1);
  // return its result interpreted as a C int
  return (int)lua_tointeger(L, -1);
}

static int CALLBACK(lua_State *L) {
  // store Lua function in Lua registry,
  // save its index as C global findex
  findex = luaL_ref(L, LUA_REGISTRYINDEX);
  // push the heavy lifting C function
  // as a light userdata (plain C pointer)
  lua_pushlightuserdata(L, (void *)&REAL_CALLBACK);
  return 1;
}

static int Add(lua_State *L) {
  int (*f)(lua_State *L, int b) =
    (int (*)(lua_State *, int))lua_touserdata(L, 1);
  int b = (int)lua_tointeger(L, 2);

  // shim real work to the C function
  int c = f(L, b);

  lua_pushinteger(L, c);
  return 1;
}

int main() {
  lua_State *L = luaL_newstate();
  luaL_openlibs(L); // loads the standard library, including print

  lua_pushcfunction(L, CALLBACK);
  lua_setglobal(L, "CALLBACK");

  lua_pushcfunction(L, Add);
  lua_setglobal(L, "Add");

  // the_above_code: the Lua snippet above, as a C string
  luaL_dostring(L, the_above_code);
  lua_close(L);
  return 0;
}

This should be mostly self-explanatory, even if you're not familiar with Lua's C API, because it's fairly self-documenting. But in short, Lua's C API uses an implicit stack of values, and all the functions operate on this stack, with index 1 referring to the bottom-most value and index -1 to the top-most one. We convert between C and Lua types with functions that push values onto this stack or interpret values already on it. (We never pop anything explicitly; luaL_ref and lua_call consume the values they use.) A light userdata is how Lua holds a plain C pointer as a Lua value.

The other thing worth explaining briefly here is LUA_REGISTRYINDEX, a pseudo-index referring to a global table that's only accessible from C. The luaL_ref function pops the top-most stack value, stores it in this table under a new unique integer key, and returns that key. This is perfect for our case, because we can then retrieve the function from this table later in our callback. All we need to do is store the key in the findex global.

The most obvious problem with this solution is that the callback can only remember a single function. In other words, the following code would not work: findex is overwritten on each call to CALLBACK, so it "loses" the first function and only remembers the second. (Both calls also return the same pointer, &REAL_CALLBACK, so Add has no way to tell them apart.)

local a1, a2 = 1, 2
local f1 = CALLBACK(function(b) return a1 + b end)
local f2 = CALLBACK(function(b) return a2 + b end)
print(Add(f1, 2)) -- prints 4 -- wrong!
print(Add(f2, 2)) -- prints 4, right
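To see the failure without involving Lua at all, here's a hedged pure-C simulation of the naive design. fake_registry stands in for the Lua registry, and callback_sim and real_callback_sim are hypothetical stand-ins for CALLBACK and REAL_CALLBACK; each "Lua function" is reduced to its upvalue a.

```c
#include <assert.h>

typedef int (*int_fn)(int);

/* Stand-in for the Lua registry: each "Lua function" is just its upvalue a. */
static int fake_registry[16];
static int fake_registry_len;

/* The single C global from the naive implementation. */
static int findex;

/* Plays REAL_CALLBACK: uses whatever findex currently holds. */
static int real_callback_sim(int b) {
  return fake_registry[findex] + b;
}

/* Plays the naive CALLBACK: overwrite findex, return the same pointer every time. */
static int_fn callback_sim(int upvalue_a) {
  fake_registry[fake_registry_len] = upvalue_a;
  findex = fake_registry_len++;
  return &real_callback_sim;
}
```

Both handles are the same pointer, so whichever function was registered last wins.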

The question is, how do you create a new and unique C pointer for each Lua function passed in, one that uses the correct Lua function? The answer, of course, is dynamically generated assembly!

Creating C closures dynamically at runtime

Now we're finally getting to the fun stuff! We're going to change CALLBACK to dynamically generate a "C" function at runtime, one that sets our C global findex to the right value just before jumping to REAL_CALLBACK.

Naturally, this will need to have different implementations for different platforms, which is easy enough using #if guards. But for now we're only concerned with 64-bit Windows. The x64 calling convention document will be a helpful reference, and perhaps Overview of x64 ABI conventions too, but I'll explain the basics that we need to know.

The equivalent C code that we'll be generating is essentially this:

/* unspecified */ generated_function(/* unspecified */) {
  findex = _closed_over_findex;
  goto REAL_CALLBACK(/* pass args as-is */);
  /* return value is whatever REAL_CALLBACK returns */
}

In other words:

  1. Create a temporary _closed_over_findex value
  2. Set the C global findex to our new _closed_over_findex value
  3. Jump to first instruction at &REAL_CALLBACK

Learning assembly via Visual Studio's debugger

As a starting point, we can look at the assembly MSVC generates for code very similar to what we need. Let's create a minimal main that just sets a few globals and jumps to a label. Then set a breakpoint on the first line of main (foo1 = 111), run it, and press Ctrl+Alt+D to open the Disassembly window.

static int foo1;
static int foo2;
static int foo3;

int main() {
  foo1 = 111;
  goto later;
  foo2 = 222;
  foo3 = 333;
later:
  return 0;
}

When you compile in Release mode, you get this:

nop          ;; 
xor eax, eax ;; set ret to 0
ret          ;; exit main

This happens because nothing ever reads the globals, so the compiler optimizes the stores away entirely. That's not particularly useful here, but we can fix it by adding volatile to the variables. This tells the compiler that every access to them may have side effects, so it must perform each read and write exactly as written, which in our case lets us inspect the real output.

- static int foo1;
- static int foo2;
- static int foo3;
+ static volatile int foo1;
+ static volatile int foo2;
+ static volatile int foo3;

Now we get this:

nop          ;; 
mov dword ptr [foo1 (07FF7838E0670h)], 6Fh
xor eax, eax ;; set ret to 0
ret          ;; exit main

You'll get a different address than 07FF7838E0670h (it can even change from run to run because of ASLR), but it will be the same value as &foo1. You can check this with printf("%p\n", (void *)&foo1);. Use %p rather than an integer format like %d or %ld: on 64-bit Windows, pointers are 64 bits wide, while int and long are only 32 bits.

Still, the compiler is too smart for us here: the jmp disappeared along with the unreachable stores. So we have to compile in the Debug configuration to see the jmp at all. Now we get this:

;int main() {
push  rbp  
push  rdi  
sub   rsp,0E8h  
lea   rbp,[rsp+20h]  
lea   rcx,[__29E14FF1_ConsoleApplication1@c (07FF717A41008h)]  
call  __CheckForDebuggerJustMyCode (07FF717A31352h)  
nop  
;  foo1 = 111;
mov   dword ptr [foo1 (07FF717A3C200h)],6Fh  
;  goto later;
jmp   $later (07FF717A317ACh)  
;  foo2 = 222;
mov   dword ptr [foo2 (07FF717A3C204h)],0DEh  
;  foo3 = 333;
mov   dword ptr [foo3 (07FF717A3C208h)],14Dh  
;later:
;  return 0;
xor   eax,eax  
;}
lea   rsp,[rbp+0C8h]  
pop   rdi  
pop   rbp  
ret  

Now we start seeing stack setup, and debug shims, but we can ignore all that. The only interesting part here is this:

mov   dword ptr [foo1 (07FF717A3C200h)],6Fh  
jmp   $later (07FF717A317ACh)  

Now, tying this to our original code, we'll want to generate this assembly:

mov   dword ptr [findex], _closed_over_findex
jmp   REAL_CALLBACK

Learning how to generate assembly

There are two websites that will help us here: an online assembler that uses GCC, and the asmjit site.

If we use both and compare outputs, we can double-check our solution.

We'll start by inputting the assembly as-is, then fix the errors as they come up. For starters, the assemblers don't know our symbol names, so we can swap them for placeholder addresses, which we can later replace with our own values. In this case, let's stick with the first example:

mov   dword ptr [07FF717A3C200h],6Fh
jmp   07FF717A317ACh

The first issue we run into is that the first website (which uses GCC) rejects the h suffix for hex values. So let's use the 0x prefix instead:

mov   dword ptr [0x07FF717A3C200],0x6F
jmp   0x07FF717A317AC

Now we see some "Unsupported instructions" errors on the GCC site, and we continue to see "Relocations" on the asmjit site. Let's simplify this by breaking down each instruction.

Learning the jmp instruction

First, let's look at jmp by itself. Here's a table of the values I tried to encode, with the results from the GCC site and asmjit site:

;; input               ;; gcc        ;; asmjit
jmp   0x07FF717A317AC  ;; error      ;; 40E900000000
jmp   0xffffffff       ;; error      ;; 40E900000000
jmp   0xfffffff        ;; e900000000 ;; 40E900000000

Interesting: we finally get machine code from GCC, and the 0xE9 matches what asmjit has been giving us, except that it's missing asmjit's 0x40 prefix (an empty REX prefix, which has no effect on this instruction).

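The jmp manual page shows that E9 takes a relative address: a signed 32-bit displacement that is added to the address of the next instruction. That explains why it only accepts targets within about 2 GB of the jump itself, rather than arbitrary 64-bit addresses. If we want an absolute address, particularly a 64-bit one, we can use the FF /4 form (jmp r/m64), which reads its target from a register or from memory. This is confirmed in this SO answer.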
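To make "relative" concrete, here's a hedged sketch of how an E9 displacement is computed. encode_jmp_rel32 is a hypothetical helper, and the addresses in the usage note are made up; the key point is that the displacement is measured from the end of the 5-byte instruction and must fit in a signed 32-bit integer.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Encode `jmp rel32` (opcode E9) for an instruction located at address `at`.
   The CPU adds the displacement to the address of the NEXT instruction,
   which is at + 5 because this encoding is 5 bytes long. */
static bool encode_jmp_rel32(uint8_t out[5], uint64_t at, uint64_t target) {
  int64_t rel = (int64_t)(target - (at + 5));
  if (rel < INT32_MIN || rel > INT32_MAX)
    return false;                            /* too far: needs an absolute jump */
  uint32_t u = (uint32_t)(int32_t)rel;
  out[0] = 0xE9;
  for (int i = 0; i < 4; i++)
    out[1 + i] = (uint8_t)(u >> (8 * i));    /* little-endian displacement */
  return true;
}
```

For example, a jump at 0x1000 to 0x1010 encodes as E9 0B 00 00 00, since 0x1010 - (0x1000 + 5) = 0xB. If the code is assumed to sit near address 0, a target like 0x07FF717A317AC is far out of range, which is consistent with the GCC site rejecting the larger targets in the table above.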

Now we need a register we can freely overwrite. Consulting the x64 conventions page, we want a scratch register that isn't used for parameters, since our shim runs at the beginning of a function and must leave its arguments untouched. rax fits: it's "volatile" (a scratch register that may be clobbered) and isn't used to pass parameters (those go in rcx, rdx, r8, and r9).

So let's change our assembly to use that:

;; input   ;; gcc    ;; asmjit
jmp   rax  ;; FFE0   ;; FFE0

Voila! We have a perfect match. And since rax is a full 64-bit register, jmp rax can reach any 64-bit address. Nice.
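The E0 isn't arbitrary. Opcode FF covers several instructions, and the "/4" in the manual's FF /4 (jmp r/m64) means the value 4 goes in the reg field of the ModRM byte that follows. A small sketch of that layout, with hypothetical names modrm, RAX, and RCX, shows where E0 comes from; the same formula gives the C0 in 48C7C0 and the 00 in C7006F000000 that show up in the mov section.

```c
#include <assert.h>
#include <stdint.h>

/* Register numbers used in the ModRM reg/rm fields (no REX extension). */
enum { RAX = 0, RCX = 1, RDX = 2, RBX = 3 };

/* ModRM byte: mod (2 bits) | reg (3 bits) | rm (3 bits).
   mod = 3 (binary 11) means rm names a register rather than memory;
   mod = 0 with rm = rax means "the memory at [rax]". */
static uint8_t modrm(unsigned mod, unsigned reg, unsigned rm) {
  assert(mod < 4 && reg < 8 && rm < 8);
  return (uint8_t)((mod << 6) | (reg << 3) | rm);
}
```

So jmp rax is FF followed by modrm(3, 4, RAX), and jmp rcx would be FF E1.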

So our jump is just two bytes: 0xFF and 0xE0. Let's note this for later and move on to mov.

Learning the mov instruction

Going back to our example:

mov   dword ptr [0x07FF717A3C200],0x6F

This works on the asmjit site, but not on the GCC site. And the asmjit version has relocations again.

Let's see what happens if we reuse the rax trick. We'll load the address into rax first (we need to load rax anyway for the jmp instruction, so it knows where to go) and store through it:

mov   dword ptr [rax], 0x6F
;; gcc:    C7006F000000
;; asmjit: C7006F000000

Another perfect match! And no relocations! The immediate value (0x6F) is encoded directly in the instruction, as the 4 little-endian bytes 6F 00 00 00.
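Putting that encoding into an emit-and-advance style, a hedged sketch might look like this. movImm32ToMemAtRax is a hypothetical name, and uint8_t stands in for the Windows BYTE type so the snippet compiles anywhere.

```c
#include <assert.h>
#include <stdint.h>

/* Emit `mov dword ptr [rax], imm32`: opcode C7 /0, then ModRM 0x00
   (mod=00 memory, reg=/0, rm=rax), then the 4-byte immediate, little-endian.
   Returns a pointer just past the emitted bytes. */
static uint8_t *movImm32ToMemAtRax(uint8_t *exe, uint32_t imm) {
  *exe++ = 0xC7;
  *exe++ = 0x00;
  for (int i = 0; i < 4; i++)
    *exe++ = (uint8_t)(imm >> (8 * i));
  return exe;
}
```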

There's only one problem: the dword. A double word is only 32 bits. What if our value is 64 bits? We'll need a qword:

mov   qword ptr [rax], 0x6F
;; gcc:    48C7006F000000
;; asmjit: 48C7006F000000

It works, and the only difference is the 0x48 prefix (REX.W, which selects a 64-bit operand). But once we try a value wider than 32 bits, such as 0xffffffffff, both sites fail. That's because this is actually invalid asm: a mov with a memory destination only accepts a 32-bit immediate (sign-extended to 64 bits for a qword).

The only form of mov that takes a full 64-bit immediate is the one whose destination is a register, and since we're loading our values into rax anyway, that's the form we want. If we retry it with just plain rax as the destination, we get this:

mov   rax, 0x6F               ;; 48C7C06F000000
mov   rax, 0xfffffff          ;; 48C7C0FFFFFF0F
mov   rax, 0xffffffff         ;; 48B8FFFFFFFF00000000
mov   rax, 0xfffffffff        ;; 48B8FFFFFFFF0F000000
mov   rax, 0x7f7f7f7f7f7f7f7f ;; 48B87F7F7F7F7F7F7F7F
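Why does 0xfffffff use the short 48C7C0 form while 0xffffffff switches to 48B8? The 32-bit immediate in 48C7C0 is sign-extended to 64 bits, so 0xffffffff would land in rax as 0xffffffffffffffff. A short sketch of that check, using hypothetical helpers sign_extend_imm32 and fits_simm32, shows why the cutoff for the short form sits at 0x7fffffff:

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* What the CPU does with the imm32 in `48 C7 C0 imm32`:
   sign-extend it to 64 bits before writing rax. */
static uint64_t sign_extend_imm32(uint32_t imm) {
  return (uint64_t)(int64_t)(int32_t)imm;
}

/* A value can use the short form only if sign-extending its low
   32 bits gives the same 64-bit value back. */
static bool fits_simm32(uint64_t val) {
  return sign_extend_imm32((uint32_t)val) == val;
}
```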

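Which means our movToRax function looks like this:

static BYTE* movToRax(BYTE* exe, UINT64 val) {
  // 48 C7 C0 imm32 sign-extends its immediate to 64 bits,
  // so it's only correct for values up to 0x7fffffff
  if (val > 0x7fffffff) {
    // mov rax, imm64: 48 B8 followed by 8 bytes
    *exe++ = 0x48;
    *exe++ = 0xb8;
    *((UINT64*)exe) = val;
    exe += sizeof(UINT64);
  }
  else {
    // mov rax, imm32: 48 C7 C0 followed by 4 bytes
    *exe++ = 0x48;
    *exe++ = 0xc7;
    *exe++ = 0xc0;
    *((UINT32*)exe) = (UINT32)val;
    exe += sizeof(UINT32);
  }
  return exe;
}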

There's probably a more concise way to do this, but this gets the job done, and it returns a pointer to the byte just past the instruction it wrote, which we will use when we tie it all together.
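One more concise alternative, if a few extra bytes per callback don't matter, is to always emit the 10-byte 48B8 form and skip the size check entirely. This is a sketch of that design choice rather than what the code above does; movToRax64 is a hypothetical name, and uint8_t stands in for BYTE.

```c
#include <assert.h>
#include <stdint.h>

/* Alternative to movToRax: always emit `mov rax, imm64` (48 B8 imm64).
   No size check, at the cost of up to 3 extra bytes per load. */
static uint8_t *movToRax64(uint8_t *exe, uint64_t val) {
  *exe++ = 0x48;                          /* REX.W: 64-bit operand */
  *exe++ = 0xB8;                          /* B8+r with r = 0 (rax) */
  for (int i = 0; i < 8; i++)
    *exe++ = (uint8_t)(val >> (8 * i));   /* little-endian imm64 */
  return exe;
}
```

The short 48C7C0 form is 7 bytes, so the unconditional version costs at most 3 extra bytes per load, and its output matches the 48B8 rows of the table above.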

We will need the inverse, a movRaxTo function, because our overall assembly generator will need to:

  1. Move _closed_over_findex to RAX
  2. Move RAX to &findex
  3. Move &REAL_CALLBACK to RAX
  4. Jump to RAX

We can use movToRax for steps 1 and 3, and movRaxTo for step 2. Step 4 gets its own jmpToRax function, which follows the same pattern as movToRax except that it just emits 0xFF and 0xE0 and doesn't need to check the size.
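A minimal sketch of jmpToRax, assuming the same emit-and-advance convention as movToRax; uint8_t stands in for BYTE so it compiles outside Windows.

```c
#include <assert.h>
#include <stdint.h>

/* Emit `jmp rax`: FF /4 with ModRM E0 (mod=11, reg=/4, rm=rax).
   Returns a pointer just past the two emitted bytes. */
static uint8_t *jmpToRax(uint8_t *exe) {
  *exe++ = 0xFF;
  *exe++ = 0xE0;
  return exe;
}
```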

Implementing movRaxTo is an exercise for the reader. Which means now we're at the super fun part!
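If you'd like to check your answer, here's one possible movRaxTo, a sketch under the assumption that findex is a 4-byte int. It uses opcode A3, which in 64-bit mode stores EAX to a full 8-byte absolute address (GNU as spells this form movabs), so no second scratch register is needed. Taking a void * matches the movRaxTo(exe, &findex) call in generate_function; uint8_t again stands in for BYTE.

```c
#include <assert.h>
#include <stdint.h>

/* One possible movRaxTo: `mov [moffs64], eax` (opcode A3 + 8-byte address).
   Stores only EAX (4 bytes), matching a 4-byte int findex. Prefixing 0x48
   (REX.W) would store all 8 bytes of RAX, clobbering whatever follows. */
static uint8_t *movRaxTo(uint8_t *exe, void *dest) {
  uint64_t addr = (uint64_t)(uintptr_t)dest;
  *exe++ = 0xA3;
  for (int i = 0; i < 8; i++)
    *exe++ = (uint8_t)(addr >> (8 * i));  /* little-endian absolute address */
  return exe;
}
```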

Generating assembly at runtime

We're almost ready to generate assembly! First, let's change CALLBACK to use the function we need but don't have yet:

static int CALLBACK(lua_State *L) {
- findex = luaL_ref(L, LUA_REGISTRYINDEX);
- lua_pushlightuserdata(L, &REAL_CALLBACK);
+ int index = luaL_ref(L, LUA_REGISTRYINDEX);
+ void *_generated_function = generate_function(index);
+ lua_pushlightuserdata(L, _generated_function);
  return 1;
}

A few changes:

  1. Instead of storing the index in findex directly, we're passing the int to generate_function, which will deal with findex in a new way.

  2. We have a new function, generate_function, which returns a void* pointing to a function it creates dynamically at runtime. We call that pointer _generated_function here.

  3. Instead of pushing &REAL_CALLBACK, we push the pointer _generated_function.

Now we have to implement the new generate_function, which produces assembly corresponding to generated_function as outlined above.

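static void *generate_function(int _closed_over_findex) {
  // reserve + commit memory we can both write and execute
  BYTE *exe = VirtualAlloc(
    NULL,
    0x10000,
    MEM_RESERVE | MEM_COMMIT,
    PAGE_EXECUTE_READWRITE
  );

  BYTE *fn = exe;

  exe = movToRax(exe, (UINT64)_closed_over_findex);
  exe = movRaxTo(exe, &findex);
  exe = movToRax(exe, (UINT64)&REAL_CALLBACK);
  exe = jmpToRax(exe);

  // make sure the CPU sees the freshly written instructions
  FlushInstructionCache(GetCurrentProcess(), fn, exe - fn);

  return fn;
}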

That's it! We now have executable memory containing machine code we wrote at runtime. Each generated function sets findex and then jumps to REAL_CALLBACK. We return a pointer to it, and at runtime it's indistinguishable from a function we'd have written in C, except mainly that it has no debug information.

Caveats and simplifications

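Notice that we didn't do any error handling (VirtualAlloc, for one, can return NULL). In practice, lowkPRO does a lot more, but it's rather boring and distracting. Also note that windows.h defines CALLBACK as a macro for __stdcall, so in real code that includes it, the Lua-facing function needs a different name.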

Also notice that we're using VirtualAlloc very inefficiently: it reserves address space in 64 KB chunks, so each callback of a few dozen bytes ties up a whole 64 KB region. You do not want to use it this way. In practice, lowkPRO allocates chunks as needed and splits each one into nodes in a large linked list, which are handed out as callbacks are created. Nodes are freed by a winapi.freecallback function that lowkPRO provides.
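Here's a hedged sketch of the general idea behind carving one big allocation into reusable nodes: an intrusive free list of fixed-size slots. It is not lowkPRO's code; stub_pool, stub_slot, the 32-byte slot size, and the function names are all hypothetical, and it ignores page protection (a writable-then-executable scheme would have to make the region writable while the free list is updated).

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define STUB_SLOT_SIZE 32  /* hypothetical: room for one ~30-byte stub */

typedef union stub_slot {
  union stub_slot *next;          /* meaningful while the slot is free */
  uint8_t code[STUB_SLOT_SIZE];   /* meaningful while it holds a stub */
} stub_slot;

typedef struct {
  stub_slot *free_list;
} stub_pool;

/* Split one large region (e.g. a single 64 KB allocation) into slots. */
static void stub_pool_init(stub_pool *pool, void *region, size_t size) {
  stub_slot *slots = region;
  size_t n = size / sizeof(stub_slot);
  pool->free_list = NULL;
  for (size_t i = n; i-- > 0;) {  /* push in reverse so slot 0 comes out first */
    slots[i].next = pool->free_list;
    pool->free_list = &slots[i];
  }
}

/* Take a slot for a new callback, or NULL when the region is used up. */
static void *stub_alloc(stub_pool *pool) {
  stub_slot *s = pool->free_list;
  if (s)
    pool->free_list = s->next;
  return s;
}

/* Return a slot (e.g. when a callback is freed) for reuse. */
static void stub_free(stub_pool *pool, void *p) {
  stub_slot *s = p;
  s->next = pool->free_list;
  pool->free_list = s;
}
```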

And finally, it's not good practice to have memory that's both writable and executable at the same time. In practice, lowkPRO uses VirtualProtect to make it writable only while creating a callback, and executable at all other times.
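VirtualAlloc and VirtualProtect are Windows-only, so here's a hedged sketch of the same trampoline on Linux x86-64, swapping in POSIX mmap and mprotect. It writes the stub while the page is read+write, then flips it to read+execute before anything runs, mirroring the writable-then-executable approach described above. Assumptions: the System V x86-64 ABI, where rax is likewise a scratch register that isn't used to pass integer arguments to non-variadic functions; real_callback is a hypothetical stand-in for REAL_CALLBACK with no Lua involved; and, like the VirtualAlloc version, it wastes a whole page per stub. findex is volatile because only the generated code writes it, which the compiler can't see.

```c
#define _DEFAULT_SOURCE  /* exposes MAP_ANONYMOUS under strict -std modes on glibc */
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <sys/mman.h>
#include <unistd.h>

typedef int (*int_fn)(int);

/* Written only by generated code, so the compiler must not cache it. */
static volatile int findex;

/* Hypothetical stand-in for REAL_CALLBACK: reports which index was set. */
static int real_callback(int b) {
  return findex * 100 + b;
}

static uint8_t *emit_u64(uint8_t *p, uint64_t v) {
  for (int i = 0; i < 8; i++)
    *p++ = (uint8_t)(v >> (8 * i));   /* little-endian */
  return p;
}

/* Builds: mov rax, idx / mov [&findex], eax / mov rax, &real_callback / jmp rax */
static int_fn make_stub(int idx) {
  size_t page = (size_t)sysconf(_SC_PAGESIZE);
  uint8_t *mem = mmap(NULL, page, PROT_READ | PROT_WRITE,
                      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (mem == MAP_FAILED)
    return NULL;

  uint8_t *p = mem;
  *p++ = 0x48; *p++ = 0xB8;                               /* mov rax, imm64 */
  p = emit_u64(p, (uint64_t)(int64_t)idx);
  *p++ = 0xA3;                                            /* mov [moffs64], eax */
  p = emit_u64(p, (uint64_t)(uintptr_t)&findex);
  *p++ = 0x48; *p++ = 0xB8;                               /* mov rax, imm64 */
  p = emit_u64(p, (uint64_t)(uintptr_t)&real_callback);
  *p++ = 0xFF; *p++ = 0xE0;                               /* jmp rax */

  /* Drop write permission before the code can run. */
  if (mprotect(mem, page, PROT_READ | PROT_EXEC) != 0) {
    munmap(mem, page);
    return NULL;
  }
  return (int_fn)(uintptr_t)mem;
}

static void free_stub(int_fn stub) {
  munmap((void *)(uintptr_t)stub, (size_t)sysconf(_SC_PAGESIZE));
}
```

Each stub is a distinct pointer that sets findex to its own index before jumping, which is exactly the per-function C pointer the naive version couldn't produce.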

lowkPRO also does a little more work, using thread_local and a critical section to make sure that creating, freeing, and calling callbacks are thread-safe operations.

The techniques I describe in this article were used to bridge all Windows API callbacks to Lua, from WNDPROC to pD3DCompile, using a script written in C# that generates C++ code which itself takes about 30 minutes to recompile, porting about half a million Windows API entries.