Using Zig comptime for conceptual dryness


The code is on GitHub.

While writing my C# Forth, I grew unhappy about the conceptual repetition in the code. To add a new Forth word, you have to add a new value to the enum that represents the opcode, add a new entry to a hashtable that maps it to a string (what the user types), and finally implement the action for the word. See here.

Granted, by using a struct, I could have kept these items geographically closer, but the repetition would still be there. Using reflection, I could have removed it, but then performance would suck in the inner loop of my interpreter.

My familiarity with the Zig language made me realize that I could remove the redundancy by using Zig’s comptime, without any loss in performance. This post describes a prototype of that.

The idea is to have functions on my main struct in the form op_WORDNAME. At compile time, I generate the enum with all the opcodes and the series of conditional statements that call the right op_WORDNAME function when the user types WORDNAME. By making the functions inline, I don’t even pay the price of a function call. BTW: Zig gives a compile-time error if it can’t inline.

In the Vm code below, if you want to support a new word, you add a new function. The rest of the code is unchanged.


const std = @import("std");

const Vm = struct {
    state: i32 = 0,

    inline fn op_double(this: *Vm) void {
        this.state = this.state * 2;
    }
    inline fn op_plus1(this: *Vm) void {
        this.state = this.state + 1;
    }
    inline fn op_notFound(_: *Vm) void {
        std.log.info("{s}", .{"Word not found."});
    }
    inline fn op_bye(_: *Vm) void {
        std.process.exit(0);
    }
};

This is the shell main loop. It is pretty standard stuff until you get to the innermost while loop, marked with (*). Even if it looks like the code is calling two functions (findToken and execToken), it isn’t. The compiler replaces these two function calls with a series of if statements that match a string with the corresponding opcode and find the correct function to call.

Even the Token enum is not defined anywhere in the code. It is generated automatically at compile time and contains one value for each op_XXX function defined on Vm. Even though I have not defined it by hand, I can still access its values normally (see Token.notFound below).


fn shellLoop(stdin: std.fs.File.Reader, stdout: std.fs.File.Writer) !void {
    const max_input = 1024;
    var input_buffer: [max_input]u8 = undefined;
    var vm = Vm{};

    while (true) {
        try stdout.print("> ", .{});

        var input_str = (try stdin.readUntilDelimiterOrEof(input_buffer[0..], '\n')) orelse {
            try stdout.print("\n", .{});
            return;
        };

        if (input_str.len == 0) continue;
        var words = std.mem.tokenize(u8, input_str, " ");

        while (words.next()) |word| { // (*)
            const token = findToken(word) orelse Token.notFound;
            execToken(&vm, token);
        }
        std.log.info("{}", .{vm.state});
    }
}

So, how do we do it? Let’s start with findToken. The code is relatively simple because you are not writing a macro. You are just writing normal Zig code. For .NET programmers, this is like having System.Reflection and System.Reflection.Emit available at compile time.

You could do something similar with C# source generators, but you would need to operate on a rather complex AST and generate code using string concatenation. It would probably be dozens of lines of very intricate code. Here, it is three simple lines.

An inline for tells Zig to unroll the loop. The compiler iterates over all the fields of the enum and generates a series of if statements that return the enum value matching the given string (what the user typed).


inline fn findToken(word: []const u8) ?Token {
    inline for (@typeInfo(Token).Enum.fields) |enField| {
        if (std.mem.eql(u8, enField.name, word))
            return @field(Token, enField.name);
    }
    return null;
}
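
To make the unrolling concrete, this is roughly what the inline for amounts to once the loop has been expanded. It is a hand-written sketch, not compiler output: the Token enum is typed out by hand here (in the post it is generated), and findTokenUnrolled is a made-up name.

```zig
const std = @import("std");

// Assumption: written by hand for this sketch; the post generates it.
const Token = enum { double, plus1, notFound, bye };

// One `if` per enum field, in declaration order, which is what the
// `inline for` in findToken expands to.
fn findTokenUnrolled(word: []const u8) ?Token {
    if (std.mem.eql(u8, "double", word)) return Token.double;
    if (std.mem.eql(u8, "plus1", word)) return Token.plus1;
    if (std.mem.eql(u8, "notFound", word)) return Token.notFound;
    if (std.mem.eql(u8, "bye", word)) return Token.bye;
    return null;
}
```

The point of the comptime version is that nobody has to keep this list in sync with the Vm by hand.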

Token execution is similar. Again we unroll the loop at compile time, generating a series of if statements that execute the Vm function corresponding to the given Token.


inline fn execToken(vm: *Vm, tok: Token) void {
    inline for (@typeInfo(Token).Enum.fields) |enField| {
        const enumValue = @field(Token, enField.name);
        if (enumValue == tok) {
            const empty = .{};
            _ = @call(empty, @field(Vm, "op_" ++ @tagName(enumValue)), .{vm});
        }
    }
}
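
Here is a cut-down, self-contained version of the lookup and dispatch pair, in case you want to play with the idea in isolation. It makes a few assumptions that differ from the real code: Token is written by hand instead of generated, the op_ functions are plain fn rather than inline fn and are called directly rather than through @call, the fields come from std.meta.fields instead of @typeInfo, and run is a hypothetical helper standing in for the shell loop.

```zig
const std = @import("std");

// Cut-down Vm with two words, mirroring the one in the post.
const Vm = struct {
    state: i32 = 0,

    fn op_double(this: *Vm) void {
        this.state = this.state * 2;
    }
    fn op_plus1(this: *Vm) void {
        this.state = this.state + 1;
    }
};

// Assumption: written by hand here; the post generates this enum.
const Token = enum { double, plus1 };

// Map a word to its token by comparing against each field name.
fn findToken(word: []const u8) ?Token {
    inline for (std.meta.fields(Token)) |f| {
        if (std.mem.eql(u8, f.name, word)) return @field(Token, f.name);
    }
    return null;
}

// Call Vm.op_<name> for the field whose value equals `tok`.
fn execToken(vm: *Vm, tok: Token) void {
    inline for (std.meta.fields(Token)) |f| {
        if (@field(Token, f.name) == tok) {
            @field(Vm, "op_" ++ f.name)(vm);
        }
    }
}

// Hypothetical helper: execute already-split words, skipping unknown ones.
fn run(vm: *Vm, words: []const []const u8) void {
    for (words) |w| {
        if (findToken(w)) |tok| execToken(vm, tok);
    }
}
```

For example, running the words plus1 plus1 double on a fresh Vm leaves state at 4.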

We generate the Token enum by iterating over all the declarations on the Vm struct and building a set of enum fields that are then used to construct the correct enum at compile time with the @Type builtin function.


const Token = GenerateTokenEnumType(Vm);

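// Builds an enum with one field per op_XXX declaration found on T.
fn GenerateTokenEnumType(comptime T: type) type {
    const declInfos = std.meta.declarations(T);
    // Upper bound: at most one enum field per declaration.
    var enumFields: [declInfos.len]std.builtin.TypeInfo.EnumField = undefined;
    var decls = [_]std.builtin.TypeInfo.Declaration{};
    // Count only the op_ declarations so the array has no undefined gaps.
    var count: usize = 0;
    inline for (declInfos) |decl| {
        if (std.mem.startsWith(u8, decl.name, "op_")) {
            // Strip the "op_" prefix to get the word the user types.
            enumFields[count] = .{ .name = decl.name[3..], .value = count };
            count += 1;
        }
    }
    return @Type(.{
        .Enum = .{
            .layout = .Auto,
            .tag_type = u8,
            .fields = enumFields[0..count],
            .decls = &decls,
            .is_exhaustive = true,
        },
    });
}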
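
For the Vm in this post, the generated type should have the same shape as the hand-written enum below. This is my reading of the generator (the values 0, 1 and 2 also show up as opcodes in the disassembly further down), not something printed by the compiler. The names HandWrittenToken and wordFor are made up for the comparison.

```zig
const std = @import("std");

// Assumption: mirrors what GenerateTokenEnumType(Vm) produces for the four
// op_ functions in the post, with values following declaration order.
const HandWrittenToken = enum(u8) {
    double = 0,
    plus1 = 1,
    notFound = 2,
    bye = 3,
};

// The field name doubles as the word the user types.
fn wordFor(tok: HandWrittenToken) []const u8 {
    return @tagName(tok);
}
```

Because the field names are the op_ function names minus the prefix, the same string serves as the user-facing word and as the key for finding the function.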

Ok, but does it really work? Well, you can run it with zig build run, but is it really inlining correctly? The assembly says yes: there are no calls to findToken, execToken, or the op_ functions in the main loop. The only calls left are to std.mem.eql and to the overflow panic handler.

const token = findToken(word) orelse Token.notFound;
2310db: f6 85 a1 fa ff ff 01 testb $0x1,-0x55f(%rbp)
2310e2: 75 09 jne 2310ed <shellLoop+0x38d>
2310e4: c6 85 9f fa ff ff 02 movb $0x2,-0x561(%rbp)
2310eb: eb 0c jmp 2310f9 <shellLoop+0x399>
2310ed: 8a 85 a0 fa ff ff mov -0x560(%rbp),%al
2310f3: 88 85 9f fa ff ff mov %al,-0x561(%rbp)
home/lucabol/dev/zig-forth/src/main.zig:65
execToken(&vm, token);
2310f9: 8a 85 9f fa ff ff mov -0x561(%rbp),%al
2310ff: 48 8d 8d 58 fb ff ff lea -0x4a8(%rbp),%rcx
231106: 48 89 4d d8 mov %rcx,-0x28(%rbp)
23110a: 88 45 d7 mov %al,-0x29(%rbp)
execToken():
home/lucabol/dev/zig-forth/src/main.zig:32
if (enumValue == tok) {
23110d: 31 c0 xor %eax,%eax
23110f: 3a 45 d7 cmp -0x29(%rbp),%al
231112: 75 50 jne 231164 <shellLoop+0x404>
home/lucabol/dev/zig-forth/src/main.zig:34
_ = @call(empty, @field(Vm, "op_" ++ @tagName(enumValue)), .{vm});
231114: 48 8b 45 d8 mov -0x28(%rbp),%rax
231118: 48 89 45 e0 mov %rax,-0x20(%rbp)
Vm.op_double():
home/lucabol/dev/zig-forth/src/main.zig:7
this.state = this.state * 2;
23111c: 48 8b 45 e0 mov -0x20(%rbp),%rax
231120: 48 89 85 58 fa ff ff mov %rax,-0x5a8(%rbp)
231127: 48 8b 4d e0 mov -0x20(%rbp),%rcx
23112b: b8 02 00 00 00 mov $0x2,%eax
231130: 0f af 01 imul (%rcx),%eax
231133: 89 85 60 fa ff ff mov %eax,-0x5a0(%rbp)
231139: 0f 90 c0 seto %al
23113c: 70 02 jo 231140 <shellLoop+0x3e0>
23113e: eb 13 jmp 231153 <shellLoop+0x3f3>
231140: 48 bf 68 1f 20 00 00 00 00 00 movabs $0x201f68,%rdi
23114a: 31 c0 xor %eax,%eax
23114c: 89 c6 mov %eax,%esi
23114e: e8 fd 32 fd ff callq 204450 <std.builtin.default_panic>
231153: 48 8b 85 58 fa ff ff mov -0x5a8(%rbp),%rax
23115a: 8b 8d 60 fa ff ff mov -0x5a0(%rbp),%ecx
231160: 89 08 mov %ecx,(%rax)
execToken():
home/lucabol/dev/zig-forth/src/main.zig:32
if (enumValue == tok) {
231162: eb 02 jmp 231166 <shellLoop+0x406>
231164: eb 00 jmp 231166 <shellLoop+0x406>
231166: b0 01 mov $0x1,%al
231168: 3a 45 d7 cmp -0x29(%rbp),%al
23116b: 75 4c jne 2311b9 <shellLoop+0x459>
home/lucabol/dev/zig-forth/src/main.zig:34
_ = @call(empty, @field(Vm, "op_" ++ @tagName(enumValue)), .{vm});
23116d: 48 8b 45 d8 mov -0x28(%rbp),%rax
231171: 48 89 45 e8 mov %rax,-0x18(%rbp)
Vm.op_plus1():
home/lucabol/dev/zig-forth/src/main.zig:10
this.state = this.state + 1;
231175: 48 8b 45 e8 mov -0x18(%rbp),%rax
231179: 48 89 85 48 fa ff ff mov %rax,-0x5b8(%rbp)
231180: 48 8b 45 e8 mov -0x18(%rbp),%rax
231184: 8b 00 mov (%rax),%eax
231186: ff c0 inc %eax
231188: 89 85 54 fa ff ff mov %eax,-0x5ac(%rbp)
23118e: 0f 90 c0 seto %al

Also, more directly, see below:

return @field(Token, enField.name);
230ff0: c6 85 a1 fa ff ff 01 movb $0x1,-0x55f(%rbp)
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
230ff7: c6 85 a0 fa ff ff 00 movb $0x0,-0x560(%rbp)
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
230ffe: e9 d8 00 00 00 jmpq 2310db <shellLoop+0x37b>
home/lucabol/dev/zig-forth/src/main.zig:24
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
if (std.mem.eql(u8, enField.name, word))
231003: 48 8b 85 a8 fa ff ff mov -0x558(%rbp),%rax
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
23100a: 48 89 45 b0 mov %rax,-0x50(%rbp)
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
23100e: 48 8b 85 b0 fa ff ff mov -0x550(%rbp),%rax
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
231015: 48 89 45 b8 mov %rax,-0x48(%rbp)
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
231019: 48 bf e0 2a 20 00 00 00 00 00 movabs $0x202ae0,%rdi
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
231023: 48 8d b5 a8 fa ff ff lea -0x558(%rbp),%rsi
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
23102a: e8 61 ac fd ff callq 20bc90 <std.mem.eql>
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
23102f: a8 01 test $0x1,%al
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
231031: 75 02 jne 231035 <shellLoop+0x2d5>
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
231033: eb 13 jmp 231048 <shellLoop+0x2e8>
home/lucabol/dev/zig-forth/src/main.zig:25
inlined by /home/lucabol/dev/zig-forth/src/main.zig:64 (shellLoop)
return @field(Token, enField.name);

And the driver is obvious.


pub fn main() !u8 {
    const stdin = std.io.getStdIn().reader();
    const stdout = std.io.getStdOut().writer();
    try stdout.print("*** Hello, I am a Forth shell! ***\n", .{});

    try shellLoop(stdin, stdout);

    return 0; // We either crash or we are fine.
}
