diff options
Diffstat (limited to 'src/codegen_jit.h')
| -rw-r--r-- | src/codegen_jit.h | 853 |
1 files changed, 740 insertions, 113 deletions
diff --git a/src/codegen_jit.h b/src/codegen_jit.h index ff1d687..2cd4aa2 100644 --- a/src/codegen_jit.h +++ b/src/codegen_jit.h @@ -32,6 +32,7 @@ typedef struct FuncMap { void *addr; size_t size; size_t alloc_size; + _TY ret_type; /* return type of this function */ struct FuncMap *next; } FuncMap; @@ -49,6 +50,14 @@ typedef struct VarMap { struct VarMap *next; } VarMap; +typedef struct GlobalVar { + char *name; + uint8_t *addr; /* pointer into globals_buf */ + _TY type; + int size; /* total bytes allocated */ + struct GlobalVar *next; +} GlobalVar; + typedef struct { FuncMap *func_list; VarMap *var_list; @@ -56,6 +65,17 @@ typedef struct { CodeBuf cb; char *current_func_name; PatchEntry *patch_list; + /* Globals data segment */ + uint8_t *globals_buf; /* mmap'd RW page(s) for global variables */ + size_t globals_cap; + size_t globals_used; + GlobalVar *global_list; + /* Break/continue patch stacks (up to 64 nesting levels) */ + size_t break_patches[64][256]; /* offsets to patch */ + int break_patch_count[64]; + size_t cont_patches[64][256]; + int cont_patch_count[64]; + int loop_depth; } JIT; static void jit_init(JIT *jit) { @@ -66,6 +86,19 @@ static void jit_init(JIT *jit) { jit->cb.len = jit->cb.cap = 0; jit->current_func_name = NULL; jit->patch_list = NULL; + jit->globals_cap = 1024 * 1024; + jit->globals_used = 0; + jit->globals_buf = (uint8_t *)mmap(NULL, jit->globals_cap, + PROT_READ|PROT_WRITE, + MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); + if (jit->globals_buf == MAP_FAILED) { + perror("mmap globals"); exit(1); + } + memset(jit->globals_buf, 0, jit->globals_cap); + jit->global_list = NULL; + jit->loop_depth = 0; + memset(jit->break_patch_count, 0, sizeof(jit->break_patch_count)); + memset(jit->cont_patch_count, 0, sizeof(jit->cont_patch_count)); } static void jit_free(JIT *jit) { @@ -79,6 +112,14 @@ static void jit_free(JIT *jit) { } jit->patch_list = NULL; + for (GlobalVar *g = jit->global_list; g;) { + GlobalVar *n = g->next; free(g->name); free(g); g = n; + } + jit->global_list = NULL; + + if (jit->globals_buf && jit->globals_buf != MAP_FAILED) + munmap(jit->globals_buf, jit->globals_cap); + for (FuncMap *f = jit->func_list; f;) { FuncMap *n = f->next; if (f->addr && f->alloc_size > 0) munmap(f->addr, f->alloc_size); @@ -91,8 +132,6 @@ static void jit_free(JIT *jit) { jit->cb.len = jit->cb.cap = 0; } -/* --- Code buffer --- */ - static void cb_init(CodeBuf *c) { c->cap = 1024; c->len = 0; c->buf = (uint8_t *)malloc(c->cap); @@ -114,8 +153,6 @@ static void emit32(CodeBuf *c, uint32_t v) { cb_grow(c,4); memcpy(c->buf+c-> static void emit64(CodeBuf *c, uint64_t v) { cb_grow(c,8); memcpy(c->buf+c->len,&v,8); c->len+=8; } static void emitN(CodeBuf *c, const void *p, size_t n) { cb_grow(c,n); memcpy(c->buf+c->len,p,n); c->len+=n; } -/* --- x86-64 encoding helpers --- */ - static void emit_rex(CodeBuf *c, int reg, int rm, int w) { uint8_t rex = 0x40; if (w) rex |= 0x08; @@ -145,9 +182,18 @@ static void emit_mov_reg_mem_reg(CodeBuf *c, int dst, int base) { static void emit_mov_mem_reg_reg(CodeBuf *c, int base, int src) { emit_rex(c,src,base,1); emit8(c,0x89); emit_modrm(c,0,src,base); } +static void emit_mov_mem8_reg_reg(CodeBuf *c, int base) { + if (base & 8) emit8(c, 0x41); /* REX.B for extended base registers */ + emit8(c, 0x88); emit_modrm(c, 0, RAX, base); +} static void emit_movzx_rax_mem8(CodeBuf *c, int disp32) { emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB6); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); } +static void emit_movzx_rax_mem8_base(CodeBuf *c, int base) { + if (base & 8) emit8(c, 0x49); else emit8(c, 0x48); + emit8(c, 0x0F); emit8(c, 0xB6); + emit_modrm(c, 0, RAX, base); +} static void emit_mov_mem8_rax(CodeBuf *c, int disp32) { emit_rex(c,RAX,RBP,0); emit8(c,0x88); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); } @@ -174,8 +220,8 @@ static void emit_add_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x01,0xD8}, static void emit_imul_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xAF,0xC3},4); } static void emit_idiv_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); } static void emit_imod(CodeBuf *c) { - emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); // cqo; idiv rbx - emit_mov_reg_reg(c, RAX, RDX); // remainder -> RAX + emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); + emit_mov_reg_reg(c, RAX, RDX); } static void emit_or_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x09,0xD8},3); } static void emit_xor_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x31,0xD8},3); } @@ -192,53 +238,97 @@ static void emit_jcc_rel32(CodeBuf *c, uint8_t cc, int32_t rel) { static void emit_test_rax_rax(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x85,0xC0},3); } static void emit_prologue(CodeBuf *c, int total_stack_size) { - emit8(c,0x55); // push rbp - emitN(c,(uint8_t[]){0x48,0x89,0xE5},3); // mov rbp, rsp - emit_push_reg(c, RBX); // save callee-saved RBX + emit8(c,0x55); + emitN(c,(uint8_t[]){0x48,0x89,0xE5},3); + emit_push_reg(c, RBX); int stack_bytes = ((total_stack_size+15)/16)*16; if (stack_bytes > 0) { - emitN(c,(uint8_t[]){0x48,0x81,0xEC},3); // sub rsp, imm32 + emitN(c,(uint8_t[]){0x48,0x81,0xEC},3); emit32(c,(uint32_t)stack_bytes); } } static void emit_epilogue(CodeBuf *c) { - emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3); // restore RBX; leave; ret + emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3); } static void emit_lea_rax_rbp_disp(CodeBuf *c, int disp32) { emit8(c,0x48); emit8(c,0x8D); emit_modrm(c,2,RAX,RBP); emit32(c,(uint32_t)disp32); } +static void emit_movzx_rax_mem16(CodeBuf *c, int disp32) { + emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB7); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); +} +static void emit_mov_mem16_rax(CodeBuf *c, int disp32) { + emit8(c,0x66); emit_rex(c,RAX,RBP,0); emit8(c,0x89); emit_modrm(c,2,RAX,RBP); emit32(c,disp32); +} static void emit_load_rax_from_mem(CodeBuf *c, int disp32, int size) { - if (size == 1) emit_movzx_rax_mem8(c, disp32); else emit_mov_reg_mem64(c, RAX, disp32); + if (size == 1) emit_movzx_rax_mem8(c, disp32); + else if (size == 2) emit_movzx_rax_mem16(c, disp32); + else emit_mov_reg_mem64(c, RAX, disp32); } static void emit_store_rax_to_mem(CodeBuf *c, int disp32, int size) { - if (size == 1) emit_mov_mem8_rax(c, disp32); else emit_mov_mem64_reg(c, disp32, RAX); + if (size == 1) emit_mov_mem8_rax(c, disp32); + else if (size == 2) emit_mov_mem16_rax(c, disp32); + else emit_mov_mem64_reg(c, disp32, RAX); } -/* --- Variable and type helpers --- */ - static int calculate_type_size(_TY type) { if (type.ptr_level > 0) return 8; - int base_size = (type.base == TY_CHAR) ? 1 : 8; - size_t total = (size_t)base_size; + int slot; + switch (type.base) { + case TY_CHAR: slot = 1; break; + case TY_BOOL: slot = 1; break; + case TY_SHORT: slot = 2; break; + case TY_INT: slot = 8; break; + case TY_FLOAT: slot = 8; break; + case TY_LONG: slot = 8; break; + case TY_VOID: slot = 0; break; + default: slot = 8; break; + } if (type.array_size > 0) { - total *= (size_t)type.array_size; + size_t total = (size_t)slot * (size_t)type.array_size; if (total > (size_t)INT_MAX) { fprintf(stderr, "[JIT] array size too large\n"); exit(1); } + return (int)total; } - return (int)total; + return slot; } static int align_offset(int offset, _TY type) { - int align = (type.base == TY_CHAR) ? 1 : 8; + int align; + if (type.ptr_level > 0) { align = 8; } + else switch (type.base) { + case TY_CHAR: align = 1; break; + case TY_BOOL: align = 1; break; + case TY_SHORT: align = 2; break; + case TY_INT: align = 8; break; + case TY_FLOAT: align = 8; break; + case TY_LONG: align = 8; break; + default: align = 8; break; + } if (offset % align == 0) return offset; + if (offset < 0) + return ((offset - (align - 1)) / align) * align; return (offset / align) * align; } +static int ty_slot_size(_TY ty) { + if (ty.ptr_level > 0) return 8; + switch (ty.base) { + case TY_CHAR: return 1; + case TY_BOOL: return 1; + case TY_SHORT: return 2; + case TY_INT: return 8; /* promoted to 64-bit slot */ + case TY_FLOAT: return 8; /* stored in 64-bit slot (integer bits) */ + case TY_LONG: return 8; + case TY_VOID: return 0; + default: return 8; + } +} + static void reset_varmap(JIT *jit) { for (VarMap *v = jit->var_list; v;) { VarMap *n = v->next; free(v->name); free(v); v = n; } jit->var_list = NULL; - jit->next_local_offset = -16; // after saved RBX at RBP-8, 16-byte aligned + jit->next_local_offset = -16; } static void add_var(JIT *jit, const char *name, _TY type) { @@ -248,9 +338,9 @@ static void add_var(JIT *jit, const char *name, _TY type) { if (!v->name) { fprintf(stderr, "[JIT] strdup failed in add_var\n"); free(v); exit(1); } v->type = type; int type_size = calculate_type_size(type); + jit->next_local_offset -= type_size; jit->next_local_offset = align_offset(jit->next_local_offset, type); v->offset = jit->next_local_offset; - jit->next_local_offset -= type_size; v->next = jit->var_list; jit->var_list = v; } @@ -261,23 +351,60 @@ static int get_var_offset(JIT *jit, const char *name) { fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1); } +static GlobalVar *find_global(JIT *jit, const char *name) { + for (GlobalVar *g = jit->global_list; g; g = g->next) + if (strcmp(g->name, name) == 0) return g; + return NULL; +} + +static GlobalVar *register_global(JIT *jit, const char *name, _TY type) { + if (find_global(jit, name)) { + fprintf(stderr, "[JIT] Duplicate global '%s'\n", name); exit(1); + } + int elem_sz = ty_slot_size((_TY){type.base, type.ptr_level > 0 ? type.ptr_level : 0, -1}); + int n_elems = (type.array_size > 0) ? type.array_size : 1; + int total_sz = elem_sz * n_elems; + size_t align = (elem_sz < 8) ? elem_sz : 8; + size_t off = (jit->globals_used + align - 1) & ~(align - 1); + if (off + (size_t)total_sz > jit->globals_cap) { + fprintf(stderr, "[JIT] Globals segment full\n"); exit(1); + } + GlobalVar *g = (GlobalVar *)malloc(sizeof(GlobalVar)); + if (!g) { fprintf(stderr, "[JIT] OOM\n"); exit(1); } + g->name = strdup(name); + g->addr = jit->globals_buf + off; + g->type = type; + g->size = total_sz; + g->next = jit->global_list; + jit->global_list = g; + jit->globals_used = off + total_sz; + return g; +} + static _TY get_var_type(JIT *jit, const char *name) { for (VarMap *v = jit->var_list; v; v = v->next) if (strcmp(v->name, name) == 0) return v->type; + /* Fall back to globals */ + GlobalVar *g = find_global(jit, name); + if (g) return g->type; fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1); } -/* --- Type checking --- */ +static _TY get_func_ret_type(JIT *jit, const char *name); +static void reset_varmap(JIT *jit); static _TY get_expr_type(JIT *jit, _EX *expr) { switch (expr->kind) { case EX_NUMBER: return (_TY){TY_INT, 0, -1}; case EX_STRING: return (_TY){TY_CHAR, 1, -1}; - case EX_VAR: return get_var_type(jit, expr->name); + case EX_VAR: { + _TY t = get_var_type(jit, expr->name); + if (t.array_size > 0) return (_TY){t.base, t.ptr_level + 1, -1}; + return t; + } case EX_BINOP: { _TY left_type = get_expr_type(jit, expr->binop.l); - // All arithmetic, comparison, bitwise ops produce int - (void)get_expr_type(jit, expr->binop.r); + (void)get_expr_type(jit, expr->binop.r); switch (expr->binop.op) { case TK_PLUS: case TK_MINUS: case TK_STAR: case TK_SLASH: case TK_PERCENT: case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE: @@ -287,7 +414,13 @@ static _TY get_expr_type(JIT *jit, _EX *expr) { default: return left_type; } } - case EX_CALL: return (_TY){TY_INT, 0, -1}; + case EX_CALL: + if (strcmp(expr->call.func_name, "syscall") == 0) + return (_TY){TY_LONG, 0, -1}; + if (strcmp(expr->call.func_name, "__initlist__") == 0 || + strcmp(expr->call.func_name, "__sizeof__") == 0) + return (_TY){TY_INT, 0, -1}; + return get_func_ret_type(jit, expr->call.func_name); case EX_INDEX: { _TY t = get_expr_type(jit, expr->index.array); if (t.array_size > 0) return (_TY){t.base, t.ptr_level, -1}; @@ -303,31 +436,53 @@ static _TY get_expr_type(JIT *jit, _EX *expr) { _TY t = get_expr_type(jit, expr->addr.expr); return (_TY){t.base, t.ptr_level+1, -1}; } + case EX_TERNARY: + return get_expr_type(jit, expr->ternary.then_expr); + case EX_CAST: + return expr->cast.to; default: return (_TY){TY_INT, 0, -1}; } } +static int type_is_integer(_TY ty) { + if (ty.ptr_level > 0) return 0; + switch (ty.base) { + case TY_INT: case TY_CHAR: case TY_SHORT: + case TY_LONG: case TY_BOOL: return 1; + default: return 0; + } +} + static int types_compatible(_TY expected, _TY actual) { if (expected.base == actual.base && expected.ptr_level == actual.ptr_level && expected.array_size == actual.array_size) return 1; - // Allow untyped int literals to be assigned anywhere if (actual.base == TY_INT && actual.ptr_level == 0 && actual.array_size == -1) return 1; + if (type_is_integer(expected) && type_is_integer(actual)) return 1; + if (expected.ptr_level > 0 && actual.ptr_level > 0) return 1; + if (expected.ptr_level > 0 && actual.array_size > 0 && + expected.base == actual.base && + expected.ptr_level == actual.ptr_level + 1) return 1; return 0; } -/* --- Function registry --- */ - -static void register_func(JIT *jit, const char *name) { +static void register_func(JIT *jit, const char *name, _TY ret_type) { FuncMap *f = (FuncMap *)malloc(sizeof(FuncMap)); if (!f) { fprintf(stderr, "[JIT] malloc failed in register_func\n"); exit(1); } f->name = strdup(name); if (!f->name) { fprintf(stderr, "[JIT] strdup failed in register_func\n"); free(f); exit(1); } f->addr = NULL; f->size = 0; f->alloc_size = 0; + f->ret_type = ret_type; f->next = jit->func_list; jit->func_list = f; } +static _TY get_func_ret_type(JIT *jit, const char *name) { + for (FuncMap *f = jit->func_list; f; f = f->next) + if (strcmp(f->name, name) == 0) return f->ret_type; + return (_TY){TY_INT, 0, -1}; +} + static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, size_t alloc_size) { for (FuncMap *f = jit->func_list; f; f = f->next) { if (strcmp(f->name, name) != 0) continue; @@ -352,17 +507,19 @@ static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, s static void *get_func_addr(JIT *jit, const char *name) { for (FuncMap *f = jit->func_list; f; f = f->next) { if (strcmp(f->name, name) == 0) - return f->addr ? f->addr : (void*)0xDEADBEEF; // placeholder; patched later + return f->addr ? f->addr : (void*)0xDEADBEEF; } fprintf(stderr, "[JIT] get_func_addr: unknown function '%s'\n", name); exit(1); } -/* --- Stack size calculation --- */ static int calculate_stack_size(_STN *s) { int total = 0; while (s) { - if (s->kind == STK_VAR_DECL) total += calculate_type_size(s->var_decl.type); + if (s->kind == STK_VAR_DECL) { + int sz = calculate_type_size(s->var_decl.type); + total += (sz < 8 && s->var_decl.type.array_size <= 0) ? 8 : sz; + } if (s->kind == STK_BLOCK) total += calculate_stack_size(s->body); if (s->kind == STK_FOR) { if (s->fr.init) total += calculate_stack_size(s->fr.init); @@ -370,15 +527,21 @@ static int calculate_stack_size(_STN *s) { } s = s->n; } - return total; + return total + 8; +} + +static int calculate_total_stack_size(_FN *f) { + int param_space = 0; + for (int i = 0; i < f->pac && i < 6; i++) + param_space += (calculate_type_size(f->param_types[i]) < 8) ? 8 + : calculate_type_size(f->param_types[i]); + return calculate_stack_size(f->body) + param_space; } -/* --- Code generation (forward declarations) --- */ static void gen_expr_jit(JIT *jit, _EX *e); static int gen_stmt_jit(JIT *jit, _STN *s); -/* --- Statement generation --- */ static int gen_stmt_jit(JIT *jit, _STN *s) { while (s) { @@ -387,6 +550,19 @@ static int gen_stmt_jit(JIT *jit, _STN *s) { case STK_VAR_DECL: add_var(jit, s->var_decl.name, s->var_decl.type); if (s->var_decl.init) { + if (s->var_decl.init->kind == EX_CALL && + strcmp(s->var_decl.init->call.func_name, "__initlist__") == 0) { + _TY vty = s->var_decl.type; + int elem_sz = ty_slot_size((_TY){vty.base, 0, -1}); + int arr_off = get_var_offset(jit, s->var_decl.name); + for (int ii = 0; ii < s->var_decl.init->call.argc; ii++) { + gen_expr_jit(jit, s->var_decl.init->call.args[ii]); + /* address of arr[ii] = RBP + arr_off + ii*elem_sz */ + int elem_off = arr_off + ii * elem_sz; + emit_store_rax_to_mem(&jit->cb, elem_off, elem_sz); + } + break; + } _TY init_type = get_expr_type(jit, s->var_decl.init); if (!types_compatible(s->var_decl.type, init_type)) { fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n", @@ -394,8 +570,7 @@ static int gen_stmt_jit(JIT *jit, _STN *s) { exit(1); } gen_expr_jit(jit, s->var_decl.init); - int sz = (s->var_decl.type.ptr_level > 0) ? 8 - : (s->var_decl.type.base == TY_CHAR) ? 1 : 8; + int sz = ty_slot_size(s->var_decl.type); emit_store_rax_to_mem(&jit->cb, get_var_offset(jit, s->var_decl.name), sz); } break; @@ -403,6 +578,15 @@ static int gen_stmt_jit(JIT *jit, _STN *s) { case STK_ASSIGN: { _EX *lhs = s->assign.lhs; if (lhs->kind == EX_VAR) { + GlobalVar *gv = find_global(jit, lhs->name); + if (gv) { + gen_expr_jit(jit, s->assign.expr); + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr); + int sz = ty_slot_size(gv->type); + if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + break; + } int offset = get_var_offset(jit, lhs->name); _TY type = get_var_type(jit, lhs->name); _TY expr_type = get_expr_type(jit, s->assign.expr); @@ -412,36 +596,74 @@ static int gen_stmt_jit(JIT *jit, _STN *s) { exit(1); } gen_expr_jit(jit, s->assign.expr); - int sz = (type.ptr_level > 0) ? 8 : (type.base == TY_CHAR) ? 1 : 8; + int sz = ty_slot_size(type); emit_store_rax_to_mem(&jit->cb, offset, sz); } else if (lhs->kind == EX_DEREF) { gen_expr_jit(jit, lhs->deref.expr); - emit_mov_reg_reg(&jit->cb, RBX, RAX); // RBX = address - gen_expr_jit(jit, s->assign.expr); // RAX = value - emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + emit_push_reg(&jit->cb, RAX); // save address on stack + gen_expr_jit(jit, s->assign.expr); // RAX = value (may clobber RBX) + emit_pop_reg(&jit->cb, RBX); // RBX = address + _TY ptr_ty = get_expr_type(jit, lhs->deref.expr); + int dsz = (ptr_ty.ptr_level > 1) ? 8 : ty_slot_size((_TY){ptr_ty.base, 0, -1}); + if (dsz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } else if (lhs->kind == EX_INDEX) { if (lhs->index.array->kind != EX_VAR) { gen_expr_jit(jit, lhs->index.array); break; } _TY var_type = get_var_type(jit, lhs->index.array->name); - if (var_type.array_size <= 0) { - fprintf(stderr, "[JIT] Cannot index non-array '%s'\n", lhs->index.array->name); exit(1); + int element_size = ty_slot_size((_TY){var_type.base, 0, -1}); + + if (var_type.ptr_level > 0) { + GlobalVar *gv_ptr2 = find_global(jit, lhs->index.array->name); + if (gv_ptr2) { + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr2->addr); + emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref global ptr */ + } else { + int ptr_offset = get_var_offset(jit, lhs->index.array->name); + emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer + } + gen_expr_jit(jit, lhs->index.index); // RAX = index + if (element_size > 1) { + emit_mov_reg_reg(&jit->cb, RCX, RBX); + emit_mov_reg_imm64(&jit->cb, RBX, element_size); + emit_imul_rax_rbx(&jit->cb); // RAX = byte offset + emit_mov_reg_reg(&jit->cb, RBX, RCX); + } + emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &ptr[index] + emit_push_reg(&jit->cb, RBX); // save address + gen_expr_jit(jit, s->assign.expr); // RAX = value + emit_pop_reg(&jit->cb, RBX); // restore address + if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + + } else if (var_type.array_size > 0) { + GlobalVar *gv_arr2 = find_global(jit, lhs->index.array->name); + gen_expr_jit(jit, lhs->index.index); // RAX = index + emit_mov_reg_imm64(&jit->cb, RBX, element_size); + emit_imul_rax_rbx(&jit->cb); // RAX = byte offset + if (gv_arr2) { + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr2->addr); + emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index] + } else { + int array_offset = get_var_offset(jit, lhs->index.array->name); + emit_mov_reg_imm64(&jit->cb, RBX, array_offset); + emit_add_reg_reg(&jit->cb, RBX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RBP); + emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset + } + emit_push_reg(&jit->cb, RBX); // save address + gen_expr_jit(jit, s->assign.expr); // RAX = value + emit_pop_reg(&jit->cb, RBX); // RBX = address + if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + + } else { + fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", lhs->index.array->name); exit(1); } - int array_offset = get_var_offset(jit, lhs->index.array->name); - int element_size = (var_type.base == TY_CHAR) ? 1 : 8; - gen_expr_jit(jit, lhs->index.index); // RAX = index - emit_mov_reg_imm64(&jit->cb, RBX, element_size); - emit_imul_rax_rbx(&jit->cb); // RAX = byte offset - emit_mov_reg_imm64(&jit->cb, RBX, array_offset); - emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset - emit_mov_reg_reg(&jit->cb, RAX, RBX); - emit_mov_reg_reg(&jit->cb, RBX, RBP); - emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index] - emit_mov_reg_reg(&jit->cb, RBX, RAX); - gen_expr_jit(jit, s->assign.expr); // RAX = value - emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); } else { fprintf(stderr, "[JIT] Unsupported assignment LHS kind %d\n", lhs->kind); exit(1); @@ -479,38 +701,151 @@ static int gen_stmt_jit(JIT *jit, _STN *s) { } case STK_WHILE: { + int depth = jit->loop_depth++; + jit->break_patch_count[depth] = 0; + jit->cont_patch_count[depth] = 0; + size_t loop_start = jit->cb.len; gen_expr_jit(jit, s->whl.cond); emit_test_rax_rax(&jit->cb); size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); gen_stmt_jit(jit, s->whl.body); + size_t cont_target = jit->cb.len; + for (int i = 0; i < jit->cont_patch_count[depth]; i++) { + size_t p = jit->cont_patches[depth][i]; + int32_t r = (int32_t)(cont_target - (p + 5)); + memcpy(jit->cb.buf + p + 1, &r, 4); + } emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); - int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6)); + size_t break_target = jit->cb.len; + int32_t rel_end = (int32_t)(break_target - (jz_pos + 6)); memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4); + for (int i = 0; i < jit->break_patch_count[depth]; i++) { + size_t p = jit->break_patches[depth][i]; + int32_t r = (int32_t)(break_target - (p + 5)); + memcpy(jit->cb.buf + p + 1, &r, 4); + } + jit->loop_depth--; break; } case STK_FOR: { + int depth = jit->loop_depth++; + jit->break_patch_count[depth] = 0; + jit->cont_patch_count[depth] = 0; + if (s->fr.init) gen_stmt_jit(jit, s->fr.init); size_t loop_start = jit->cb.len; - if (s->fr.cond) { + size_t jz_pos_for = 0; + int has_cond = (s->fr.cond != NULL); + if (has_cond) { gen_expr_jit(jit, s->fr.cond); emit_test_rax_rax(&jit->cb); - size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); - gen_stmt_jit(jit, s->fr.body); - if (s->fr.step) gen_stmt_jit(jit, s->fr.step); - emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); - int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6)); - memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4); - } else { - gen_stmt_jit(jit, s->fr.body); - if (s->fr.step) gen_stmt_jit(jit, s->fr.step); - emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); + jz_pos_for = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0); + } + gen_stmt_jit(jit, s->fr.body); + size_t cont_target = jit->cb.len; + for (int i = 0; i < jit->cont_patch_count[depth]; i++) { + size_t p = jit->cont_patches[depth][i]; + int32_t r = (int32_t)(cont_target - (p + 5)); + memcpy(jit->cb.buf + p + 1, &r, 4); } + if (s->fr.step) gen_stmt_jit(jit, s->fr.step); + emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5))); + size_t break_target = jit->cb.len; + if (has_cond) { + int32_t rel_end = (int32_t)(break_target - (jz_pos_for + 6)); + memcpy(jit->cb.buf + jz_pos_for + 2, &rel_end, 4); + } + for (int i = 0; i < jit->break_patch_count[depth]; i++) { + size_t p = jit->break_patches[depth][i]; + int32_t r = (int32_t)(break_target - (p + 5)); + memcpy(jit->cb.buf + p + 1, &r, 4); + } + jit->loop_depth--; break; } - default: + case STK_DOWHILE: { + int depth = jit->loop_depth++; + jit->break_patch_count[depth] = 0; + jit->cont_patch_count[depth] = 0; + + size_t loop_start = jit->cb.len; + gen_stmt_jit(jit, s->dowhl.body); + /* continue → jump to condition check */ + size_t cont_target = jit->cb.len; + for (int i = 0; i < jit->cont_patch_count[depth]; i++) { + size_t p = jit->cont_patches[depth][i]; + int32_t r = (int32_t)(cont_target - (p + 5)); + memcpy(jit->cb.buf + p + 1, &r, 4); + } + gen_expr_jit(jit, s->dowhl.cond); + emit_test_rax_rax(&jit->cb); + emit_jcc_rel32(&jit->cb, 0x05, (int32_t)(loop_start - (jit->cb.len + 6))); + size_t break_target = jit->cb.len; + for (int i = 0; i < jit->break_patch_count[depth]; i++) { + size_t p = jit->break_patches[depth][i]; + int32_t r = (int32_t)(break_target - (p + 5)); + memcpy(jit->cb.buf + p + 1, &r, 4); + } + jit->loop_depth--; + break; + } + + case STK_BREAK: { + if (jit->loop_depth == 0) { + fprintf(stderr, "[JIT] 'break' outside loop\n"); exit(1); + } + int depth = jit->loop_depth - 1; + size_t pos = jit->cb.len; + emit_jmp_rel32(&jit->cb, 0); + int cnt = jit->break_patch_count[depth]; + if (cnt >= 256) { fprintf(stderr, "[JIT] Too many breaks\n"); exit(1); } + jit->break_patches[depth][cnt] = pos; + jit->break_patch_count[depth]++; + break; + } + + case STK_CONTINUE: { + if (jit->loop_depth == 0) { + fprintf(stderr, "[JIT] 'continue' outside loop\n"); exit(1); + } + int depth = jit->loop_depth - 1; + size_t pos = jit->cb.len; + emit_jmp_rel32(&jit->cb, 0); + int cnt = jit->cont_patch_count[depth]; + if (cnt >= 256) { fprintf(stderr, "[JIT] Too many continues\n"); exit(1); } + jit->cont_patches[depth][cnt] = pos; + jit->cont_patch_count[depth]++; + break; + } + + case STK_GLOBAL: { + GlobalVar *gv = find_global(jit, s->global.name); + if (!gv) { + fprintf(stderr, "[JIT] Global '%s' not registered\n", s->global.name); exit(1); + } + if (s->global.init) { + if (s->global.init->kind == EX_CALL && + strcmp(s->global.init->call.func_name, "__initlist__") == 0) { + int esz = ty_slot_size((_TY){gv->type.base, 0, -1}); + for (int ii = 0; ii < s->global.init->call.argc; ii++) { + gen_expr_jit(jit, s->global.init->call.args[ii]); + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)(gv->addr + ii * esz)); + if (esz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + } + } else { + gen_expr_jit(jit, s->global.init); + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr); + int sz = ty_slot_size(gv->type); + if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + } + } + break; + } fprintf(stderr, "[JIT] Unsupported statement kind %d\n", s->kind); exit(1); } s = s->n; @@ -518,7 +853,6 @@ static int gen_stmt_jit(JIT *jit, _STN *s) { return 0; } -/* --- Expression generation --- */ static void gen_expr_jit(JIT *jit, _EX *e) { if (!e) return; @@ -529,16 +863,110 @@ static void gen_expr_jit(JIT *jit, _EX *e) { break; case EX_VAR: { - int off = get_var_offset(jit, e->name); - _TY ty = get_var_type(jit, e->name); - int sz = (ty.ptr_level > 0) ? 8 : (ty.base == TY_CHAR) ? 1 : 8; - if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off); - else emit_mov_reg_mem64(&jit->cb, RAX, off); - break; + GlobalVar *gv = NULL; + for (VarMap *v = jit->var_list; v; v = v->next) { + if (strcmp(v->name, e->name) == 0) goto local_var; + } + gv = find_global(jit, e->name); + if (gv) { + _TY ty = gv->type; + if (ty.array_size > 0) { + /* Global array: load its absolute address */ + emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv->addr); + break; + } + /* Global scalar: absolute address → load through it */ + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr); + int sz = ty_slot_size(ty); + if (sz == 1) emit_movzx_rax_mem8_base(&jit->cb, RBX); + else emit_mov_reg_mem_reg(&jit->cb, RAX, RBX); + break; + } + local_var: { + int off = get_var_offset(jit, e->name); + _TY ty = get_var_type(jit, e->name); + if (ty.array_size > 0) { + emit_lea_rax_rbp_disp(&jit->cb, off); + break; + } + int sz = ty_slot_size(ty); + if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off); + else emit_mov_reg_mem64(&jit->cb, RAX, off); + break; + } } case EX_BINOP: { - // Short-circuit AND + if (e->binop.op == TK_INC || e->binop.op == TK_DEC) { + int is_pre = (e->binop.r->value == -1); + int is_inc = (e->binop.op == TK_INC); + _EX *lval = e->binop.l; + + _TY lty = get_expr_type(jit, lval); + int step = (lty.ptr_level > 0) ? ty_slot_size((_TY){lty.base,0,-1}) : 1; + + gen_expr_jit(jit, lval); /* RAX = old value */ + emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = old value */ + emit_mov_reg_imm64(&jit->cb, RBX, step); + if (is_inc) emit_add_reg_reg(&jit->cb, RCX, RBX); + else emit_sub_reg_reg(&jit->cb, RCX, RBX); + if (lval->kind == EX_VAR) { + GlobalVar *gv_inc = find_global(jit, lval->name); + emit_mov_reg_reg(&jit->cb, RAX, RCX); + if (gv_inc) { + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_inc->addr); + int sz = ty_slot_size(lty); + if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX); + else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + } else { + int off = get_var_offset(jit, lval->name); + int sz = ty_slot_size(lty); + emit_store_rax_to_mem(&jit->cb, off, sz); + } + } else { + emit_push_reg(&jit->cb, RCX); /* save new value */ + if (lval->kind == EX_DEREF) { + gen_expr_jit(jit, lval->deref.expr); /* RAX = address */ + } else { /* EX_INDEX */ + _TY vty = get_var_type(jit, lval->index.array->name); + int esz = ty_slot_size((_TY){vty.base,0,-1}); + if (vty.ptr_level > 0) { + int poff = get_var_offset(jit, lval->index.array->name); + emit_mov_reg_mem64(&jit->cb, RAX, poff); + emit_push_reg(&jit->cb, RAX); + gen_expr_jit(jit, lval->index.index); + if (esz > 1) { + emit_mov_reg_imm64(&jit->cb, RBX, esz); + emit_imul_rax_rbx(&jit->cb); + } + emit_pop_reg(&jit->cb, RBX); + emit_add_reg_reg(&jit->cb, RAX, RBX); + } else { + int aoff = get_var_offset(jit, lval->index.array->name); + gen_expr_jit(jit, lval->index.index); + emit_mov_reg_imm64(&jit->cb, RBX, esz); + emit_imul_rax_rbx(&jit->cb); + emit_mov_reg_imm64(&jit->cb, RBX, aoff); + emit_add_reg_reg(&jit->cb, RBX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RBP); + emit_add_reg_reg(&jit->cb, RAX, RBX); + } + } + emit_mov_reg_reg(&jit->cb, RBX, RAX); /* RBX = address */ + emit_pop_reg(&jit->cb, RCX); /* RCX = new val */ + emit_mov_reg_reg(&jit->cb, RAX, RCX); + emit_mov_mem_reg_reg(&jit->cb, RBX, RAX); + } + if (is_pre) emit_mov_reg_reg(&jit->cb, RAX, RCX); + if (!is_pre) { + emit_mov_reg_reg(&jit->cb, RAX, RCX); + emit_mov_reg_imm64(&jit->cb, RBX, step); + if (is_inc) emit_sub_reg_reg(&jit->cb, RAX, RBX); + else emit_add_reg_reg(&jit->cb, RAX, RBX); + } + break; + } if (e->binop.op == TK_AND) { gen_expr_jit(jit, e->binop.l); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); @@ -554,7 +982,6 @@ static void gen_expr_jit(JIT *jit, _EX *e) { memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4); break; } - // Short-circuit OR if (e->binop.op == TK_OR) { gen_expr_jit(jit, e->binop.l); emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb); @@ -570,18 +997,54 @@ static void gen_expr_jit(JIT *jit, _EX *e) { memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4); break; } - // All other binary ops: eval left -> push; eval right -> RBX = left gen_expr_jit(jit, e->binop.l); emit_push_reg(&jit->cb, RAX); gen_expr_jit(jit, e->binop.r); emit_pop_reg(&jit->cb, RBX); // RBX = left, RAX = right switch (e->binop.op) { - case TK_PLUS: emit_add_rax_rbx(&jit->cb); break; - case TK_MINUS: + case TK_PLUS: { + _TY lt = get_expr_type(jit, e->binop.l); + _TY rt = get_expr_type(jit, e->binop.r); + int lptr = (lt.ptr_level > 0 || lt.array_size > 0); + int rptr = (rt.ptr_level > 0 || rt.array_size > 0); + if (lptr && !rptr) { + int esz = ty_slot_size((_TY){lt.base, 0, -1}); + if (esz > 1) { + emit_mov_reg_imm64(&jit->cb, RCX, esz); + emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); + } + } else if (rptr && !lptr) { + int esz = ty_slot_size((_TY){rt.base, 0, -1}); + emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = ptr */ + emit_mov_reg_reg(&jit->cb, RAX, RBX); /* RAX = int */ + emit_mov_reg_reg(&jit->cb, RBX, RCX); /* RBX = ptr */ + if (esz > 1) { + emit_mov_reg_imm64(&jit->cb, RCX, esz); + emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); + } + } + emit_add_rax_rbx(&jit->cb); + break; + } + case TK_MINUS: { + _TY lt = get_expr_type(jit, e->binop.l); + _TY rt = get_expr_type(jit, e->binop.r); + int lptr = (lt.ptr_level > 0 || lt.array_size > 0); + int rptr = (rt.ptr_level > 0 || rt.array_size > 0); + if (lptr && !rptr) { + /* ptr - int: scale int by element size */ + int esz = ty_slot_size((_TY){lt.base, 0, -1}); + if (esz > 1) { + emit_mov_reg_imm64(&jit->cb, RCX, esz); + emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); + } + } + /* RAX=right(scaled), RBX=left: result = left - right */ emit_mov_reg_reg(&jit->cb, RCX, RAX); emit_mov_reg_reg(&jit->cb, RAX, RBX); emit_rex(&jit->cb, RCX, RAX, 1); emit8(&jit->cb, 0x29); emit_modrm(&jit->cb, 3, RCX, RAX); break; + } case TK_STAR: emit_imul_rax_rbx(&jit->cb); break; case TK_SLASH: emit_mov_reg_reg(&jit->cb, RCX, RAX); @@ -626,24 +1089,84 @@ static void gen_expr_jit(JIT *jit, _EX *e) { } case EX_CALL: { + if (strcmp(e->call.func_name, "__sizeof__") == 0) { + _EX *arg = e->call.args[0]; + _TY ty; + if (arg->kind == EX_VAR) { + ty = get_var_type(jit, arg->name); + } else { + ty = get_expr_type(jit, arg); + } + int sz; + if (ty.array_size > 0) { + int esz = (ty.base == TY_CHAR) ? 1 : (ty.base == TY_SHORT) ? 2 : + (ty.base == TY_LONG) ? 8 : 4; + sz = esz * ty.array_size; + } else if (ty.ptr_level > 0) { + sz = 8; + } else { + switch (ty.base) { + case TY_CHAR: sz = 1; break; + case TY_SHORT: sz = 2; break; + case TY_LONG: sz = 8; break; + default: sz = 4; break; /* int, bool, float */ + } + } + emit_mov_reg_imm64(&jit->cb, RAX, sz); + break; + } + + if (strcmp(e->call.func_name, "syscall") == 0) { + int argc = e->call.argc; + if (argc < 1) { + fprintf(stderr, "[JIT] syscall() requires at least 1 argument (number)\n"); exit(1); + } + if (argc > 7) { + fprintf(stderr, "[JIT] syscall() supports at most 7 arguments\n"); exit(1); + } + const int sc_regs[6] = {RDI, RSI, RDX, R10, R8, R9}; + int nargs = argc - 1; /* number of args after the syscall number */ + + for (int i = argc - 1; i >= 0; i--) { + gen_expr_jit(jit, e->call.args[i]); + emit_push_reg(&jit->cb, RAX); + } + + emit_pop_reg(&jit->cb, RAX); + + for (int i = 0; i < nargs; i++) { + emit_pop_reg(&jit->cb, sc_regs[i]); + } + + emit8(&jit->cb, 0x0F); + emit8(&jit->cb, 0x05); + break; + } + int total_args = e->call.argc; int stack_args = total_args > 6 ? total_args - 6 : 0; + int reg_args = total_args < 6 ? total_args : 6; int padding = (stack_args % 2) ? 8 : 0; const int arg_regs[6] = {RDI, RSI, RDX, RCX, R8, R9}; + /* Push stack-passed args right-to-left (args[total-1] first). */ for (int i = total_args-1; i >= 6; i--) { gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX); } - for (int i = 0; i < total_args && i < 6; i++) { - gen_expr_jit(jit, e->call.args[i]); emit_mov_reg_reg(&jit->cb, arg_regs[i], RAX); + for (int i = reg_args-1; i >= 0; i--) { + gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX); + } + + for (int i = 0; i < reg_args; i++) { + emit_pop_reg(&jit->cb, arg_regs[i]); } + if (padding) { emit8(&jit->cb,0x48); emit8(&jit->cb,0x83); emit8(&jit->cb,0xEC); emit8(&jit->cb,0x08); } void *addr = get_func_addr(jit, e->call.func_name); if (addr == (void*)0xDEADBEEF) { - // Forward call: emit placeholder imm64, record offset for later patching emit_movabs_rax_imm64(&jit->cb, (uint64_t)0xDEADBEEF); PatchEntry *patch = (PatchEntry*)malloc(sizeof(PatchEntry)); if (!patch) { fprintf(stderr, "[JIT] malloc failed (patch)\n"); exit(1); } @@ -659,7 +1182,7 @@ static void gen_expr_jit(JIT *jit, _EX *e) { } else { emit_movabs_rax_imm64(&jit->cb, (uint64_t)addr); } - emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0); // CALL RAX + emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0); if (stack_args * 8 + padding > 0) { emit8(&jit->cb,0x48); emit8(&jit->cb,0x81); emit8(&jit->cb,0xC4); @@ -669,12 +1192,56 @@ static void gen_expr_jit(JIT *jit, _EX *e) { } case EX_ADDR: - if (e->addr.expr->kind != EX_VAR) { + if (e->addr.expr->kind == EX_VAR) { + GlobalVar *gv_addr = find_global(jit, e->addr.expr->name); + if (gv_addr) { + emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv_addr->addr); + } else { + emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name)); + } + } else { fprintf(stderr, "[JIT] &expr: only &var supported\n"); exit(1); } - emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name)); break; + case EX_TERNARY: { + gen_expr_jit(jit, e->ternary.cond); + emit_mov_reg_reg(&jit->cb, RBX, RAX); + emit_or_rax_rbx(&jit->cb); + size_t jz_pos = jit->cb.len; + emit_jcc_rel32(&jit->cb, 0x04, 0); + gen_expr_jit(jit, e->ternary.then_expr); + size_t jmp_pos = jit->cb.len; + emit_jmp_rel32(&jit->cb, 0); + int32_t rel_jz = (int32_t)(jit->cb.len - (jz_pos + 6)); + memcpy(jit->cb.buf + jz_pos + 2, &rel_jz, 4); + gen_expr_jit(jit, e->ternary.else_expr); + int32_t rel_jmp = (int32_t)(jit->cb.len - (jmp_pos + 5)); + memcpy(jit->cb.buf + jmp_pos + 1, &rel_jmp, 4); + break; + } + + case EX_CAST: { + gen_expr_jit(jit, e->cast.expr); + _TY to = e->cast.to; + if (to.ptr_level > 0) break; /* pointer cast: value unchanged */ + switch (to.base) { + case TY_CHAR: + emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); + emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, al */ + break; + case TY_SHORT: + emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB7); + emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, ax */ + break; + case TY_INT: + emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x63); emit_modrm(&jit->cb, 3, RAX, RAX); + break; + default: break; /* long/void/etc: no-op */ + } + break; + } + case EX_DEREF: gen_expr_jit(jit, e->deref.expr); emit_rex(&jit->cb, RAX, RAX, 1); emit8(&jit->cb, 0x8B); emit_modrm(&jit->cb, 0, RAX, RAX); @@ -692,12 +1259,17 @@ static void gen_expr_jit(JIT *jit, _EX *e) { gen_expr_jit(jit, e->index.array); break; } _TY var_type = get_var_type(jit, e->index.array->name); - int element_size = (var_type.base == TY_CHAR) ? 1 : 8; + int element_size = ty_slot_size((_TY){var_type.base, 0, -1}); if (var_type.ptr_level > 0) { - // Pointer indexing: load pointer, add index*element_size - int ptr_offset = get_var_offset(jit, e->index.array->name); - emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer + GlobalVar *gv_ptr = find_global(jit, e->index.array->name); + if (gv_ptr) { + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr->addr); + emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref: RBX = *addr = pointer value */ + } else { + int ptr_offset = get_var_offset(jit, e->index.array->name); + emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer + } gen_expr_jit(jit, e->index.index); // RAX = index if (element_size > 1) { emit_mov_reg_reg(&jit->cb, RCX, RBX); // save pointer @@ -714,17 +1286,22 @@ static void gen_expr_jit(JIT *jit, _EX *e) { } } else if (var_type.array_size > 0) { - // Array indexing: compute RBP-relative address = array_offset - index*element_size - int array_offset = get_var_offset(jit, e->index.array->name); + GlobalVar *gv_arr = find_global(jit, e->index.array->name); gen_expr_jit(jit, e->index.index); // RAX = index emit_mov_reg_imm64(&jit->cb, RBX, element_size); emit_imul_rax_rbx(&jit->cb); // RAX = byte offset - emit_mov_reg_imm64(&jit->cb, RBX, array_offset); - emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset - emit_mov_reg_reg(&jit->cb, RAX, RBX); - emit_mov_reg_reg(&jit->cb, RBX, RBP); - emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index] - emit_mov_reg_reg(&jit->cb, RBX, RAX); + if (gv_arr) { + /* Global array: base is absolute address */ + emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr->addr); + emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index] + } else { + int array_offset = get_var_offset(jit, e->index.array->name); + emit_mov_reg_imm64(&jit->cb, RBX, array_offset); + emit_add_reg_reg(&jit->cb, RBX, RAX); + emit_mov_reg_reg(&jit->cb, RAX, RBX); + emit_mov_reg_reg(&jit->cb, RBX, RBP); + emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset + } if (element_size == 1) { emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6); emit_modrm(&jit->cb, 0, RAX, RBX); @@ -743,22 +1320,35 @@ static void gen_expr_jit(JIT *jit, _EX *e) { } } -/* --- Compile one function --- */ - static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) { reset_varmap(jit); jit->current_func_name = f->name; - int total_stack_size = calculate_stack_size(f->body); + int total_stack_size = calculate_total_stack_size(f); cb_init(&jit->cb); emit_prologue(&jit->cb, total_stack_size); const int param_regs[6] = {RDI, RSI, RDX, RCX, R8, R9}; + for (int i = 0; i < f->pac && i < 6; i++) { add_var(jit, f->params[i], f->param_types[i]); emit_mov_mem64_reg(&jit->cb, get_var_offset(jit, f->params[i]), param_regs[i]); } + int num_stack_params = f->pac > 6 ? f->pac - 6 : 0; + int stack_arg_padding = (num_stack_params % 2) ? 8 : 0; + int stack_arg_base = 16 + stack_arg_padding; /* offset of last-pushed (highest-index) stack arg */ + for (int i = 6; i < f->pac; i++) { + VarMap *v = (VarMap *)malloc(sizeof(VarMap)); + if (!v) { fprintf(stderr, "[JIT] malloc failed (stack param)\n"); exit(1); } + v->name = strdup(f->params[i]); + if (!v->name) { fprintf(stderr, "[JIT] strdup failed (stack param)\n"); free(v); exit(1); } + v->type = f->param_types[i]; + v->offset = stack_arg_base + (i - 6) * 8; + v->next = jit->var_list; + jit->var_list = v; + } + int did_return = gen_stmt_jit(jit, f->body); if (!did_return) { emit_movabs_rax_imm64(&jit->cb, 0); @@ -777,7 +1367,6 @@ static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) { return mem; } -/* --- Patch forward calls after all functions are compiled --- */ static void patch_function_calls(JIT *jit) { for (PatchEntry *patch = jit->patch_list; patch; patch = patch->next) { @@ -797,16 +1386,36 @@ static void patch_function_calls(JIT *jit) { } } -/* --- Compile all functions and patch forward calls --- */ - static void jit_compile_all(JIT *jit, _FN *fn_list) { - for (_FN *cur = fn_list; cur; cur = cur->n) - register_func(jit, cur->name); + /* First pass: register all globals so type lookups work during codegen */ + for (_FN *cur = fn_list; cur; cur = cur->n) { + if (strncmp(cur->name, "__global_", 9) == 0) { + /* Extract variable name from __global_NAME__ */ + const char *start = cur->name + 9; + size_t len = strlen(start); + if (len >= 2 && start[len-2] == '_' && start[len-1] == '_') len -= 2; + char *vname = strndup(start, len); + if (!vname) { fprintf(stderr, "[JIT] OOM\n"); exit(1); } + /* Get type from the STK_GLOBAL node */ + _STN *gdecl = cur->body; + if (gdecl && gdecl->kind == STK_GLOBAL) { + register_global(jit, vname, gdecl->global.type); + } + free(vname); + } + } + + /* Second pass: register all real functions (so forward calls work) */ + for (_FN *cur = fn_list; cur; cur = cur->n) { + if (strncmp(cur->name, "__global_", 9) != 0) + register_func(jit, cur->name, cur->ret_type); + } - // Compile in reverse order so callees are typically compiled before callers + /* Compile real functions in reverse order */ _FN *functions[64]; int count = 0; for (_FN *cur = fn_list; cur; cur = cur->n) { + if (strncmp(cur->name, "__global_", 9) == 0) continue; if (count >= 64) { fprintf(stderr, "[JIT] Too many functions (max 64)\n"); exit(1); } functions[count++] = cur; } @@ -814,9 +1423,27 @@ static void jit_compile_all(JIT *jit, _FN *fn_list) { gen_function_jit(jit, functions[i], NULL); patch_function_calls(jit); -} -/* --- Entry point --- */ + /* Run global initialisers in declaration order */ + for (_FN *cur = fn_list; cur; cur = cur->n) { + if (strncmp(cur->name, "__global_", 9) != 0) continue; + reset_varmap(jit); + jit->current_func_name = cur->name; + cb_init(&jit->cb); + emit8(&jit->cb, 0x55); /* push rbp */ + emitN(&jit->cb, (uint8_t[]){0x48,0x89,0xE5}, 3); /* mov rbp, rsp */ + gen_stmt_jit(jit, cur->body); + emit8(&jit->cb, 0xC9); /* leave */ + emit8(&jit->cb, 0xC3); /* ret */ + size_t isz = jit->cb.len; + void *ibuf = mmap(NULL, isz, PROT_READ|PROT_WRITE|PROT_EXEC, + MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); + if (ibuf == MAP_FAILED) { perror("mmap init"); exit(1); } + memcpy(ibuf, jit->cb.buf, isz); + ((void(*)(void))ibuf)(); + munmap(ibuf, isz); + } +} static int jit_run(JIT *jit, int argc, char **argv) { int (*main_func)(int, char **) = get_func_addr(jit, "main"); @@ -824,4 +1451,4 @@ static int jit_run(JIT *jit, int argc, char **argv) { return main_func(argc, argv); } -#endif
\ No newline at end of file +#endif |
