From 29e9b2ea437d35daee7a054ed47d57a1e3ee9f64 Mon Sep 17 00:00:00 2001 From: Johannes Maier Date: Wed, 10 Jan 2024 16:03:06 +0100 Subject: [PATCH] Extend framework --- vuln.c | 169 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 155 insertions(+), 14 deletions(-) diff --git a/vuln.c b/vuln.c index 9fcfc78..de90eb4 100644 --- a/vuln.c +++ b/vuln.c @@ -2,35 +2,62 @@ #include #include #include +#include +#include +#include +#include +#include -#define MAX_PROGRAM_LEN 0x10000 +#define MAX_PROGRAM_LEN 0x1000 +#define MAX_NUM_REGISTERS 0x100 -typedef enum Instruction : uint8_t { TODO } Instruction; +typedef enum Opcode : uint8_t { COUNT_OPCODES } Opcode; -Instruction *get_program(size_t *program_len) { - puts("Now to your next program: How long should it bee?"); +typedef struct Instruction { + Opcode opcode; +} Instruction; - size_t len; - char len_buf[0x10]; +typedef struct Context { + size_t pc; + uint8_t *next_code_start; + size_t overall_code_size; + size_t register_count; + int regs[]; +} Context; + +typedef int (*exec_func_t)(); + +static bool premium_activated = false; + +size_t get_size_t(size_t limit) { + size_t val; + char buf[0x10]; char *end_ptr; do { - if (fgets(len_buf, sizeof(len_buf), stdin) == NULL) { + if (fgets(buf, sizeof(buf), stdin) == NULL) { exit(EXIT_FAILURE); } - len = strtoull(len_buf, &end_ptr, 0); + val = strtoull(buf, &end_ptr, 0); - if (len_buf == end_ptr) { + if (buf == end_ptr) { puts("That's not a integer, come back when you passed elementary school!"); exit(EXIT_FAILURE); } - if (len <= MAX_PROGRAM_LEN) { + if (val <= limit) { break; } puts("Nah, that's to long. Let's try again."); - } while (true); + return val; +} + +Instruction *get_program(size_t *program_len, size_t *register_count) { + puts("Now to your next program: How long should it bee?"); + size_t len = get_size_t(MAX_PROGRAM_LEN); + puts("How many registers do you want to use?"); + size_t num_registers = get_size_t(MAX_NUM_REGISTERS); Instruction *program = malloc(len * sizeof(Instruction)); @@ -45,23 +72,137 @@ Instruction *get_program(size_t *program_len) { } *program_len = len; + *register_count = num_registers; return program; } -int run_jit(Instruction *program, size_t len) { return 0; } +bool validate_program(Instruction *program, size_t len, size_t register_count) { + for (size_t i = 0; i < len; ++i) { + // prevent use of wrong opcodes + if (program[i].opcode >= COUNT_OPCODES) { + return false; + } + // prevent use of wrong registers + if (program[i].reg1 >= register_count || program[i].reg2 >= register_count) { + return false; + } + } + return true; +} + +Context *init_ctx(uint8_t *code, size_t allocated_code_len, size_t register_count) { + Context *ctx = (struct Context *)code; + ctx->register_count = register_count; + memset(ctx->regs, 0, sizeof(*ctx->regs) * ctx->register_count); + ctx->pc = 0; + ctx->next_code_start = code + sizeof(*ctx) + ctx->register_count * sizeof(*ctx->regs); + ctx->overall_code_size = allocated_code_len; + return ctx; +} + +void init_seccomp() { + // TODO: +} + +void exec_code(uint8_t *code) { + exec_func_t exec_func = (exec_func_t)code; + init_seccomp(); + close(0); + close(1); + close(2); + uint8_t res = exec_func(); + _exit(res); +} + +void gen_code(uint8_t *code, Context *ctx, Instruction *program) { + (void)code; + (void)ctx; + (void)program; +} + +int run_jit(Instruction *program, size_t len, size_t register_count) { + // TODO: + size_t expected_code_len = 0; + // page alignment + size_t allocated_code_len = (expected_code_len + 0xFFF) & ~0xFFF; + + // allocate memory for context and code + uint8_t *code = (uint8_t *)mmap(NULL, allocated_code_len, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (code == (void *)-1) { + puts("Cannot mmap memory for code."); + exit(EXIT_FAILURE); + } + + Context *ctx = init_ctx(code, len, register_count); + + gen_code(code, ctx, program); + + // make code executable and non-writeable + if (mprotect(code, allocated_code_len, PROT_READ | PROT_EXEC) != 0) { + puts("Cannot make code executable!"); + exit(EXIT_FAILURE); + } + + int child_pid = fork(); + switch (child_pid) { + case -1: + puts("I'm infertile, I cannot have a child \U0001F62D"); + exit(EXIT_FAILURE); + case 0: + // child + exec_code(code); + __builtin_unreachable(); + default: + // parent + break; + } + + // continue in the parent; child never gets here + + // unmap allocated memory + if (munmap(code, allocated_code_len) != 0) { + puts("Cannot unmap code."); + exit(EXIT_FAILURE); + } + + // wait for child and extract exit code + int wstatus = 0; + if (waitpid(child_pid, &wstatus, 0) == -1) { + puts("waitpid failed!"); + exit(EXIT_FAILURE); + } + + if (!WIFEXITED(wstatus)) { + puts("Program crashed! WHAT?"); + exit(EXIT_FAILURE); + } + + uint8_t exit_code = WEXITSTATUS(wstatus); + + return exit_code; +} int main() { + // TODO: signal handlers? SIGCHILD? seccomp? + // TODO: better pun, add reference to pop-culture puts("Welcome to JIT-aaS (Just In Time - always a Surprise)"); Instruction *program; size_t program_len; + size_t register_count; int exit_code; while (true) { - program = get_program(&program_len); + // TODO: check for password and enable premium mode + program = get_program(&program_len, ®ister_count); + if (!validate_program(program, program_len, register_count)) { + puts("Your program is not valid. You possible use invalid register ids!"); + free(program); + continue; + } - exit_code = run_jit(program, program_len); + exit_code = run_jit(program, program_len, register_count); printf("Your program exited with %d\n", exit_code); free(program);