summaryrefslogtreecommitdiff
path: root/src/codegen_jit.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/codegen_jit.h')
-rw-r--r--src/codegen_jit.h853
1 files changed, 740 insertions, 113 deletions
diff --git a/src/codegen_jit.h b/src/codegen_jit.h
index ff1d687..2cd4aa2 100644
--- a/src/codegen_jit.h
+++ b/src/codegen_jit.h
@@ -32,6 +32,7 @@ typedef struct FuncMap {
void *addr;
size_t size;
size_t alloc_size;
+ _TY ret_type; /* return type of this function */
struct FuncMap *next;
} FuncMap;
@@ -49,6 +50,14 @@ typedef struct VarMap {
struct VarMap *next;
} VarMap;
+typedef struct GlobalVar {
+ char *name;
+ uint8_t *addr; /* pointer into globals_buf */
+ _TY type;
+ int size; /* total bytes allocated */
+ struct GlobalVar *next;
+} GlobalVar;
+
typedef struct {
FuncMap *func_list;
VarMap *var_list;
@@ -56,6 +65,17 @@ typedef struct {
CodeBuf cb;
char *current_func_name;
PatchEntry *patch_list;
+ /* Globals data segment */
+ uint8_t *globals_buf; /* mmap'd RW page(s) for global variables */
+ size_t globals_cap;
+ size_t globals_used;
+ GlobalVar *global_list;
+ /* Break/continue patch stacks (up to 64 nesting levels) */
+ size_t break_patches[64][256]; /* offsets to patch */
+ int break_patch_count[64];
+ size_t cont_patches[64][256];
+ int cont_patch_count[64];
+ int loop_depth;
} JIT;
static void jit_init(JIT *jit) {
@@ -66,6 +86,19 @@ static void jit_init(JIT *jit) {
jit->cb.len = jit->cb.cap = 0;
jit->current_func_name = NULL;
jit->patch_list = NULL;
+ jit->globals_cap = 1024 * 1024;
+ jit->globals_used = 0;
+ jit->globals_buf = (uint8_t *)mmap(NULL, jit->globals_cap,
+ PROT_READ|PROT_WRITE,
+ MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
+ if (jit->globals_buf == MAP_FAILED) {
+ perror("mmap globals"); exit(1);
+ }
+ memset(jit->globals_buf, 0, jit->globals_cap);
+ jit->global_list = NULL;
+ jit->loop_depth = 0;
+ memset(jit->break_patch_count, 0, sizeof(jit->break_patch_count));
+ memset(jit->cont_patch_count, 0, sizeof(jit->cont_patch_count));
}
static void jit_free(JIT *jit) {
@@ -79,6 +112,14 @@ static void jit_free(JIT *jit) {
}
jit->patch_list = NULL;
+ for (GlobalVar *g = jit->global_list; g;) {
+ GlobalVar *n = g->next; free(g->name); free(g); g = n;
+ }
+ jit->global_list = NULL;
+
+ if (jit->globals_buf && jit->globals_buf != MAP_FAILED)
+ munmap(jit->globals_buf, jit->globals_cap);
+
for (FuncMap *f = jit->func_list; f;) {
FuncMap *n = f->next;
if (f->addr && f->alloc_size > 0) munmap(f->addr, f->alloc_size);
@@ -91,8 +132,6 @@ static void jit_free(JIT *jit) {
jit->cb.len = jit->cb.cap = 0;
}
-/* --- Code buffer --- */
-
static void cb_init(CodeBuf *c) {
c->cap = 1024; c->len = 0;
c->buf = (uint8_t *)malloc(c->cap);
@@ -114,8 +153,6 @@ static void emit32(CodeBuf *c, uint32_t v) { cb_grow(c,4); memcpy(c->buf+c->
static void emit64(CodeBuf *c, uint64_t v) { cb_grow(c,8); memcpy(c->buf+c->len,&v,8); c->len+=8; }
static void emitN(CodeBuf *c, const void *p, size_t n) { cb_grow(c,n); memcpy(c->buf+c->len,p,n); c->len+=n; }
-/* --- x86-64 encoding helpers --- */
-
static void emit_rex(CodeBuf *c, int reg, int rm, int w) {
uint8_t rex = 0x40;
if (w) rex |= 0x08;
@@ -145,9 +182,18 @@ static void emit_mov_reg_mem_reg(CodeBuf *c, int dst, int base) {
static void emit_mov_mem_reg_reg(CodeBuf *c, int base, int src) {
emit_rex(c,src,base,1); emit8(c,0x89); emit_modrm(c,0,src,base);
}
+static void emit_mov_mem8_reg_reg(CodeBuf *c, int base) {
+ if (base & 8) emit8(c, 0x41); /* REX.B for extended base registers */
+ emit8(c, 0x88); emit_modrm(c, 0, RAX, base);
+}
static void emit_movzx_rax_mem8(CodeBuf *c, int disp32) {
emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB6); emit_modrm(c,2,RAX,RBP); emit32(c,disp32);
}
+static void emit_movzx_rax_mem8_base(CodeBuf *c, int base) {
+ if (base & 8) emit8(c, 0x49); else emit8(c, 0x48);
+ emit8(c, 0x0F); emit8(c, 0xB6);
+ emit_modrm(c, 0, RAX, base);
+}
static void emit_mov_mem8_rax(CodeBuf *c, int disp32) {
emit_rex(c,RAX,RBP,0); emit8(c,0x88); emit_modrm(c,2,RAX,RBP); emit32(c,disp32);
}
@@ -174,8 +220,8 @@ static void emit_add_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x01,0xD8},
static void emit_imul_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x0F,0xAF,0xC3},4); }
static void emit_idiv_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); }
static void emit_imod(CodeBuf *c) {
- emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5); // cqo; idiv rbx
- emit_mov_reg_reg(c, RAX, RDX); // remainder -> RAX
+ emitN(c,(uint8_t[]){0x48,0x99,0x48,0xF7,0xFB},5);
+ emit_mov_reg_reg(c, RAX, RDX);
}
static void emit_or_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x09,0xD8},3); }
static void emit_xor_rax_rbx(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x31,0xD8},3); }
@@ -192,53 +238,97 @@ static void emit_jcc_rel32(CodeBuf *c, uint8_t cc, int32_t rel) {
static void emit_test_rax_rax(CodeBuf *c) { emitN(c,(uint8_t[]){0x48,0x85,0xC0},3); }
static void emit_prologue(CodeBuf *c, int total_stack_size) {
- emit8(c,0x55); // push rbp
- emitN(c,(uint8_t[]){0x48,0x89,0xE5},3); // mov rbp, rsp
- emit_push_reg(c, RBX); // save callee-saved RBX
+ emit8(c,0x55);
+ emitN(c,(uint8_t[]){0x48,0x89,0xE5},3);
+ emit_push_reg(c, RBX);
int stack_bytes = ((total_stack_size+15)/16)*16;
if (stack_bytes > 0) {
- emitN(c,(uint8_t[]){0x48,0x81,0xEC},3); // sub rsp, imm32
+ emitN(c,(uint8_t[]){0x48,0x81,0xEC},3);
emit32(c,(uint32_t)stack_bytes);
}
}
static void emit_epilogue(CodeBuf *c) {
- emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3); // restore RBX; leave; ret
+ emit_pop_reg(c, RBX); emit8(c,0xC9); emit8(c,0xC3);
}
static void emit_lea_rax_rbp_disp(CodeBuf *c, int disp32) {
emit8(c,0x48); emit8(c,0x8D); emit_modrm(c,2,RAX,RBP); emit32(c,(uint32_t)disp32);
}
+static void emit_movzx_rax_mem16(CodeBuf *c, int disp32) {
+ emit8(c,0x48); emit8(c,0x0F); emit8(c,0xB7); emit_modrm(c,2,RAX,RBP); emit32(c,disp32);
+}
+static void emit_mov_mem16_rax(CodeBuf *c, int disp32) {
+ emit8(c,0x66); emit_rex(c,RAX,RBP,0); emit8(c,0x89); emit_modrm(c,2,RAX,RBP); emit32(c,disp32);
+}
static void emit_load_rax_from_mem(CodeBuf *c, int disp32, int size) {
- if (size == 1) emit_movzx_rax_mem8(c, disp32); else emit_mov_reg_mem64(c, RAX, disp32);
+ if (size == 1) emit_movzx_rax_mem8(c, disp32);
+ else if (size == 2) emit_movzx_rax_mem16(c, disp32);
+ else emit_mov_reg_mem64(c, RAX, disp32);
}
static void emit_store_rax_to_mem(CodeBuf *c, int disp32, int size) {
- if (size == 1) emit_mov_mem8_rax(c, disp32); else emit_mov_mem64_reg(c, disp32, RAX);
+ if (size == 1) emit_mov_mem8_rax(c, disp32);
+ else if (size == 2) emit_mov_mem16_rax(c, disp32);
+ else emit_mov_mem64_reg(c, disp32, RAX);
}
-/* --- Variable and type helpers --- */
-
static int calculate_type_size(_TY type) {
if (type.ptr_level > 0) return 8;
- int base_size = (type.base == TY_CHAR) ? 1 : 8;
- size_t total = (size_t)base_size;
+ int slot;
+ switch (type.base) {
+ case TY_CHAR: slot = 1; break;
+ case TY_BOOL: slot = 1; break;
+ case TY_SHORT: slot = 2; break;
+ case TY_INT: slot = 8; break;
+ case TY_FLOAT: slot = 8; break;
+ case TY_LONG: slot = 8; break;
+ case TY_VOID: slot = 0; break;
+ default: slot = 8; break;
+ }
if (type.array_size > 0) {
- total *= (size_t)type.array_size;
+ size_t total = (size_t)slot * (size_t)type.array_size;
if (total > (size_t)INT_MAX) { fprintf(stderr, "[JIT] array size too large\n"); exit(1); }
+ return (int)total;
}
- return (int)total;
+ return slot;
}
static int align_offset(int offset, _TY type) {
- int align = (type.base == TY_CHAR) ? 1 : 8;
+ int align;
+ if (type.ptr_level > 0) { align = 8; }
+ else switch (type.base) {
+ case TY_CHAR: align = 1; break;
+ case TY_BOOL: align = 1; break;
+ case TY_SHORT: align = 2; break;
+ case TY_INT: align = 8; break;
+ case TY_FLOAT: align = 8; break;
+ case TY_LONG: align = 8; break;
+ default: align = 8; break;
+ }
if (offset % align == 0) return offset;
+ if (offset < 0)
+ return ((offset - (align - 1)) / align) * align;
return (offset / align) * align;
}
+static int ty_slot_size(_TY ty) {
+ if (ty.ptr_level > 0) return 8;
+ switch (ty.base) {
+ case TY_CHAR: return 1;
+ case TY_BOOL: return 1;
+ case TY_SHORT: return 2;
+ case TY_INT: return 8; /* promoted to 64-bit slot */
+ case TY_FLOAT: return 8; /* stored in 64-bit slot (integer bits) */
+ case TY_LONG: return 8;
+ case TY_VOID: return 0;
+ default: return 8;
+ }
+}
+
static void reset_varmap(JIT *jit) {
for (VarMap *v = jit->var_list; v;) {
VarMap *n = v->next; free(v->name); free(v); v = n;
}
jit->var_list = NULL;
- jit->next_local_offset = -16; // after saved RBX at RBP-8, 16-byte aligned
+ jit->next_local_offset = -16;
}
static void add_var(JIT *jit, const char *name, _TY type) {
@@ -248,9 +338,9 @@ static void add_var(JIT *jit, const char *name, _TY type) {
if (!v->name) { fprintf(stderr, "[JIT] strdup failed in add_var\n"); free(v); exit(1); }
v->type = type;
int type_size = calculate_type_size(type);
+ jit->next_local_offset -= type_size;
jit->next_local_offset = align_offset(jit->next_local_offset, type);
v->offset = jit->next_local_offset;
- jit->next_local_offset -= type_size;
v->next = jit->var_list;
jit->var_list = v;
}
@@ -261,23 +351,60 @@ static int get_var_offset(JIT *jit, const char *name) {
fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1);
}
+static GlobalVar *find_global(JIT *jit, const char *name) {
+ for (GlobalVar *g = jit->global_list; g; g = g->next)
+ if (strcmp(g->name, name) == 0) return g;
+ return NULL;
+}
+
+static GlobalVar *register_global(JIT *jit, const char *name, _TY type) {
+ if (find_global(jit, name)) {
+ fprintf(stderr, "[JIT] Duplicate global '%s'\n", name); exit(1);
+ }
+ int elem_sz = ty_slot_size((_TY){type.base, type.ptr_level > 0 ? type.ptr_level : 0, -1});
+ int n_elems = (type.array_size > 0) ? type.array_size : 1;
+ int total_sz = elem_sz * n_elems;
+ size_t align = (elem_sz < 8) ? elem_sz : 8;
+ size_t off = (jit->globals_used + align - 1) & ~(align - 1);
+ if (off + (size_t)total_sz > jit->globals_cap) {
+ fprintf(stderr, "[JIT] Globals segment full\n"); exit(1);
+ }
+ GlobalVar *g = (GlobalVar *)malloc(sizeof(GlobalVar));
+ if (!g) { fprintf(stderr, "[JIT] OOM\n"); exit(1); }
+ g->name = strdup(name);
+ g->addr = jit->globals_buf + off;
+ g->type = type;
+ g->size = total_sz;
+ g->next = jit->global_list;
+ jit->global_list = g;
+ jit->globals_used = off + total_sz;
+ return g;
+}
+
static _TY get_var_type(JIT *jit, const char *name) {
for (VarMap *v = jit->var_list; v; v = v->next)
if (strcmp(v->name, name) == 0) return v->type;
+ /* Fall back to globals */
+ GlobalVar *g = find_global(jit, name);
+ if (g) return g->type;
fprintf(stderr, "[JIT] Unknown variable '%s'\n", name); exit(1);
}
-/* --- Type checking --- */
+static _TY get_func_ret_type(JIT *jit, const char *name);
+static void reset_varmap(JIT *jit);
static _TY get_expr_type(JIT *jit, _EX *expr) {
switch (expr->kind) {
case EX_NUMBER: return (_TY){TY_INT, 0, -1};
case EX_STRING: return (_TY){TY_CHAR, 1, -1};
- case EX_VAR: return get_var_type(jit, expr->name);
+ case EX_VAR: {
+ _TY t = get_var_type(jit, expr->name);
+ if (t.array_size > 0) return (_TY){t.base, t.ptr_level + 1, -1};
+ return t;
+ }
case EX_BINOP: {
_TY left_type = get_expr_type(jit, expr->binop.l);
- // All arithmetic, comparison, bitwise ops produce int
- (void)get_expr_type(jit, expr->binop.r);
+ (void)get_expr_type(jit, expr->binop.r);
switch (expr->binop.op) {
case TK_PLUS: case TK_MINUS: case TK_STAR: case TK_SLASH: case TK_PERCENT:
case TK_EQ: case TK_NE: case TK_LT: case TK_LE: case TK_GT: case TK_GE:
@@ -287,7 +414,13 @@ static _TY get_expr_type(JIT *jit, _EX *expr) {
default: return left_type;
}
}
- case EX_CALL: return (_TY){TY_INT, 0, -1};
+ case EX_CALL:
+ if (strcmp(expr->call.func_name, "syscall") == 0)
+ return (_TY){TY_LONG, 0, -1};
+ if (strcmp(expr->call.func_name, "__initlist__") == 0 ||
+ strcmp(expr->call.func_name, "__sizeof__") == 0)
+ return (_TY){TY_INT, 0, -1};
+ return get_func_ret_type(jit, expr->call.func_name);
case EX_INDEX: {
_TY t = get_expr_type(jit, expr->index.array);
if (t.array_size > 0) return (_TY){t.base, t.ptr_level, -1};
@@ -303,31 +436,53 @@ static _TY get_expr_type(JIT *jit, _EX *expr) {
_TY t = get_expr_type(jit, expr->addr.expr);
return (_TY){t.base, t.ptr_level+1, -1};
}
+ case EX_TERNARY:
+ return get_expr_type(jit, expr->ternary.then_expr);
+ case EX_CAST:
+ return expr->cast.to;
default: return (_TY){TY_INT, 0, -1};
}
}
+static int type_is_integer(_TY ty) {
+ if (ty.ptr_level > 0) return 0;
+ switch (ty.base) {
+ case TY_INT: case TY_CHAR: case TY_SHORT:
+ case TY_LONG: case TY_BOOL: return 1;
+ default: return 0;
+ }
+}
+
static int types_compatible(_TY expected, _TY actual) {
if (expected.base == actual.base &&
expected.ptr_level == actual.ptr_level &&
expected.array_size == actual.array_size) return 1;
- // Allow untyped int literals to be assigned anywhere
if (actual.base == TY_INT && actual.ptr_level == 0 && actual.array_size == -1) return 1;
+ if (type_is_integer(expected) && type_is_integer(actual)) return 1;
+ if (expected.ptr_level > 0 && actual.ptr_level > 0) return 1;
+ if (expected.ptr_level > 0 && actual.array_size > 0 &&
+ expected.base == actual.base &&
+ expected.ptr_level == actual.ptr_level + 1) return 1;
return 0;
}
-/* --- Function registry --- */
-
-static void register_func(JIT *jit, const char *name) {
+static void register_func(JIT *jit, const char *name, _TY ret_type) {
FuncMap *f = (FuncMap *)malloc(sizeof(FuncMap));
if (!f) { fprintf(stderr, "[JIT] malloc failed in register_func\n"); exit(1); }
f->name = strdup(name);
if (!f->name) { fprintf(stderr, "[JIT] strdup failed in register_func\n"); free(f); exit(1); }
f->addr = NULL; f->size = 0; f->alloc_size = 0;
+ f->ret_type = ret_type;
f->next = jit->func_list;
jit->func_list = f;
}
+static _TY get_func_ret_type(JIT *jit, const char *name) {
+ for (FuncMap *f = jit->func_list; f; f = f->next)
+ if (strcmp(f->name, name) == 0) return f->ret_type;
+ return (_TY){TY_INT, 0, -1};
+}
+
static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, size_t alloc_size) {
for (FuncMap *f = jit->func_list; f; f = f->next) {
if (strcmp(f->name, name) != 0) continue;
@@ -352,17 +507,19 @@ static void set_func_addr(JIT *jit, const char *name, void *addr, size_t size, s
static void *get_func_addr(JIT *jit, const char *name) {
for (FuncMap *f = jit->func_list; f; f = f->next) {
if (strcmp(f->name, name) == 0)
- return f->addr ? f->addr : (void*)0xDEADBEEF; // placeholder; patched later
+ return f->addr ? f->addr : (void*)0xDEADBEEF;
}
fprintf(stderr, "[JIT] get_func_addr: unknown function '%s'\n", name); exit(1);
}
-/* --- Stack size calculation --- */
static int calculate_stack_size(_STN *s) {
int total = 0;
while (s) {
- if (s->kind == STK_VAR_DECL) total += calculate_type_size(s->var_decl.type);
+ if (s->kind == STK_VAR_DECL) {
+ int sz = calculate_type_size(s->var_decl.type);
+ total += (sz < 8 && s->var_decl.type.array_size <= 0) ? 8 : sz;
+ }
if (s->kind == STK_BLOCK) total += calculate_stack_size(s->body);
if (s->kind == STK_FOR) {
if (s->fr.init) total += calculate_stack_size(s->fr.init);
@@ -370,15 +527,21 @@ static int calculate_stack_size(_STN *s) {
}
s = s->n;
}
- return total;
+ return total + 8;
+}
+
+static int calculate_total_stack_size(_FN *f) {
+ int param_space = 0;
+ for (int i = 0; i < f->pac && i < 6; i++)
+ param_space += (calculate_type_size(f->param_types[i]) < 8) ? 8
+ : calculate_type_size(f->param_types[i]);
+ return calculate_stack_size(f->body) + param_space;
}
-/* --- Code generation (forward declarations) --- */
static void gen_expr_jit(JIT *jit, _EX *e);
static int gen_stmt_jit(JIT *jit, _STN *s);
-/* --- Statement generation --- */
static int gen_stmt_jit(JIT *jit, _STN *s) {
while (s) {
@@ -387,6 +550,19 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
case STK_VAR_DECL:
add_var(jit, s->var_decl.name, s->var_decl.type);
if (s->var_decl.init) {
+ if (s->var_decl.init->kind == EX_CALL &&
+ strcmp(s->var_decl.init->call.func_name, "__initlist__") == 0) {
+ _TY vty = s->var_decl.type;
+ int elem_sz = ty_slot_size((_TY){vty.base, 0, -1});
+ int arr_off = get_var_offset(jit, s->var_decl.name);
+ for (int ii = 0; ii < s->var_decl.init->call.argc; ii++) {
+ gen_expr_jit(jit, s->var_decl.init->call.args[ii]);
+ /* address of arr[ii] = RBP + arr_off + ii*elem_sz */
+ int elem_off = arr_off + ii * elem_sz;
+ emit_store_rax_to_mem(&jit->cb, elem_off, elem_sz);
+ }
+ break;
+ }
_TY init_type = get_expr_type(jit, s->var_decl.init);
if (!types_compatible(s->var_decl.type, init_type)) {
fprintf(stderr, "[JIT] Type mismatch: cannot assign %s to %s\n",
@@ -394,8 +570,7 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
exit(1);
}
gen_expr_jit(jit, s->var_decl.init);
- int sz = (s->var_decl.type.ptr_level > 0) ? 8
- : (s->var_decl.type.base == TY_CHAR) ? 1 : 8;
+ int sz = ty_slot_size(s->var_decl.type);
emit_store_rax_to_mem(&jit->cb, get_var_offset(jit, s->var_decl.name), sz);
}
break;
@@ -403,6 +578,15 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
case STK_ASSIGN: {
_EX *lhs = s->assign.lhs;
if (lhs->kind == EX_VAR) {
+ GlobalVar *gv = find_global(jit, lhs->name);
+ if (gv) {
+ gen_expr_jit(jit, s->assign.expr);
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr);
+ int sz = ty_slot_size(gv->type);
+ if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ break;
+ }
int offset = get_var_offset(jit, lhs->name);
_TY type = get_var_type(jit, lhs->name);
_TY expr_type = get_expr_type(jit, s->assign.expr);
@@ -412,36 +596,74 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
exit(1);
}
gen_expr_jit(jit, s->assign.expr);
- int sz = (type.ptr_level > 0) ? 8 : (type.base == TY_CHAR) ? 1 : 8;
+ int sz = ty_slot_size(type);
emit_store_rax_to_mem(&jit->cb, offset, sz);
} else if (lhs->kind == EX_DEREF) {
gen_expr_jit(jit, lhs->deref.expr);
- emit_mov_reg_reg(&jit->cb, RBX, RAX); // RBX = address
- gen_expr_jit(jit, s->assign.expr); // RAX = value
- emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ emit_push_reg(&jit->cb, RAX); // save address on stack
+ gen_expr_jit(jit, s->assign.expr); // RAX = value (may clobber RBX)
+ emit_pop_reg(&jit->cb, RBX); // RBX = address
+ _TY ptr_ty = get_expr_type(jit, lhs->deref.expr);
+ int dsz = (ptr_ty.ptr_level > 1) ? 8 : ty_slot_size((_TY){ptr_ty.base, 0, -1});
+ if (dsz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
} else if (lhs->kind == EX_INDEX) {
if (lhs->index.array->kind != EX_VAR) {
gen_expr_jit(jit, lhs->index.array); break;
}
_TY var_type = get_var_type(jit, lhs->index.array->name);
- if (var_type.array_size <= 0) {
- fprintf(stderr, "[JIT] Cannot index non-array '%s'\n", lhs->index.array->name); exit(1);
+ int element_size = ty_slot_size((_TY){var_type.base, 0, -1});
+
+ if (var_type.ptr_level > 0) {
+ GlobalVar *gv_ptr2 = find_global(jit, lhs->index.array->name);
+ if (gv_ptr2) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr2->addr);
+ emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref global ptr */
+ } else {
+ int ptr_offset = get_var_offset(jit, lhs->index.array->name);
+ emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer
+ }
+ gen_expr_jit(jit, lhs->index.index); // RAX = index
+ if (element_size > 1) {
+ emit_mov_reg_reg(&jit->cb, RCX, RBX);
+ emit_mov_reg_imm64(&jit->cb, RBX, element_size);
+ emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
+ emit_mov_reg_reg(&jit->cb, RBX, RCX);
+ }
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &ptr[index]
+ emit_push_reg(&jit->cb, RBX); // save address
+ gen_expr_jit(jit, s->assign.expr); // RAX = value
+ emit_pop_reg(&jit->cb, RBX); // restore address
+ if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+
+ } else if (var_type.array_size > 0) {
+ GlobalVar *gv_arr2 = find_global(jit, lhs->index.array->name);
+ gen_expr_jit(jit, lhs->index.index); // RAX = index
+ emit_mov_reg_imm64(&jit->cb, RBX, element_size);
+ emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
+ if (gv_arr2) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr2->addr);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index]
+ } else {
+ int array_offset = get_var_offset(jit, lhs->index.array->name);
+ emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
+ emit_add_reg_reg(&jit->cb, RBX, RAX);
+ emit_mov_reg_reg(&jit->cb, RAX, RBX);
+ emit_mov_reg_reg(&jit->cb, RBX, RBP);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset
+ }
+ emit_push_reg(&jit->cb, RBX); // save address
+ gen_expr_jit(jit, s->assign.expr); // RAX = value
+ emit_pop_reg(&jit->cb, RBX); // RBX = address
+ if (element_size == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+
+ } else {
+ fprintf(stderr, "[JIT] Cannot index non-array/non-pointer '%s'\n", lhs->index.array->name); exit(1);
}
- int array_offset = get_var_offset(jit, lhs->index.array->name);
- int element_size = (var_type.base == TY_CHAR) ? 1 : 8;
- gen_expr_jit(jit, lhs->index.index); // RAX = index
- emit_mov_reg_imm64(&jit->cb, RBX, element_size);
- emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
- emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
- emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset
- emit_mov_reg_reg(&jit->cb, RAX, RBX);
- emit_mov_reg_reg(&jit->cb, RBX, RBP);
- emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index]
- emit_mov_reg_reg(&jit->cb, RBX, RAX);
- gen_expr_jit(jit, s->assign.expr); // RAX = value
- emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
} else {
fprintf(stderr, "[JIT] Unsupported assignment LHS kind %d\n", lhs->kind); exit(1);
@@ -479,38 +701,151 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
}
case STK_WHILE: {
+ int depth = jit->loop_depth++;
+ jit->break_patch_count[depth] = 0;
+ jit->cont_patch_count[depth] = 0;
+
size_t loop_start = jit->cb.len;
gen_expr_jit(jit, s->whl.cond);
emit_test_rax_rax(&jit->cb);
size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0);
gen_stmt_jit(jit, s->whl.body);
+ size_t cont_target = jit->cb.len;
+ for (int i = 0; i < jit->cont_patch_count[depth]; i++) {
+ size_t p = jit->cont_patches[depth][i];
+ int32_t r = (int32_t)(cont_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
- int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6));
+ size_t break_target = jit->cb.len;
+ int32_t rel_end = (int32_t)(break_target - (jz_pos + 6));
memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4);
+ for (int i = 0; i < jit->break_patch_count[depth]; i++) {
+ size_t p = jit->break_patches[depth][i];
+ int32_t r = (int32_t)(break_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ jit->loop_depth--;
break;
}
case STK_FOR: {
+ int depth = jit->loop_depth++;
+ jit->break_patch_count[depth] = 0;
+ jit->cont_patch_count[depth] = 0;
+
if (s->fr.init) gen_stmt_jit(jit, s->fr.init);
size_t loop_start = jit->cb.len;
- if (s->fr.cond) {
+ size_t jz_pos_for = 0;
+ int has_cond = (s->fr.cond != NULL);
+ if (has_cond) {
gen_expr_jit(jit, s->fr.cond);
emit_test_rax_rax(&jit->cb);
- size_t jz_pos = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0);
- gen_stmt_jit(jit, s->fr.body);
- if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
- emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
- int32_t rel_end = (int32_t)(jit->cb.len - (jz_pos + 6));
- memcpy(jit->cb.buf + jz_pos + 2, &rel_end, 4);
- } else {
- gen_stmt_jit(jit, s->fr.body);
- if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
- emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
+ jz_pos_for = jit->cb.len; emit_jcc_rel32(&jit->cb, 0x04, 0);
+ }
+ gen_stmt_jit(jit, s->fr.body);
+ size_t cont_target = jit->cb.len;
+ for (int i = 0; i < jit->cont_patch_count[depth]; i++) {
+ size_t p = jit->cont_patches[depth][i];
+ int32_t r = (int32_t)(cont_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
}
+ if (s->fr.step) gen_stmt_jit(jit, s->fr.step);
+ emit_jmp_rel32(&jit->cb, (int32_t)(loop_start - (jit->cb.len + 5)));
+ size_t break_target = jit->cb.len;
+ if (has_cond) {
+ int32_t rel_end = (int32_t)(break_target - (jz_pos_for + 6));
+ memcpy(jit->cb.buf + jz_pos_for + 2, &rel_end, 4);
+ }
+ for (int i = 0; i < jit->break_patch_count[depth]; i++) {
+ size_t p = jit->break_patches[depth][i];
+ int32_t r = (int32_t)(break_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ jit->loop_depth--;
break;
}
- default:
+ case STK_DOWHILE: {
+ int depth = jit->loop_depth++;
+ jit->break_patch_count[depth] = 0;
+ jit->cont_patch_count[depth] = 0;
+
+ size_t loop_start = jit->cb.len;
+ gen_stmt_jit(jit, s->dowhl.body);
+ /* continue → jump to condition check */
+ size_t cont_target = jit->cb.len;
+ for (int i = 0; i < jit->cont_patch_count[depth]; i++) {
+ size_t p = jit->cont_patches[depth][i];
+ int32_t r = (int32_t)(cont_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ gen_expr_jit(jit, s->dowhl.cond);
+ emit_test_rax_rax(&jit->cb);
+ emit_jcc_rel32(&jit->cb, 0x05, (int32_t)(loop_start - (jit->cb.len + 6)));
+ size_t break_target = jit->cb.len;
+ for (int i = 0; i < jit->break_patch_count[depth]; i++) {
+ size_t p = jit->break_patches[depth][i];
+ int32_t r = (int32_t)(break_target - (p + 5));
+ memcpy(jit->cb.buf + p + 1, &r, 4);
+ }
+ jit->loop_depth--;
+ break;
+ }
+
+ case STK_BREAK: {
+ if (jit->loop_depth == 0) {
+ fprintf(stderr, "[JIT] 'break' outside loop\n"); exit(1);
+ }
+ int depth = jit->loop_depth - 1;
+ size_t pos = jit->cb.len;
+ emit_jmp_rel32(&jit->cb, 0);
+ int cnt = jit->break_patch_count[depth];
+ if (cnt >= 256) { fprintf(stderr, "[JIT] Too many breaks\n"); exit(1); }
+ jit->break_patches[depth][cnt] = pos;
+ jit->break_patch_count[depth]++;
+ break;
+ }
+
+ case STK_CONTINUE: {
+ if (jit->loop_depth == 0) {
+ fprintf(stderr, "[JIT] 'continue' outside loop\n"); exit(1);
+ }
+ int depth = jit->loop_depth - 1;
+ size_t pos = jit->cb.len;
+ emit_jmp_rel32(&jit->cb, 0);
+ int cnt = jit->cont_patch_count[depth];
+ if (cnt >= 256) { fprintf(stderr, "[JIT] Too many continues\n"); exit(1); }
+ jit->cont_patches[depth][cnt] = pos;
+ jit->cont_patch_count[depth]++;
+ break;
+ }
+
+ case STK_GLOBAL: {
+ GlobalVar *gv = find_global(jit, s->global.name);
+ if (!gv) {
+ fprintf(stderr, "[JIT] Global '%s' not registered\n", s->global.name); exit(1);
+ }
+ if (s->global.init) {
+ if (s->global.init->kind == EX_CALL &&
+ strcmp(s->global.init->call.func_name, "__initlist__") == 0) {
+ int esz = ty_slot_size((_TY){gv->type.base, 0, -1});
+ for (int ii = 0; ii < s->global.init->call.argc; ii++) {
+ gen_expr_jit(jit, s->global.init->call.args[ii]);
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)(gv->addr + ii * esz));
+ if (esz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ }
+ } else {
+ gen_expr_jit(jit, s->global.init);
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr);
+ int sz = ty_slot_size(gv->type);
+ if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ }
+ }
+ break;
+ }
fprintf(stderr, "[JIT] Unsupported statement kind %d\n", s->kind); exit(1);
}
s = s->n;
@@ -518,7 +853,6 @@ static int gen_stmt_jit(JIT *jit, _STN *s) {
return 0;
}
-/* --- Expression generation --- */
static void gen_expr_jit(JIT *jit, _EX *e) {
if (!e) return;
@@ -529,16 +863,110 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
break;
case EX_VAR: {
- int off = get_var_offset(jit, e->name);
- _TY ty = get_var_type(jit, e->name);
- int sz = (ty.ptr_level > 0) ? 8 : (ty.base == TY_CHAR) ? 1 : 8;
- if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off);
- else emit_mov_reg_mem64(&jit->cb, RAX, off);
- break;
+ GlobalVar *gv = NULL;
+ for (VarMap *v = jit->var_list; v; v = v->next) {
+ if (strcmp(v->name, e->name) == 0) goto local_var;
+ }
+ gv = find_global(jit, e->name);
+ if (gv) {
+ _TY ty = gv->type;
+ if (ty.array_size > 0) {
+ /* Global array: load its absolute address */
+ emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv->addr);
+ break;
+ }
+ /* Global scalar: absolute address → load through it */
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv->addr);
+ int sz = ty_slot_size(ty);
+ if (sz == 1) emit_movzx_rax_mem8_base(&jit->cb, RBX);
+ else emit_mov_reg_mem_reg(&jit->cb, RAX, RBX);
+ break;
+ }
+ local_var: {
+ int off = get_var_offset(jit, e->name);
+ _TY ty = get_var_type(jit, e->name);
+ if (ty.array_size > 0) {
+ emit_lea_rax_rbp_disp(&jit->cb, off);
+ break;
+ }
+ int sz = ty_slot_size(ty);
+ if (sz == 1) emit_movzx_rax_mem8(&jit->cb, off);
+ else emit_mov_reg_mem64(&jit->cb, RAX, off);
+ break;
+ }
}
case EX_BINOP: {
- // Short-circuit AND
+ if (e->binop.op == TK_INC || e->binop.op == TK_DEC) {
+ int is_pre = (e->binop.r->value == -1);
+ int is_inc = (e->binop.op == TK_INC);
+ _EX *lval = e->binop.l;
+
+ _TY lty = get_expr_type(jit, lval);
+ int step = (lty.ptr_level > 0) ? ty_slot_size((_TY){lty.base,0,-1}) : 1;
+
+ gen_expr_jit(jit, lval); /* RAX = old value */
+ emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = old value */
+ emit_mov_reg_imm64(&jit->cb, RBX, step);
+ if (is_inc) emit_add_reg_reg(&jit->cb, RCX, RBX);
+ else emit_sub_reg_reg(&jit->cb, RCX, RBX);
+ if (lval->kind == EX_VAR) {
+ GlobalVar *gv_inc = find_global(jit, lval->name);
+ emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ if (gv_inc) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_inc->addr);
+ int sz = ty_slot_size(lty);
+ if (sz == 1) emit_mov_mem8_reg_reg(&jit->cb, RBX);
+ else emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ } else {
+ int off = get_var_offset(jit, lval->name);
+ int sz = ty_slot_size(lty);
+ emit_store_rax_to_mem(&jit->cb, off, sz);
+ }
+ } else {
+ emit_push_reg(&jit->cb, RCX); /* save new value */
+ if (lval->kind == EX_DEREF) {
+ gen_expr_jit(jit, lval->deref.expr); /* RAX = address */
+ } else { /* EX_INDEX */
+ _TY vty = get_var_type(jit, lval->index.array->name);
+ int esz = ty_slot_size((_TY){vty.base,0,-1});
+ if (vty.ptr_level > 0) {
+ int poff = get_var_offset(jit, lval->index.array->name);
+ emit_mov_reg_mem64(&jit->cb, RAX, poff);
+ emit_push_reg(&jit->cb, RAX);
+ gen_expr_jit(jit, lval->index.index);
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RBX, esz);
+ emit_imul_rax_rbx(&jit->cb);
+ }
+ emit_pop_reg(&jit->cb, RBX);
+ emit_add_reg_reg(&jit->cb, RAX, RBX);
+ } else {
+ int aoff = get_var_offset(jit, lval->index.array->name);
+ gen_expr_jit(jit, lval->index.index);
+ emit_mov_reg_imm64(&jit->cb, RBX, esz);
+ emit_imul_rax_rbx(&jit->cb);
+ emit_mov_reg_imm64(&jit->cb, RBX, aoff);
+ emit_add_reg_reg(&jit->cb, RBX, RAX);
+ emit_mov_reg_reg(&jit->cb, RAX, RBX);
+ emit_mov_reg_reg(&jit->cb, RBX, RBP);
+ emit_add_reg_reg(&jit->cb, RAX, RBX);
+ }
+ }
+ emit_mov_reg_reg(&jit->cb, RBX, RAX); /* RBX = address */
+ emit_pop_reg(&jit->cb, RCX); /* RCX = new val */
+ emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ emit_mov_mem_reg_reg(&jit->cb, RBX, RAX);
+ }
+ if (is_pre) emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ if (!is_pre) {
+ emit_mov_reg_reg(&jit->cb, RAX, RCX);
+ emit_mov_reg_imm64(&jit->cb, RBX, step);
+ if (is_inc) emit_sub_reg_reg(&jit->cb, RAX, RBX);
+ else emit_add_reg_reg(&jit->cb, RAX, RBX);
+ }
+ break;
+ }
if (e->binop.op == TK_AND) {
gen_expr_jit(jit, e->binop.l);
emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb);
@@ -554,7 +982,6 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4);
break;
}
- // Short-circuit OR
if (e->binop.op == TK_OR) {
gen_expr_jit(jit, e->binop.l);
emit_mov_reg_reg(&jit->cb, RBX, RAX); emit_or_rax_rbx(&jit->cb);
@@ -570,18 +997,54 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
memcpy(jit->cb.buf + jmp_pos + 1, &rel, 4);
break;
}
- // All other binary ops: eval left -> push; eval right -> RBX = left
gen_expr_jit(jit, e->binop.l);
emit_push_reg(&jit->cb, RAX);
gen_expr_jit(jit, e->binop.r);
emit_pop_reg(&jit->cb, RBX); // RBX = left, RAX = right
switch (e->binop.op) {
- case TK_PLUS: emit_add_rax_rbx(&jit->cb); break;
- case TK_MINUS:
+ case TK_PLUS: {
+ _TY lt = get_expr_type(jit, e->binop.l);
+ _TY rt = get_expr_type(jit, e->binop.r);
+ int lptr = (lt.ptr_level > 0 || lt.array_size > 0);
+ int rptr = (rt.ptr_level > 0 || rt.array_size > 0);
+ if (lptr && !rptr) {
+ int esz = ty_slot_size((_TY){lt.base, 0, -1});
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RCX, esz);
+ emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4);
+ }
+ } else if (rptr && !lptr) {
+ int esz = ty_slot_size((_TY){rt.base, 0, -1});
+ emit_mov_reg_reg(&jit->cb, RCX, RAX); /* RCX = ptr */
+ emit_mov_reg_reg(&jit->cb, RAX, RBX); /* RAX = int */
+ emit_mov_reg_reg(&jit->cb, RBX, RCX); /* RBX = ptr */
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RCX, esz);
+ emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4);
+ }
+ }
+ emit_add_rax_rbx(&jit->cb);
+ break;
+ }
+ case TK_MINUS: {
+ _TY lt = get_expr_type(jit, e->binop.l);
+ _TY rt = get_expr_type(jit, e->binop.r);
+ int lptr = (lt.ptr_level > 0 || lt.array_size > 0);
+ int rptr = (rt.ptr_level > 0 || rt.array_size > 0);
+ if (lptr && !rptr) {
+ /* ptr - int: scale int by element size */
+ int esz = ty_slot_size((_TY){lt.base, 0, -1});
+ if (esz > 1) {
+ emit_mov_reg_imm64(&jit->cb, RCX, esz);
+ emitN(&jit->cb, (uint8_t[]){0x48,0x0F,0xAF,0xC1}, 4);
+ }
+ }
+ /* RAX=right(scaled), RBX=left: result = left - right */
emit_mov_reg_reg(&jit->cb, RCX, RAX);
emit_mov_reg_reg(&jit->cb, RAX, RBX);
emit_rex(&jit->cb, RCX, RAX, 1); emit8(&jit->cb, 0x29); emit_modrm(&jit->cb, 3, RCX, RAX);
break;
+ }
case TK_STAR: emit_imul_rax_rbx(&jit->cb); break;
case TK_SLASH:
emit_mov_reg_reg(&jit->cb, RCX, RAX);
@@ -626,24 +1089,84 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
case EX_CALL: {
+ if (strcmp(e->call.func_name, "__sizeof__") == 0) {
+ _EX *arg = e->call.args[0];
+ _TY ty;
+ if (arg->kind == EX_VAR) {
+ ty = get_var_type(jit, arg->name);
+ } else {
+ ty = get_expr_type(jit, arg);
+ }
+ int sz;
+ if (ty.array_size > 0) {
+ int esz = (ty.base == TY_CHAR) ? 1 : (ty.base == TY_SHORT) ? 2 :
+ (ty.base == TY_LONG) ? 8 : 4;
+ sz = esz * ty.array_size;
+ } else if (ty.ptr_level > 0) {
+ sz = 8;
+ } else {
+ switch (ty.base) {
+ case TY_CHAR: sz = 1; break;
+ case TY_SHORT: sz = 2; break;
+ case TY_LONG: sz = 8; break;
+ default: sz = 4; break; /* int, bool, float */
+ }
+ }
+ emit_mov_reg_imm64(&jit->cb, RAX, sz);
+ break;
+ }
+
+ if (strcmp(e->call.func_name, "syscall") == 0) {
+ int argc = e->call.argc;
+ if (argc < 1) {
+ fprintf(stderr, "[JIT] syscall() requires at least 1 argument (number)\n"); exit(1);
+ }
+ if (argc > 7) {
+ fprintf(stderr, "[JIT] syscall() supports at most 7 arguments\n"); exit(1);
+ }
+ const int sc_regs[6] = {RDI, RSI, RDX, R10, R8, R9};
+ int nargs = argc - 1; /* number of args after the syscall number */
+
+ for (int i = argc - 1; i >= 0; i--) {
+ gen_expr_jit(jit, e->call.args[i]);
+ emit_push_reg(&jit->cb, RAX);
+ }
+
+ emit_pop_reg(&jit->cb, RAX);
+
+ for (int i = 0; i < nargs; i++) {
+ emit_pop_reg(&jit->cb, sc_regs[i]);
+ }
+
+ emit8(&jit->cb, 0x0F);
+ emit8(&jit->cb, 0x05);
+ break;
+ }
+
int total_args = e->call.argc;
int stack_args = total_args > 6 ? total_args - 6 : 0;
+ int reg_args = total_args < 6 ? total_args : 6;
int padding = (stack_args % 2) ? 8 : 0;
const int arg_regs[6] = {RDI, RSI, RDX, RCX, R8, R9};
+ /* Push stack-passed args right-to-left (args[total-1] first). */
for (int i = total_args-1; i >= 6; i--) {
gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX);
}
- for (int i = 0; i < total_args && i < 6; i++) {
- gen_expr_jit(jit, e->call.args[i]); emit_mov_reg_reg(&jit->cb, arg_regs[i], RAX);
+ for (int i = reg_args-1; i >= 0; i--) {
+ gen_expr_jit(jit, e->call.args[i]); emit_push_reg(&jit->cb, RAX);
+ }
+
+ for (int i = 0; i < reg_args; i++) {
+ emit_pop_reg(&jit->cb, arg_regs[i]);
}
+
if (padding) {
emit8(&jit->cb,0x48); emit8(&jit->cb,0x83); emit8(&jit->cb,0xEC); emit8(&jit->cb,0x08);
}
void *addr = get_func_addr(jit, e->call.func_name);
if (addr == (void*)0xDEADBEEF) {
- // Forward call: emit placeholder imm64, record offset for later patching
emit_movabs_rax_imm64(&jit->cb, (uint64_t)0xDEADBEEF);
PatchEntry *patch = (PatchEntry*)malloc(sizeof(PatchEntry));
if (!patch) { fprintf(stderr, "[JIT] malloc failed (patch)\n"); exit(1); }
@@ -659,7 +1182,7 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
} else {
emit_movabs_rax_imm64(&jit->cb, (uint64_t)addr);
}
- emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0); // CALL RAX
+ emit8(&jit->cb, 0xFF); emit8(&jit->cb, 0xD0);
if (stack_args * 8 + padding > 0) {
emit8(&jit->cb,0x48); emit8(&jit->cb,0x81); emit8(&jit->cb,0xC4);
@@ -669,12 +1192,56 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
case EX_ADDR:
- if (e->addr.expr->kind != EX_VAR) {
+ if (e->addr.expr->kind == EX_VAR) {
+ GlobalVar *gv_addr = find_global(jit, e->addr.expr->name);
+ if (gv_addr) {
+ emit_mov_reg_imm64(&jit->cb, RAX, (uint64_t)gv_addr->addr);
+ } else {
+ emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name));
+ }
+ } else {
fprintf(stderr, "[JIT] &expr: only &var supported\n"); exit(1);
}
- emit_lea_rax_rbp_disp(&jit->cb, get_var_offset(jit, e->addr.expr->name));
break;
+ case EX_TERNARY: {
+ gen_expr_jit(jit, e->ternary.cond);
+ emit_mov_reg_reg(&jit->cb, RBX, RAX);
+ emit_or_rax_rbx(&jit->cb);
+ size_t jz_pos = jit->cb.len;
+ emit_jcc_rel32(&jit->cb, 0x04, 0);
+ gen_expr_jit(jit, e->ternary.then_expr);
+ size_t jmp_pos = jit->cb.len;
+ emit_jmp_rel32(&jit->cb, 0);
+ int32_t rel_jz = (int32_t)(jit->cb.len - (jz_pos + 6));
+ memcpy(jit->cb.buf + jz_pos + 2, &rel_jz, 4);
+ gen_expr_jit(jit, e->ternary.else_expr);
+ int32_t rel_jmp = (int32_t)(jit->cb.len - (jmp_pos + 5));
+ memcpy(jit->cb.buf + jmp_pos + 1, &rel_jmp, 4);
+ break;
+ }
+
+ case EX_CAST: {
+ gen_expr_jit(jit, e->cast.expr);
+ _TY to = e->cast.to;
+ if (to.ptr_level > 0) break; /* pointer cast: value unchanged */
+ switch (to.base) {
+ case TY_CHAR:
+ emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6);
+ emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, al */
+ break;
+ case TY_SHORT:
+ emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB7);
+ emit_modrm(&jit->cb, 3, RAX, RAX); /* movzx rax, ax */
+ break;
+ case TY_INT:
+ emit8(&jit->cb, 0x48); emit8(&jit->cb, 0x63); emit_modrm(&jit->cb, 3, RAX, RAX);
+ break;
+ default: break; /* long/void/etc: no-op */
+ }
+ break;
+ }
+
case EX_DEREF:
gen_expr_jit(jit, e->deref.expr);
emit_rex(&jit->cb, RAX, RAX, 1); emit8(&jit->cb, 0x8B); emit_modrm(&jit->cb, 0, RAX, RAX);
@@ -692,12 +1259,17 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
gen_expr_jit(jit, e->index.array); break;
}
_TY var_type = get_var_type(jit, e->index.array->name);
- int element_size = (var_type.base == TY_CHAR) ? 1 : 8;
+ int element_size = ty_slot_size((_TY){var_type.base, 0, -1});
if (var_type.ptr_level > 0) {
- // Pointer indexing: load pointer, add index*element_size
- int ptr_offset = get_var_offset(jit, e->index.array->name);
- emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer
+ GlobalVar *gv_ptr = find_global(jit, e->index.array->name);
+ if (gv_ptr) {
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_ptr->addr);
+ emit_mov_reg_mem_reg(&jit->cb, RBX, RBX); /* deref: RBX = *addr = pointer value */
+ } else {
+ int ptr_offset = get_var_offset(jit, e->index.array->name);
+ emit_mov_reg_mem64(&jit->cb, RBX, ptr_offset); // RBX = pointer
+ }
gen_expr_jit(jit, e->index.index); // RAX = index
if (element_size > 1) {
emit_mov_reg_reg(&jit->cb, RCX, RBX); // save pointer
@@ -714,17 +1286,22 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
} else if (var_type.array_size > 0) {
- // Array indexing: compute RBP-relative address = array_offset - index*element_size
- int array_offset = get_var_offset(jit, e->index.array->name);
+ GlobalVar *gv_arr = find_global(jit, e->index.array->name);
gen_expr_jit(jit, e->index.index); // RAX = index
emit_mov_reg_imm64(&jit->cb, RBX, element_size);
emit_imul_rax_rbx(&jit->cb); // RAX = byte offset
- emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
- emit_sub_reg_reg(&jit->cb, RBX, RAX); // RBX = array_offset - byte_offset
- emit_mov_reg_reg(&jit->cb, RAX, RBX);
- emit_mov_reg_reg(&jit->cb, RBX, RBP);
- emit_add_reg_reg(&jit->cb, RAX, RBX); // RAX = &arr[index]
- emit_mov_reg_reg(&jit->cb, RBX, RAX);
+ if (gv_arr) {
+ /* Global array: base is absolute address */
+ emit_mov_reg_imm64(&jit->cb, RBX, (uint64_t)gv_arr->addr);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = &arr[index]
+ } else {
+ int array_offset = get_var_offset(jit, e->index.array->name);
+ emit_mov_reg_imm64(&jit->cb, RBX, array_offset);
+ emit_add_reg_reg(&jit->cb, RBX, RAX);
+ emit_mov_reg_reg(&jit->cb, RAX, RBX);
+ emit_mov_reg_reg(&jit->cb, RBX, RBP);
+ emit_add_reg_reg(&jit->cb, RBX, RAX); // RBX = RBP + array_offset + byte_offset
+ }
if (element_size == 1) {
emit_rex(&jit->cb, RAX, RBX, 0); emit8(&jit->cb, 0x0F); emit8(&jit->cb, 0xB6);
emit_modrm(&jit->cb, 0, RAX, RBX);
@@ -743,22 +1320,35 @@ static void gen_expr_jit(JIT *jit, _EX *e) {
}
}
-/* --- Compile one function --- */
-
static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) {
reset_varmap(jit);
jit->current_func_name = f->name;
- int total_stack_size = calculate_stack_size(f->body);
+ int total_stack_size = calculate_total_stack_size(f);
cb_init(&jit->cb);
emit_prologue(&jit->cb, total_stack_size);
const int param_regs[6] = {RDI, RSI, RDX, RCX, R8, R9};
+
for (int i = 0; i < f->pac && i < 6; i++) {
add_var(jit, f->params[i], f->param_types[i]);
emit_mov_mem64_reg(&jit->cb, get_var_offset(jit, f->params[i]), param_regs[i]);
}
+ int num_stack_params = f->pac > 6 ? f->pac - 6 : 0;
+ int stack_arg_padding = (num_stack_params % 2) ? 8 : 0;
+ int stack_arg_base = 16 + stack_arg_padding; /* offset of last-pushed (highest-index) stack arg */
+ for (int i = 6; i < f->pac; i++) {
+ VarMap *v = (VarMap *)malloc(sizeof(VarMap));
+ if (!v) { fprintf(stderr, "[JIT] malloc failed (stack param)\n"); exit(1); }
+ v->name = strdup(f->params[i]);
+ if (!v->name) { fprintf(stderr, "[JIT] strdup failed (stack param)\n"); free(v); exit(1); }
+ v->type = f->param_types[i];
+ v->offset = stack_arg_base + (i - 6) * 8;
+ v->next = jit->var_list;
+ jit->var_list = v;
+ }
+
int did_return = gen_stmt_jit(jit, f->body);
if (!did_return) {
emit_movabs_rax_imm64(&jit->cb, 0);
@@ -777,7 +1367,6 @@ static void *gen_function_jit(JIT *jit, _FN *f, size_t *out_size) {
return mem;
}
-/* --- Patch forward calls after all functions are compiled --- */
static void patch_function_calls(JIT *jit) {
for (PatchEntry *patch = jit->patch_list; patch; patch = patch->next) {
@@ -797,16 +1386,36 @@ static void patch_function_calls(JIT *jit) {
}
}
-/* --- Compile all functions and patch forward calls --- */
-
static void jit_compile_all(JIT *jit, _FN *fn_list) {
- for (_FN *cur = fn_list; cur; cur = cur->n)
- register_func(jit, cur->name);
+ /* First pass: register all globals so type lookups work during codegen */
+ for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) == 0) {
+ /* Extract variable name from __global_NAME__ */
+ const char *start = cur->name + 9;
+ size_t len = strlen(start);
+ if (len >= 2 && start[len-2] == '_' && start[len-1] == '_') len -= 2;
+ char *vname = strndup(start, len);
+ if (!vname) { fprintf(stderr, "[JIT] OOM\n"); exit(1); }
+ /* Get type from the STK_GLOBAL node */
+ _STN *gdecl = cur->body;
+ if (gdecl && gdecl->kind == STK_GLOBAL) {
+ register_global(jit, vname, gdecl->global.type);
+ }
+ free(vname);
+ }
+ }
+
+ /* Second pass: register all real functions (so forward calls work) */
+ for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) != 0)
+ register_func(jit, cur->name, cur->ret_type);
+ }
- // Compile in reverse order so callees are typically compiled before callers
+ /* Compile real functions in reverse order */
_FN *functions[64];
int count = 0;
for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) == 0) continue;
if (count >= 64) { fprintf(stderr, "[JIT] Too many functions (max 64)\n"); exit(1); }
functions[count++] = cur;
}
@@ -814,9 +1423,27 @@ static void jit_compile_all(JIT *jit, _FN *fn_list) {
gen_function_jit(jit, functions[i], NULL);
patch_function_calls(jit);
-}
-/* --- Entry point --- */
+ /* Run global initialisers in declaration order */
+ for (_FN *cur = fn_list; cur; cur = cur->n) {
+ if (strncmp(cur->name, "__global_", 9) != 0) continue;
+ reset_varmap(jit);
+ jit->current_func_name = cur->name;
+ cb_init(&jit->cb);
+ emit8(&jit->cb, 0x55); /* push rbp */
+ emitN(&jit->cb, (uint8_t[]){0x48,0x89,0xE5}, 3); /* mov rbp, rsp */
+ gen_stmt_jit(jit, cur->body);
+ emit8(&jit->cb, 0xC9); /* leave */
+ emit8(&jit->cb, 0xC3); /* ret */
+ size_t isz = jit->cb.len;
+ void *ibuf = mmap(NULL, isz, PROT_READ|PROT_WRITE|PROT_EXEC,
+ MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
+ if (ibuf == MAP_FAILED) { perror("mmap init"); exit(1); }
+ memcpy(ibuf, jit->cb.buf, isz);
+ ((void(*)(void))ibuf)();
+ munmap(ibuf, isz);
+ }
+}
static int jit_run(JIT *jit, int argc, char **argv) {
int (*main_func)(int, char **) = get_func_addr(jit, "main");
@@ -824,4 +1451,4 @@ static int jit_run(JIT *jit, int argc, char **argv) {
return main_func(argc, argv);
}
-#endif \ No newline at end of file
+#endif