#ifndef CODEGEN_JIT_H #define CODEGEN_JIT_H #include "ast.h" #include "token.h" #include #include #include #include #include #include #include #include #include enum { RAX = 0, RCX = 1, RDX = 2, RBX = 3, RSP = 4, RBP = 5, RSI = 6, RDI = 7, R8 = 8, R9 = 9, R10 = 10, R11 = 11, R12 = 12, R13 = 13, R14 = 14, R15 = 15 }; typedef struct { uint8_t *buf; size_t len; size_t cap; } CodeBuf; typedef struct FuncMap { char *name; void *addr; size_t size; size_t alloc_size; _TY ret_type; /* return type of this function */ struct FuncMap *next; } FuncMap; typedef struct PatchEntry { size_t offset; // offset of the imm64 within owning function's code char *func_name; // target function to patch in char *owning_func; struct PatchEntry *next; } PatchEntry; typedef struct VarMap { char *name; int offset; // RBP-relative offset _TY type; struct VarMap *next; } VarMap; typedef struct GlobalVar { char *name; uint8_t *addr; /* pointer into globals_buf */ _TY type; int size; /* total bytes allocated */ struct GlobalVar *next; } GlobalVar; typedef struct { FuncMap *func_list; VarMap *var_list; int next_local_offset; CodeBuf cb; char *current_func_name; PatchEntry *patch_list; /* Globals data segment */ uint8_t *globals_buf; /* mmap'd RW page(s) for global variables */ size_t globals_cap; size_t globals_used; GlobalVar *global_list; /* Break/continue patch stacks (up to 64 nesting levels) */ size_t break_patches[64][256]; /* offsets to patch */ int break_patch_count[64]; size_t cont_patches[64][256]; int cont_patch_count[64]; int loop_depth; } JIT; static void jit_init(JIT *jit) { jit->func_list = NULL; jit->var_list = NULL; jit->next_local_offset = -8; jit->cb.buf = NULL; jit->cb.len = jit->cb.cap = 0; jit->current_func_name = NULL; jit->patch_list = NULL; jit->globals_cap = 1024 * 1024; jit->globals_used = 0; jit->globals_buf = (uint8_t *)mmap(NULL, jit->globals_cap, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); if (jit->globals_buf == MAP_FAILED) { perror("mmap globals"); exit(1); } memset(jit->globals_buf, 0, jit->globals_cap); jit->global_list = NULL; jit->loop_depth = 0; memset(jit->break_patch_count, 0, sizeof(jit->break_patch_count)); memset(jit->cont_patch_count, 0, sizeof(jit->cont_patch_count)); } static void jit_free(JIT *jit) { for (VarMap *v = jit->var_list; v;) { VarMap *n = v->next; free(v->name); free(v); v = n; } jit->var_list = NULL; for (PatchEntry *p = jit->patch_list; p;) { PatchEntry *n = p->next; free(p->func_name); free(p->owning_func); free(p); p = n; } jit->patch_list = NULL; for (GlobalVar *g = jit->global_list; g;) { GlobalVar *n = g->next; free(g->name); free(g); g = n; } jit->global_list = NULL; if (jit->globals_buf && jit->globals_buf != MAP_FAILED) munmap(jit->globals_buf, jit->globals_cap); for (FuncMap *f = jit->func_list; f;) { FuncMap *n = f->next; if (f->addr && f->alloc_size > 0) munmap(f->addr, f->alloc_size); free(f->name); free(f); f = n; } jit->func_list = NULL; free(jit->cb.buf); jit->cb.buf = NULL; jit->cb.len = jit->cb.cap = 0; } static void cb_init(CodeBuf *c) { c->cap = 1024; c->len = 0; c->buf = (uint8_t *)malloc(c->cap); if (!c->buf) { fprintf(stderr, "[JIT] failed to allocate code buffer\n"); exit(1); } } static void cb_free(CodeBuf *c) { free(c->buf); c->buf = NULL; c->len = c->cap = 0; } static void cb_grow(CodeBuf *c, size_t need) { if (c->len + need > c->cap) { while (c->len + need > c->cap) c->cap *= 2; uint8_t *nb = (uint8_t *)realloc(c->buf, c->cap); if (!nb) { fprintf(stderr, "[JIT] failed to grow code buffer\n"); free(c->buf); exit(1); } c->buf = nb; } } static void emit8(CodeBuf *c, uint8_t b) { cb_grow(c,1); c->buf[c->len++] = b; } static void emit32(CodeBuf *c, uint32_t v) { cb_grow(c,4); memcpy(c->buf+c->len,&v,4); c->len+=4; } static void emit64(CodeBuf *c, uint64_t v) { cb_grow(c,8); memcpy(c->buf+c->len,&v,8); c->len+=8; } static void emitN(CodeBuf *c, const void *p, size_t n) { cb_grow(c,n); memcpy(c->buf+c->len,p,n); c->len+=n; } static void emit_rex(CodeBuf *c, int reg, int rm, int w) { uint8_t rex = 0x40; if (w) rex |= 0x08; if (reg&8) rex |= 0x04; if (rm&8) rex |= 0x01; emit8(c, (rex != 0x40 || w) ? rex : 0x48); } static void emit_modrm(CodeBuf *c, uint8_t mod, uint8_t reg, uint8_t rm) { emit8(c, (mod<<6)|((reg&7)<<3)|(rm&7)); } static void emit_movabs_rax_imm64(CodeBuf *c, uint64_t imm) { emit8(c,0x48); emit8(c,0xB8); emit64(c,imm); } static void emit_mov_reg_imm64(CodeBuf *c, int reg, int64_t imm) { emit_rex(c,reg,0,1); emit8(c,0xB8+(reg&7)); emit64(c,imm); } static void emit_mov_mem64_reg(CodeBuf *c, int disp32, int reg) { emit_rex(c,reg,RBP,1); emit8(c,0x89); emit_modrm(c,2,reg,RBP); emit32(c,(uint32_t)disp32); } static void emit_mov_reg_mem64(CodeBuf *c, int reg, int disp32) { emit_rex(c,reg,RBP,1); emit8(c,0x8B); emit_modrm(c,2,reg,RBP); emit32(c,(uint32_t)disp32); } static void emit_mov_reg_mem_reg(CodeBuf *c, int dst, int base) { emit_rex(c,dst,base,1); emit8(c,0x8B); emit_modrm(c,0,dst,base); } static void emit_mov_mem_reg_reg(CodeBuf *c, int base, int src) { emit_rex(c,src,base,1); emit8(c,0x89); emit_modrm(c,0,src,base); } static void emit_mov_mem8_reg_reg(CodeBuf *c, int base) { if (base & 8) emit8(c, 0x41); /* REX.B for extended base registers */ emit8(c, 0x88); emit_modrm(c, 0, RAX, base); } static void emit_movzx_rax_mem8(CodeBuf *c, int disp32) { emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB6); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); } static void emit_movzx_rax_mem8_base(CodeBuf *c, int base) { if (base & 8) emit8(c, 0x49); else emit8(c, 0x48); emit8(c, 0x0F); emit8(c, 0xB6); emit_modrm(c, 0, RAX, base); } static void emit_mov_mem8_rax(CodeBuf *c, int disp32) { emit_rex(c,RAX,RBP,0); emit8(c,0x88); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); } static void emit_mov_reg_reg(CodeBuf *c, int dst, int src) { emit_rex(c,src,dst,1); emit8(c,0x89); emit_modrm(c,3,src,dst); } static void emit_add_reg_reg(CodeBuf *c, int dst, int src) { emit_rex(c,src,dst,1); emit8(c,0x01); emit_modrm(c,3,src,dst); } static void emit_sub_reg_reg(CodeBuf *c, int dst, int src) { emit_rex(c,src,dst,1); emit8(c,0x29); emit_modrm(c,3,src,dst); } static void emit_add_reg_imm32(CodeBuf *c, int dst, int32_t imm) { emit_rex(c,0,dst,1); emit8(c,0x81); emit_modrm(c,3,0,dst); emit32(c,(uint32_t)imm); } static void emit_push_reg(CodeBuf *c, int reg) { if (reg&8) emit8(c,0x41); emit8(c,0x50+(reg&7)); } static void emit_pop_reg(CodeBuf *c, int reg) { if (reg&8) emit8(c,0x41); emit8(c,0x58+(reg&7)); } static void emit_add_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x01,0xD8},3); } static void emit_imul_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xAF,0xC3},4); } static void emit_idiv_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); } static void emit_imod(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); emit_mov_reg_reg(c, RAX, RDX); } static void emit_or_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x09,0xD8},3); } static void emit_xor_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x31,0xD8},3); } static void emit_and_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x21,0xD8},3); } static void emit_shl_rax_cl(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0xD3,0xE0},3); } static void emit_shr_rax_cl(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0xD3,0xE8},3); } static void emit_cmp_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x39,0xD8},3); } static void emit_setcc_al(CodeBuf *c, uint8_t cc) { emitN(c,(uint8_t[]){0x0F,0x90|cc,0xC0},3); } static void emit_movzx_rax_al(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xB6,0xC0},4); } static void emit_jmp_rel32(CodeBuf *c, int32_t rel) { emit8(c,0xE9); emit32(c,(uint32_t)rel); } static void emit_jcc_rel32(CodeBuf *c, uint8_t cc, int32_t rel) { emitN(c,(uint8_t[]){0x0F,(uint8_t)(0x80|cc)},2); emit32(c,(uint32_t)rel); } static void emit_test_rax_rax(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x85,0xC0},3); } static void emit_prologue(CodeBuf *c, int total_stack_size) { emit8(c,0x55); emitN(c,(uint8_t[]){0x48,0x89,0xE5},3); emit_push_reg(c, RBX); int stack_bytes = ((total_stack_size+15)/16)*16; if (stack_bytes > 0) { emitN(c,(uint8_t[]){0x48,0x81,0xEC},3); emit32(c,(uint32_t)stack_bytes); } } static void emit_epilogue(CodeBuf *c) { emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3); } static void emit_lea_rax_rbp_disp(CodeBuf *c, int disp32) { emit8(c,0x48); emit8(c,0x8D); emit_modrm(c,2,RAX,RBP); emit32(c,(uint32_t)disp32); } static void emit_movzx_rax_mem16(CodeBuf *c, int disp32) { emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB7); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); } static void emit_mov_mem16_rax(CodeBuf *c, int disp32) { emit8(c,0x66); emit_rex(c,RAX,RBP,0); emit8(c,0x89); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); } static void emit_load_rax_from_mem(CodeBuf *c, int disp32, int size) { if (size == 1) emit_movzx_rax_mem8(c, disp32); else if (size == 2) emit_movzx_rax_mem16(c, disp32); else emit_mov_reg_mem64(c, RAX, disp32); } static void emit_store_rax_to_mem(CodeBuf *c, int disp32, int size) { if (size == 1) emit_mov_mem8_rax(c, disp32); else if (size == 2) emit_mov_mem16_rax(c, disp32); else emit_mov_mem64_reg(c, disp32, RAX); } static int calculate_type_size(_TY type) { if (type.ptr_level > 0) return 8; int slot; switch (type.base) { case TY_CHAR: slot = 1; break; case TY_BOOL: slot = 1; break; case TY_SHORT: slot = 2; break; case TY_INT: slot = 8; break; case TY_FLOAT: slot = 8; break; case TY_LONG: slot = 8; break; case TY_VOID: slot = 0; break; default: slot = 8; break; } if (type.array_size > 0) { size_t total = (size_t)slot * (size_t)type.array_size; if (total > (size_t)INT_MAX) { fprintf(stderr, "[JIT] array size too large\n"); exit(1); } return (int)total; } return slot; } static int align_offset(int offset, _TY type) { int align; if (type.ptr_level > 0) { align = 8; } else switch (type.base) { case TY_CHAR: align = 1; break; case TY_BOOL: align = 1; break; case TY_SHORT: align = 2; break; case TY_INT: align = 8; break; case TY_FLOAT: align = 8; break; case TY_LONG: align = 8; break; default: align = 8; break; } if (offset % align == 0) return offset; if (offset < 0) return ((offset - (align - 1)) / align) * align; return (offset / align) * align; } static int ty_slot_size(_TY ty) { if (ty.ptr_level > 0) return 8; switch (ty.base) { case TY_CHAR: return 1; case TY_BOOL: return 1; case TY_SHORT: return 2; case TY_INT: return 8; /* promoted to 64-bit slot */ case TY_FLOAT: return 8; /* stored in 64-bit slot (integer bits) */ case TY_LONG: return 8; case TY_VOID: return 0; default: return 8; } } static void reset_varmap(JIT *jit) { for (VarMap *v = jit->var_list; v;) { VarMap *n = v->next; free(v->name); free(v); v = n; } jit->var_list = NULL; jit->next_local_offset = -16; } static void add_var(JIT *jit, const char *name, _TY type) { VarMap *v = (VarMap *)malloc(sizeof(VarMap)); if (!v) { fprintf(stderr, "[JIT] malloc failed in add_var\n"); exit(1); } v->name = strdup(name); if (!v->name) { fprintf(stderr, "[JIT] strdup failed in add_var\n"); free(v); exit(1); } v->type = type; int type_size = calculate_type_size(type); jit->next_local_offset -= type_size; jit->next_local_offset = align_offset(jit->next_local_offset, type); v->offset = jit->next_local_offset; v->next = jit->var_list; jit->var_list = v; } static int get_var_offset(JIT *jit, const char *name) { for (VarMap *v = jit->var_list; v; v = v->next) if (strcmp(v->name, name) == 0) return v->offset; fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1); } static GlobalVar *find_global(JIT *jit, const char *name) { for (GlobalVar *g = jit->global_list; g; g = g->next) if (strcmp(g->name, name) == 0) return g; return NULL; } static GlobalVar *register_global(JIT *jit, const char *name, _TY type) { if (find_global(jit, name)) { fprintf(stderr, "[JIT] Duplicate global '%s'\n", name); exit(1); } int elem_sz = ty_slot_size((_TY){type.base, type.ptr_level > 0 ? type.ptr_level : 0, -1}); int n_elems = (type.array_size > 0) ? type.array_size : 1; int total_sz = elem_sz * n_elems; size_t align = (elem_sz < 8) ? elem_sz : 8; size_t off = (jit->globals_used + align - 1) & ~(align - 1); if (off + (size_t)total_sz > jit->globals_cap) { fprintf(stderr, "[JIT] Globals segment full\n"); exit(1); } GlobalVar *g = (GlobalVar *)malloc(sizeof(GlobalVar)); if (!g) { fprintf(stderr, "[JIT] OOM\n"); exit(1); } g->name = strdup(name); g->addr = jit->globals_buf + off; g->type = type; g->size = total_sz; g->next = jit->global_list; jit->global_list = g; jit->globals_used = off + total_sz; return g; } static _TY get_var_type(JIT *jit, const char *name) { for (VarMap *v = jit->var_list; v; v = v->next) if (strcmp(v->name, name) == 0) return v->type; /* Fall back to globals */ GlobalVar *g = find_global(jit, name); if (g) return g->type; fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1); } static _TY get_func_ret_type(JIT *jit, const char *name); static void reset_varmap(JIT *jit); static _TY get_expr_type(JIT *jit, _EX *expr) { switch (expr->kind) { case EX_NUMBER: return (_TY){TY_INT, 0, -1}; case EX_STRING: return (_TY){TY_CHAR, 1, -1}; case EX_VAR: { _TY t = get_var_type(jit, expr->name); if (t.array_size > 0) return (_TY){t.base, t.ptr_level + 1, -1}; return t; } case EX_BINOP: { _TY left_type = get_expr_type(jit, expr->binop.l); (void)get_expr_type(jit, expr->binop.r); switch (expr->binop.op) { case TK_PLUS: case TK_MINUS: case TK_STAR: case TK_SLASH: case TK_PERCENT: case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE: case TK_AND: case TK_OR: case TK_AMP: case TK_BAR: case TK_CARET: case TK_SHL: case TK_SHR: return (_TY){TY_INT, 0, -1}; default: return left_type; } } case EX_CALL: if (strcmp(expr->call.func_name, "syscall") == 0) return (_TY){TY_LONG, 0, -1}; if (strcmp(expr->call.func_name, "__initlist__") == 0 || strcmp(expr->call.func_name, "__sizeof__") == 0) return (_TY){TY_INT, 0, -1}; return get_func_ret_type(jit, expr->call.func_name); case EX_INDEX: { _TY t = get_expr_type(jit, expr->index.array); if (t.array_size > 0) return (_TY){t.base, t.ptr_level, -1}; if (t.ptr_level > 0) return (_TY){t.base, t.ptr_level-1, -1}; return (_TY){TY_INT, 0, -1}; } case EX_DEREF: { _TY t = get_expr_type(jit, expr->deref.expr); if (t.ptr_level > 0) return (_TY){t.base, t.ptr_level-1, -1}; return (_TY){TY_INT, 0, -1}; } case EX_ADDR: { _TY t = get_expr_type(jit, expr->addr.expr); return (_TY){t.base, t.ptr_level+1, -1}; } case EX_TERNARY: return get_expr_type(jit, expr->ternary.then_expr); case EX_CAST: return expr->cast.to; default: return (_TY){TY_INT, 0, -1}; } } static int type_is_integer(_TY ty) { if (ty.ptr_level > 0) return 0; switch (ty.base) { case TY_INT: case TY_CHAR: case TY_SHORT: case TY_LONG: case TY_BOOL: return 1; default: return 0; } } static int types_compatible(_TY expected, _TY actual) { if (expected.base == actual.base && expected.ptr_level == actual.ptr_level && expected.array_size == actual.array_size) return 1; if (actual.base == TY_INT && actual.ptr_level == 0 && actual.array_size == -1) return 1; if (type_is_integer(expected) && type_is_integer(actual)) return 1; if (expected.ptr_level > 0 && actual.ptr_level > 0) return 1; if (expected.ptr_level > 0 && actual.array_size > 0 && expected.base == actual.base && expected.ptr_level == actual.ptr_level + 1) return 1; return 0; } static void register_func(JIT *jit, const char *name, _TY ret_type) { FuncMap *f = (FuncMap *)malloc(sizeof(FuncMap)); if (!f) { fprintf(stderr, "[JIT] malloc failed in register_func\n"); exit(1); } f->name = strdup(name); if (!f->name) { fprintf(stderr, "[JIT] strdup failed in register_func\n"); free(f); exit(1); } f->addr = NULL; f->size = 0; f->alloc_size = 0; f->ret_type = ret_type; f->next = jit->func_list; jit->func_list = f; } static _TY get_func_ret_type(JIT *jit, const char *name) { for (FuncMap *f = jit->func_list; f; f = f->next) if (strcmp(f->name, name) == 0) return f->ret_type; return (_TY){TY_INT, 0, -1}; } static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, size_t alloc_size) { for (FuncMap *f = jit->func_list; f; f = f->next) { if (strcmp(f->name, name) != 0) continue; f->addr = addr; f->size = size; f->alloc_size = alloc_size; fprintf(stderr, "\n[JIT] %s @ %p (%zu bytes)\n", name, addr, size); uint8_t *p = (uint8_t *)addr; for (size_t i = 0; i < size; i += 16) { fprintf(stderr, "\t%04zx ", i); for (size_t j = 0; j < 16; j++) (i+j < size) ? fprintf(stderr, "%02x ", p[i+j]) : fprintf(stderr, " "); fprintf(stderr, " |"); for (size_t j = 0; j < 16 && i+j < size; j++) fprintf(stderr, "%c", isprint(p[i+j]) ? p[i+j] : '.'); fprintf(stderr, "\n"); } return; } fprintf(stderr, "[JIT] set_func_addr: unknown function '%s'\n", name); exit(1); } static void *get_func_addr(JIT *jit, const char *name) { for (FuncMap *f = jit->func_list; f; f = f->next) { if (strcmp(f->name, name) == 0) return f->addr ? f->addr : (void*)0xDEADBEEF; } fprintf(stderr, "[JIT] get_func_addr: unknown function '%s'\n", name); exit(1); } static int calculate_stack_size(_STN *s) { int total = 0; while (s) { if (s->kind == STK_VAR_DECL) { int sz = calculate_type_size(s->var_decl.type); total += (sz < 8 && s->var_decl.type.array_size <= 0) ? 8 : sz; } if (s->kind == STK_BLOCK) total += calculate_stack_size(s->body); if (s->kind == STK_FOR) { if (s->fr.init) total += calculate_stack_size(s->fr.init); total += calculate_stack_size(s->fr.body); } s = s->n; } return total + 8; } static int calculate_total_stack_size(_FN *f) { int param_space = 0; for (int i = 0; i < f->pac && i < 6; i++) param_space += (calculate_type_size(f->param_types[i]) < 8) ? 8 : calculate_type_size(f->param_types[i]); return calculate_stack_size(f->body) + param_space; } static void gen_expr_jit(JIT *jit, _EX *e); static int gen_stmt_jit(JIT *jit, _STN *s); static int gen_stmt_jit(JIT *jit, _STN *s) { while (s) { switch (s->kind) { case STK_VAR_DECL: add_var(jit, s->var_decl.name, s->var_decl.type); if (s->var_decl.init) { if (s->var_decl.init->kind == EX_CALL && strcmp(s->var_decl.init->call.func_name, "__initlist__") == 0) { _TY vty = s->var_decl.type; int elem_sz = ty_slot_size((_TY){vty.base, 0, -1}); int arr_off = get_var_offset(jit, s->var_decl.name); for (int ii = 0; ii < s->var_decl.init->call.argc; ii++) { gen_expr_jit(jit, s->var_decl.init->call.args[ii]); /* address of arr[ii] = RBP + arr_off + ii*elem_sz */ int elem_off = arr_off + ii * elem_sz; emit_store_rax_to_mem(&jit->cb, elem_off, elem_sz); } break; } _TY init_type = get_expr_type(jit, s->var_decl.init); if (!types_compatible(s->var_decl.type, init_type)) { fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n", tybase_name(init_type.base), tybase_name(s->var_decl.type.base)); exit(1); } gen_expr_jit(jit, s->var_decl.init); int sz = ty_slot_size(s->var_decl.type); emit_store_rax_to_mem(&jit->cb, get_var_offset(jit, s->var_decl.name), sz); } break; case STK_ASSIGN: { _EX *lhs = s->assign.lhs; if (lhs->kind == EX_VAR) { GlobalVar *gv = find_global(jit, lhs->name); if (gv) { gen_expr_jit(jit, s->assign.expr); emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr); int sz = ty_slot_size(gv->type); if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); break; } int offset = get_var_offset(jit, lhs->name); _TY type = get_var_type(jit, lhs->name); _TY expr_type = get_expr_type(jit, s->assign.expr); if (!types_compatible(type, expr_type)) { fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n", tybase_name(expr_type.base), tybase_name(type.base)); exit(1); } gen_expr_jit(jit, s->assign.expr); int sz = ty_slot_size(type); emit_store_rax_to_mem(&jit->cb, offset, sz); } else if (lhs->kind == EX_DEREF) { gen_expr_jit(jit, lhs->deref.expr); emit_push_reg(&jit->cb, RAX); // save address on stack gen_expr_jit(jit, s->assign.expr); // RAX = value (may clobber RBX) emit_pop_reg(&jit->cb, RBX); // RBX = address _TY ptr_ty = get_expr_type(jit, lhs->deref.expr); int dsz = (ptr_ty.ptr_level > 1) ? 8 : ty_slot_size((_TY){ptr_ty.base, 0, -1}); if (dsz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } else if (lhs->kind == EX_INDEX) { if (lhs->index.array->kind != EX_VAR) { gen_expr_jit(jit, lhs->index.array); break; } _TY var_type = get_var_type(jit, lhs->index.array->name); int element_size = ty_slot_size((_TY){var_type.base, 0, -1}); if (var_type.ptr_level > 0) { GlobalVar *gv_ptr2 = find_global(jit, lhs->index.array->name); if (gv_ptr2) { emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr2->addr); emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref global ptr */ } else { int ptr_offset = get_var_offset(jit, lhs->index.array->name); emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer } gen_expr_jit(jit, lhs->index.index); // RAX = index if (element_size > 1) { emit_mov_reg_reg(&jit->cb, RCX, RBX); emit_mov_reg_imm64(&jit->cb, RBX, element_size); emit_imul_rax_rbx(&jit->cb); // RAX = byte offset emit_mov_reg_reg(&jit->cb, RBX, RCX); } emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &ptr[index] emit_push_reg(&jit->cb, RBX); // save address gen_expr_jit(jit, s->assign.expr); // RAX = value emit_pop_reg(&jit->cb, RBX); // restore address if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } else if (var_type.array_size > 0) { GlobalVar *gv_arr2 = find_global(jit, lhs->index.array->name); gen_expr_jit(jit, lhs->index.index); // RAX = index emit_mov_reg_imm64(&jit->cb, RBX, element_size); emit_imul_rax_rbx(&jit->cb); // RAX = byte offset if (gv_arr2) { emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr2->addr); emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index] } else { int array_offset = get_var_offset(jit, lhs->index.array->name); emit_mov_reg_imm64(&jit->cb, RBX, array_offset); emit_add_reg_reg(&jit->cb, RBX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_mov_reg_reg(&jit->cb, RBX, RBP); emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset } emit_push_reg(&jit->cb, RBX); // save address gen_expr_jit(jit, s->assign.expr); // RAX = value emit_pop_reg(&jit->cb, RBX); // RBX = address if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } else { fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", lhs->index.array->name); exit(1); } } else { fprintf(stderr, "[JIT] Unsupported assignment LHS kind %d\n", lhs->kind); exit(1); } break; } case STK_RETURN: if (s->return_expr) gen_expr_jit(jit, s->return_expr); emit_epilogue(&jit->cb); return 1; case STK_EXPR: gen_expr_jit(jit, s->expr); break; case STK_BLOCK: { int r = gen_stmt_jit(jit, s->body); if (r) return 1; break; } case STK_IF: { gen_expr_jit(jit, s->ifs.cond); emit_test_rax_rax(&jit->cb); size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); // JZ else gen_stmt_jit(jit, s->ifs.thenb); size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); // JMP end int32_t rel_else = (int32_t)(jit->cb.len - (jz_pos + 6)); memcpy(jit->cb.buf + jz_pos + 2, &rel_else, 4); if (s->ifs.elseb) gen_stmt_jit(jit, s->ifs.elseb); int32_t rel_end = (int32_t)(jit->cb.len - (jmp_pos + 5)); memcpy(jit->cb.buf + jmp_pos + 1, &rel_end, 4); break; } case STK_WHILE: { int depth = jit->loop_depth++; jit->break_patch_count[depth] = 0; jit->cont_patch_count[depth] = 0; size_t loop_start = jit->cb.len; gen_expr_jit(jit, s->whl.cond); emit_test_rax_rax(&jit->cb); size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); gen_stmt_jit(jit, s->whl.body); size_t cont_target = jit->cb.len; for (int i = 0; i < jit->cont_patch_count[depth]; i++) { size_t p = jit->cont_patches[depth][i]; int32_t r = (int32_t)(cont_target - (p + 5)); memcpy(jit->cb.buf + p + 1, &r, 4); } emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); size_t break_target = jit->cb.len; int32_t rel_end = (int32_t)(break_target - (jz_pos + 6)); memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4); for (int i = 0; i < jit->break_patch_count[depth]; i++) { size_t p = jit->break_patches[depth][i]; int32_t r = (int32_t)(break_target - (p + 5)); memcpy(jit->cb.buf + p + 1, &r, 4); } jit->loop_depth--; break; } case STK_FOR: { int depth = jit->loop_depth++; jit->break_patch_count[depth] = 0; jit->cont_patch_count[depth] = 0; if (s->fr.init) gen_stmt_jit(jit, s->fr.init); size_t loop_start = jit->cb.len; size_t jz_pos_for = 0; int has_cond = (s->fr.cond != NULL); if (has_cond) { gen_expr_jit(jit, s->fr.cond); emit_test_rax_rax(&jit->cb); jz_pos_for = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); } gen_stmt_jit(jit, s->fr.body); size_t cont_target = jit->cb.len; for (int i = 0; i < jit->cont_patch_count[depth]; i++) { size_t p = jit->cont_patches[depth][i]; int32_t r = (int32_t)(cont_target - (p + 5)); memcpy(jit->cb.buf + p + 1, &r, 4); } if (s->fr.step) gen_stmt_jit(jit, s->fr.step); emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); size_t break_target = jit->cb.len; if (has_cond) { int32_t rel_end = (int32_t)(break_target - (jz_pos_for + 6)); memcpy(jit->cb.buf + jz_pos_for + 2, &rel_end, 4); } for (int i = 0; i < jit->break_patch_count[depth]; i++) { size_t p = jit->break_patches[depth][i]; int32_t r = (int32_t)(break_target - (p + 5)); memcpy(jit->cb.buf + p + 1, &r, 4); } jit->loop_depth--; break; } case STK_DOWHILE: { int depth = jit->loop_depth++; jit->break_patch_count[depth] = 0; jit->cont_patch_count[depth] = 0; size_t loop_start = jit->cb.len; gen_stmt_jit(jit, s->dowhl.body); /* continue → jump to condition check */ size_t cont_target = jit->cb.len; for (int i = 0; i < jit->cont_patch_count[depth]; i++) { size_t p = jit->cont_patches[depth][i]; int32_t r = (int32_t)(cont_target - (p + 5)); memcpy(jit->cb.buf + p + 1, &r, 4); } gen_expr_jit(jit, s->dowhl.cond); emit_test_rax_rax(&jit->cb); emit_jcc_rel32(&jit->cb, 0x05, (int32_t)(loop_start - (jit->cb.len + 6))); size_t break_target = jit->cb.len; for (int i = 0; i < jit->break_patch_count[depth]; i++) { size_t p = jit->break_patches[depth][i]; int32_t r = (int32_t)(break_target - (p + 5)); memcpy(jit->cb.buf + p + 1, &r, 4); } jit->loop_depth--; break; } case STK_BREAK: { if (jit->loop_depth == 0) { fprintf(stderr, "[JIT] 'break' outside loop\n"); exit(1); } int depth = jit->loop_depth - 1; size_t pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); int cnt = jit->break_patch_count[depth]; if (cnt >= 256) { fprintf(stderr, "[JIT] Too many breaks\n"); exit(1); } jit->break_patches[depth][cnt] = pos; jit->break_patch_count[depth]++; break; } case STK_CONTINUE: { if (jit->loop_depth == 0) { fprintf(stderr, "[JIT] 'continue' outside loop\n"); exit(1); } int depth = jit->loop_depth - 1; size_t pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); int cnt = jit->cont_patch_count[depth]; if (cnt >= 256) { fprintf(stderr, "[JIT] Too many continues\n"); exit(1); } jit->cont_patches[depth][cnt] = pos; jit->cont_patch_count[depth]++; break; } case STK_GLOBAL: { GlobalVar *gv = find_global(jit, s->global.name); if (!gv) { fprintf(stderr, "[JIT] Global '%s' not registered\n", s->global.name); exit(1); } if (s->global.init) { if (s->global.init->kind == EX_CALL && strcmp(s->global.init->call.func_name, "__initlist__") == 0) { int esz = ty_slot_size((_TY){gv->type.base, 0, -1}); for (int ii = 0; ii < s->global.init->call.argc; ii++) { gen_expr_jit(jit, s->global.init->call.args[ii]); emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)(gv->addr + ii * esz)); if (esz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } } else { gen_expr_jit(jit, s->global.init); emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr); int sz = ty_slot_size(gv->type); if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } } break; } fprintf(stderr, "[JIT] Unsupported statement kind %d\n", s->kind); exit(1); } s = s->n; } return 0; } static void gen_expr_jit(JIT *jit, _EX *e) { if (!e) return; switch (e->kind) { case EX_NUMBER: emit_movabs_rax_imm64(&jit->cb, (uint64_t)e->value); break; case EX_VAR: { GlobalVar *gv = NULL; for (VarMap *v = jit->var_list; v; v = v->next) { if (strcmp(v->name, e->name) == 0) goto local_var; } gv = find_global(jit, e->name); if (gv) { _TY ty = gv->type; if (ty.array_size > 0) { /* Global array: load its absolute address */ emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv->addr); break; } /* Global scalar: absolute address → load through it */ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr); int sz = ty_slot_size(ty); if (sz == 1) emit_movzx_rax_mem8_base(&jit->cb, RBX); else emit_mov_reg_mem_reg(&jit->cb, RAX, RBX); break; } local_var: { int off = get_var_offset(jit, e->name); _TY ty = get_var_type(jit, e->name); if (ty.array_size > 0) { emit_lea_rax_rbp_disp(&jit->cb, off); break; } int sz = ty_slot_size(ty); if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off); else emit_mov_reg_mem64(&jit->cb, RAX, off); break; } } case EX_BINOP: { if (e->binop.op == TK_INC || e->binop.op == TK_DEC) { int is_pre = (e->binop.r->value == -1); int is_inc = (e->binop.op == TK_INC); _EX *lval = e->binop.l; _TY lty = get_expr_type(jit, lval); int step = (lty.ptr_level > 0) ? ty_slot_size((_TY){lty.base,0,-1}) : 1; gen_expr_jit(jit, lval); /* RAX = old value */ emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = old value */ emit_mov_reg_imm64(&jit->cb, RBX, step); if (is_inc) emit_add_reg_reg(&jit->cb, RCX, RBX); else emit_sub_reg_reg(&jit->cb, RCX, RBX); if (lval->kind == EX_VAR) { GlobalVar *gv_inc = find_global(jit, lval->name); emit_mov_reg_reg(&jit->cb, RAX, RCX); if (gv_inc) { emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_inc->addr); int sz = ty_slot_size(lty); if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } else { int off = get_var_offset(jit, lval->name); int sz = ty_slot_size(lty); emit_store_rax_to_mem(&jit->cb, off, sz); } } else { emit_push_reg(&jit->cb, RCX); /* save new value */ if (lval->kind == EX_DEREF) { gen_expr_jit(jit, lval->deref.expr); /* RAX = address */ } else { /* EX_INDEX */ _TY vty = get_var_type(jit, lval->index.array->name); int esz = ty_slot_size((_TY){vty.base,0,-1}); if (vty.ptr_level > 0) { int poff = get_var_offset(jit, lval->index.array->name); emit_mov_reg_mem64(&jit->cb, RAX, poff); emit_push_reg(&jit->cb, RAX); gen_expr_jit(jit, lval->index.index); if (esz > 1) { emit_mov_reg_imm64(&jit->cb, RBX, esz); emit_imul_rax_rbx(&jit->cb); } emit_pop_reg(&jit->cb, RBX); emit_add_reg_reg(&jit->cb, RAX, RBX); } else { int aoff = get_var_offset(jit, lval->index.array->name); gen_expr_jit(jit, lval->index.index); emit_mov_reg_imm64(&jit->cb, RBX, esz); emit_imul_rax_rbx(&jit->cb); emit_mov_reg_imm64(&jit->cb, RBX, aoff); emit_add_reg_reg(&jit->cb, RBX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_mov_reg_reg(&jit->cb, RBX, RBP); emit_add_reg_reg(&jit->cb, RAX, RBX); } } emit_mov_reg_reg(&jit->cb, RBX, RAX); /* RBX = address */ emit_pop_reg(&jit->cb, RCX); /* RCX = new val */ emit_mov_reg_reg(&jit->cb, RAX, RCX); emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } if (is_pre) emit_mov_reg_reg(&jit->cb, RAX, RCX); if (!is_pre) { emit_mov_reg_reg(&jit->cb, RAX, RCX); emit_mov_reg_imm64(&jit->cb, RBX, step); if (is_inc) emit_sub_reg_reg(&jit->cb, RAX, RBX); else emit_add_reg_reg(&jit->cb, RAX, RBX); } break; } if (e->binop.op == TK_AND) { gen_expr_jit(jit, e->binop.l); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); gen_expr_jit(jit, e->binop.r); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); emit_setcc_al(&jit->cb, 0x05); emit_movzx_rax_al(&jit->cb); size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); int32_t rel = (int32_t)(jit->cb.len - (jz_pos + 6)); memcpy(jit->cb.buf + jz_pos + 2, &rel, 4); emit_movabs_rax_imm64(&jit->cb, 0); rel = (int32_t)(jit->cb.len - (jmp_pos + 5)); memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4); break; } if (e->binop.op == TK_OR) { gen_expr_jit(jit, e->binop.l); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); size_t jnz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x05, 0); gen_expr_jit(jit, e->binop.r); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); emit_setcc_al(&jit->cb, 0x05); emit_movzx_rax_al(&jit->cb); size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); int32_t rel = (int32_t)(jit->cb.len - (jnz_pos + 6)); memcpy(jit->cb.buf + jnz_pos + 2, &rel, 4); emit_movabs_rax_imm64(&jit->cb, 1); rel = (int32_t)(jit->cb.len - (jmp_pos + 5)); memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4); break; } gen_expr_jit(jit, e->binop.l); emit_push_reg(&jit->cb, RAX); gen_expr_jit(jit, e->binop.r); emit_pop_reg(&jit->cb, RBX); // RBX = left, RAX = right switch (e->binop.op) { case TK_PLUS: { _TY lt = get_expr_type(jit, e->binop.l); _TY rt = get_expr_type(jit, e->binop.r); int lptr = (lt.ptr_level > 0 || lt.array_size > 0); int rptr = (rt.ptr_level > 0 || rt.array_size > 0); if (lptr && !rptr) { int esz = ty_slot_size((_TY){lt.base, 0, -1}); if (esz > 1) { emit_mov_reg_imm64(&jit->cb, RCX, esz); emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); } } else if (rptr && !lptr) { int esz = ty_slot_size((_TY){rt.base, 0, -1}); emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = ptr */ emit_mov_reg_reg(&jit->cb, RAX, RBX); /* RAX = int */ emit_mov_reg_reg(&jit->cb, RBX, RCX); /* RBX = ptr */ if (esz > 1) { emit_mov_reg_imm64(&jit->cb, RCX, esz); emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); } } emit_add_rax_rbx(&jit->cb); break; } case TK_MINUS: { _TY lt = get_expr_type(jit, e->binop.l); _TY rt = get_expr_type(jit, e->binop.r); int lptr = (lt.ptr_level > 0 || lt.array_size > 0); int rptr = (rt.ptr_level > 0 || rt.array_size > 0); if (lptr && !rptr) { /* ptr - int: scale int by element size */ int esz = ty_slot_size((_TY){lt.base, 0, -1}); if (esz > 1) { emit_mov_reg_imm64(&jit->cb, RCX, esz); emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); } } /* RAX=right(scaled), RBX=left: result = left - right */ emit_mov_reg_reg(&jit->cb, RCX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_rex(&jit->cb, RCX, RAX, 1); emit8(&jit->cb, 0x29); emit_modrm(&jit->cb, 3, RCX, RAX); break; } case TK_STAR: emit_imul_rax_rbx(&jit->cb); break; case TK_SLASH: emit_mov_reg_reg(&jit->cb, RCX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_mov_reg_reg(&jit->cb, RBX, RCX); emit_idiv_rbx(&jit->cb); break; case TK_PERCENT: emit_mov_reg_reg(&jit->cb, RCX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_mov_reg_reg(&jit->cb, RBX, RCX); emit_imod(&jit->cb); break; case TK_BAR: emit_or_rax_rbx(&jit->cb); break; case TK_CARET: emit_xor_rax_rbx(&jit->cb); break; case TK_AMP: emit_and_rax_rbx(&jit->cb); break; case TK_SHL: emit_mov_reg_reg(&jit->cb, RCX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_shl_rax_cl(&jit->cb); break; case TK_SHR: emit_mov_reg_reg(&jit->cb, RCX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_shr_rax_cl(&jit->cb); break; case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE: { emit_mov_reg_reg(&jit->cb, RCX, RAX); // RCX = right emit_mov_reg_reg(&jit->cb, RAX, RBX); // RAX = left emit_mov_reg_reg(&jit->cb, RBX, RCX); // RBX = right emit_cmp_rax_rbx(&jit->cb); uint8_t cc; switch (e->binop.op) { case TK_EQ: cc=0x04; break; case TK_NE: cc=0x05; break; case TK_LT: cc=0x0C; break; case TK_LE: cc=0x0E; break; case TK_GT: cc=0x0F; break; case TK_GE: cc=0x0D; break; default: cc=0x04; break; } emit_setcc_al(&jit->cb, cc); emit_movzx_rax_al(&jit->cb); break; } default: fprintf(stderr, "[JIT] Unsupported binop %d\n", e->binop.op); exit(1); } break; } case EX_CALL: { if (strcmp(e->call.func_name, "__sizeof__") == 0) { _EX *arg = e->call.args[0]; _TY ty; if (arg->kind == EX_VAR) { ty = get_var_type(jit, arg->name); } else { ty = get_expr_type(jit, arg); } int sz; if (ty.array_size > 0) { int esz = (ty.base == TY_CHAR) ? 1 : (ty.base == TY_SHORT) ? 2 : (ty.base == TY_LONG) ? 8 : 4; sz = esz * ty.array_size; } else if (ty.ptr_level > 0) { sz = 8; } else { switch (ty.base) { case TY_CHAR: sz = 1; break; case TY_SHORT: sz = 2; break; case TY_LONG: sz = 8; break; default: sz = 4; break; /* int, bool, float */ } } emit_mov_reg_imm64(&jit->cb, RAX, sz); break; } if (strcmp(e->call.func_name, "syscall") == 0) { int argc = e->call.argc; if (argc < 1) { fprintf(stderr, "[JIT] syscall() requires at least 1 argument (number)\n"); exit(1); } if (argc > 7) { fprintf(stderr, "[JIT] syscall() supports at most 7 arguments\n"); exit(1); } const int sc_regs[6] = {RDI, RSI, RDX, R10, R8, R9}; int nargs = argc - 1; /* number of args after the syscall number */ for (int i = argc - 1; i >= 0; i--) { gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX); } emit_pop_reg(&jit->cb, RAX); for (int i = 0; i < nargs; i++) { emit_pop_reg(&jit->cb, sc_regs[i]); } emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0x05); break; } int total_args = e->call.argc; int stack_args = total_args > 6 ? total_args - 6 : 0; int reg_args = total_args < 6 ? total_args : 6; int padding = (stack_args % 2) ? 8 : 0; const int arg_regs[6] = {RDI, RSI, RDX, RCX, R8, R9}; /* Push stack-passed args right-to-left (args[total-1] first). */ for (int i = total_args-1; i >= 6; i--) { gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX); } for (int i = reg_args-1; i >= 0; i--) { gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX); } for (int i = 0; i < reg_args; i++) { emit_pop_reg(&jit->cb, arg_regs[i]); } if (padding) { emit8(&jit->cb,0x48); emit8(&jit->cb,0x83); emit8(&jit->cb,0xEC); emit8(&jit->cb,0x08); } void *addr = get_func_addr(jit, e->call.func_name); if (addr == (void*)0xDEADBEEF) { emit_movabs_rax_imm64(&jit->cb, (uint64_t)0xDEADBEEF); PatchEntry *patch = (PatchEntry*)malloc(sizeof(PatchEntry)); if (!patch) { fprintf(stderr, "[JIT] malloc failed (patch)\n"); exit(1); } patch->func_name = strdup(e->call.func_name); patch->owning_func = strdup(jit->current_func_name); if (!patch->func_name || !patch->owning_func) { fprintf(stderr, "[JIT] strdup failed (patch)\n"); free(patch->func_name); free(patch->owning_func); free(patch); exit(1); } patch->offset = jit->cb.len - 8; // points at the imm64 in the buffer patch->next = jit->patch_list; jit->patch_list = patch; } else { emit_movabs_rax_imm64(&jit->cb, (uint64_t)addr); } emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0); if (stack_args * 8 + padding > 0) { emit8(&jit->cb,0x48); emit8(&jit->cb,0x81); emit8(&jit->cb,0xC4); emit32(&jit->cb, (uint32_t)(stack_args*8 + padding)); } break; } case EX_ADDR: if (e->addr.expr->kind == EX_VAR) { GlobalVar *gv_addr = find_global(jit, e->addr.expr->name); if (gv_addr) { emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv_addr->addr); } else { emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name)); } } else { fprintf(stderr, "[JIT] &expr: only &var supported\n"); exit(1); } break; case EX_TERNARY: { gen_expr_jit(jit, e->ternary.cond); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); gen_expr_jit(jit, e->ternary.then_expr); size_t jmp_pos = jit->cb.len; emit_jmp_rel32(&jit->cb, 0); int32_t rel_jz = (int32_t)(jit->cb.len - (jz_pos + 6)); memcpy(jit->cb.buf + jz_pos + 2, &rel_jz, 4); gen_expr_jit(jit, e->ternary.else_expr); int32_t rel_jmp = (int32_t)(jit->cb.len - (jmp_pos + 5)); memcpy(jit->cb.buf + jmp_pos + 1, &rel_jmp, 4); break; } case EX_CAST: { gen_expr_jit(jit, e->cast.expr); _TY to = e->cast.to; if (to.ptr_level > 0) break; /* pointer cast: value unchanged */ switch (to.base) { case TY_CHAR: emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, al */ break; case TY_SHORT: emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB7); emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, ax */ break; case TY_INT: emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x63); emit_modrm(&jit->cb, 3, RAX, RAX); break; default: break; /* long/void/etc: no-op */ } break; } case EX_DEREF: gen_expr_jit(jit, e->deref.expr); emit_rex(&jit->cb, RAX, RAX, 1); emit8(&jit->cb, 0x8B); emit_modrm(&jit->cb, 0, RAX, RAX); break; case EX_STRING: { char *s = strdup(e->string); if (!s) { fprintf(stderr, "[JIT] strdup failed (string literal)\n"); exit(1); } emit_movabs_rax_imm64(&jit->cb, (uint64_t)s); break; } case EX_INDEX: { if (e->index.array->kind != EX_VAR) { gen_expr_jit(jit, e->index.array); break; } _TY var_type = get_var_type(jit, e->index.array->name); int element_size = ty_slot_size((_TY){var_type.base, 0, -1}); if (var_type.ptr_level > 0) { GlobalVar *gv_ptr = find_global(jit, e->index.array->name); if (gv_ptr) { emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr->addr); emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref: RBX = *addr = pointer value */ } else { int ptr_offset = get_var_offset(jit, e->index.array->name); emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer } gen_expr_jit(jit, e->index.index); // RAX = index if (element_size > 1) { emit_mov_reg_reg(&jit->cb, RCX, RBX); // save pointer emit_mov_reg_imm64(&jit->cb, RBX, element_size); emit_imul_rax_rbx(&jit->cb); // RAX = byte offset emit_mov_reg_reg(&jit->cb, RBX, RCX); // restore pointer } emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &ptr[index] if (element_size == 1) { emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); emit_modrm(&jit->cb, 0, RAX, RBX); } else { emit_mov_reg_mem_reg(&jit->cb, RAX, RBX); } } else if (var_type.array_size > 0) { GlobalVar *gv_arr = find_global(jit, e->index.array->name); gen_expr_jit(jit, e->index.index); // RAX = index emit_mov_reg_imm64(&jit->cb, RBX, element_size); emit_imul_rax_rbx(&jit->cb); // RAX = byte offset if (gv_arr) { /* Global array: base is absolute address */ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr->addr); emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index] } else { int array_offset = get_var_offset(jit, e->index.array->name); emit_mov_reg_imm64(&jit->cb, RBX, array_offset); emit_add_reg_reg(&jit->cb, RBX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_mov_reg_reg(&jit->cb, RBX, RBP); emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset } if (element_size == 1) { emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); emit_modrm(&jit->cb, 0, RAX, RBX); } else { emit_mov_reg_mem_reg(&jit->cb, RAX, RBX); } } else { fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", e->index.array->name); exit(1); } break; } default: fprintf(stderr, "[JIT] Unsupported expression kind %d\n", e->kind); exit(1); } } static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) { reset_varmap(jit); jit->current_func_name = f->name; int total_stack_size = calculate_total_stack_size(f); cb_init(&jit->cb); emit_prologue(&jit->cb, total_stack_size); const int param_regs[6] = {RDI, RSI, RDX, RCX, R8, R9}; for (int i = 0; i < f->pac && i < 6; i++) { add_var(jit, f->params[i], f->param_types[i]); emit_mov_mem64_reg(&jit->cb, get_var_offset(jit, f->params[i]), param_regs[i]); } int num_stack_params = f->pac > 6 ? f->pac - 6 : 0; int stack_arg_padding = (num_stack_params % 2) ? 8 : 0; int stack_arg_base = 16 + stack_arg_padding; /* offset of last-pushed (highest-index) stack arg */ for (int i = 6; i < f->pac; i++) { VarMap *v = (VarMap *)malloc(sizeof(VarMap)); if (!v) { fprintf(stderr, "[JIT] malloc failed (stack param)\n"); exit(1); } v->name = strdup(f->params[i]); if (!v->name) { fprintf(stderr, "[JIT] strdup failed (stack param)\n"); free(v); exit(1); } v->type = f->param_types[i]; v->offset = stack_arg_base + (i - 6) * 8; v->next = jit->var_list; jit->var_list = v; } int did_return = gen_stmt_jit(jit, f->body); if (!did_return) { emit_movabs_rax_imm64(&jit->cb, 0); emit_epilogue(&jit->cb); } size_t pagesz = (size_t)sysconf(_SC_PAGESIZE); size_t alloc_size = ((jit->cb.len + pagesz - 1) / pagesz) * pagesz; void *mem = mmap(NULL, alloc_size, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); if (mem == MAP_FAILED) { perror("mmap"); cb_free(&jit->cb); return NULL; } memcpy(mem, jit->cb.buf, jit->cb.len); if (out_size) *out_size = jit->cb.len; set_func_addr(jit, f->name, mem, jit->cb.len, alloc_size); cb_free(&jit->cb); return mem; } static void patch_function_calls(JIT *jit) { for (PatchEntry *patch = jit->patch_list; patch; patch = patch->next) { void *target = NULL; for (FuncMap *f = jit->func_list; f; f = f->next) if (strcmp(f->name, patch->func_name) == 0 && f->addr) { target = f->addr; break; } if (!target) { fprintf(stderr, "[JIT] patch: function '%s' not found\n", patch->func_name); exit(1); } void *owner = NULL; for (FuncMap *f = jit->func_list; f; f = f->next) if (strcmp(f->name, patch->owning_func) == 0 && f->addr) { owner = f->addr; break; } if (!owner) { fprintf(stderr, "[JIT] patch: owning function '%s' not found\n", patch->owning_func); exit(1); } *(void **)((uint8_t *)owner + patch->offset) = target; } } static void jit_compile_all(JIT *jit, _FN *fn_list) { /* First pass: register all globals so type lookups work during codegen */ for (_FN *cur = fn_list; cur; cur = cur->n) { if (strncmp(cur->name, "__global_", 9) == 0) { /* Extract variable name from __global_NAME__ */ const char *start = cur->name + 9; size_t len = strlen(start); if (len >= 2 && start[len-2] == '_' && start[len-1] == '_') len -= 2; char *vname = strndup(start, len); if (!vname) { fprintf(stderr, "[JIT] OOM\n"); exit(1); } /* Get type from the STK_GLOBAL node */ _STN *gdecl = cur->body; if (gdecl && gdecl->kind == STK_GLOBAL) { register_global(jit, vname, gdecl->global.type); } free(vname); } } /* Second pass: register all real functions (so forward calls work) */ for (_FN *cur = fn_list; cur; cur = cur->n) { if (strncmp(cur->name, "__global_", 9) != 0) register_func(jit, cur->name, cur->ret_type); } /* Compile real functions in reverse order */ _FN *functions[64]; int count = 0; for (_FN *cur = fn_list; cur; cur = cur->n) { if (strncmp(cur->name, "__global_", 9) == 0) continue; if (count >= 64) { fprintf(stderr, "[JIT] Too many functions (max 64)\n"); exit(1); } functions[count++] = cur; } for (int i = count-1; i >= 0; i--) gen_function_jit(jit, functions[i], NULL); patch_function_calls(jit); /* Run global initialisers in declaration order */ for (_FN *cur = fn_list; cur; cur = cur->n) { if (strncmp(cur->name, "__global_", 9) != 0) continue; reset_varmap(jit); jit->current_func_name = cur->name; cb_init(&jit->cb); emit8(&jit->cb, 0x55); /* push rbp */ emitN(&jit->cb, (uint8_t[]){0x48,0x89,0xE5}, 3); /* mov rbp, rsp */ gen_stmt_jit(jit, cur->body); emit8(&jit->cb, 0xC9); /* leave */ emit8(&jit->cb, 0xC3); /* ret */ size_t isz = jit->cb.len; void *ibuf = mmap(NULL, isz, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); if (ibuf == MAP_FAILED) { perror("mmap init"); exit(1); } memcpy(ibuf, jit->cb.buf, isz); ((void(*)(void))ibuf)(); munmap(ibuf, isz); } } static int jit_run(JIT *jit, int argc, char **argv) { int (*main_func)(int, char **) = get_func_addr(jit, "main"); fprintf(stderr, "[JIT] Running main (argc=%d argv=%p)\n", argc, (void*)argv); return main_func(argc, argv); } #endif