diff options
| author | David Moc <personal@cdatgoose.org> | 2026-03-05 23:38:49 +0100 |
|---|---|---|
| committer | David Moc <personal@cdatgoose.org> | 2026-03-05 23:38:49 +0100 |
| commit | 0385817bb1301a778bb33f8405a435293b9f8905 (patch) | |
| tree | 53f4b6f13e393bd368c37ba4363826b46940dfd3 /src/codegen_jit.h | |
| parent | 262abf9b552a168ef3ae91f91af97683f16420a7 (diff) | |
Pushing to repo for safety.
Diffstat (limited to 'src/codegen_jit.h')
| -rw-r--r-- | src/codegen_jit.h | 827 |
1 files changed, 827 insertions, 0 deletions
diff --git a/src/codegen_jit.h b/src/codegen_jit.h new file mode 100644 index 0000000..ff1d687 --- /dev/null +++ b/src/codegen_jit.h @@ -0,0 +1,827 @@ +#ifndef CODEGEN_JIT_H +#define CODEGEN_JIT_H + +#include "ast.h" +#include "token.h" + +#include <assert.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/mman.h> +#include <unistd.h> +#include <limits.h> +#include <ctype.h> + +enum { + RAX = 0, RCX = 1, RDX = 2, RBX = 3, + RSP = 4, RBP = 5, RSI = 6, RDI = 7, + R8 = 8, R9 = 9, R10 = 10, R11 = 11, + R12 = 12, R13 = 13, R14 = 14, R15 = 15 +}; + +typedef struct { + uint8_t *buf; + size_t len; + size_t cap; +} CodeBuf; + +typedef struct FuncMap { + char *name; + void *addr; + size_t size; + size_t alloc_size; + struct FuncMap *next; +} FuncMap; + +typedef struct PatchEntry { + size_t offset; // offset of the imm64 within owning function's code + char *func_name; // target function to patch in + char *owning_func; + struct PatchEntry *next; +} PatchEntry; + +typedef struct VarMap { + char *name; + int offset; // RBP-relative offset + _TY type; + struct VarMap *next; +} VarMap; + +typedef struct { + FuncMap *func_list; + VarMap *var_list; + int next_local_offset; + CodeBuf cb; + char *current_func_name; + PatchEntry *patch_list; +} JIT; + +static void jit_init(JIT *jit) { + jit->func_list = NULL; + jit->var_list = NULL; + jit->next_local_offset = -8; + jit->cb.buf = NULL; + jit->cb.len = jit->cb.cap = 0; + jit->current_func_name = NULL; + jit->patch_list = NULL; +} + +static void jit_free(JIT *jit) { + for (VarMap *v = jit->var_list; v;) { + VarMap *n = v->next; free(v->name); free(v); v = n; + } + jit->var_list = NULL; + + for (PatchEntry *p = jit->patch_list; p;) { + PatchEntry *n = p->next; free(p->func_name); free(p->owning_func); free(p); p = n; + } + jit->patch_list = NULL; + + for (FuncMap *f = jit->func_list; f;) { + FuncMap *n = f->next; + if (f->addr && f->alloc_size > 0) munmap(f->addr, f->alloc_size); + free(f->name); free(f); f = n; + } + jit->func_list = NULL; + + free(jit->cb.buf); + jit->cb.buf = NULL; + jit->cb.len = jit->cb.cap = 0; +} + +/* --- Code buffer --- */ + +static void cb_init(CodeBuf *c) { + c->cap = 1024; c->len = 0; + c->buf = (uint8_t *)malloc(c->cap); + if (!c->buf) { fprintf(stderr, "[JIT] failed to allocate code buffer\n"); exit(1); } +} +static void cb_free(CodeBuf *c) { + free(c->buf); c->buf = NULL; c->len = c->cap = 0; +} +static void cb_grow(CodeBuf *c, size_t need) { + if (c->len + need > c->cap) { + while (c->len + need > c->cap) c->cap *= 2; + uint8_t *nb = (uint8_t *)realloc(c->buf, c->cap); + if (!nb) { fprintf(stderr, "[JIT] failed to grow code buffer\n"); free(c->buf); exit(1); } + c->buf = nb; + } +} +static void emit8(CodeBuf *c, uint8_t b) { cb_grow(c,1); c->buf[c->len++] = b; } +static void emit32(CodeBuf *c, uint32_t v) { cb_grow(c,4); memcpy(c->buf+c->len,&v,4); c->len+=4; } +static void emit64(CodeBuf *c, uint64_t v) { cb_grow(c,8); memcpy(c->buf+c->len,&v,8); c->len+=8; } +static void emitN(CodeBuf *c, const void *p, size_t n) { cb_grow(c,n); memcpy(c->buf+c->len,p,n); c->len+=n; } + +/* --- x86-64 encoding helpers --- */ + +static void emit_rex(CodeBuf *c, int reg, int rm, int w) { + uint8_t rex = 0x40; + if (w) rex |= 0x08; + if (reg&8) rex |= 0x04; + if (rm&8) rex |= 0x01; + emit8(c, (rex != 0x40 || w) ? rex : 0x48); +} +static void emit_modrm(CodeBuf *c, uint8_t mod, uint8_t reg, uint8_t rm) { + emit8(c, (mod<<6)|((reg&7)<<3)|(rm&7)); +} + +static void emit_movabs_rax_imm64(CodeBuf *c, uint64_t imm) { + emit8(c,0x48); emit8(c,0xB8); emit64(c,imm); +} +static void emit_mov_reg_imm64(CodeBuf *c, int reg, int64_t imm) { + emit_rex(c,reg,0,1); emit8(c,0xB8+(reg&7)); emit64(c,imm); +} +static void emit_mov_mem64_reg(CodeBuf *c, int disp32, int reg) { + emit_rex(c,reg,RBP,1); emit8(c,0x89); emit_modrm(c,2,reg,RBP); emit32(c,(uint32_t)disp32); +} +static void emit_mov_reg_mem64(CodeBuf *c, int reg, int disp32) { + emit_rex(c,reg,RBP,1); emit8(c,0x8B); emit_modrm(c,2,reg,RBP); emit32(c,(uint32_t)disp32); +} +static void emit_mov_reg_mem_reg(CodeBuf *c, int dst, int base) { + emit_rex(c,dst,base,1); emit8(c,0x8B); emit_modrm(c,0,dst,base); +} +static void emit_mov_mem_reg_reg(CodeBuf *c, int base, int src) { + emit_rex(c,src,base,1); emit8(c,0x89); emit_modrm(c,0,src,base); +} +static void emit_movzx_rax_mem8(CodeBuf *c, int disp32) { + emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB6); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); +} +static void emit_mov_mem8_rax(CodeBuf *c, int disp32) { + emit_rex(c,RAX,RBP,0); emit8(c,0x88); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); +} +static void emit_mov_reg_reg(CodeBuf *c, int dst, int src) { + emit_rex(c,src,dst,1); emit8(c,0x89); emit_modrm(c,3,src,dst); +} +static void emit_add_reg_reg(CodeBuf *c, int dst, int src) { + emit_rex(c,src,dst,1); emit8(c,0x01); emit_modrm(c,3,src,dst); +} +static void emit_sub_reg_reg(CodeBuf *c, int dst, int src) { + emit_rex(c,src,dst,1); emit8(c,0x29); emit_modrm(c,3,src,dst); +} +static void emit_add_reg_imm32(CodeBuf *c, int dst, int32_t imm) { + emit_rex(c,0,dst,1); emit8(c,0x81); emit_modrm(c,3,0,dst); emit32(c,(uint32_t)imm); +} +static void emit_push_reg(CodeBuf *c, int reg) { + if (reg&8) emit8(c,0x41); emit8(c,0x50+(reg&7)); +} +static void emit_pop_reg(CodeBuf *c, int reg) { + if (reg&8) emit8(c,0x41); emit8(c,0x58+(reg&7)); +} + +static void emit_add_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x01,0xD8},3); } +static void emit_imul_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xAF,0xC3},4); } +static void emit_idiv_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); } +static void emit_imod(CodeBuf *c) { + emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); // cqo; idiv rbx + emit_mov_reg_reg(c, RAX, RDX); // remainder -> RAX +} +static void emit_or_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x09,0xD8},3); } +static void emit_xor_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x31,0xD8},3); } +static void emit_and_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x21,0xD8},3); } +static void emit_shl_rax_cl(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0xD3,0xE0},3); } +static void emit_shr_rax_cl(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0xD3,0xE8},3); } +static void emit_cmp_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x39,0xD8},3); } +static void emit_setcc_al(CodeBuf *c, uint8_t cc) { emitN(c,(uint8_t[]){0x0F,0x90|cc,0xC0},3); } +static void emit_movzx_rax_al(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xB6,0xC0},4); } +static void emit_jmp_rel32(CodeBuf *c, int32_t rel) { emit8(c,0xE9); emit32(c,(uint32_t)rel); } +static void emit_jcc_rel32(CodeBuf *c, uint8_t cc, int32_t rel) { + emitN(c,(uint8_t[]){0x0F,(uint8_t)(0x80|cc)},2); emit32(c,(uint32_t)rel); +} +static void emit_test_rax_rax(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x85,0xC0},3); } + +static void emit_prologue(CodeBuf *c, int total_stack_size) { + emit8(c,0x55); // push rbp + emitN(c,(uint8_t[]){0x48,0x89,0xE5},3); // mov rbp, rsp + emit_push_reg(c, RBX); // save callee-saved RBX + int stack_bytes = ((total_stack_size+15)/16)*16; + if (stack_bytes > 0) { + emitN(c,(uint8_t[]){0x48,0x81,0xEC},3); // sub rsp, imm32 + emit32(c,(uint32_t)stack_bytes); + } +} +static void emit_epilogue(CodeBuf *c) { + emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3); // restore RBX; leave; ret +} +static void emit_lea_rax_rbp_disp(CodeBuf *c, int disp32) { + emit8(c,0x48); emit8(c,0x8D); emit_modrm(c,2,RAX,RBP); emit32(c,(uint32_t)disp32); +} +static void emit_load_rax_from_mem(CodeBuf *c, int disp32, int size) { + if (size == 1) emit_movzx_rax_mem8(c, disp32); else emit_mov_reg_mem64(c, RAX, disp32); +} +static void emit_store_rax_to_mem(CodeBuf *c, int disp32, int size) { + if (size == 1) emit_mov_mem8_rax(c, disp32); else emit_mov_mem64_reg(c, disp32, RAX); +} + +/* --- Variable and type helpers --- */ + +static int calculate_type_size(_TY type) { + if (type.ptr_level > 0) return 8; + int base_size = (type.base == TY_CHAR) ? 1 : 8; + size_t total = (size_t)base_size; + if (type.array_size > 0) { + total *= (size_t)type.array_size; + if (total > (size_t)INT_MAX) { fprintf(stderr, "[JIT] array size too large\n"); exit(1); } + } + return (int)total; +} + +static int align_offset(int offset, _TY type) { + int align = (type.base == TY_CHAR) ? 1 : 8; + if (offset % align == 0) return offset; + return (offset / align) * align; +} + +static void reset_varmap(JIT *jit) { + for (VarMap *v = jit->var_list; v;) { + VarMap *n = v->next; free(v->name); free(v); v = n; + } + jit->var_list = NULL; + jit->next_local_offset = -16; // after saved RBX at RBP-8, 16-byte aligned +} + +static void add_var(JIT *jit, const char *name, _TY type) { + VarMap *v = (VarMap *)malloc(sizeof(VarMap)); + if (!v) { fprintf(stderr, "[JIT] malloc failed in add_var\n"); exit(1); } + v->name = strdup(name); + if (!v->name) { fprintf(stderr, "[JIT] strdup failed in add_var\n"); free(v); exit(1); } + v->type = type; + int type_size = calculate_type_size(type); + jit->next_local_offset = align_offset(jit->next_local_offset, type); + v->offset = jit->next_local_offset; + jit->next_local_offset -= type_size; + v->next = jit->var_list; + jit->var_list = v; +} + +static int get_var_offset(JIT *jit, const char *name) { + for (VarMap *v = jit->var_list; v; v = v->next) + if (strcmp(v->name, name) == 0) return v->offset; + fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1); +} + +static _TY get_var_type(JIT *jit, const char *name) { + for (VarMap *v = jit->var_list; v; v = v->next) + if (strcmp(v->name, name) == 0) return v->type; + fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1); +} + +/* --- Type checking --- */ + +static _TY get_expr_type(JIT *jit, _EX *expr) { + switch (expr->kind) { + case EX_NUMBER: return (_TY){TY_INT, 0, -1}; + case EX_STRING: return (_TY){TY_CHAR, 1, -1}; + case EX_VAR: return get_var_type(jit, expr->name); + case EX_BINOP: { + _TY left_type = get_expr_type(jit, expr->binop.l); + // All arithmetic, comparison, bitwise ops produce int + (void)get_expr_type(jit, expr->binop.r); + switch (expr->binop.op) { + case TK_PLUS: case TK_MINUS: case TK_STAR: case TK_SLASH: case TK_PERCENT: + case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE: + case TK_AND: case TK_OR: + case TK_AMP: case TK_BAR: case TK_CARET: case TK_SHL: case TK_SHR: + return (_TY){TY_INT, 0, -1}; + default: return left_type; + } + } + case EX_CALL: return (_TY){TY_INT, 0, -1}; + case EX_INDEX: { + _TY t = get_expr_type(jit, expr->index.array); + if (t.array_size > 0) return (_TY){t.base, t.ptr_level, -1}; + if (t.ptr_level > 0) return (_TY){t.base, t.ptr_level-1, -1}; + return (_TY){TY_INT, 0, -1}; + } + case EX_DEREF: { + _TY t = get_expr_type(jit, expr->deref.expr); + if (t.ptr_level > 0) return (_TY){t.base, t.ptr_level-1, -1}; + return (_TY){TY_INT, 0, -1}; + } + case EX_ADDR: { + _TY t = get_expr_type(jit, expr->addr.expr); + return (_TY){t.base, t.ptr_level+1, -1}; + } + default: return (_TY){TY_INT, 0, -1}; + } +} + +static int types_compatible(_TY expected, _TY actual) { + if (expected.base == actual.base && + expected.ptr_level == actual.ptr_level && + expected.array_size == actual.array_size) return 1; + // Allow untyped int literals to be assigned anywhere + if (actual.base == TY_INT && actual.ptr_level == 0 && actual.array_size == -1) return 1; + return 0; +} + +/* --- Function registry --- */ + +static void register_func(JIT *jit, const char *name) { + FuncMap *f = (FuncMap *)malloc(sizeof(FuncMap)); + if (!f) { fprintf(stderr, "[JIT] malloc failed in register_func\n"); exit(1); } + f->name = strdup(name); + if (!f->name) { fprintf(stderr, "[JIT] strdup failed in register_func\n"); free(f); exit(1); } + f->addr = NULL; f->size = 0; f->alloc_size = 0; + f->next = jit->func_list; + jit->func_list = f; +} + +static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, size_t alloc_size) { + for (FuncMap *f = jit->func_list; f; f = f->next) { + if (strcmp(f->name, name) != 0) continue; + f->addr = addr; f->size = size; f->alloc_size = alloc_size; + + fprintf(stderr, "\n[JIT] %s @ %p (%zu bytes)\n", name, addr, size); + uint8_t *p = (uint8_t *)addr; + for (size_t i = 0; i < size; i += 16) { + fprintf(stderr, "\t%04zx ", i); + for (size_t j = 0; j < 16; j++) + (i+j < size) ? fprintf(stderr, "%02x ", p[i+j]) : fprintf(stderr, " "); + fprintf(stderr, " |"); + for (size_t j = 0; j < 16 && i+j < size; j++) + fprintf(stderr, "%c", isprint(p[i+j]) ? p[i+j] : '.'); + fprintf(stderr, "\n"); + } + return; + } + fprintf(stderr, "[JIT] set_func_addr: unknown function '%s'\n", name); exit(1); +} + +static void *get_func_addr(JIT *jit, const char *name) { + for (FuncMap *f = jit->func_list; f; f = f->next) { + if (strcmp(f->name, name) == 0) + return f->addr ? f->addr : (void*)0xDEADBEEF; // placeholder; patched later + } + fprintf(stderr, "[JIT] get_func_addr: unknown function '%s'\n", name); exit(1); +} + +/* --- Stack size calculation --- */ + +static int calculate_stack_size(_STN *s) { + int total = 0; + while (s) { + if (s->kind == STK_VAR_DECL) total += calculate_type_size(s->var_decl.type); + if (s->kind == STK_BLOCK) total += calculate_stack_size(s->body); + if (s->kind == STK_FOR) { + if (s->fr.init) total += calculate_stack_size(s->fr.init); + total += calculate_stack_size(s->fr.body); + } + s = s->n; + } + return total; +} + +/* --- Code generation (forward declarations) --- */ + +static void gen_expr_jit(JIT *jit, _EX *e); +static int gen_stmt_jit(JIT *jit, _STN *s); + +/* --- Statement generation --- */ + +static int gen_stmt_jit(JIT *jit, _STN *s) { + while (s) { + switch (s->kind) { + + case STK_VAR_DECL: + add_var(jit, s->var_decl.name, s->var_decl.type); + if (s->var_decl.init) { + _TY init_type = get_expr_type(jit, s->var_decl.init); + if (!types_compatible(s->var_decl.type, init_type)) { + fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n", + tybase_name(init_type.base), tybase_name(s->var_decl.type.base)); + exit(1); + } + gen_expr_jit(jit, s->var_decl.init); + int sz = (s->var_decl.type.ptr_level > 0) ? 8 + : (s->var_decl.type.base == TY_CHAR) ? 1 : 8; + emit_store_rax_to_mem(&jit->cb, get_var_offset(jit, s->var_decl.name), sz); + } + break; + + case STK_ASSIGN: { + _EX *lhs = s->assign.lhs; + if (lhs->kind == EX_VAR) { + int offset = get_var_offset(jit, lhs->name); + _TY type = get_var_type(jit, lhs->name); + _TY expr_type = get_expr_type(jit, s->assign.expr); + if (!types_compatible(type, expr_type)) { + fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n", + tybase_name(expr_type.base), tybase_name(type.base)); + exit(1); + } + gen_expr_jit(jit, s->assign.expr); + int sz = (type.ptr_level > 0) ? 8 : (type.base == TY_CHAR) ? 1 : 8; + emit_store_rax_to_mem(&jit->cb, offset, sz); + + } else if (lhs->kind == EX_DEREF) { + gen_expr_jit(jit, lhs->deref.expr); + emit_mov_reg_reg(&jit->cb, RBX, RAX); // RBX = address + gen_expr_jit(jit, s->assign.expr); // RAX = value + emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + + } else if (lhs->kind == EX_INDEX) { + if (lhs->index.array->kind != EX_VAR) { + gen_expr_jit(jit, lhs->index.array); break; + } + _TY var_type = get_var_type(jit, lhs->index.array->name); + if (var_type.array_size <= 0) { + fprintf(stderr, "[JIT] Cannot index non-array '%s'\n", lhs->index.array->name); exit(1); + } + int array_offset = get_var_offset(jit, lhs->index.array->name); + int element_size = (var_type.base == TY_CHAR) ? 1 : 8; + gen_expr_jit(jit, lhs->index.index); // RAX = index + emit_mov_reg_imm64(&jit->cb, RBX, element_size); + emit_imul_rax_rbx(&jit->cb); // RAX = byte offset + emit_mov_reg_imm64(&jit->cb, RBX, array_offset); + emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RBP); + emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index] + emit_mov_reg_reg(&jit->cb, RBX, RAX); + gen_expr_jit(jit, s->assign.expr); // RAX = value + emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + + } else { + fprintf(stderr, "[JIT] Unsupported assignment LHS kind %d\n", lhs->kind); exit(1); + } + break; + } + + case STK_RETURN: + if (s->return_expr) gen_expr_jit(jit, s->return_expr); + emit_epilogue(&jit->cb); + return 1; + + case STK_EXPR: + gen_expr_jit(jit, s->expr); + break; + + case STK_BLOCK: { + int r = gen_stmt_jit(jit, s->body); + if (r) return 1; + break; + } + + case STK_IF: { + gen_expr_jit(jit, s->ifs.cond); + emit_test_rax_rax(&jit->cb); + size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); // JZ else + gen_stmt_jit(jit, s->ifs.thenb); + size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); // JMP end + int32_t rel_else = (int32_t)(jit->cb.len - (jz_pos + 6)); + memcpy(jit->cb.buf + jz_pos + 2, &rel_else, 4); + if (s->ifs.elseb) gen_stmt_jit(jit, s->ifs.elseb); + int32_t rel_end = (int32_t)(jit->cb.len - (jmp_pos + 5)); + memcpy(jit->cb.buf + jmp_pos + 1, &rel_end, 4); + break; + } + + case STK_WHILE: { + size_t loop_start = jit->cb.len; + gen_expr_jit(jit, s->whl.cond); + emit_test_rax_rax(&jit->cb); + size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); + gen_stmt_jit(jit, s->whl.body); + emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); + int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6)); + memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4); + break; + } + + case STK_FOR: { + if (s->fr.init) gen_stmt_jit(jit, s->fr.init); + size_t loop_start = jit->cb.len; + if (s->fr.cond) { + gen_expr_jit(jit, s->fr.cond); + emit_test_rax_rax(&jit->cb); + size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); + gen_stmt_jit(jit, s->fr.body); + if (s->fr.step) gen_stmt_jit(jit, s->fr.step); + emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); + int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6)); + memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4); + } else { + gen_stmt_jit(jit, s->fr.body); + if (s->fr.step) gen_stmt_jit(jit, s->fr.step); + emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); + } + break; + } + + default: + fprintf(stderr, "[JIT] Unsupported statement kind %d\n", s->kind); exit(1); + } + s = s->n; + } + return 0; +} + +/* --- Expression generation --- */ + +static void gen_expr_jit(JIT *jit, _EX *e) { + if (!e) return; + switch (e->kind) { + + case EX_NUMBER: + emit_movabs_rax_imm64(&jit->cb, (uint64_t)e->value); + break; + + case EX_VAR: { + int off = get_var_offset(jit, e->name); + _TY ty = get_var_type(jit, e->name); + int sz = (ty.ptr_level > 0) ? 8 : (ty.base == TY_CHAR) ? 1 : 8; + if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off); + else emit_mov_reg_mem64(&jit->cb, RAX, off); + break; + } + + case EX_BINOP: { + // Short-circuit AND + if (e->binop.op == TK_AND) { + gen_expr_jit(jit, e->binop.l); + emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); + size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); + gen_expr_jit(jit, e->binop.r); + emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); + emit_setcc_al(&jit->cb, 0x05); emit_movzx_rax_al(&jit->cb); + size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); + int32_t rel = (int32_t)(jit->cb.len - (jz_pos + 6)); + memcpy(jit->cb.buf + jz_pos + 2, &rel, 4); + emit_movabs_rax_imm64(&jit->cb, 0); + rel = (int32_t)(jit->cb.len - (jmp_pos + 5)); + memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4); + break; + } + // Short-circuit OR + if (e->binop.op == TK_OR) { + gen_expr_jit(jit, e->binop.l); + emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); + size_t jnz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x05, 0); + gen_expr_jit(jit, e->binop.r); + emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); + emit_setcc_al(&jit->cb, 0x05); emit_movzx_rax_al(&jit->cb); + size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); + int32_t rel = (int32_t)(jit->cb.len - (jnz_pos + 6)); + memcpy(jit->cb.buf + jnz_pos + 2, &rel, 4); + emit_movabs_rax_imm64(&jit->cb, 1); + rel = (int32_t)(jit->cb.len - (jmp_pos + 5)); + memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4); + break; + } + // All other binary ops: eval left -> push; eval right -> RBX = left + gen_expr_jit(jit, e->binop.l); + emit_push_reg(&jit->cb, RAX); + gen_expr_jit(jit, e->binop.r); + emit_pop_reg(&jit->cb, RBX); // RBX = left, RAX = right + switch (e->binop.op) { + case TK_PLUS: emit_add_rax_rbx(&jit->cb); break; + case TK_MINUS: + emit_mov_reg_reg(&jit->cb, RCX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_rex(&jit->cb, RCX, RAX, 1); emit8(&jit->cb, 0x29); emit_modrm(&jit->cb, 3, RCX, RAX); + break; + case TK_STAR: emit_imul_rax_rbx(&jit->cb); break; + case TK_SLASH: + emit_mov_reg_reg(&jit->cb, RCX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RCX); + emit_idiv_rbx(&jit->cb); break; + case TK_PERCENT: + emit_mov_reg_reg(&jit->cb, RCX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RCX); + emit_imod(&jit->cb); break; + case TK_BAR: emit_or_rax_rbx(&jit->cb); break; + case TK_CARET: emit_xor_rax_rbx(&jit->cb); break; + case TK_AMP: emit_and_rax_rbx(&jit->cb); break; + case TK_SHL: + emit_mov_reg_reg(&jit->cb, RCX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_shl_rax_cl(&jit->cb); break; + case TK_SHR: + emit_mov_reg_reg(&jit->cb, RCX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_shr_rax_cl(&jit->cb); break; + case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE: { + emit_mov_reg_reg(&jit->cb, RCX, RAX); // RCX = right + emit_mov_reg_reg(&jit->cb, RAX, RBX); // RAX = left + emit_mov_reg_reg(&jit->cb, RBX, RCX); // RBX = right + emit_cmp_rax_rbx(&jit->cb); + uint8_t cc; + switch (e->binop.op) { + case TK_EQ: cc=0x04; break; case TK_NE: cc=0x05; break; + case TK_LT: cc=0x0C; break; case TK_LE: cc=0x0E; break; + case TK_GT: cc=0x0F; break; case TK_GE: cc=0x0D; break; + default: cc=0x04; break; + } + emit_setcc_al(&jit->cb, cc); emit_movzx_rax_al(&jit->cb); + break; + } + default: + fprintf(stderr, "[JIT] Unsupported binop %d\n", e->binop.op); exit(1); + } + break; + } + + case EX_CALL: { + int total_args = e->call.argc; + int stack_args = total_args > 6 ? total_args - 6 : 0; + int padding = (stack_args % 2) ? 8 : 0; + const int arg_regs[6] = {RDI, RSI, RDX, RCX, R8, R9}; + + for (int i = total_args-1; i >= 6; i--) { + gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX); + } + for (int i = 0; i < total_args && i < 6; i++) { + gen_expr_jit(jit, e->call.args[i]); emit_mov_reg_reg(&jit->cb, arg_regs[i], RAX); + } + if (padding) { + emit8(&jit->cb,0x48); emit8(&jit->cb,0x83); emit8(&jit->cb,0xEC); emit8(&jit->cb,0x08); + } + + void *addr = get_func_addr(jit, e->call.func_name); + if (addr == (void*)0xDEADBEEF) { + // Forward call: emit placeholder imm64, record offset for later patching + emit_movabs_rax_imm64(&jit->cb, (uint64_t)0xDEADBEEF); + PatchEntry *patch = (PatchEntry*)malloc(sizeof(PatchEntry)); + if (!patch) { fprintf(stderr, "[JIT] malloc failed (patch)\n"); exit(1); } + patch->func_name = strdup(e->call.func_name); + patch->owning_func = strdup(jit->current_func_name); + if (!patch->func_name || !patch->owning_func) { + fprintf(stderr, "[JIT] strdup failed (patch)\n"); + free(patch->func_name); free(patch->owning_func); free(patch); exit(1); + } + patch->offset = jit->cb.len - 8; // points at the imm64 in the buffer + patch->next = jit->patch_list; + jit->patch_list = patch; + } else { + emit_movabs_rax_imm64(&jit->cb, (uint64_t)addr); + } + emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0); // CALL RAX + + if (stack_args * 8 + padding > 0) { + emit8(&jit->cb,0x48); emit8(&jit->cb,0x81); emit8(&jit->cb,0xC4); + emit32(&jit->cb, (uint32_t)(stack_args*8 + padding)); + } + break; + } + + case EX_ADDR: + if (e->addr.expr->kind != EX_VAR) { + fprintf(stderr, "[JIT] &expr: only &var supported\n"); exit(1); + } + emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name)); + break; + + case EX_DEREF: + gen_expr_jit(jit, e->deref.expr); + emit_rex(&jit->cb, RAX, RAX, 1); emit8(&jit->cb, 0x8B); emit_modrm(&jit->cb, 0, RAX, RAX); + break; + + case EX_STRING: { + char *s = strdup(e->string); + if (!s) { fprintf(stderr, "[JIT] strdup failed (string literal)\n"); exit(1); } + emit_movabs_rax_imm64(&jit->cb, (uint64_t)s); + break; + } + + case EX_INDEX: { + if (e->index.array->kind != EX_VAR) { + gen_expr_jit(jit, e->index.array); break; + } + _TY var_type = get_var_type(jit, e->index.array->name); + int element_size = (var_type.base == TY_CHAR) ? 1 : 8; + + if (var_type.ptr_level > 0) { + // Pointer indexing: load pointer, add index*element_size + int ptr_offset = get_var_offset(jit, e->index.array->name); + emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer + gen_expr_jit(jit, e->index.index); // RAX = index + if (element_size > 1) { + emit_mov_reg_reg(&jit->cb, RCX, RBX); // save pointer + emit_mov_reg_imm64(&jit->cb, RBX, element_size); + emit_imul_rax_rbx(&jit->cb); // RAX = byte offset + emit_mov_reg_reg(&jit->cb, RBX, RCX); // restore pointer + } + emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &ptr[index] + if (element_size == 1) { + emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); + emit_modrm(&jit->cb, 0, RAX, RBX); + } else { + emit_mov_reg_mem_reg(&jit->cb, RAX, RBX); + } + + } else if (var_type.array_size > 0) { + // Array indexing: compute RBP-relative address = array_offset - index*element_size + int array_offset = get_var_offset(jit, e->index.array->name); + gen_expr_jit(jit, e->index.index); // RAX = index + emit_mov_reg_imm64(&jit->cb, RBX, element_size); + emit_imul_rax_rbx(&jit->cb); // RAX = byte offset + emit_mov_reg_imm64(&jit->cb, RBX, array_offset); + emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RBP); + emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index] + emit_mov_reg_reg(&jit->cb, RBX, RAX); + if (element_size == 1) { + emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); + emit_modrm(&jit->cb, 0, RAX, RBX); + } else { + emit_mov_reg_mem_reg(&jit->cb, RAX, RBX); + } + + } else { + fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", e->index.array->name); exit(1); + } + break; + } + + default: + fprintf(stderr, "[JIT] Unsupported expression kind %d\n", e->kind); exit(1); + } +} + +/* --- Compile one function --- */ + +static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) { + reset_varmap(jit); + jit->current_func_name = f->name; + + int total_stack_size = calculate_stack_size(f->body); + cb_init(&jit->cb); + emit_prologue(&jit->cb, total_stack_size); + + const int param_regs[6] = {RDI, RSI, RDX, RCX, R8, R9}; + for (int i = 0; i < f->pac && i < 6; i++) { + add_var(jit, f->params[i], f->param_types[i]); + emit_mov_mem64_reg(&jit->cb, get_var_offset(jit, f->params[i]), param_regs[i]); + } + + int did_return = gen_stmt_jit(jit, f->body); + if (!did_return) { + emit_movabs_rax_imm64(&jit->cb, 0); + emit_epilogue(&jit->cb); + } + + size_t pagesz = (size_t)sysconf(_SC_PAGESIZE); + size_t alloc_size = ((jit->cb.len + pagesz - 1) / pagesz) * pagesz; + void *mem = mmap(NULL, alloc_size, PROT_READ|PROT_WRITE|PROT_EXEC, + MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); + if (mem == MAP_FAILED) { perror("mmap"); cb_free(&jit->cb); return NULL; } + memcpy(mem, jit->cb.buf, jit->cb.len); + if (out_size) *out_size = jit->cb.len; + set_func_addr(jit, f->name, mem, jit->cb.len, alloc_size); + cb_free(&jit->cb); + return mem; +} + +/* --- Patch forward calls after all functions are compiled --- */ + +static void patch_function_calls(JIT *jit) { + for (PatchEntry *patch = jit->patch_list; patch; patch = patch->next) { + void *target = NULL; + for (FuncMap *f = jit->func_list; f; f = f->next) + if (strcmp(f->name, patch->func_name) == 0 && f->addr) { target = f->addr; break; } + if (!target) { + fprintf(stderr, "[JIT] patch: function '%s' not found\n", patch->func_name); exit(1); + } + void *owner = NULL; + for (FuncMap *f = jit->func_list; f; f = f->next) + if (strcmp(f->name, patch->owning_func) == 0 && f->addr) { owner = f->addr; break; } + if (!owner) { + fprintf(stderr, "[JIT] patch: owning function '%s' not found\n", patch->owning_func); exit(1); + } + *(void **)((uint8_t *)owner + patch->offset) = target; + } +} + +/* --- Compile all functions and patch forward calls --- */ + +static void jit_compile_all(JIT *jit, _FN *fn_list) { + for (_FN *cur = fn_list; cur; cur = cur->n) + register_func(jit, cur->name); + + // Compile in reverse order so callees are typically compiled before callers + _FN *functions[64]; + int count = 0; + for (_FN *cur = fn_list; cur; cur = cur->n) { + if (count >= 64) { fprintf(stderr, "[JIT] Too many functions (max 64)\n"); exit(1); } + functions[count++] = cur; + } + for (int i = count-1; i >= 0; i--) + gen_function_jit(jit, functions[i], NULL); + + patch_function_calls(jit); +} + +/* --- Entry point --- */ + +static int jit_run(JIT *jit, int argc, char **argv) { + int (*main_func)(int, char **) = get_func_addr(jit, "main"); + fprintf(stderr, "[JIT] Running main (argc=%d argv=%p)\n", argc, (void*)argv); + return main_func(argc, argv); +} + +#endif
\ No newline at end of file |
