summary refs log tree commit diff
path: root/src/codegen_jit.h
diff options
context:
space:
mode:
author David Moc <personal@cdatgoose.org> 2026-03-05 23:38:49 +0100
committer David Moc <personal@cdatgoose.org> 2026-03-05 23:38:49 +0100
commit 0385817bb1301a778bb33f8405a435293b9f8905 (patch)
tree 53f4b6f13e393bd368c37ba4363826b46940dfd3 /src/codegen_jit.h
parent 262abf9b552a168ef3ae91f91af97683f16420a7 (diff)
Pushing to repo for safety.
Diffstat (limited to 'src/codegen_jit.h')
-rw-r--r-- src/codegen_jit.h 827
1 files changed, 827 insertions, 0 deletions
diff --git a/src/codegen_jit.h b/src/codegen_jit.h
new file mode 100644
index 0000000..ff1d687
--- /dev/null
+++ b/src/codegen_jit.h
@@ -0,0 +1,827 @@
+#ifndef CODEGEN_JIT_H
+#define CODEGEN_JIT_H
+
+#include "ast.h"
+#include "token.h"
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+#include <limits.h>
+#include <ctype.h>
+
+/* x86-64 general-purpose register numbers as they appear in instruction
+ * encodings. The low three bits go into ModRM / opcode register fields; bit 3
+ * (values 8-15) is carried by a REX prefix bit (see emit_rex). */
+enum {
+ RAX = 0, RCX = 1, RDX = 2, RBX = 3,
+ RSP = 4, RBP = 5, RSI = 6, RDI = 7,
+ R8 = 8, R9 = 9, R10 = 10, R11 = 11,
+ R12 = 12, R13 = 13, R14 = 14, R15 = 15
+};
+
+/* Growable byte buffer that machine code is assembled into before it is
+ * copied to executable memory. */
+typedef struct {
+ uint8_t *buf; // heap storage (malloc/realloc); NULL when not initialised
+ size_t len; // number of bytes emitted so far
+ size_t cap; // allocated capacity of buf in bytes
+} CodeBuf;
+
+/* One entry of the singly linked function registry. */
+typedef struct FuncMap {
+ char *name; // owned copy of the function name (strdup)
+ void *addr; // mmap'd code address; NULL until the function has been compiled
+ size_t size; // number of code bytes copied to addr
+ size_t alloc_size; // length of the mapping at addr (passed to munmap)
+ struct FuncMap *next;
+} FuncMap;
+
+/* A call site whose target was not yet compiled when the call was emitted.
+ * patch_function_calls() later overwrites the placeholder imm64 with the
+ * real address. */
+typedef struct PatchEntry {
+ size_t offset; // offset of the imm64 within owning function's code
+ char *func_name; // target function to patch in
+ char *owning_func; // owned copy of the name of the function containing the call site
+ struct PatchEntry *next;
+} PatchEntry;
+
+/* One local variable (or spilled parameter) of the function currently being
+ * compiled. New entries are pushed at the head of the list. */
+typedef struct VarMap {
+ char *name; // owned copy of the variable name (strdup)
+ int offset; // RBP-relative offset
+ _TY type; // declared type, used to pick load/store width
+ struct VarMap *next;
+} VarMap;
+
+/* Whole-program JIT state. */
+typedef struct {
+ FuncMap *func_list; // every registered function, compiled or not
+ VarMap *var_list; // variables of the function currently being compiled
+ int next_local_offset; // RBP-relative offset handed to the next variable
+ CodeBuf cb; // scratch buffer for the function currently being compiled
+ char *current_func_name; // borrowed pointer to the current function's name (not freed by jit_free)
+ PatchEntry *patch_list; // forward calls waiting for patch_function_calls()
+} JIT;
+
+/* Put a JIT context into its empty state: no functions, variables or pending
+ * patches, no code buffer, and the local-slot cursor at RBP-8. */
+static void jit_init(JIT *jit) {
+    jit->cb.buf = NULL;
+    jit->cb.cap = 0;
+    jit->cb.len = 0;
+
+    jit->func_list = NULL;
+    jit->patch_list = NULL;
+    jit->var_list = NULL;
+
+    jit->current_func_name = NULL;
+    jit->next_local_offset = -8;
+}
+
+/* Release everything the JIT context owns: the variable list, the patch list,
+ * the function registry (unmapping each function's code) and the scratch code
+ * buffer. The context is left empty, with all list heads NULL. */
+static void jit_free(JIT *jit) {
+    VarMap *var = jit->var_list;
+    while (var) {
+        VarMap *after = var->next;
+        free(var->name);
+        free(var);
+        var = after;
+    }
+    jit->var_list = NULL;
+
+    PatchEntry *patch = jit->patch_list;
+    while (patch) {
+        PatchEntry *after = patch->next;
+        free(patch->func_name);
+        free(patch->owning_func);
+        free(patch);
+        patch = after;
+    }
+    jit->patch_list = NULL;
+
+    FuncMap *fn = jit->func_list;
+    while (fn) {
+        FuncMap *after = fn->next;
+        // Only entries that were actually compiled own a mapping.
+        if (fn->addr && fn->alloc_size > 0) {
+            munmap(fn->addr, fn->alloc_size);
+        }
+        free(fn->name);
+        free(fn);
+        fn = after;
+    }
+    jit->func_list = NULL;
+
+    free(jit->cb.buf);
+    jit->cb.buf = NULL;
+    jit->cb.cap = 0;
+    jit->cb.len = 0;
+}
+
+/* --- Code buffer --- */
+
+/* Start an empty code buffer with 1024 bytes of capacity; exits the process
+ * if the allocation fails. */
+static void cb_init(CodeBuf *c) {
+    const size_t initial_cap = 1024;
+    uint8_t *mem = (uint8_t *)malloc(initial_cap);
+
+    c->cap = initial_cap;
+    c->len = 0;
+    c->buf = mem;
+    if (mem == NULL) {
+        fprintf(stderr, "[JIT] failed to allocate code buffer\n");
+        exit(1);
+    }
+}
+/* Free the buffer's storage and reset it to the empty, unallocated state. */
+static void cb_free(CodeBuf *c) {
+    free(c->buf);
+    c->buf = NULL;
+    c->cap = 0;
+    c->len = 0;
+}
+/* Ensure the buffer can hold `need` more bytes, doubling the capacity as
+ * required. Exits the process on allocation failure or size overflow.
+ *
+ * Also works on a buffer that has never been initialised or has been freed
+ * (buf == NULL, cap == 0): doubling a zero capacity would otherwise loop
+ * forever, so growth starts from 1024 bytes in that case. */
+static void cb_grow(CodeBuf *c, size_t need) {
+    if (need > SIZE_MAX - c->len) {
+        fprintf(stderr, "[JIT] failed to grow code buffer\n");
+        exit(1);
+    }
+    size_t required = c->len + need;
+    if (c->buf && required <= c->cap) return;
+
+    size_t new_cap = c->cap ? c->cap : 1024;
+    while (new_cap < required) {
+        if (new_cap > SIZE_MAX / 2) { new_cap = required; break; }  // doubling would overflow
+        new_cap *= 2;
+    }
+
+    uint8_t *nb = (uint8_t *)realloc(c->buf, new_cap);
+    if (!nb) {
+        fprintf(stderr, "[JIT] failed to grow code buffer\n");
+        exit(1);
+    }
+    c->buf = nb;
+    c->cap = new_cap;
+}
+/* Append a single byte. */
+static void emit8(CodeBuf *c, uint8_t b) {
+    cb_grow(c, 1);
+    c->buf[c->len] = b;
+    c->len += 1;
+}
+
+/* Append a 32-bit value in host byte order. */
+static void emit32(CodeBuf *c, uint32_t v) {
+    cb_grow(c, 4);
+    memcpy(&c->buf[c->len], &v, 4);
+    c->len += 4;
+}
+
+/* Append a 64-bit value in host byte order. */
+static void emit64(CodeBuf *c, uint64_t v) {
+    cb_grow(c, 8);
+    memcpy(&c->buf[c->len], &v, 8);
+    c->len += 8;
+}
+
+/* Append `n` raw bytes starting at `p`. */
+static void emitN(CodeBuf *c, const void *p, size_t n) {
+    cb_grow(c, n);
+    memcpy(&c->buf[c->len], p, n);
+    c->len += n;
+}
+
+/* --- x86-64 encoding helpers --- */
+
+/* Emit a REX prefix.
+ *
+ *   reg - register number that goes into ModRM.reg (bit 3 -> REX.R)
+ *   rm  - register number that goes into ModRM.rm / opcode reg (bit 3 -> REX.B)
+ *   w   - non-zero to request a 64-bit operand size (REX.W)
+ *
+ * The prefix is emitted exactly as requested. With w == 0 and no extended
+ * registers this is the plain 0x40 prefix; previously 0x48 was emitted in
+ * that case, which silently turned REX.W on. */
+static void emit_rex(CodeBuf *c, int reg, int rm, int w) {
+    uint8_t rex = 0x40;
+    if (w)       rex |= 0x08;  // REX.W
+    if (reg & 8) rex |= 0x04;  // REX.R
+    if (rm & 8)  rex |= 0x01;  // REX.B
+    emit8(c, rex);
+}
+/* Emit a ModRM byte: mod in bits 7-6, reg in bits 5-3, rm in bits 2-0.
+ * Only the low three bits of `reg` and `rm` are used. */
+static void emit_modrm(CodeBuf *c, uint8_t mod, uint8_t reg, uint8_t rm) {
+    const int mod_bits = mod << 6;
+    const int reg_bits = (reg & 7) << 3;
+    const int rm_bits = rm & 7;
+    emit8(c, mod_bits | reg_bits | rm_bits);
+}
+
+/* movabs rax, imm64  (48 B8 followed by the 8 immediate bytes). */
+static void emit_movabs_rax_imm64(CodeBuf *c, uint64_t imm) {
+    static const uint8_t opcode[2] = {0x48, 0xB8};
+    emit8(c, opcode[0]);
+    emit8(c, opcode[1]);
+    emit64(c, imm);
+}
+/* mov reg, imm64  (REX.W B8+r imm64).
+ *
+ * The destination is encoded in the opcode byte, so for R8-R15 its high bit
+ * must be carried by REX.B. It is therefore passed as the `rm` argument of
+ * emit_rex; passing it as `reg` would set REX.R, which this opcode ignores,
+ * and the move would target the wrong register. */
+static void emit_mov_reg_imm64(CodeBuf *c, int reg, int64_t imm) {
+    emit_rex(c, 0, reg, 1);
+    emit8(c, (uint8_t)(0xB8 + (reg & 7)));
+    emit64(c, (uint64_t)imm);
+}
+/* mov qword [rbp + disp32], reg */
+static void emit_mov_mem64_reg(CodeBuf *c, int disp32, int reg) {
+    const uint32_t disp = (uint32_t)disp32;
+    emit_rex(c, reg, RBP, 1);
+    emit8(c, 0x89);                 // MOV r/m64, r64
+    emit_modrm(c, 2, reg, RBP);     // mod=10: [rbp + disp32]
+    emit32(c, disp);
+}
+
+/* mov reg, qword [rbp + disp32] */
+static void emit_mov_reg_mem64(CodeBuf *c, int reg, int disp32) {
+    const uint32_t disp = (uint32_t)disp32;
+    emit_rex(c, reg, RBP, 1);
+    emit8(c, 0x8B);                 // MOV r64, r/m64
+    emit_modrm(c, 2, reg, RBP);     // mod=10: [rbp + disp32]
+    emit32(c, disp);
+}
+/* Emit the ModRM (and, where required, SIB / disp8) bytes that address
+ * [base] with no displacement.
+ *
+ * Two register encodings cannot use the plain mod=00 form:
+ *   - low bits 100 (RSP, R12) mean "SIB follows", so a SIB byte naming the
+ *     base with no index is emitted;
+ *   - low bits 101 (RBP, R13) mean "disp32 / RIP-relative" under mod=00, so
+ *     mod=01 with a zero disp8 is used instead. */
+static void emit_modrm_base_only(CodeBuf *c, int reg, int base) {
+    int low = base & 7;
+    if (low == 5) {
+        emit_modrm(c, 1, (uint8_t)reg, (uint8_t)base);
+        emit8(c, 0x00);
+    } else {
+        emit_modrm(c, 0, (uint8_t)reg, (uint8_t)base);
+        if (low == 4) emit8(c, 0x24);  // SIB: scale=1, no index, base=100
+    }
+}
+
+/* mov dst, qword [base] */
+static void emit_mov_reg_mem_reg(CodeBuf *c, int dst, int base) {
+    emit_rex(c, dst, base, 1);
+    emit8(c, 0x8B);
+    emit_modrm_base_only(c, dst, base);
+}
+
+/* mov qword [base], src */
+static void emit_mov_mem_reg_reg(CodeBuf *c, int base, int src) {
+    emit_rex(c, src, base, 1);
+    emit8(c, 0x89);
+    emit_modrm_base_only(c, src, base);
+}
+/* movzx rax, byte [rbp + disp32] */
+static void emit_movzx_rax_mem8(CodeBuf *c, int disp32) {
+    static const uint8_t opcode[3] = {0x48, 0x0F, 0xB6};
+    for (int i = 0; i < 3; i++) {
+        emit8(c, opcode[i]);
+    }
+    emit_modrm(c, 2, RAX, RBP);
+    emit32(c, disp32);
+}
+
+/* mov byte [rbp + disp32], al */
+static void emit_mov_mem8_rax(CodeBuf *c, int disp32) {
+    emit_rex(c, RAX, RBP, 0);
+    emit8(c, 0x88);                 // MOV r/m8, r8
+    emit_modrm(c, 2, RAX, RBP);
+    emit32(c, disp32);
+}
+/* Shared encoder for the 64-bit "op r/m64, r64" register-to-register form:
+ * REX.W, one opcode byte, ModRM with mod=11, reg=src, rm=dst. */
+static void emit_op_rm64_r64(CodeBuf *c, uint8_t opcode, int dst, int src) {
+    emit_rex(c, src, dst, 1);
+    emit8(c, opcode);
+    emit_modrm(c, 3, src, dst);
+}
+
+/* mov dst, src */
+static void emit_mov_reg_reg(CodeBuf *c, int dst, int src) {
+    emit_op_rm64_r64(c, 0x89, dst, src);
+}
+
+/* add dst, src */
+static void emit_add_reg_reg(CodeBuf *c, int dst, int src) {
+    emit_op_rm64_r64(c, 0x01, dst, src);
+}
+
+/* sub dst, src */
+static void emit_sub_reg_reg(CodeBuf *c, int dst, int src) {
+    emit_op_rm64_r64(c, 0x29, dst, src);
+}
+
+/* add dst, imm32  (81 /0 id, immediate sign-extended by the CPU) */
+static void emit_add_reg_imm32(CodeBuf *c, int dst, int32_t imm) {
+    const uint32_t raw = (uint32_t)imm;
+    emit_rex(c, 0, dst, 1);
+    emit8(c, 0x81);
+    emit_modrm(c, 3, 0, dst);
+    emit32(c, raw);
+}
+/* push reg  (50+r, preceded by REX.B = 0x41 for R8-R15) */
+static void emit_push_reg(CodeBuf *c, int reg) {
+    const int needs_rex_b = (reg & 8) != 0;
+    if (needs_rex_b) {
+        emit8(c, 0x41);
+    }
+    emit8(c, 0x50 + (reg & 7));
+}
+
+/* pop reg  (58+r, preceded by REX.B = 0x41 for R8-R15) */
+static void emit_pop_reg(CodeBuf *c, int reg) {
+    const int needs_rex_b = (reg & 8) != 0;
+    if (needs_rex_b) {
+        emit8(c, 0x41);
+    }
+    emit8(c, 0x58 + (reg & 7));
+}
+
+/* Fixed-encoding instructions operating on RAX (and RBX / CL). */
+
+/* add rax, rbx */
+static void emit_add_rax_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x01, 0xD8};
+    emitN(c, code, sizeof code);
+}
+
+/* imul rax, rbx */
+static void emit_imul_rax_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x0F, 0xAF, 0xC3};
+    emitN(c, code, sizeof code);
+}
+
+/* cqo; idiv rbx  -- quotient in RAX, remainder in RDX */
+static void emit_idiv_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x99, 0x48, 0xF7, 0xFB};
+    emitN(c, code, sizeof code);
+}
+
+/* cqo; idiv rbx; mov rax, rdx  -- leaves the remainder in RAX */
+static void emit_imod(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x99, 0x48, 0xF7, 0xFB};
+    emitN(c, code, sizeof code);
+    emit_mov_reg_reg(c, RAX, RDX);
+}
+
+/* or rax, rbx */
+static void emit_or_rax_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x09, 0xD8};
+    emitN(c, code, sizeof code);
+}
+
+/* xor rax, rbx */
+static void emit_xor_rax_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x31, 0xD8};
+    emitN(c, code, sizeof code);
+}
+
+/* and rax, rbx */
+static void emit_and_rax_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x21, 0xD8};
+    emitN(c, code, sizeof code);
+}
+
+/* shl rax, cl */
+static void emit_shl_rax_cl(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0xD3, 0xE0};
+    emitN(c, code, sizeof code);
+}
+
+/* shr rax, cl  (logical right shift) */
+static void emit_shr_rax_cl(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0xD3, 0xE8};
+    emitN(c, code, sizeof code);
+}
+
+/* cmp rax, rbx */
+static void emit_cmp_rax_rbx(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x39, 0xD8};
+    emitN(c, code, sizeof code);
+}
+
+/* setcc al -- `cc` is the 4-bit x86 condition code */
+static void emit_setcc_al(CodeBuf *c, uint8_t cc) {
+    emit8(c, 0x0F);
+    emit8(c, (uint8_t)(0x90 | cc));
+    emit8(c, 0xC0);
+}
+
+/* movzx rax, al */
+static void emit_movzx_rax_al(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x0F, 0xB6, 0xC0};
+    emitN(c, code, sizeof code);
+}
+
+/* jmp rel32 -- `rel` is relative to the end of this 5-byte instruction */
+static void emit_jmp_rel32(CodeBuf *c, int32_t rel) {
+    emit8(c, 0xE9);
+    emit32(c, (uint32_t)rel);
+}
+
+/* jcc rel32 -- `rel` is relative to the end of this 6-byte instruction */
+static void emit_jcc_rel32(CodeBuf *c, uint8_t cc, int32_t rel) {
+    emit8(c, 0x0F);
+    emit8(c, (uint8_t)(0x80 | cc));
+    emit32(c, (uint32_t)rel);
+}
+
+/* test rax, rax */
+static void emit_test_rax_rax(CodeBuf *c) {
+    static const uint8_t code[] = {0x48, 0x85, 0xC0};
+    emitN(c, code, sizeof code);
+}
+
+/* Function prologue: push rbp; mov rbp, rsp; push rbx; then reserve
+ * `total_stack_size` bytes of locals rounded up to a multiple of 16.
+ * After this the saved RBX lives at [rbp-8] and locals start below it. */
+static void emit_prologue(CodeBuf *c, int total_stack_size) {
+    static const uint8_t mov_rbp_rsp[] = {0x48, 0x89, 0xE5};
+    static const uint8_t sub_rsp_imm32[] = {0x48, 0x81, 0xEC};
+
+    emit8(c, 0x55);                                 // push rbp
+    emitN(c, mov_rbp_rsp, sizeof mov_rbp_rsp);      // mov rbp, rsp
+    emit_push_reg(c, RBX);                          // save callee-saved RBX
+
+    int stack_bytes = ((total_stack_size + 15) / 16) * 16;
+    if (stack_bytes <= 0) {
+        return;
+    }
+    emitN(c, sub_rsp_imm32, sizeof sub_rsp_imm32);  // sub rsp, imm32
+    emit32(c, (uint32_t)stack_bytes);
+}
+/* Function epilogue: restore RBX, tear down the frame, return.
+ *
+ * The prologue pushes RBX right after establishing RBP, so the saved value
+ * lives at [rbp-8]. RSP cannot be used to find it: by the time the epilogue
+ * runs RSP sits below the locals, and a `pop rbx` there would load whatever
+ * happens to be in the lowest local slot. RBX is therefore reloaded relative
+ * to RBP before `leave` discards the frame. */
+static void emit_epilogue(CodeBuf *c) {
+    emit_mov_reg_mem64(c, RBX, -8);  // mov rbx, [rbp-8]
+    emit8(c, 0xC9);                  // leave
+    emit8(c, 0xC3);                  // ret
+}
+/* lea rax, [rbp + disp32] */
+static void emit_lea_rax_rbp_disp(CodeBuf *c, int disp32) {
+    const uint32_t disp = (uint32_t)disp32;
+    emit8(c, 0x48);
+    emit8(c, 0x8D);
+    emit_modrm(c, 2, RAX, RBP);
+    emit32(c, disp);
+}
+
+/* Load RAX from [rbp + disp32]: zero-extended byte when size == 1,
+ * otherwise a full 64-bit load. */
+static void emit_load_rax_from_mem(CodeBuf *c, int disp32, int size) {
+    if (size != 1) {
+        emit_mov_reg_mem64(c, RAX, disp32);
+        return;
+    }
+    emit_movzx_rax_mem8(c, disp32);
+}
+
+/* Store RAX to [rbp + disp32]: AL only when size == 1, otherwise all
+ * 64 bits. */
+static void emit_store_rax_to_mem(CodeBuf *c, int disp32, int size) {
+    if (size != 1) {
+        emit_mov_mem64_reg(c, disp32, RAX);
+        return;
+    }
+    emit_mov_mem8_rax(c, disp32);
+}
+
+/* --- Variable and type helpers --- */
+
+/* Size in bytes of a value of `type`: 8 for any pointer, otherwise the base
+ * size (1 for char, 8 for everything else) multiplied by the array length
+ * when the type is an array. Exits if the result does not fit in an int. */
+static int calculate_type_size(_TY type) {
+    if (type.ptr_level > 0) {
+        return 8;
+    }
+
+    size_t bytes = (type.base == TY_CHAR) ? 1u : 8u;
+    if (type.array_size <= 0) {
+        return (int)bytes;
+    }
+
+    bytes *= (size_t)type.array_size;
+    if (bytes > (size_t)INT_MAX) {
+        fprintf(stderr, "[JIT] array size too large\n");
+        exit(1);
+    }
+    return (int)bytes;
+}
+
+/* Round `offset` down (toward negative infinity) to the alignment of `type`:
+ * 1 for a char base type, 8 otherwise.
+ *
+ * Local offsets are negative and the frame grows downward, so rounding must
+ * move away from zero. C integer division truncates toward zero, which for a
+ * negative offset would round *up*, onto memory that has already been handed
+ * out; the remainder is therefore handled explicitly by sign. */
+static int align_offset(int offset, _TY type) {
+    int align = (type.base == TY_CHAR) ? 1 : 8;
+    int rem = offset % align;
+    if (rem == 0) return offset;
+    if (rem < 0) return offset - (align + rem);  // e.g. -17 -> -24
+    return offset - rem;                         // e.g.  17 ->  16
+}
+
+/* Discard all variables of the previous function and rewind the local-slot
+ * cursor so the next function starts with a fresh frame layout. */
+static void reset_varmap(JIT *jit) {
+    VarMap *var = jit->var_list;
+    while (var != NULL) {
+        VarMap *after = var->next;
+        free(var->name);
+        free(var);
+        var = after;
+    }
+    jit->var_list = NULL;
+
+    // after saved RBX at RBP-8, 16-byte aligned
+    jit->next_local_offset = -16;
+}
+
+/* Declare a variable in the current function and assign it a frame slot.
+ *
+ * `offset` is the RBP-relative address of the variable (element 0 for an
+ * array); further array elements sit at successively lower addresses, and
+ * next_local_offset is moved down by the full size of the type.
+ *
+ * Layout fix: next_local_offset only leaves room for something as wide as
+ * the previous variable's element. When the previous variable is byte-sized
+ * (a char or char array) and the new one is 8 bytes wide, starting at
+ * next_local_offset would make the new variable's upper 7 bytes overlap the
+ * bytes just handed out. The start is therefore lowered by 7 first, and then
+ * rounded down for alignment.
+ *
+ * Exits the process on allocation failure. */
+static void add_var(JIT *jit, const char *name, _TY type) {
+    VarMap *v = (VarMap *)malloc(sizeof(VarMap));
+    if (!v) { fprintf(stderr, "[JIT] malloc failed in add_var\n"); exit(1); }
+    v->name = strdup(name);
+    if (!v->name) { fprintf(stderr, "[JIT] strdup failed in add_var\n"); free(v); exit(1); }
+    v->type = type;
+
+    int type_size = calculate_type_size(type);
+    int new_is_byte = (type.ptr_level == 0 && type.base == TY_CHAR);
+    VarMap *prev = jit->var_list;
+    int prev_is_byte = prev && prev->type.ptr_level == 0 && prev->type.base == TY_CHAR;
+
+    int start = jit->next_local_offset;
+    if (!new_is_byte && prev_is_byte) start -= 7;  // clear the bytes used by the previous char(s)
+
+    int off = align_offset(start, type);
+    if (off > start) off -= 8;  // guard: never let alignment move the slot upward
+
+    v->offset = off;
+    jit->next_local_offset = off - type_size;
+    v->next = jit->var_list;
+    jit->var_list = v;
+}
+
+/* RBP-relative offset of the variable called `name` in the current function.
+ * Exits the process if no such variable has been declared. */
+static int get_var_offset(JIT *jit, const char *name) {
+    VarMap *var = jit->var_list;
+    while (var != NULL) {
+        if (!strcmp(var->name, name)) {
+            return var->offset;
+        }
+        var = var->next;
+    }
+    fprintf(stderr, "[JIT] Unknown variable '%s'\n", name);
+    exit(1);
+}
+
+/* Declared type of the variable called `name` in the current function.
+ * Exits the process if no such variable has been declared. */
+static _TY get_var_type(JIT *jit, const char *name) {
+    VarMap *var = jit->var_list;
+    while (var != NULL) {
+        if (!strcmp(var->name, name)) {
+            return var->type;
+        }
+        var = var->next;
+    }
+    fprintf(stderr, "[JIT] Unknown variable '%s'\n", name);
+    exit(1);
+}
+
+/* --- Type checking --- */
+
+/* Static type of an expression.
+ *
+ * Numbers and calls are plain int; string literals are pointer-to-char;
+ * variables have their declared type. Indexing strips the array part (or one
+ * pointer level), dereference strips one pointer level, address-of adds one.
+ * Anything that cannot be indexed/dereferenced falls back to int. Looking up
+ * an undeclared variable exits the process (via get_var_type). */
+static _TY get_expr_type(JIT *jit, _EX *expr) {
+    const _TY plain_int = (_TY){TY_INT, 0, -1};
+
+    switch (expr->kind) {
+    case EX_NUMBER:
+        return plain_int;
+
+    case EX_STRING:
+        return (_TY){TY_CHAR, 1, -1};
+
+    case EX_VAR:
+        return get_var_type(jit, expr->name);
+
+    case EX_BINOP: {
+        // Both operands are typed (left first) so that unknown variables are
+        // reported even though only the left type can influence the result.
+        _TY lhs_ty = get_expr_type(jit, expr->binop.l);
+        (void)get_expr_type(jit, expr->binop.r);
+
+        switch (expr->binop.op) {
+        // All arithmetic, comparison, logical and bitwise ops produce int
+        case TK_PLUS: case TK_MINUS: case TK_STAR: case TK_SLASH: case TK_PERCENT:
+        case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE:
+        case TK_AND: case TK_OR:
+        case TK_AMP: case TK_BAR: case TK_CARET: case TK_SHL: case TK_SHR:
+            return plain_int;
+        default:
+            return lhs_ty;
+        }
+    }
+
+    case EX_CALL:
+        return plain_int;
+
+    case EX_INDEX: {
+        _TY base_ty = get_expr_type(jit, expr->index.array);
+        if (base_ty.array_size > 0) {
+            return (_TY){base_ty.base, base_ty.ptr_level, -1};
+        }
+        if (base_ty.ptr_level > 0) {
+            return (_TY){base_ty.base, base_ty.ptr_level-1, -1};
+        }
+        return plain_int;
+    }
+
+    case EX_DEREF: {
+        _TY ptr_ty = get_expr_type(jit, expr->deref.expr);
+        if (ptr_ty.ptr_level > 0) {
+            return (_TY){ptr_ty.base, ptr_ty.ptr_level-1, -1};
+        }
+        return plain_int;
+    }
+
+    case EX_ADDR: {
+        _TY inner_ty = get_expr_type(jit, expr->addr.expr);
+        return (_TY){inner_ty.base, inner_ty.ptr_level+1, -1};
+    }
+
+    default:
+        return plain_int;
+    }
+}
+
+/* Returns 1 when a value of type `actual` may be stored into a slot of type
+ * `expected`, 0 otherwise. Types match when base, pointer level and array
+ * size are all equal; in addition a plain non-array int is accepted
+ * everywhere (this is how untyped int literals are allowed). */
+static int types_compatible(_TY expected, _TY actual) {
+    int identical = expected.base == actual.base
+                 && expected.ptr_level == actual.ptr_level
+                 && expected.array_size == actual.array_size;
+    if (identical) {
+        return 1;
+    }
+
+    // Allow untyped int literals to be assigned anywhere
+    int plain_int = actual.base == TY_INT
+                 && actual.ptr_level == 0
+                 && actual.array_size == -1;
+    return plain_int ? 1 : 0;
+}
+
+/* --- Function registry --- */
+
+/* Add a not-yet-compiled function called `name` to the head of the registry.
+ * The name is copied. Exits the process on allocation failure. */
+static void register_func(JIT *jit, const char *name) {
+    FuncMap *entry = (FuncMap *)malloc(sizeof(FuncMap));
+    if (entry == NULL) {
+        fprintf(stderr, "[JIT] malloc failed in register_func\n");
+        exit(1);
+    }
+
+    entry->name = strdup(name);
+    if (entry->name == NULL) {
+        fprintf(stderr, "[JIT] strdup failed in register_func\n");
+        free(entry);
+        exit(1);
+    }
+
+    entry->addr = NULL;
+    entry->size = 0;
+    entry->alloc_size = 0;
+
+    entry->next = jit->func_list;
+    jit->func_list = entry;
+}
+
+/* Record where a registered function's code lives and print a hex/ASCII dump
+ * of it to stderr, 16 bytes per row. Exits the process if `name` was never
+ * registered. */
+static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, size_t alloc_size) {
+    FuncMap *entry = jit->func_list;
+    while (entry != NULL && strcmp(entry->name, name) != 0) {
+        entry = entry->next;
+    }
+    if (entry == NULL) {
+        fprintf(stderr, "[JIT] set_func_addr: unknown function '%s'\n", name);
+        exit(1);
+    }
+
+    entry->addr = addr;
+    entry->size = size;
+    entry->alloc_size = alloc_size;
+
+    fprintf(stderr, "\n[JIT] %s @ %p (%zu bytes)\n", name, addr, size);
+    const uint8_t *bytes = (const uint8_t *)addr;
+    for (size_t row = 0; row < size; row += 16) {
+        fprintf(stderr, "\t%04zx ", row);
+
+        // Hex column; short final rows are padded so the ASCII column lines up.
+        for (size_t col = 0; col < 16; col++) {
+            if (row + col < size) {
+                fprintf(stderr, "%02x ", bytes[row + col]);
+            } else {
+                fprintf(stderr, " ");
+            }
+        }
+
+        fprintf(stderr, " |");
+        for (size_t col = 0; col < 16 && row + col < size; col++) {
+            uint8_t ch = bytes[row + col];
+            fprintf(stderr, "%c", isprint(ch) ? ch : '.');
+        }
+        fprintf(stderr, "\n");
+    }
+}
+
+/* Address of a registered function's code. A function that is registered but
+ * not yet compiled yields the placeholder (void*)0xDEADBEEF, which callers
+ * use to detect forward calls. Exits the process if `name` is unknown. */
+static void *get_func_addr(JIT *jit, const char *name) {
+    FuncMap *entry = jit->func_list;
+    while (entry != NULL) {
+        if (!strcmp(entry->name, name)) {
+            if (entry->addr != NULL) {
+                return entry->addr;
+            }
+            return (void*)0xDEADBEEF; // placeholder; patched later
+        }
+        entry = entry->next;
+    }
+    fprintf(stderr, "[JIT] get_func_addr: unknown function '%s'\n", name);
+    exit(1);
+}
+
+/* --- Stack size calculation --- */
+
+/* Upper bound, in bytes, on the frame space needed by every variable declared
+ * anywhere in the statement list `s` (including nested statements).
+ *
+ * Two things are accounted for beyond a plain sum of type sizes:
+ *   - declarations nested inside `if` branches, `while` bodies and `for`
+ *     steps are visited, not just those in blocks and `for` init/body;
+ *     otherwise their slots would lie below the reserved area and be
+ *     overwritten by pushes and calls;
+ *   - char-based locals are rounded up to a multiple of 8 plus 8 bytes of
+ *     slack, because such a variable can start at the top of an 8-byte slot
+ *     and can force padding before the next 8-byte variable.
+ * Over-reserving is harmless; the prologue rounds the total up to 16. */
+static int calculate_stack_size(_STN *s) {
+    int total = 0;
+    for (; s; s = s->n) {
+        switch (s->kind) {
+        case STK_VAR_DECL: {
+            int sz = calculate_type_size(s->var_decl.type);
+            if (s->var_decl.type.ptr_level == 0 && s->var_decl.type.base == TY_CHAR)
+                sz = ((sz + 7) / 8) * 8 + 8;
+            total += sz;
+            break;
+        }
+        case STK_BLOCK:
+            total += calculate_stack_size(s->body);
+            break;
+        case STK_IF:
+            total += calculate_stack_size(s->ifs.thenb);
+            total += calculate_stack_size(s->ifs.elseb);  // NULL-safe: loop simply does not run
+            break;
+        case STK_WHILE:
+            total += calculate_stack_size(s->whl.body);
+            break;
+        case STK_FOR:
+            total += calculate_stack_size(s->fr.init);
+            total += calculate_stack_size(s->fr.body);
+            total += calculate_stack_size(s->fr.step);
+            break;
+        default:
+            break;
+        }
+    }
+    return total;
+}
+
+/* --- Code generation (forward declarations) --- */
+
+static void gen_expr_jit(JIT *jit, _EX *e);
+static int gen_stmt_jit(JIT *jit, _STN *s);
+
+/* --- Statement generation --- */
+
+/* Store the value of `value` through the address currently held in RAX.
+ *
+ * The address is parked on the machine stack while `value` is evaluated,
+ * because expression code uses RBX, RCX and RDX as scratch registers and
+ * would otherwise destroy it. `size` 1 stores AL only; anything else stores
+ * all 64 bits of RAX. */
+static void gen_store_through_rax_addr(JIT *jit, _EX *value, int size) {
+    emit_push_reg(&jit->cb, RAX);                   // save target address
+    gen_expr_jit(jit, value);                       // RAX = value
+    emit_pop_reg(&jit->cb, RBX);                    // RBX = target address
+    if (size == 1) {
+        emit8(&jit->cb, 0x88);                      // mov byte [rbx], al
+        emit_modrm(&jit->cb, 0, RAX, RBX);
+    } else {
+        emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);   // mov qword [rbx], rax
+    }
+}
+
+/* Generate code for the statement list starting at `s` (linked through ->n).
+ *
+ * Returns 1 if generation stopped at a `return` statement reached directly
+ * in this list or in a nested block (so the caller need not emit a default
+ * epilogue), 0 if the end of the list was reached. Returns inside `if`,
+ * `while` and `for` bodies do not count, since control may still fall
+ * through. Unsupported constructs and type mismatches exit the process.
+ *
+ * Assignment notes:
+ *   - for `*p = e` and `a[i] = e` the target address is computed first and
+ *     kept on the stack while `e` is evaluated (see
+ *     gen_store_through_rax_addr);
+ *   - stores into char elements write one byte, matching the byte-wide loads
+ *     used when indexing, so neighbouring elements are left intact;
+ *   - indexed stores work on both stack arrays and pointer variables. */
+static int gen_stmt_jit(JIT *jit, _STN *s) {
+    for (; s; s = s->n) {
+        switch (s->kind) {
+
+        case STK_VAR_DECL: {
+            // Declare first so the initialiser may refer to the variable itself.
+            add_var(jit, s->var_decl.name, s->var_decl.type);
+            if (!s->var_decl.init) break;
+
+            _TY init_type = get_expr_type(jit, s->var_decl.init);
+            if (!types_compatible(s->var_decl.type, init_type)) {
+                fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n",
+                        tybase_name(init_type.base), tybase_name(s->var_decl.type.base));
+                exit(1);
+            }
+            gen_expr_jit(jit, s->var_decl.init);
+            int sz = (s->var_decl.type.ptr_level > 0) ? 8
+                   : (s->var_decl.type.base == TY_CHAR) ? 1 : 8;
+            emit_store_rax_to_mem(&jit->cb, get_var_offset(jit, s->var_decl.name), sz);
+            break;
+        }
+
+        case STK_ASSIGN: {
+            _EX *lhs = s->assign.lhs;
+
+            if (lhs->kind == EX_VAR) {
+                int offset = get_var_offset(jit, lhs->name);
+                _TY type = get_var_type(jit, lhs->name);
+                _TY expr_type = get_expr_type(jit, s->assign.expr);
+                if (!types_compatible(type, expr_type)) {
+                    fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n",
+                            tybase_name(expr_type.base), tybase_name(type.base));
+                    exit(1);
+                }
+                gen_expr_jit(jit, s->assign.expr);
+                int sz = (type.ptr_level > 0) ? 8 : (type.base == TY_CHAR) ? 1 : 8;
+                emit_store_rax_to_mem(&jit->cb, offset, sz);
+
+            } else if (lhs->kind == EX_DEREF) {
+                gen_expr_jit(jit, lhs->deref.expr);             // RAX = address
+                // NOTE(review): the store is always 8 bytes wide, also through
+                // a char pointer; narrowing it must be coordinated with how
+                // dereference loads are generated.
+                gen_store_through_rax_addr(jit, s->assign.expr, 8);
+
+            } else if (lhs->kind == EX_INDEX) {
+                _EX *base = lhs->index.array;
+                if (base->kind != EX_VAR) {
+                    fprintf(stderr, "[JIT] Indexed assignment is only supported on named arrays and pointers\n");
+                    exit(1);
+                }
+                _TY var_type = get_var_type(jit, base->name);
+                if (var_type.ptr_level == 0 && var_type.array_size <= 0) {
+                    fprintf(stderr, "[JIT] Cannot index non-array '%s'\n", base->name); exit(1);
+                }
+                int var_offset = get_var_offset(jit, base->name);
+                // Elements are bytes only for char arrays and char*; elements
+                // of deeper pointer types are themselves pointers.
+                int element_size = (var_type.base == TY_CHAR && var_type.ptr_level <= 1) ? 1 : 8;
+
+                gen_expr_jit(jit, lhs->index.index);                    // RAX = index
+                emit_mov_reg_imm64(&jit->cb, RBX, element_size);
+                emit_imul_rax_rbx(&jit->cb);                            // RAX = byte offset
+                if (var_type.ptr_level > 0) {
+                    emit_mov_reg_mem64(&jit->cb, RBX, var_offset);      // RBX = pointer value
+                    emit_add_reg_reg(&jit->cb, RAX, RBX);               // RAX = &ptr[index]
+                } else {
+                    // Stack arrays grow downward: element i is at RBP + offset - i*size.
+                    emit_mov_reg_imm64(&jit->cb, RBX, var_offset);
+                    emit_sub_reg_reg(&jit->cb, RBX, RAX);               // RBX = offset - byte offset
+                    emit_add_reg_reg(&jit->cb, RBX, RBP);               // RBX = &arr[index]
+                    emit_mov_reg_reg(&jit->cb, RAX, RBX);               // RAX = &arr[index]
+                }
+                gen_store_through_rax_addr(jit, s->assign.expr, element_size);
+
+            } else {
+                fprintf(stderr, "[JIT] Unsupported assignment LHS kind %d\n", lhs->kind); exit(1);
+            }
+            break;
+        }
+
+        case STK_RETURN:
+            if (s->return_expr) gen_expr_jit(jit, s->return_expr);
+            emit_epilogue(&jit->cb);
+            return 1;
+
+        case STK_EXPR:
+            gen_expr_jit(jit, s->expr);
+            break;
+
+        case STK_BLOCK:
+            // A return directly inside the block ends this list as well.
+            if (gen_stmt_jit(jit, s->body)) return 1;
+            break;
+
+        case STK_IF: {
+            gen_expr_jit(jit, s->ifs.cond);
+            emit_test_rax_rax(&jit->cb);
+            size_t jz_pos = jit->cb.len;
+            emit_jcc_rel32(&jit->cb, 0x04, 0);                          // JZ else (patched below)
+            gen_stmt_jit(jit, s->ifs.thenb);
+            size_t jmp_pos = jit->cb.len;
+            emit_jmp_rel32(&jit->cb, 0);                                // JMP end (patched below)
+
+            // rel32 is measured from the end of the jump: 6 bytes for Jcc, 5 for JMP.
+            int32_t rel_else = (int32_t)(jit->cb.len - (jz_pos + 6));
+            memcpy(jit->cb.buf + jz_pos + 2, &rel_else, 4);
+            if (s->ifs.elseb) gen_stmt_jit(jit, s->ifs.elseb);
+            int32_t rel_end = (int32_t)(jit->cb.len - (jmp_pos + 5));
+            memcpy(jit->cb.buf + jmp_pos + 1, &rel_end, 4);
+            break;
+        }
+
+        case STK_WHILE: {
+            size_t loop_start = jit->cb.len;
+            gen_expr_jit(jit, s->whl.cond);
+            emit_test_rax_rax(&jit->cb);
+            size_t jz_pos = jit->cb.len;
+            emit_jcc_rel32(&jit->cb, 0x04, 0);                          // JZ exit (patched below)
+            gen_stmt_jit(jit, s->whl.body);
+            emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
+            int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6));
+            memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4);
+            break;
+        }
+
+        case STK_FOR: {
+            if (s->fr.init) gen_stmt_jit(jit, s->fr.init);
+            size_t loop_start = jit->cb.len;
+            size_t jz_pos = 0;
+            if (s->fr.cond) {
+                gen_expr_jit(jit, s->fr.cond);
+                emit_test_rax_rax(&jit->cb);
+                jz_pos = jit->cb.len;
+                emit_jcc_rel32(&jit->cb, 0x04, 0);                      // JZ exit (patched below)
+            }
+            gen_stmt_jit(jit, s->fr.body);
+            if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
+            emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
+            if (s->fr.cond) {
+                int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6));
+                memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4);
+            }
+            break;
+        }
+
+        default:
+            fprintf(stderr, "[JIT] Unsupported statement kind %d\n", s->kind); exit(1);
+        }
+    }
+    return 0;
+}
+
+/* --- Expression generation --- */
+
+/* Short-circuit `&&` / `||`. Leaves 0 or 1 in RAX.
+ *
+ * The left operand is tested; for `&&` a zero (for `||` a non-zero) value
+ * jumps straight to the constant result without evaluating the right
+ * operand. Otherwise the right operand is evaluated and normalised to 0/1. */
+static void gen_short_circuit_jit(JIT *jit, _EX *e) {
+    int is_and = (e->binop.op == TK_AND);
+
+    gen_expr_jit(jit, e->binop.l);
+    emit_test_rax_rax(&jit->cb);
+    size_t skip_pos = jit->cb.len;
+    emit_jcc_rel32(&jit->cb, is_and ? 0x04 : 0x05, 0);     // AND: JZ, OR: JNZ (patched below)
+
+    gen_expr_jit(jit, e->binop.r);
+    emit_test_rax_rax(&jit->cb);
+    emit_setcc_al(&jit->cb, 0x05);                         // setne al
+    emit_movzx_rax_al(&jit->cb);
+    size_t jmp_pos = jit->cb.len;
+    emit_jmp_rel32(&jit->cb, 0);                           // skip the constant (patched below)
+
+    int32_t rel = (int32_t)(jit->cb.len - (skip_pos + 6));
+    memcpy(jit->cb.buf + skip_pos + 2, &rel, 4);
+    emit_movabs_rax_imm64(&jit->cb, is_and ? 0 : 1);       // short-circuit result
+    rel = (int32_t)(jit->cb.len - (jmp_pos + 5));
+    memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4);
+}
+
+/* Function call. Result is whatever the callee leaves in RAX.
+ *
+ * The first six arguments go in RDI, RSI, RDX, RCX, R8, R9; the rest are
+ * pushed right-to-left. Register arguments are first evaluated onto the
+ * stack and only then popped into their registers: evaluating a later
+ * argument may itself use RCX/RDX or perform a call, which would destroy
+ * argument registers that had already been loaded. The 8-byte padding for an
+ * odd number of stack arguments is reserved *before* those arguments are
+ * pushed so that they sit directly above the return address.
+ *
+ * A callee that is registered but not yet compiled gets a placeholder
+ * address and a PatchEntry so the imm64 can be fixed up later. */
+static void gen_call_jit(JIT *jit, _EX *e) {
+    static const int arg_regs[6] = {RDI, RSI, RDX, RCX, R8, R9};
+    int total_args = e->call.argc;
+    int reg_args = total_args < 6 ? total_args : 6;
+    int stack_args = total_args > 6 ? total_args - 6 : 0;
+    int padding = (stack_args % 2) ? 8 : 0;
+
+    if (padding) {
+        emit8(&jit->cb,0x48); emit8(&jit->cb,0x83); emit8(&jit->cb,0xEC); emit8(&jit->cb,0x08); // sub rsp, 8
+    }
+    for (int i = total_args-1; i >= 6; i--) {
+        gen_expr_jit(jit, e->call.args[i]);
+        emit_push_reg(&jit->cb, RAX);
+    }
+    for (int i = 0; i < reg_args; i++) {
+        gen_expr_jit(jit, e->call.args[i]);
+        emit_push_reg(&jit->cb, RAX);
+    }
+    for (int i = reg_args-1; i >= 0; i--) {
+        emit_pop_reg(&jit->cb, arg_regs[i]);
+    }
+
+    void *addr = get_func_addr(jit, e->call.func_name);
+    if (addr == (void*)0xDEADBEEF) {
+        // Forward call: emit placeholder imm64, record offset for later patching
+        emit_movabs_rax_imm64(&jit->cb, (uint64_t)0xDEADBEEF);
+        PatchEntry *patch = (PatchEntry*)malloc(sizeof(PatchEntry));
+        if (!patch) { fprintf(stderr, "[JIT] malloc failed (patch)\n"); exit(1); }
+        patch->func_name = strdup(e->call.func_name);
+        patch->owning_func = strdup(jit->current_func_name);
+        if (!patch->func_name || !patch->owning_func) {
+            fprintf(stderr, "[JIT] strdup failed (patch)\n");
+            free(patch->func_name); free(patch->owning_func); free(patch); exit(1);
+        }
+        patch->offset = jit->cb.len - 8;                   // points at the imm64 in the buffer
+        patch->next = jit->patch_list;
+        jit->patch_list = patch;
+    } else {
+        emit_movabs_rax_imm64(&jit->cb, (uint64_t)addr);
+    }
+    emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0);          // call rax
+
+    if (stack_args * 8 + padding > 0) {
+        emit8(&jit->cb,0x48); emit8(&jit->cb,0x81); emit8(&jit->cb,0xC4); // add rsp, imm32
+        emit32(&jit->cb, (uint32_t)(stack_args*8 + padding));
+    }
+}
+
+/* Load `var[index]` into RAX, where `var` is a stack array or a pointer
+ * variable. Byte elements are zero-extended.
+ *
+ * The index is evaluated before the base address is formed, so nothing that
+ * index evaluation does to the scratch registers can corrupt the address.
+ * Elements are one byte wide only for char arrays and char*; for deeper
+ * pointer types (e.g. char**) each element is an 8-byte pointer. */
+static void gen_index_load_jit(JIT *jit, _EX *e) {
+    _EX *base = e->index.array;
+    if (base->kind != EX_VAR) {
+        fprintf(stderr, "[JIT] Indexing is only supported on named arrays and pointers\n");
+        exit(1);
+    }
+    _TY var_type = get_var_type(jit, base->name);
+    if (var_type.ptr_level == 0 && var_type.array_size <= 0) {
+        fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", base->name); exit(1);
+    }
+    int var_offset = get_var_offset(jit, base->name);
+    int element_size = (var_type.base == TY_CHAR && var_type.ptr_level <= 1) ? 1 : 8;
+
+    gen_expr_jit(jit, e->index.index);                     // RAX = index
+    if (element_size > 1) {
+        emit_mov_reg_imm64(&jit->cb, RBX, element_size);
+        emit_imul_rax_rbx(&jit->cb);                       // RAX = byte offset
+    }
+    if (var_type.ptr_level > 0) {
+        emit_mov_reg_mem64(&jit->cb, RBX, var_offset);     // RBX = pointer value
+        emit_add_reg_reg(&jit->cb, RBX, RAX);              // RBX = &ptr[index]
+    } else {
+        // Stack arrays grow downward: element i is at RBP + offset - i*size.
+        emit_mov_reg_imm64(&jit->cb, RBX, var_offset);
+        emit_sub_reg_reg(&jit->cb, RBX, RAX);              // RBX = offset - byte offset
+        emit_add_reg_reg(&jit->cb, RBX, RBP);              // RBX = &arr[index]
+    }
+
+    if (element_size == 1) {
+        emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6);
+        emit_modrm(&jit->cb, 0, RAX, RBX);                 // movzx rax, byte [rbx]
+    } else {
+        emit_mov_reg_mem_reg(&jit->cb, RAX, RBX);          // mov rax, [rbx]
+    }
+}
+
+/* Generate code that evaluates `e` and leaves the result in RAX.
+ *
+ * RBX, RCX and RDX are used as scratch and are not preserved. A NULL
+ * expression generates nothing. Unsupported expression kinds and operators
+ * exit the process. */
+static void gen_expr_jit(JIT *jit, _EX *e) {
+    if (!e) return;
+    switch (e->kind) {
+
+    case EX_NUMBER:
+        emit_movabs_rax_imm64(&jit->cb, (uint64_t)e->value);
+        break;
+
+    case EX_VAR: {
+        int off = get_var_offset(jit, e->name);
+        _TY ty = get_var_type(jit, e->name);
+        int sz = (ty.ptr_level > 0) ? 8 : (ty.base == TY_CHAR) ? 1 : 8;
+        if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off);
+        else emit_mov_reg_mem64(&jit->cb, RAX, off);
+        break;
+    }
+
+    case EX_BINOP: {
+        if (e->binop.op == TK_AND || e->binop.op == TK_OR) {
+            gen_short_circuit_jit(jit, e);
+            break;
+        }
+        // All other binary ops: eval left -> push; eval right -> RBX = left
+        gen_expr_jit(jit, e->binop.l);
+        emit_push_reg(&jit->cb, RAX);
+        gen_expr_jit(jit, e->binop.r);
+        emit_pop_reg(&jit->cb, RBX);                       // RBX = left, RAX = right
+        switch (e->binop.op) {
+        case TK_PLUS:  emit_add_rax_rbx(&jit->cb); break;
+        case TK_STAR:  emit_imul_rax_rbx(&jit->cb); break;
+        case TK_BAR:   emit_or_rax_rbx(&jit->cb); break;
+        case TK_CARET: emit_xor_rax_rbx(&jit->cb); break;
+        case TK_AMP:   emit_and_rax_rbx(&jit->cb); break;
+        case TK_MINUS:
+            emit_mov_reg_reg(&jit->cb, RCX, RAX);          // RCX = right
+            emit_mov_reg_reg(&jit->cb, RAX, RBX);          // RAX = left
+            emit_sub_reg_reg(&jit->cb, RAX, RCX);          // RAX = left - right
+            break;
+        case TK_SLASH:
+        case TK_PERCENT:
+            emit_mov_reg_reg(&jit->cb, RCX, RAX);          // RCX = right
+            emit_mov_reg_reg(&jit->cb, RAX, RBX);          // RAX = left (dividend)
+            emit_mov_reg_reg(&jit->cb, RBX, RCX);          // RBX = right (divisor)
+            if (e->binop.op == TK_SLASH) emit_idiv_rbx(&jit->cb);
+            else emit_imod(&jit->cb);
+            break;
+        case TK_SHL:
+        case TK_SHR:
+            emit_mov_reg_reg(&jit->cb, RCX, RAX);          // CL = shift count
+            emit_mov_reg_reg(&jit->cb, RAX, RBX);          // RAX = value to shift
+            if (e->binop.op == TK_SHL) emit_shl_rax_cl(&jit->cb);
+            else emit_shr_rax_cl(&jit->cb);
+            break;
+        case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE: {
+            emit_mov_reg_reg(&jit->cb, RCX, RAX);          // RCX = right
+            emit_mov_reg_reg(&jit->cb, RAX, RBX);          // RAX = left
+            emit_mov_reg_reg(&jit->cb, RBX, RCX);          // RBX = right
+            emit_cmp_rax_rbx(&jit->cb);
+            uint8_t cc;
+            switch (e->binop.op) {
+            case TK_EQ: cc = 0x04; break;
+            case TK_NE: cc = 0x05; break;
+            case TK_LT: cc = 0x0C; break;
+            case TK_LE: cc = 0x0E; break;
+            case TK_GT: cc = 0x0F; break;
+            case TK_GE: cc = 0x0D; break;
+            default:    cc = 0x04; break;
+            }
+            emit_setcc_al(&jit->cb, cc);
+            emit_movzx_rax_al(&jit->cb);
+            break;
+        }
+        default:
+            fprintf(stderr, "[JIT] Unsupported binop %d\n", e->binop.op); exit(1);
+        }
+        break;
+    }
+
+    case EX_CALL:
+        gen_call_jit(jit, e);
+        break;
+
+    case EX_ADDR:
+        if (e->addr.expr->kind != EX_VAR) {
+            fprintf(stderr, "[JIT] &expr: only &var supported\n"); exit(1);
+        }
+        emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name));
+        break;
+
+    case EX_DEREF: {
+        _TY ptr_type = get_expr_type(jit, e->deref.expr);
+        gen_expr_jit(jit, e->deref.expr);                  // RAX = address
+        if (ptr_type.base == TY_CHAR && ptr_type.ptr_level == 1) {
+            // Pointee is a single char: read one byte, not eight.
+            emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6);
+            emit_modrm(&jit->cb, 0, RAX, RAX);             // movzx rax, byte [rax]
+        } else {
+            emit_mov_reg_mem_reg(&jit->cb, RAX, RAX);      // mov rax, [rax]
+        }
+        break;
+    }
+
+    case EX_STRING: {
+        // The copy is referenced by address from the generated code and is
+        // never freed here.
+        char *s = strdup(e->string);
+        if (!s) { fprintf(stderr, "[JIT] strdup failed (string literal)\n"); exit(1); }
+        emit_movabs_rax_imm64(&jit->cb, (uint64_t)s);
+        break;
+    }
+
+    case EX_INDEX:
+        gen_index_load_jit(jit, e);
+        break;
+
+    default:
+        fprintf(stderr, "[JIT] Unsupported expression kind %d\n", e->kind); exit(1);
+    }
+}
+
+/* --- Compile one function --- */
+
+/* Compile one function into freshly mapped executable memory.
+ *
+ *   jit      - JIT context; its variable map and code buffer are reset/reused
+ *   f        - function to compile (must already be registered by name)
+ *   out_size - if non-NULL, receives the number of code bytes produced
+ *
+ * Returns the address of the code, or NULL if mmap fails (after perror).
+ *
+ * Up to six parameters are received in RDI, RSI, RDX, RCX, R8, R9 and spilled
+ * to frame slots. Those slots are included in the reserved stack size: they
+ * are allocated exactly like locals, and if they were not reserved they
+ * would sit below RSP and be overwritten by the first push or call. A char
+ * parameter is spilled with a one-byte store so it cannot overwrite the slot
+ * above it. Parameters beyond the sixth are not bound to variables; a
+ * warning is printed for them.
+ *
+ * If the body does not end in a return, `return 0` is appended. */
+static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) {
+    static const int param_regs[6] = {RDI, RSI, RDX, RCX, R8, R9};
+
+    reset_varmap(jit);
+    jit->current_func_name = f->name;
+
+    int reg_params = f->pac;
+    if (reg_params > 6) {
+        fprintf(stderr, "[JIT] warning: function '%s' has %d parameters; only the first 6 are bound\n",
+                f->name, f->pac);
+        reg_params = 6;
+    }
+    if (reg_params < 0) reg_params = 0;
+
+    int total_stack_size = calculate_stack_size(f->body);
+    for (int i = 0; i < reg_params; i++) {
+        _TY pt = f->param_types[i];
+        int slot = ((calculate_type_size(pt) + 7) / 8) * 8;
+        // Byte-sized slots may need padding before the next 8-byte variable.
+        if (pt.ptr_level == 0 && pt.base == TY_CHAR) slot += 8;
+        total_stack_size += slot;
+    }
+
+    cb_init(&jit->cb);
+    emit_prologue(&jit->cb, total_stack_size);
+
+    for (int i = 0; i < reg_params; i++) {
+        _TY pt = f->param_types[i];
+        add_var(jit, f->params[i], pt);
+        int off = get_var_offset(jit, f->params[i]);
+        if (pt.ptr_level == 0 && pt.base == TY_CHAR) {
+            emit_mov_reg_reg(&jit->cb, RAX, param_regs[i]);    // RAX is free at function entry
+            emit_store_rax_to_mem(&jit->cb, off, 1);
+        } else {
+            emit_mov_mem64_reg(&jit->cb, off, param_regs[i]);
+        }
+    }
+
+    int did_return = gen_stmt_jit(jit, f->body);
+    if (!did_return) {
+        emit_movabs_rax_imm64(&jit->cb, 0);
+        emit_epilogue(&jit->cb);
+    }
+
+    // Copy the finished code into page-granular read/write/execute memory.
+    // It stays writable because forward calls are patched in place later.
+    size_t pagesz = (size_t)sysconf(_SC_PAGESIZE);
+    size_t alloc_size = ((jit->cb.len + pagesz - 1) / pagesz) * pagesz;
+    void *mem = mmap(NULL, alloc_size, PROT_READ|PROT_WRITE|PROT_EXEC,
+                     MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
+    if (mem == MAP_FAILED) { perror("mmap"); cb_free(&jit->cb); return NULL; }
+    memcpy(mem, jit->cb.buf, jit->cb.len);
+    if (out_size) *out_size = jit->cb.len;
+    set_func_addr(jit, f->name, mem, jit->cb.len, alloc_size);
+    cb_free(&jit->cb);
+    return mem;
+}
+
+/* --- Patch forward calls after all functions are compiled --- */
+
+/* Resolve every recorded forward call: overwrite the placeholder imm64 in the
+ * owning function's code with the callee's real address.
+ *
+ * Exits the process if the callee or the owning function has no compiled
+ * code, or if the recorded offset does not leave room for an address inside
+ * the owner's code. The address is written with memcpy because the imm64
+ * sits at an arbitrary byte offset inside the instruction stream and is not
+ * guaranteed to be pointer-aligned. */
+static void patch_function_calls(JIT *jit) {
+    for (PatchEntry *patch = jit->patch_list; patch; patch = patch->next) {
+        void *target = NULL;
+        for (FuncMap *f = jit->func_list; f; f = f->next)
+            if (strcmp(f->name, patch->func_name) == 0 && f->addr) { target = f->addr; break; }
+        if (!target) {
+            fprintf(stderr, "[JIT] patch: function '%s' not found\n", patch->func_name); exit(1);
+        }
+
+        FuncMap *owner = NULL;
+        for (FuncMap *f = jit->func_list; f; f = f->next)
+            if (strcmp(f->name, patch->owning_func) == 0 && f->addr) { owner = f; break; }
+        if (!owner) {
+            fprintf(stderr, "[JIT] patch: owning function '%s' not found\n", patch->owning_func); exit(1);
+        }
+
+        if (patch->offset > owner->size || owner->size - patch->offset < sizeof target) {
+            fprintf(stderr, "[JIT] patch: offset %zu out of range in '%s'\n",
+                    patch->offset, patch->owning_func);
+            exit(1);
+        }
+        memcpy((uint8_t *)owner->addr + patch->offset, &target, sizeof target);
+    }
+}
+
+/* --- Compile all functions and patch forward calls --- */
+
+/* Register and compile every function in `fn_list`, then patch forward calls.
+ *
+ * All functions are registered first so that calls can refer to functions
+ * that have not been compiled yet. Compilation runs in reverse list order so
+ * callees are typically compiled before their callers. There is no fixed
+ * limit on the number of functions. Exits the process if a function fails to
+ * compile or memory runs out. */
+static void jit_compile_all(JIT *jit, _FN *fn_list) {
+    size_t count = 0;
+    for (_FN *cur = fn_list; cur; cur = cur->n) {
+        register_func(jit, cur->name);
+        count++;
+    }
+    if (count == 0) return;
+
+    _FN **functions = (_FN **)malloc(count * sizeof *functions);
+    if (!functions) { fprintf(stderr, "[JIT] malloc failed in jit_compile_all\n"); exit(1); }
+
+    size_t filled = 0;
+    for (_FN *cur = fn_list; cur; cur = cur->n)
+        functions[filled++] = cur;
+
+    for (size_t i = count; i-- > 0;) {
+        if (!gen_function_jit(jit, functions[i], NULL)) {
+            fprintf(stderr, "[JIT] failed to compile function '%s'\n", functions[i]->name);
+            free(functions);
+            exit(1);
+        }
+    }
+    free(functions);
+
+    patch_function_calls(jit);
+}
+
+/* --- Entry point --- */
+
+/* Run the compiled program: look up `main`, call it with argc/argv and return
+ * its result.
+ *
+ * Exits the process if `main` is not registered (via get_func_addr) or is
+ * registered but was never compiled; in the latter case the lookup yields
+ * the 0xDEADBEEF placeholder, which must not be called. */
+static int jit_run(JIT *jit, int argc, char **argv) {
+    void *addr = get_func_addr(jit, "main");
+    if (addr == (void *)0xDEADBEEF) {
+        fprintf(stderr, "[JIT] main was registered but never compiled\n");
+        exit(1);
+    }
+
+    // Object pointer -> function pointer: copy the bits rather than convert.
+    int (*main_func)(int, char **);
+    memcpy(&main_func, &addr, sizeof main_func);
+
+    fprintf(stderr, "[JIT] Running main (argc=%d argv=%p)\n", argc, (void*)argv);
+    return main_func(argc, argv);
+}
+
+#endif
\ No newline at end of file