summaryrefslogtreecommitdiff
path: root/src/codegen_jit.h
diff options
context:
space:
mode:
authorDavid Moc <personal@cdatgoose.org>2026-03-08 15:02:30 +0100
committerDavid Moc <personal@cdatgoose.org>2026-03-08 15:02:30 +0100
commitad035ac5d942c6448a6a0464b995c2868a8378db (patch)
tree9c2c779d544a45b4234a6a3c311aa7b612ea975e /src/codegen_jit.h
parent0385817bb1301a778bb33f8405a435293b9f8905 (diff)
Gosh.\nFixed bugs in offsetting. Added syscalls. Cleaned up the previous commenting (used to pass the project through Claude to add comments but LLMs are dumb). Removed the LLM-made test runner cuz fuck AI. HEAD master
Diffstat (limited to 'src/codegen_jit.h')
-rw-r--r--src/codegen_jit.h853
1 files changed, 740 insertions, 113 deletions
diff --git a/src/codegen_jit.h b/src/codegen_jit.h
index ff1d687..2cd4aa2 100644
--- a/src/codegen_jit.h
+++ b/src/codegen_jit.h
@@ -32,6 +32,7 @@ typedef struct FuncMap {
void *addr;
size_t size;
size_t alloc_size;
+ _TY ret_type; /* return type of this function */
struct FuncMap *next;
} FuncMap;
@@ -49,6 +50,14 @@ typedef struct VarMap {
struct VarMap *next;
} VarMap;
+/* One global variable known to the JIT. Nodes form a singly linked list
+ * headed by JIT.global_list; register_global() prepends new nodes. */
+typedef struct GlobalVar {
+ char *name; /* strdup'd copy of the variable name; freed in jit_free() */
+ uint8_t *addr; /* pointer into globals_buf */
+ _TY type; /* type as passed to register_global() (array_size included) */
+ int size; /* total bytes allocated */
+ struct GlobalVar *next; /* next node in the list, NULL at the tail */
+} GlobalVar;
+
typedef struct {
FuncMap *func_list;
VarMap *var_list;
@@ -56,6 +65,17 @@ typedef struct {
CodeBuf cb;
char *current_func_name;
PatchEntry *patch_list;
+ /* Globals data segment */
+ uint8_t *globals_buf; /* mmap'd RW page(s) for global variables */
+ size_t globals_cap;
+ size_t globals_used;
+ GlobalVar *global_list;
+ /* Break/continue patch stacks (up to 64 nesting levels) */
+ size_t break_patches[64][256]; /* offsets to patch */
+ int break_patch_count[64];
+ size_t cont_patches[64][256];
+ int cont_patch_count[64];
+ int loop_depth;
} JIT;
static void jit_init(JIT *jit) {
@@ -66,6 +86,19 @@ static void jit_init(JIT *jit) {
jit->cb.len = jit->cb.cap = 0;
jit->current_func_name = NULL;
jit->patch_list = NULL;
+ jit->globals_cap = 1024 * 1024;
+ jit->globals_used = 0;
+ jit->globals_buf = (uint8_t *)mmap(NULL, jit->globals_cap,
+ PROT_READ|PROT_WRITE,
+ MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
+ if (jit->globals_buf == MAP_FAILED) {
+ perror("mmap globals"); exit(1);
+ }
+ memset(jit->globals_buf, 0, jit->globals_cap);
+ jit->global_list = NULL;
+ jit->loop_depth = 0;
+ memset(jit->break_patch_count, 0, sizeof(jit->break_patch_count));
+ memset(jit->cont_patch_count, 0, sizeof(jit->cont_patch_count));
}
static void jit_free(JIT *jit) {
@@ -79,6 +112,14 @@ static void jit_free(JIT *jit) {
}
jit->patch_list = NULL;
+ for (GlobalVar *g = jit->global_list; g;) {
+ GlobalVar *n = g->next; free(g->name); free(g); g = n;
+ }
+ jit->global_list = NULL;
+
+ if (jit->globals_buf && jit->globals_buf != MAP_FAILED)
+ munmap(jit->globals_buf, jit->globals_cap);
+
for (FuncMap *f = jit->func_list; f;) {
FuncMap *n = f->next;
if (f->addr && f->alloc_size > 0) munmap(f->addr, f->alloc_size);
@@ -91,8 +132,6 @@ static void jit_free(JIT *jit) {
jit->cb.len = jit->cb.cap = 0;
}
-/* --- Code buffer --- */
-
static void cb_init(CodeBuf *c) {
c->cap = 1024; c->len = 0;
c->buf = (uint8_t *)malloc(c->cap);
@@ -114,8 +153,6 @@ static void emit32(CodeBuf *c, uint32_t v) { cb_grow(c,4); memcpy(c->buf+c->
static void emit64(CodeBuf *c, uint64_t v) { cb_grow(c,8); memcpy(c->buf+c->len,&v,8); c->len+=8; }
static void emitN(CodeBuf *c, const void *p, size_t n) { cb_grow(c,n); memcpy(c->buf+c->len,p,n); c->len+=n; }
-/* --- x86-64 encoding helpers --- */
-
static void emit_rex(CodeBuf *c, int reg, int rm, int w) {
uint8_t rex = 0x40;
if (w) rex |= 0x08;
@@ -145,9 +182,18 @@ static void emit_mov_reg_mem_reg(CodeBuf *c, int dst, int base) {
static void emit_mov_mem_reg_reg(CodeBuf *c, int base, int src) {
emit_rex(c,src,base,1); emit8(c,0x89); emit_modrm(c,0,src,base);
}
+/*
+ * Emit a one-byte store of the low byte of RAX (AL) to the address held in
+ * register `base`. A REX.B prefix is added when `base` has bit 3 set.
+ */
+static void emit_mov_mem8_reg_reg(CodeBuf *c, int base) {
+    const int extended_base = (base & 8) != 0;
+
+    if (extended_base) {
+        emit8(c, 0x41);            /* REX.B for extended base registers */
+    }
+    emit8(c, 0x88);                /* opcode 88 /r: MOV r/m8, r8 */
+    emit_modrm(c, 0, RAX, base);   /* mod 0, reg field RAX, r/m = base */
+}
static void emit_movzx_rax_mem8(CodeBuf *c, int disp32) {
emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB6); emit_modrm(c,2,RAX,RBP); emit32(c,disp32);
}
+/*
+ * Emit a zero-extending one-byte load into RAX from the address held in
+ * register `base` (opcode 0F B6 with a REX.W prefix).
+ */
+static void emit_movzx_rax_mem8_base(CodeBuf *c, int base) {
+    /* 0x48 = REX.W; 0x49 = REX.W + REX.B when the base register has bit 3 set. */
+    const uint8_t rex = (base & 8) ? 0x49 : 0x48;
+
+    emit8(c, rex);
+    emit8(c, 0x0F);
+    emit8(c, 0xB6);
+    emit_modrm(c, 0, RAX, base);
+}
static void emit_mov_mem8_rax(CodeBuf *c, int disp32) {
emit_rex(c,RAX,RBP,0); emit8(c,0x88); emit_modrm(c,2,RAX,RBP); emit32(c,disp32);
}
@@ -174,8 +220,8 @@ static void emit_add_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x01,0xD8},
static void emit_imul_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xAF,0xC3},4); }
static void emit_idiv_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); }
static void emit_imod(CodeBuf *c) {
- emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); // cqo; idiv rbx
- emit_mov_reg_reg(c, RAX, RDX); // remainder -> RAX
+ emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5);
+ emit_mov_reg_reg(c, RAX, RDX);
}
static void emit_or_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x09,0xD8},3); }
static void emit_xor_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x31,0xD8},3); }
@@ -192,53 +238,97 @@ static void emit_jcc_rel32(CodeBuf *c, uint8_t cc, int32_t rel) {
static void emit_test_rax_rax(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x85,0xC0},3); }
static void emit_prologue(CodeBuf *c, int total_stack_size) {
- emit8(c,0x55); // push rbp
- emitN(c,(uint8_t[]){0x48,0x89,0xE5},3); // mov rbp, rsp
- emit_push_reg(c, RBX); // save callee-saved RBX
+ emit8(c,0x55);
+ emitN(c,(uint8_t[]){0x48,0x89,0xE5},3);
+ emit_push_reg(c, RBX);
int stack_bytes = ((total_stack_size+15)/16)*16;
if (stack_bytes > 0) {
- emitN(c,(uint8_t[]){0x48,0x81,0xEC},3); // sub rsp, imm32
+ emitN(c,(uint8_t[]){0x48,0x81,0xEC},3);
emit32(c,(uint32_t)stack_bytes);
}
}
static void emit_epilogue(CodeBuf *c) {
- emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3); // restore RBX; leave; ret
+ emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3);
}
static void emit_lea_rax_rbp_disp(CodeBuf *c, int disp32) {
emit8(c,0x48); emit8(c,0x8D); emit_modrm(c,2,RAX,RBP); emit32(c,(uint32_t)disp32);
}
+/*
+ * Emit a zero-extending 16-bit load into RAX from [RBP + disp32]
+ * (REX.W 0F B7 /r with a 32-bit displacement).
+ */
+static void emit_movzx_rax_mem16(CodeBuf *c, int disp32) {
+    emitN(c, (uint8_t[]){0x48, 0x0F, 0xB7}, 3);   /* REX.W, two-byte opcode */
+    emit_modrm(c, 2, RAX, RBP);                   /* mod 2: displacement follows */
+    emit32(c, (uint32_t)disp32);
+}
+/*
+ * Emit a 16-bit store of the low word of RAX (AX) to [RBP + disp32].
+ * The 0x66 operand-size prefix narrows opcode 89 /r to 16 bits.
+ */
+static void emit_mov_mem16_rax(CodeBuf *c, int disp32) {
+    emit8(c, 0x66);                 /* operand-size override prefix */
+    emit_rex(c, RAX, RBP, 0);       /* REX prefix with W = 0 */
+    emit8(c, 0x89);                 /* opcode 89 /r: MOV r/m, r */
+    emit_modrm(c, 2, RAX, RBP);     /* mod 2: displacement follows */
+    emit32(c, (uint32_t)disp32);
+}
static void emit_load_rax_from_mem(CodeBuf *c, int disp32, int size) {
- if (size == 1) emit_movzx_rax_mem8(c, disp32); else emit_mov_reg_mem64(c, RAX, disp32);
+ if (size == 1) emit_movzx_rax_mem8(c, disp32);
+ else if (size == 2) emit_movzx_rax_mem16(c, disp32);
+ else emit_mov_reg_mem64(c, RAX, disp32);
}
static void emit_store_rax_to_mem(CodeBuf *c, int disp32, int size) {
- if (size == 1) emit_mov_mem8_rax(c, disp32); else emit_mov_mem64_reg(c, disp32, RAX);
+ if (size == 1) emit_mov_mem8_rax(c, disp32);
+ else if (size == 2) emit_mov_mem16_rax(c, disp32);
+ else emit_mov_mem64_reg(c, disp32, RAX);
}
-/* --- Variable and type helpers --- */
-
static int calculate_type_size(_TY type) {
if (type.ptr_level > 0) return 8;
- int base_size = (type.base == TY_CHAR) ? 1 : 8;
- size_t total = (size_t)base_size;
+ int slot;
+ switch (type.base) {
+ case TY_CHAR: slot = 1; break;
+ case TY_BOOL: slot = 1; break;
+ case TY_SHORT: slot = 2; break;
+ case TY_INT: slot = 8; break;
+ case TY_FLOAT: slot = 8; break;
+ case TY_LONG: slot = 8; break;
+ case TY_VOID: slot = 0; break;
+ default: slot = 8; break;
+ }
if (type.array_size > 0) {
- total *= (size_t)type.array_size;
+ size_t total = (size_t)slot * (size_t)type.array_size;
if (total > (size_t)INT_MAX) { fprintf(stderr, "[JIT] array size too large\n"); exit(1); }
+ return (int)total;
}
- return (int)total;
+ return slot;
}
static int align_offset(int offset, _TY type) {
- int align = (type.base == TY_CHAR) ? 1 : 8;
+ int align;
+ if (type.ptr_level > 0) { align = 8; }
+ else switch (type.base) {
+ case TY_CHAR: align = 1; break;
+ case TY_BOOL: align = 1; break;
+ case TY_SHORT: align = 2; break;
+ case TY_INT: align = 8; break;
+ case TY_FLOAT: align = 8; break;
+ case TY_LONG: align = 8; break;
+ default: align = 8; break;
+ }
if (offset % align == 0) return offset;
+ if (offset < 0)
+ return ((offset - (align - 1)) / align) * align;
return (offset / align) * align;
}
+/*
+ * Number of bytes one value of type `ty` occupies in a storage slot.
+ *
+ * Any pointer is 8 bytes. CHAR and BOOL take 1, SHORT takes 2, VOID takes 0.
+ * INT is promoted to a 64-bit slot, FLOAT is stored in a 64-bit slot
+ * (integer bits), LONG is 8, and any unlisted base also yields 8.
+ * ty.array_size is not consulted, so for array types this is the size of
+ * one element.
+ */
+static int ty_slot_size(_TY ty) {
+    int bytes = 8;   /* pointers, INT, FLOAT, LONG and unlisted bases */
+
+    if (ty.ptr_level <= 0) {
+        switch (ty.base) {
+        case TY_CHAR:
+        case TY_BOOL:
+            bytes = 1;
+            break;
+        case TY_SHORT:
+            bytes = 2;
+            break;
+        case TY_VOID:
+            bytes = 0;
+            break;
+        default:
+            break;
+        }
+    }
+    return bytes;
+}
+
static void reset_varmap(JIT *jit) {
for (VarMap *v = jit->var_list; v;) {
VarMap *n = v->next; free(v->name); free(v); v = n;
}
jit->var_list = NULL;
- jit->next_local_offset = -16; // after saved RBX at RBP-8, 16-byte aligned
+ jit->next_local_offset = -16;
}
static void add_var(JIT *jit, const char *name, _TY type) {
@@ -248,9 +338,9 @@ static void add_var(JIT *jit, const char *name, _TY type) {
if (!v->name) { fprintf(stderr, "[JIT] strdup failed in add_var\n"); free(v); exit(1); }
v->type = type;
int type_size = calculate_type_size(type);
+ jit->next_local_offset -= type_size;
jit->next_local_offset = align_offset(jit->next_local_offset, type);
v->offset = jit->next_local_offset;
- jit->next_local_offset -= type_size;
v->next = jit->var_list;
jit->var_list = v;
}
@@ -261,23 +351,60 @@ static int get_var_offset(JIT *jit, const char *name) {
fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1);
}
+/*
+ * Look up a global by exact name.
+ * Returns the matching node from jit->global_list, or NULL when no global
+ * with that name has been registered.
+ */
+static GlobalVar *find_global(JIT *jit, const char *name) {
+    GlobalVar *cur = jit->global_list;
+
+    while (cur != NULL && strcmp(cur->name, name) != 0) {
+        cur = cur->next;
+    }
+    return cur;
+}
+
+/*
+ * Reserve space for a new global in the globals segment and record it at the
+ * head of jit->global_list.
+ *
+ * name: variable name; a private copy is made with strdup().
+ * type: declared type; array_size > 0 reserves that many element slots,
+ *       otherwise a single slot is reserved.
+ *
+ * Returns the newly created node. Prints a message and exits the process
+ * when the name is already registered, the total size does not fit in an
+ * int, the segment has no room left, or an allocation fails.
+ */
+static GlobalVar *register_global(JIT *jit, const char *name, _TY type) {
+    if (find_global(jit, name)) {
+        fprintf(stderr, "[JIT] Duplicate global '%s'\n", name); exit(1);
+    }
+    int elem_sz = ty_slot_size((_TY){type.base, type.ptr_level > 0 ? type.ptr_level : 0, -1});
+    size_t n_elems = (type.array_size > 0) ? (size_t)type.array_size : 1;
+
+    /* Multiply in size_t so a huge array cannot overflow int; g->size is an
+     * int, so anything above INT_MAX is rejected (same guard as
+     * calculate_type_size). */
+    size_t total_sz = (size_t)elem_sz * n_elems;
+    if (total_sz > (size_t)INT_MAX) {
+        fprintf(stderr, "[JIT] Global '%s' too large\n", name); exit(1);
+    }
+
+    /* A zero-size slot (void) would give align == 0, and the mask below
+     * would then collapse the offset to 0 and alias the first global. */
+    size_t align = (elem_sz < 8) ? (size_t)elem_sz : 8;
+    if (align == 0) align = 1;
+    size_t off = (jit->globals_used + align - 1) & ~(align - 1);
+
+    /* Written as a subtraction so the bounds check itself cannot wrap. */
+    if (off > jit->globals_cap || total_sz > jit->globals_cap - off) {
+        fprintf(stderr, "[JIT] Globals segment full\n"); exit(1);
+    }
+
+    GlobalVar *g = (GlobalVar *)malloc(sizeof(GlobalVar));
+    if (!g) { fprintf(stderr, "[JIT] OOM\n"); exit(1); }
+    g->name = strdup(name);
+    /* find_global() passes g->name to strcmp(), so a NULL name must never be linked in. */
+    if (!g->name) { fprintf(stderr, "[JIT] strdup failed in register_global\n"); free(g); exit(1); }
+    g->addr = jit->globals_buf + off;
+    g->type = type;
+    g->size = (int)total_sz;
+    g->next = jit->global_list;
+    jit->global_list = g;
+    jit->globals_used = off + total_sz;
+    return g;
+}
+
static _TY get_var_type(JIT *jit, const char *name) {
for (VarMap *v = jit->var_list; v; v = v->next)
if (strcmp(v->name, name) == 0) return v->type;
+ /* Fall back to globals */
+ GlobalVar *g = find_global(jit, name);
+ if (g) return g->type;
fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1);
}
-/* --- Type checking --- */
+static _TY get_func_ret_type(JIT *jit, const char *name);
+static void reset_varmap(JIT *jit);
static _TY get_expr_type(JIT *jit, _EX *expr) {
switch (expr->kind) {
case EX_NUMBER: return (_TY){TY_INT, 0, -1};
case EX_STRING: return (_TY){TY_CHAR, 1, -1};
- case EX_VAR: return get_var_type(jit, expr->name);
+ case EX_VAR: {
+ _TY t = get_var_type(jit, expr->name);
+ if (t.array_size > 0) return (_TY){t.base, t.ptr_level + 1, -1};
+ return t;
+ }
case EX_BINOP: {
_TY left_type = get_expr_type(jit, expr->binop.l);
- // All arithmetic, comparison, bitwise ops produce int
- (void)get_expr_type(jit, expr->binop.r);
+ (void)get_expr_type(jit, expr->binop.r);
switch (expr->binop.op) {
case TK_PLUS: case TK_MINUS: case TK_STAR: case TK_SLASH: case TK_PERCENT:
case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE:
@@ -287,7 +414,13 @@ static _TY get_expr_type(JIT *jit, _EX *expr) {
default: return left_type;
}
}
- case EX_CALL: return (_TY){TY_INT, 0, -1};
+ case EX_CALL:
+ if (strcmp(expr->call.func_name, "syscall") == 0)
+ return (_TY){TY_LONG, 0, -1};
+ if (strcmp(expr->call.func_name, "__initlist__") == 0 ||
+ strcmp(expr->call.func_name, "__sizeof__") == 0)
+ return (_TY){TY_INT, 0, -1};
+ return get_func_ret_type(jit, expr->call.func_name);
case EX_INDEX: {
_TY t = get_expr_type(jit, expr->index.array);
if (t.array_size > 0) return (_TY){t.base, t.ptr_level, -1};
@@ -303,31 +436,53 @@ static _TY get_expr_type(JIT *jit, _EX *expr) {
_TY t = get_expr_type(jit, expr->addr.expr);
return (_TY){t.base, t.ptr_level+1, -1};
}
+ case EX_TERNARY:
+ return get_expr_type(jit, expr->ternary.then_expr);
+ case EX_CAST:
+ return expr->cast.to;
default: return (_TY){TY_INT, 0, -1};
}
}
+/*
+ * Returns 1 when `ty` is a non-pointer whose base is INT, CHAR, SHORT, LONG
+ * or BOOL; returns 0 for pointers and for every other base.
+ */
+static int type_is_integer(_TY ty) {
+    if (ty.ptr_level > 0) {
+        return 0;
+    }
+    return ty.base == TY_INT
+        || ty.base == TY_CHAR
+        || ty.base == TY_SHORT
+        || ty.base == TY_LONG
+        || ty.base == TY_BOOL;
+}
+
static int types_compatible(_TY expected, _TY actual) {
if (expected.base == actual.base &&
expected.ptr_level == actual.ptr_level &&
expected.array_size == actual.array_size) return 1;
- // Allow untyped int literals to be assigned anywhere
if (actual.base == TY_INT && actual.ptr_level == 0 && actual.array_size == -1) return 1;
+ if (type_is_integer(expected) && type_is_integer(actual)) return 1;
+ if (expected.ptr_level > 0 && actual.ptr_level > 0) return 1;
+ if (expected.ptr_level > 0 && actual.array_size > 0 &&
+ expected.base == actual.base &&
+ expected.ptr_level == actual.ptr_level + 1) return 1;
return 0;
}
-/* --- Function registry --- */
-
-static void register_func(JIT *jit, const char *name) {
+static void register_func(JIT *jit, const char *name, _TY ret_type) {
FuncMap *f = (FuncMap *)malloc(sizeof(FuncMap));
if (!f) { fprintf(stderr, "[JIT] malloc failed in register_func\n"); exit(1); }
f->name = strdup(name);
if (!f->name) { fprintf(stderr, "[JIT] strdup failed in register_func\n"); free(f); exit(1); }
f->addr = NULL; f->size = 0; f->alloc_size = 0;
+ f->ret_type = ret_type;
f->next = jit->func_list;
jit->func_list = f;
}
+/*
+ * Return type recorded for the function called `name`.
+ * Falls back to a plain int type ({TY_INT, 0, -1}) when no function with
+ * that name is in jit->func_list.
+ */
+static _TY get_func_ret_type(JIT *jit, const char *name) {
+    FuncMap *entry = jit->func_list;
+
+    while (entry) {
+        if (strcmp(entry->name, name) == 0) {
+            return entry->ret_type;
+        }
+        entry = entry->next;
+    }
+    return (_TY){TY_INT, 0, -1};
+}
+
static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, size_t alloc_size) {
for (FuncMap *f = jit->func_list; f; f = f->next) {
if (strcmp(f->name, name) != 0) continue;
@@ -352,17 +507,19 @@ static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, s
static void *get_func_addr(JIT *jit, const char *name) {
for (FuncMap *f = jit->func_list; f; f = f->next) {
if (strcmp(f->name, name) == 0)
- return f->addr ? f->addr : (void*)0xDEADBEEF; // placeholder; patched later
+ return f->addr ? f->addr : (void*)0xDEADBEEF;
}
fprintf(stderr, "[JIT] get_func_addr: unknown function '%s'\n", name); exit(1);
}
-/* --- Stack size calculation --- */
static int calculate_stack_size(_STN *s) {
int total = 0;
while (s) {
- if (s->kind == STK_VAR_DECL) total += calculate_type_size(s->var_decl.type);
+ if (s->kind == STK_VAR_DECL) {
+ int sz = calculate_type_size(s->var_decl.type);
+ total += (sz < 8 && s->var_decl.type.array_size <= 0) ? 8 : sz;
+ }
if (s->kind == STK_BLOCK) total += calculate_stack_size(s->body);
if (s->kind == STK_FOR) {
if (s->fr.init) total += calculate_stack_size(s->fr.init);
@@ -370,15 +527,21 @@ static int calculate_stack_size(_STN *s) {
}
s = s->n;
}
- return total;
+ return total + 8;
+}
+
+/*
+ * Stack bytes needed by function `f`: the result of calculate_stack_size()
+ * on its body plus space for its parameters. Only the first six parameters
+ * are counted (NOTE(review): presumably the register-passed ones — confirm),
+ * and each counted parameter takes at least 8 bytes.
+ */
+static int calculate_total_stack_size(_FN *f) {
+    int param_space = 0;
+    int i = 0;
+
+    while (i < f->pac && i < 6) {
+        int sz = calculate_type_size(f->param_types[i]);
+        if (sz < 8) {
+            sz = 8;   /* every parameter gets at least a full 8-byte slot */
+        }
+        param_space += sz;
+        i++;
+    }
+    return calculate_stack_size(f->body) + param_space;
+}
-/* --- Code generation (forward declarations) --- */
static void gen_expr_jit(JIT *jit, _EX *e);
static int gen_stmt_jit(JIT *jit, _STN *s);
-/* --- Statement generation --- */
static int gen_stmt_jit(JIT *jit, _STN *s) {
while (s) {
@@ -387,6 +550,19 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
case STK_VAR_DECL:
add_var(jit, s->var_decl.name, s->var_decl.type);
if (s->var_decl.init) {
+ if (s->var_decl.init->kind == EX_CALL &&
+ strcmp(s->var_decl.init->call.func_name, "__initlist__") == 0) {
+ _TY vty = s->var_decl.type;
+ int elem_sz = ty_slot_size((_TY){vty.base, 0, -1});
+ int arr_off = get_var_offset(jit, s->var_decl.name);
+ for (int ii = 0; ii < s->var_decl.init->call.argc; ii++) {
+ gen_expr_jit(jit, s->var_decl.init->call.args[ii]);
+ /* address of arr[ii] = RBP + arr_off + ii*elem_sz */
+ int elem_off = arr_off + ii * elem_sz;
+ emit_store_rax_to_mem(&jit->cb, elem_off, elem_sz);
+ }
+ break;
+ }
_TY init_type = get_expr_type(jit, s->var_decl.init);
if (!types_compatible(s->var_decl.type, init_type)) {
fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n",
@@ -394,8 +570,7 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
exit(1);
}
gen_expr_jit(jit, s->var_decl.init);
- int sz = (s->var_decl.type.ptr_level > 0) ? 8
- : (s->var_decl.type.base == TY_CHAR) ? 1 : 8;
+ int sz = ty_slot_size(s->var_decl.type);
emit_store_rax_to_mem(&jit->cb, get_var_offset(jit, s->var_decl.name), sz);
}
break;
@@ -403,6 +578,15 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
case STK_ASSIGN: {
_EX *lhs = s->assign.lhs;
if (lhs->kind == EX_VAR) {
+ GlobalVar *gv = find_global(jit, lhs->name);
+ if (gv) {
+ gen_expr_jit(jit, s->assign.expr);
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr);
+ int sz = ty_slot_size(gv->type);
+ if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ break;
+ }
int offset = get_var_offset(jit, lhs->name);
_TY type = get_var_type(jit, lhs->name);
_TY expr_type = get_expr_type(jit, s->assign.expr);
@@ -412,36 +596,74 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
exit(1);
}
gen_expr_jit(jit, s->assign.expr);
- int sz = (type.ptr_level > 0) ? 8 : (type.base == TY_CHAR) ? 1 : 8;
+ int sz = ty_slot_size(type);
emit_store_rax_to_mem(&jit->cb, offset, sz);
} else if (lhs->kind == EX_DEREF) {
gen_expr_jit(jit, lhs->deref.expr);
- emit_mov_reg_reg(&jit->cb, RBX, RAX); // RBX = address
- gen_expr_jit(jit, s->assign.expr); // RAX = value
- emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ emit_push_reg(&jit->cb, RAX); // save address on stack
+ gen_expr_jit(jit, s->assign.expr); // RAX = value (may clobber RBX)
+ emit_pop_reg(&jit->cb, RBX); // RBX = address
+ _TY ptr_ty = get_expr_type(jit, lhs->deref.expr);
+ int dsz = (ptr_ty.ptr_level > 1) ? 8 : ty_slot_size((_TY){ptr_ty.base, 0, -1});
+ if (dsz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
} else if (lhs->kind == EX_INDEX) {
if (lhs->index.array->kind != EX_VAR) {
gen_expr_jit(jit, lhs->index.array); break;
}
_TY var_type = get_var_type(jit, lhs->index.array->name);
- if (var_type.array_size <= 0) {
- fprintf(stderr, "[JIT] Cannot index non-array '%s'\n", lhs->index.array->name); exit(1);
+ int element_size = ty_slot_size((_TY){var_type.base, 0, -1});
+
+ if (var_type.ptr_level > 0) {
+ GlobalVar *gv_ptr2 = find_global(jit, lhs->index.array->name);
+ if (gv_ptr2) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr2->addr);
+ emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref global ptr */
+ } else {
+ int ptr_offset = get_var_offset(jit, lhs->index.array->name);
+ emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer
+ }
+ gen_expr_jit(jit, lhs->index.index); // RAX = index
+ if (element_size > 1) {
+ emit_mov_reg_reg(&jit->cb, RCX, RBX);
+ emit_mov_reg_imm64(&jit->cb, RBX, element_size);
+ emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
+ emit_mov_reg_reg(&jit->cb, RBX, RCX);
+ }
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &ptr[index]
+ emit_push_reg(&jit->cb, RBX); // save address
+ gen_expr_jit(jit, s->assign.expr); // RAX = value
+ emit_pop_reg(&jit->cb, RBX); // restore address
+ if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+
+ } else if (var_type.array_size > 0) {
+ GlobalVar *gv_arr2 = find_global(jit, lhs->index.array->name);
+ gen_expr_jit(jit, lhs->index.index); // RAX = index
+ emit_mov_reg_imm64(&jit->cb, RBX, element_size);
+ emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
+ if (gv_arr2) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr2->addr);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index]
+ } else {
+ int array_offset = get_var_offset(jit, lhs->index.array->name);
+ emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
+ emit_add_reg_reg(&jit->cb, RBX, RAX);
+ emit_mov_reg_reg(&jit->cb, RAX, RBX);
+ emit_mov_reg_reg(&jit->cb, RBX, RBP);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset
+ }
+ emit_push_reg(&jit->cb, RBX); // save address
+ gen_expr_jit(jit, s->assign.expr); // RAX = value
+ emit_pop_reg(&jit->cb, RBX); // RBX = address
+ if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+
+ } else {
+ fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", lhs->index.array->name); exit(1);
}
- int array_offset = get_var_offset(jit, lhs->index.array->name);
- int element_size = (var_type.base == TY_CHAR) ? 1 : 8;
- gen_expr_jit(jit, lhs->index.index); // RAX = index
- emit_mov_reg_imm64(&jit->cb, RBX, element_size);
- emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
- emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
- emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset
- emit_mov_reg_reg(&jit->cb, RAX, RBX);
- emit_mov_reg_reg(&jit->cb, RBX, RBP);
- emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index]
- emit_mov_reg_reg(&jit->cb, RBX, RAX);
- gen_expr_jit(jit, s->assign.expr); // RAX = value
- emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
} else {
fprintf(stderr, "[JIT] Unsupported assignment LHS kind %d\n", lhs->kind); exit(1);
@@ -479,38 +701,151 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
}
case STK_WHILE: {
+ int depth = jit->loop_depth++;
+ jit->break_patch_count[depth] = 0;
+ jit->cont_patch_count[depth] = 0;
+
size_t loop_start = jit->cb.len;
gen_expr_jit(jit, s->whl.cond);
emit_test_rax_rax(&jit->cb);
size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0);
gen_stmt_jit(jit, s->whl.body);
+ size_t cont_target = jit->cb.len;
+ for (int i = 0; i < jit->cont_patch_count[depth]; i++) {
+ size_t p = jit->cont_patches[depth][i];
+ int32_t r = (int32_t)(cont_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
- int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6));
+ size_t break_target = jit->cb.len;
+ int32_t rel_end = (int32_t)(break_target - (jz_pos + 6));
memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4);
+ for (int i = 0; i < jit->break_patch_count[depth]; i++) {
+ size_t p = jit->break_patches[depth][i];
+ int32_t r = (int32_t)(break_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ jit->loop_depth--;
break;
}
case STK_FOR: {
+ int depth = jit->loop_depth++;
+ jit->break_patch_count[depth] = 0;
+ jit->cont_patch_count[depth] = 0;
+
if (s->fr.init) gen_stmt_jit(jit, s->fr.init);
size_t loop_start = jit->cb.len;
- if (s->fr.cond) {
+ size_t jz_pos_for = 0;
+ int has_cond = (s->fr.cond != NULL);
+ if (has_cond) {
gen_expr_jit(jit, s->fr.cond);
emit_test_rax_rax(&jit->cb);
- size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0);
- gen_stmt_jit(jit, s->fr.body);
- if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
- emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
- int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6));
- memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4);
- } else {
- gen_stmt_jit(jit, s->fr.body);
- if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
- emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
+ jz_pos_for = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0);
+ }
+ gen_stmt_jit(jit, s->fr.body);
+ size_t cont_target = jit->cb.len;
+ for (int i = 0; i < jit->cont_patch_count[depth]; i++) {
+ size_t p = jit->cont_patches[depth][i];
+ int32_t r = (int32_t)(cont_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
}
+ if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
+ emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
+ size_t break_target = jit->cb.len;
+ if (has_cond) {
+ int32_t rel_end = (int32_t)(break_target - (jz_pos_for + 6));
+ memcpy(jit->cb.buf + jz_pos_for + 2, &rel_end, 4);
+ }
+ for (int i = 0; i < jit->break_patch_count[depth]; i++) {
+ size_t p = jit->break_patches[depth][i];
+ int32_t r = (int32_t)(break_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ jit->loop_depth--;
break;
}
- default:
+ case STK_DOWHILE: {
+ int depth = jit->loop_depth++;
+ jit->break_patch_count[depth] = 0;
+ jit->cont_patch_count[depth] = 0;
+
+ size_t loop_start = jit->cb.len;
+ gen_stmt_jit(jit, s->dowhl.body);
+ /* continue → jump to condition check */
+ size_t cont_target = jit->cb.len;
+ for (int i = 0; i < jit->cont_patch_count[depth]; i++) {
+ size_t p = jit->cont_patches[depth][i];
+ int32_t r = (int32_t)(cont_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ gen_expr_jit(jit, s->dowhl.cond);
+ emit_test_rax_rax(&jit->cb);
+ emit_jcc_rel32(&jit->cb, 0x05, (int32_t)(loop_start - (jit->cb.len + 6)));
+ size_t break_target = jit->cb.len;
+ for (int i = 0; i < jit->break_patch_count[depth]; i++) {
+ size_t p = jit->break_patches[depth][i];
+ int32_t r = (int32_t)(break_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ jit->loop_depth--;
+ break;
+ }
+
+ case STK_BREAK: {
+ if (jit->loop_depth == 0) {
+ fprintf(stderr, "[JIT] 'break' outside loop\n"); exit(1);
+ }
+ int depth = jit->loop_depth - 1;
+ size_t pos = jit->cb.len;
+ emit_jmp_rel32(&jit->cb, 0);
+ int cnt = jit->break_patch_count[depth];
+ if (cnt >= 256) { fprintf(stderr, "[JIT] Too many breaks\n"); exit(1); }
+ jit->break_patches[depth][cnt] = pos;
+ jit->break_patch_count[depth]++;
+ break;
+ }
+
+ case STK_CONTINUE: {
+ if (jit->loop_depth == 0) {
+ fprintf(stderr, "[JIT] 'continue' outside loop\n"); exit(1);
+ }
+ int depth = jit->loop_depth - 1;
+ size_t pos = jit->cb.len;
+ emit_jmp_rel32(&jit->cb, 0);
+ int cnt = jit->cont_patch_count[depth];
+ if (cnt >= 256) { fprintf(stderr, "[JIT] Too many continues\n"); exit(1); }
+ jit->cont_patches[depth][cnt] = pos;
+ jit->cont_patch_count[depth]++;
+ break;
+ }
+
+ case STK_GLOBAL: {
+ GlobalVar *gv = find_global(jit, s->global.name);
+ if (!gv) {
+ fprintf(stderr, "[JIT] Global '%s' not registered\n", s->global.name); exit(1);
+ }
+ if (s->global.init) {
+ if (s->global.init->kind == EX_CALL &&
+ strcmp(s->global.init->call.func_name, "__initlist__") == 0) {
+ int esz = ty_slot_size((_TY){gv->type.base, 0, -1});
+ for (int ii = 0; ii < s->global.init->call.argc; ii++) {
+ gen_expr_jit(jit, s->global.init->call.args[ii]);
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)(gv->addr + ii * esz));
+ if (esz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ }
+ } else {
+ gen_expr_jit(jit, s->global.init);
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr);
+ int sz = ty_slot_size(gv->type);
+ if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ }
+ }
+ break;
+ }
fprintf(stderr, "[JIT] Unsupported statement kind %d\n", s->kind); exit(1);
}
s = s->n;
@@ -518,7 +853,6 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
return 0;
}
-/* --- Expression generation --- */
static void gen_expr_jit(JIT *jit, _EX *e) {
if (!e) return;
@@ -529,16 +863,110 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
break;
case EX_VAR: {
- int off = get_var_offset(jit, e->name);
- _TY ty = get_var_type(jit, e->name);
- int sz = (ty.ptr_level > 0) ? 8 : (ty.base == TY_CHAR) ? 1 : 8;
- if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off);
- else emit_mov_reg_mem64(&jit->cb, RAX, off);
- break;
+ GlobalVar *gv = NULL;
+ for (VarMap *v = jit->var_list; v; v = v->next) {
+ if (strcmp(v->name, e->name) == 0) goto local_var;
+ }
+ gv = find_global(jit, e->name);
+ if (gv) {
+ _TY ty = gv->type;
+ if (ty.array_size > 0) {
+ /* Global array: load its absolute address */
+ emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv->addr);
+ break;
+ }
+ /* Global scalar: absolute address → load through it */
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr);
+ int sz = ty_slot_size(ty);
+ if (sz == 1) emit_movzx_rax_mem8_base(&jit->cb, RBX);
+ else emit_mov_reg_mem_reg(&jit->cb, RAX, RBX);
+ break;
+ }
+ local_var: {
+ int off = get_var_offset(jit, e->name);
+ _TY ty = get_var_type(jit, e->name);
+ if (ty.array_size > 0) {
+ emit_lea_rax_rbp_disp(&jit->cb, off);
+ break;
+ }
+ int sz = ty_slot_size(ty);
+ if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off);
+ else emit_mov_reg_mem64(&jit->cb, RAX, off);
+ break;
+ }
}
case EX_BINOP: {
- // Short-circuit AND
+ if (e->binop.op == TK_INC || e->binop.op == TK_DEC) {
+ int is_pre = (e->binop.r->value == -1);
+ int is_inc = (e->binop.op == TK_INC);
+ _EX *lval = e->binop.l;
+
+ _TY lty = get_expr_type(jit, lval);
+ int step = (lty.ptr_level > 0) ? ty_slot_size((_TY){lty.base,0,-1}) : 1;
+
+ gen_expr_jit(jit, lval); /* RAX = old value */
+ emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = old value */
+ emit_mov_reg_imm64(&jit->cb, RBX, step);
+ if (is_inc) emit_add_reg_reg(&jit->cb, RCX, RBX);
+ else emit_sub_reg_reg(&jit->cb, RCX, RBX);
+ if (lval->kind == EX_VAR) {
+ GlobalVar *gv_inc = find_global(jit, lval->name);
+ emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ if (gv_inc) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_inc->addr);
+ int sz = ty_slot_size(lty);
+ if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ } else {
+ int off = get_var_offset(jit, lval->name);
+ int sz = ty_slot_size(lty);
+ emit_store_rax_to_mem(&jit->cb, off, sz);
+ }
+ } else {
+ emit_push_reg(&jit->cb, RCX); /* save new value */
+ if (lval->kind == EX_DEREF) {
+ gen_expr_jit(jit, lval->deref.expr); /* RAX = address */
+ } else { /* EX_INDEX */
+ _TY vty = get_var_type(jit, lval->index.array->name);
+ int esz = ty_slot_size((_TY){vty.base,0,-1});
+ if (vty.ptr_level > 0) {
+ int poff = get_var_offset(jit, lval->index.array->name);
+ emit_mov_reg_mem64(&jit->cb, RAX, poff);
+ emit_push_reg(&jit->cb, RAX);
+ gen_expr_jit(jit, lval->index.index);
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RBX, esz);
+ emit_imul_rax_rbx(&jit->cb);
+ }
+ emit_pop_reg(&jit->cb, RBX);
+ emit_add_reg_reg(&jit->cb, RAX, RBX);
+ } else {
+ int aoff = get_var_offset(jit, lval->index.array->name);
+ gen_expr_jit(jit, lval->index.index);
+ emit_mov_reg_imm64(&jit->cb, RBX, esz);
+ emit_imul_rax_rbx(&jit->cb);
+ emit_mov_reg_imm64(&jit->cb, RBX, aoff);
+ emit_add_reg_reg(&jit->cb, RBX, RAX);
+ emit_mov_reg_reg(&jit->cb, RAX, RBX);
+ emit_mov_reg_reg(&jit->cb, RBX, RBP);
+ emit_add_reg_reg(&jit->cb, RAX, RBX);
+ }
+ }
+ emit_mov_reg_reg(&jit->cb, RBX, RAX); /* RBX = address */
+ emit_pop_reg(&jit->cb, RCX); /* RCX = new val */
+ emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ }
+ if (is_pre) emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ if (!is_pre) {
+ emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ emit_mov_reg_imm64(&jit->cb, RBX, step);
+ if (is_inc) emit_sub_reg_reg(&jit->cb, RAX, RBX);
+ else emit_add_reg_reg(&jit->cb, RAX, RBX);
+ }
+ break;
+ }
if (e->binop.op == TK_AND) {
gen_expr_jit(jit, e->binop.l);
emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb);
@@ -554,7 +982,6 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4);
break;
}
- // Short-circuit OR
if (e->binop.op == TK_OR) {
gen_expr_jit(jit, e->binop.l);
emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb);
@@ -570,18 +997,54 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4);
break;
}
- // All other binary ops: eval left -> push; eval right -> RBX = left
gen_expr_jit(jit, e->binop.l);
emit_push_reg(&jit->cb, RAX);
gen_expr_jit(jit, e->binop.r);
emit_pop_reg(&jit->cb, RBX); // RBX = left, RAX = right
switch (e->binop.op) {
- case TK_PLUS: emit_add_rax_rbx(&jit->cb); break;
- case TK_MINUS:
+ case TK_PLUS: { /* addition with pointer scaling; on entry RBX = left operand, RAX = right operand */
+ _TY lt = get_expr_type(jit, e->binop.l);
+ _TY rt = get_expr_type(jit, e->binop.r);
+ int lptr = (lt.ptr_level > 0 || lt.array_size > 0); /* left operand is pointer- or array-typed */
+ int rptr = (rt.ptr_level > 0 || rt.array_size > 0); /* right operand is pointer- or array-typed */
+ if (lptr && !rptr) { /* ptr + int: the integer is already in RAX */
+ int esz = ty_slot_size((_TY){lt.base, 0, -1}); /* NOTE(review): built from lt.base alone, lt.ptr_level is not consulted - confirm for multi-level pointers */
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RCX, esz);
+ emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); /* imul rax, rcx: scale the integer by esz */
+ }
+ } else if (rptr && !lptr) { /* int + ptr: swap operands so the integer ends up in RAX */
+ int esz = ty_slot_size((_TY){rt.base, 0, -1}); /* same base-only size lookup as the ptr + int branch */
+ emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = ptr */
+ emit_mov_reg_reg(&jit->cb, RAX, RBX); /* RAX = int */
+ emit_mov_reg_reg(&jit->cb, RBX, RCX); /* RBX = ptr */
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RCX, esz); /* RCX is free again: the pointer now lives in RBX */
+ emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); /* imul rax, rcx: scale the integer by esz */
+ }
+ } /* both or neither operand pointer-like: no scaling is emitted */
+ emit_add_rax_rbx(&jit->cb);
+ break;
+ }
+ case TK_MINUS: { /* subtraction with pointer scaling; on entry RBX = left operand, RAX = right operand */
+ _TY lt = get_expr_type(jit, e->binop.l);
+ _TY rt = get_expr_type(jit, e->binop.r);
+ int lptr = (lt.ptr_level > 0 || lt.array_size > 0); /* left operand is pointer- or array-typed */
+ int rptr = (rt.ptr_level > 0 || rt.array_size > 0); /* right operand is pointer- or array-typed */
+ if (lptr && !rptr) {
+ /* ptr - int: scale int by element size */
+ int esz = ty_slot_size((_TY){lt.base, 0, -1}); /* NOTE(review): built from lt.base alone, lt.ptr_level is not consulted - confirm for multi-level pointers */
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RCX, esz);
+ emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4); /* imul rax, rcx */
+ }
+ }
+ /* RAX = right (scaled only in the ptr - int case), RBX = left: result = left - right. NOTE(review): for ptr - ptr the byte difference is not divided by the element size - confirm intended. */
emit_mov_reg_reg(&jit->cb, RCX, RAX);
emit_mov_reg_reg(&jit->cb, RAX, RBX);
emit_rex(&jit->cb, RCX, RAX, 1); emit8(&jit->cb, 0x29); emit_modrm(&jit->cb, 3, RCX, RAX);
break;
+ }
case TK_STAR: emit_imul_rax_rbx(&jit->cb); break;
case TK_SLASH:
emit_mov_reg_reg(&jit->cb, RCX, RAX);
@@ -626,24 +1089,84 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
case EX_CALL: {
+ if (strcmp(e->call.func_name, "__sizeof__") == 0) { /* built-in: folded to a constant while generating code, no call is emitted */
+ _EX *arg = e->call.args[0]; /* NOTE(review): args[0] is read without checking e->call.argc */
+ _TY ty;
+ if (arg->kind == EX_VAR) { /* plain variable: use its declared type */
+ ty = get_var_type(jit, arg->name);
+ } else { /* any other expression: use its computed type */
+ ty = get_expr_type(jit, arg);
+ }
+ int sz;
+ if (ty.array_size > 0) { /* array: element size times element count */
+ int esz = (ty.base == TY_CHAR) ? 1 : (ty.base == TY_SHORT) ? 2 :
+ (ty.base == TY_LONG) ? 8 : 4; /* chosen from ty.base only; ty.ptr_level is not consulted in this branch */
+ sz = esz * ty.array_size;
+ } else if (ty.ptr_level > 0) { /* any non-array pointer */
+ sz = 8;
+ } else { /* scalar: size by base type */
+ switch (ty.base) {
+ case TY_CHAR: sz = 1; break;
+ case TY_SHORT: sz = 2; break;
+ case TY_LONG: sz = 8; break;
+ default: sz = 4; break; /* int, bool, float */
+ }
+ }
+ emit_mov_reg_imm64(&jit->cb, RAX, sz); /* RAX = sz */
+ break;
+ }
+
+ if (strcmp(e->call.func_name, "syscall") == 0) { /* built-in: emits a raw SYSCALL instruction instead of a function call */
+ int argc = e->call.argc;
+ if (argc < 1) {
+ fprintf(stderr, "[JIT] syscall() requires at least 1 argument (number)\n"); exit(1);
+ }
+ if (argc > 7) { /* the syscall number plus at most 6 register arguments (sc_regs has 6 entries) */
+ fprintf(stderr, "[JIT] syscall() supports at most 7 arguments\n"); exit(1);
+ }
+ const int sc_regs[6] = {RDI, RSI, RDX, R10, R8, R9}; /* 4th slot is R10 here, whereas arg_regs for ordinary calls uses RCX */
+ int nargs = argc - 1; /* number of args after the syscall number */
+
+ for (int i = argc - 1; i >= 0; i--) { /* evaluate right-to-left and park each value on the stack, so args[0] ends up on top */
+ gen_expr_jit(jit, e->call.args[i]);
+ emit_push_reg(&jit->cb, RAX);
+ }
+
+ emit_pop_reg(&jit->cb, RAX); /* RAX = args[0], the syscall number */
+
+ for (int i = 0; i < nargs; i++) { /* pop args[1..] into sc_regs in order */
+ emit_pop_reg(&jit->cb, sc_regs[i]);
+ }
+
+ emit8(&jit->cb, 0x0F); /* 0F 05 = syscall */
+ emit8(&jit->cb, 0x05);
+ break;
+ }
+
int total_args = e->call.argc;
int stack_args = total_args > 6 ? total_args - 6 : 0;
+ int reg_args = total_args < 6 ? total_args : 6;
int padding = (stack_args % 2) ? 8 : 0;
const int arg_regs[6] = {RDI, RSI, RDX, RCX, R8, R9};
+ /* Push stack-passed args right-to-left (args[total-1] first). */
for (int i = total_args-1; i >= 6; i--) {
gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX);
}
- for (int i = 0; i < total_args && i < 6; i++) {
- gen_expr_jit(jit, e->call.args[i]); emit_mov_reg_reg(&jit->cb, arg_regs[i], RAX);
+ for (int i = reg_args-1; i >= 0; i--) {
+ gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX);
+ }
+
+ for (int i = 0; i < reg_args; i++) {
+ emit_pop_reg(&jit->cb, arg_regs[i]);
}
+
if (padding) {
emit8(&jit->cb,0x48); emit8(&jit->cb,0x83); emit8(&jit->cb,0xEC); emit8(&jit->cb,0x08);
}
void *addr = get_func_addr(jit, e->call.func_name);
if (addr == (void*)0xDEADBEEF) {
- // Forward call: emit placeholder imm64, record offset for later patching
emit_movabs_rax_imm64(&jit->cb, (uint64_t)0xDEADBEEF);
PatchEntry *patch = (PatchEntry*)malloc(sizeof(PatchEntry));
if (!patch) { fprintf(stderr, "[JIT] malloc failed (patch)\n"); exit(1); }
@@ -659,7 +1182,7 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
} else {
emit_movabs_rax_imm64(&jit->cb, (uint64_t)addr);
}
- emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0); // CALL RAX
+ emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0);
if (stack_args * 8 + padding > 0) {
emit8(&jit->cb,0x48); emit8(&jit->cb,0x81); emit8(&jit->cb,0xC4);
@@ -669,12 +1192,56 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
case EX_ADDR:
- if (e->addr.expr->kind != EX_VAR) {
+ if (e->addr.expr->kind == EX_VAR) {
+ GlobalVar *gv_addr = find_global(jit, e->addr.expr->name);
+ if (gv_addr) {
+ emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv_addr->addr);
+ } else {
+ emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name));
+ }
+ } else {
fprintf(stderr, "[JIT] &expr: only &var supported\n"); exit(1);
}
- emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name));
break;
+ case EX_TERNARY: { /* cond ? then_expr : else_expr, built from two forward jumps whose displacements are patched once the targets are known */
+ gen_expr_jit(jit, e->ternary.cond);
+ emit_mov_reg_reg(&jit->cb, RBX, RAX); /* RBX = copy of the condition value */
+ emit_or_rax_rbx(&jit->cb); /* presumably OR RAX,RBX of the value with itself, to set flags from the condition */
+ size_t jz_pos = jit->cb.len; /* where the conditional jump starts, for patching below */
+ emit_jcc_rel32(&jit->cb, 0x04, 0); /* placeholder rel32 = 0; condition code 0x04 is presumably JE/JZ (skip to else when cond == 0) */
+ gen_expr_jit(jit, e->ternary.then_expr);
+ size_t jmp_pos = jit->cb.len; /* where the unconditional jump over the else branch starts */
+ emit_jmp_rel32(&jit->cb, 0); /* placeholder rel32 = 0 */
+ int32_t rel_jz = (int32_t)(jit->cb.len - (jz_pos + 6)); /* else branch begins here; the jcc is treated as 6 bytes long */
+ memcpy(jit->cb.buf + jz_pos + 2, &rel_jz, 4); /* its rel32 field sits 2 bytes into the instruction */
+ gen_expr_jit(jit, e->ternary.else_expr);
+ int32_t rel_jmp = (int32_t)(jit->cb.len - (jmp_pos + 5)); /* end of the expression; the jmp is treated as 5 bytes long */
+ memcpy(jit->cb.buf + jmp_pos + 1, &rel_jmp, 4); /* its rel32 field sits 1 byte into the instruction */
+ break;
+ }
+
+ case EX_CAST: { /* evaluate the operand into RAX, then narrow/extend it in place according to the target type */
+ gen_expr_jit(jit, e->cast.expr);
+ _TY to = e->cast.to;
+ if (to.ptr_level > 0) break; /* pointer cast: value unchanged */
+ switch (to.base) {
+ case TY_CHAR: /* keep the low 8 bits, zero-extended */
+ emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6);
+ emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, al */
+ break;
+ case TY_SHORT: /* keep the low 16 bits, zero-extended */
+ emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB7);
+ emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, ax */
+ break;
+ case TY_INT: /* keep the low 32 bits, sign-extended (char/short above zero-extend instead) */
+ emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x63); emit_modrm(&jit->cb, 3, RAX, RAX); /* movsxd rax, eax */
+ break;
+ default: break; /* long/void/etc: no-op */
+ }
+ break;
+ }
+
case EX_DEREF:
gen_expr_jit(jit, e->deref.expr);
emit_rex(&jit->cb, RAX, RAX, 1); emit8(&jit->cb, 0x8B); emit_modrm(&jit->cb, 0, RAX, RAX);
@@ -692,12 +1259,17 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
gen_expr_jit(jit, e->index.array); break;
}
_TY var_type = get_var_type(jit, e->index.array->name);
- int element_size = (var_type.base == TY_CHAR) ? 1 : 8;
+ int element_size = ty_slot_size((_TY){var_type.base, 0, -1});
if (var_type.ptr_level > 0) {
- // Pointer indexing: load pointer, add index*element_size
- int ptr_offset = get_var_offset(jit, e->index.array->name);
- emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer
+ GlobalVar *gv_ptr = find_global(jit, e->index.array->name);
+ if (gv_ptr) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr->addr);
+ emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref: RBX = *addr = pointer value */
+ } else {
+ int ptr_offset = get_var_offset(jit, e->index.array->name);
+ emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer
+ }
gen_expr_jit(jit, e->index.index); // RAX = index
if (element_size > 1) {
emit_mov_reg_reg(&jit->cb, RCX, RBX); // save pointer
@@ -714,17 +1286,22 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
} else if (var_type.array_size > 0) {
- // Array indexing: compute RBP-relative address = array_offset - index*element_size
- int array_offset = get_var_offset(jit, e->index.array->name);
+ GlobalVar *gv_arr = find_global(jit, e->index.array->name);
gen_expr_jit(jit, e->index.index); // RAX = index
emit_mov_reg_imm64(&jit->cb, RBX, element_size);
emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
- emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
- emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset
- emit_mov_reg_reg(&jit->cb, RAX, RBX);
- emit_mov_reg_reg(&jit->cb, RBX, RBP);
- emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index]
- emit_mov_reg_reg(&jit->cb, RBX, RAX);
+ if (gv_arr) {
+ /* Global array: base is absolute address */
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr->addr);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index]
+ } else {
+ int array_offset = get_var_offset(jit, e->index.array->name);
+ emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
+ emit_add_reg_reg(&jit->cb, RBX, RAX);
+ emit_mov_reg_reg(&jit->cb, RAX, RBX);
+ emit_mov_reg_reg(&jit->cb, RBX, RBP);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset
+ }
if (element_size == 1) {
emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6);
emit_modrm(&jit->cb, 0, RAX, RBX);
@@ -743,22 +1320,35 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
}
-/* --- Compile one function --- */
-
static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) {
reset_varmap(jit);
jit->current_func_name = f->name;
- int total_stack_size = calculate_stack_size(f->body);
+ int total_stack_size = calculate_total_stack_size(f);
cb_init(&jit->cb);
emit_prologue(&jit->cb, total_stack_size);
const int param_regs[6] = {RDI, RSI, RDX, RCX, R8, R9};
+
for (int i = 0; i < f->pac && i < 6; i++) {
add_var(jit, f->params[i], f->param_types[i]);
emit_mov_mem64_reg(&jit->cb, get_var_offset(jit, f->params[i]), param_regs[i]);
}
+ int num_stack_params = f->pac > 6 ? f->pac - 6 : 0; /* parameters beyond the six passed in registers */
+ int stack_arg_padding = (num_stack_params % 2) ? 8 : 0; /* mirrors the 8-byte pad the call-site codegen subtracts from RSP for an odd number of stack args */
+ int stack_arg_base = 16 + stack_arg_padding; /* RBP-relative offset of param 6: the call site pushes stack args from the highest index down, so param 6 is pushed last and sits lowest; the 16 presumably skips saved RBP + return address */
+ for (int i = 6; i < f->pac; i++) { /* bind each stack-passed parameter to its positive RBP offset; the VarMap node is built by hand rather than through add_var */
+ VarMap *v = (VarMap *)malloc(sizeof(VarMap));
+ if (!v) { fprintf(stderr, "[JIT] malloc failed (stack param)\n"); exit(1); }
+ v->name = strdup(f->params[i]);
+ if (!v->name) { fprintf(stderr, "[JIT] strdup failed (stack param)\n"); free(v); exit(1); }
+ v->type = f->param_types[i];
+ v->offset = stack_arg_base + (i - 6) * 8; /* one 8-byte slot per parameter, higher index = higher address */
+ v->next = jit->var_list; /* prepend to the current variable list */
+ jit->var_list = v;
+ }
+
int did_return = gen_stmt_jit(jit, f->body);
if (!did_return) {
emit_movabs_rax_imm64(&jit->cb, 0);
@@ -777,7 +1367,6 @@ static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) {
return mem;
}
-/* --- Patch forward calls after all functions are compiled --- */
static void patch_function_calls(JIT *jit) {
for (PatchEntry *patch = jit->patch_list; patch; patch = patch->next) {
@@ -797,16 +1386,36 @@ static void patch_function_calls(JIT *jit) {
}
}
-/* --- Compile all functions and patch forward calls --- */
-
static void jit_compile_all(JIT *jit, _FN *fn_list) {
- for (_FN *cur = fn_list; cur; cur = cur->n)
- register_func(jit, cur->name);
+ /* First pass: register all globals so type lookups work during codegen */
+ for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) == 0) {
+ /* Extract variable name from __global_NAME__ */
+ const char *start = cur->name + 9;
+ size_t len = strlen(start);
+ if (len >= 2 && start[len-2] == '_' && start[len-1] == '_') len -= 2;
+ char *vname = strndup(start, len);
+ if (!vname) { fprintf(stderr, "[JIT] OOM\n"); exit(1); }
+ /* Get type from the STK_GLOBAL node */
+ _STN *gdecl = cur->body;
+ if (gdecl && gdecl->kind == STK_GLOBAL) {
+ register_global(jit, vname, gdecl->global.type);
+ }
+ free(vname);
+ }
+ }
+
+ /* Second pass: register all real functions (so forward calls work) */
+ for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) != 0)
+ register_func(jit, cur->name, cur->ret_type);
+ }
- // Compile in reverse order so callees are typically compiled before callers
+ /* Compile real functions in reverse order */
_FN *functions[64];
int count = 0;
for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) == 0) continue;
if (count >= 64) { fprintf(stderr, "[JIT] Too many functions (max 64)\n"); exit(1); }
functions[count++] = cur;
}
@@ -814,9 +1423,27 @@ static void jit_compile_all(JIT *jit, _FN *fn_list) {
gen_function_jit(jit, functions[i], NULL);
patch_function_calls(jit);
-}
-/* --- Entry point --- */
+ /* Run global initialisers in declaration order: each __global_* entry's body is compiled into a throwaway stub that is executed once, right here, and then unmapped */
+ for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) != 0) continue; /* only the __global_* pseudo-functions */
+ reset_varmap(jit);
+ jit->current_func_name = cur->name;
+ cb_init(&jit->cb);
+ emit8(&jit->cb, 0x55); /* push rbp */
+ emitN(&jit->cb, (uint8_t[]){0x48,0x89,0xE5}, 3); /* mov rbp, rsp (no stack space is reserved for the stub) */
+ gen_stmt_jit(jit, cur->body);
+ emit8(&jit->cb, 0xC9); /* leave */
+ emit8(&jit->cb, 0xC3); /* ret */
+ size_t isz = jit->cb.len; /* size of the generated stub in bytes */
+ void *ibuf = mmap(NULL, isz, PROT_READ|PROT_WRITE|PROT_EXEC,
+ MAP_PRIVATE|MAP_ANONYMOUS, -1, 0); /* anonymous mapping that is writable and executable at once */
+ if (ibuf == MAP_FAILED) { perror("mmap init"); exit(1); }
+ memcpy(ibuf, jit->cb.buf, isz);
+ ((void(*)(void))ibuf)(); /* execute the initialiser now, in this process */
+ munmap(ibuf, isz); /* the stub is not kept after it has run */
+ }
+}
static int jit_run(JIT *jit, int argc, char **argv) {
int (*main_func)(int, char **) = get_func_addr(jit, "main");
@@ -824,4 +1451,4 @@ static int jit_run(JIT *jit, int argc, char **argv) {
return main_func(argc, argv);
}
-#endif \ No newline at end of file
+#endif