Rework registers.

This commit is contained in:
Johannes Maier
2024-01-10 16:10:11 +01:00
parent 29e9b2ea43
commit 2d8f9eba01

60
vuln.c
View File

@@ -6,28 +6,22 @@
#include <sys/mman.h> #include <sys/mman.h>
#include <sys/wait.h> #include <sys/wait.h>
#include <unistd.h> #include <unistd.h>
#include <stdbool.h>
#define MAX_PROGRAM_LEN 0x1000 #define MAX_PROGRAM_LEN 0x1000
#define MAX_NUM_REGISTERS 0x100
typedef enum Opcode : uint8_t { COUNT_OPCODES } Opcode; typedef enum Opcode : uint8_t { COUNT_OPCODES } Opcode;
typedef enum Register : uint8_t { COUNT_REGISTERS } Register;
typedef struct Instruction { typedef struct Instruction {
Opcode opcode; Opcode opcode;
Register reg;
} Instruction; } Instruction;
typedef struct Context {
size_t pc;
uint8_t *next_code_start;
size_t overall_code_size;
size_t register_count;
int regs[];
} Context;
typedef int (*exec_func_t)(); typedef int (*exec_func_t)();
static bool premium_activated = false;
static __attribute__((unused)) bool premium_activated = false;
size_t get_size_t(size_t limit) { size_t get_size_t(size_t limit) {
size_t val; size_t val;
@@ -53,11 +47,9 @@ size_t get_size_t(size_t limit) {
return val; return val;
} }
Instruction *get_program(size_t *program_len, size_t *register_count) { Instruction *get_program(size_t *program_len) {
puts("Now to your next program: How long should it bee?"); puts("Now to your next program: How long should it bee?");
size_t len = get_size_t(MAX_PROGRAM_LEN); size_t len = get_size_t(MAX_PROGRAM_LEN);
puts("How many registers do you want to use?");
size_t num_registers = get_size_t(MAX_NUM_REGISTERS);
Instruction *program = malloc(len * sizeof(Instruction)); Instruction *program = malloc(len * sizeof(Instruction));
@@ -72,34 +64,19 @@ Instruction *get_program(size_t *program_len, size_t *register_count) {
} }
*program_len = len; *program_len = len;
*register_count = num_registers;
return program; return program;
} }
bool validate_program(Instruction *program, size_t len, size_t register_count) { bool validate_program(Instruction *program, size_t len) {
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
// prevent use of wrong opcodes // prevent use of wrong opcodes or registers
if (program[i].opcode >= COUNT_OPCODES) { if (program[i].opcode >= COUNT_OPCODES || program[i].reg >= COUNT_REGISTERS) {
return false;
}
// prevent use of wrong registers
if (program[i].reg1 >= register_count || program[i].reg2 >= register_count) {
return false; return false;
} }
} }
return true; return true;
} }
Context *init_ctx(uint8_t *code, size_t allocated_code_len, size_t register_count) {
Context *ctx = (struct Context *)code;
ctx->register_count = register_count;
memset(ctx->regs, 0, sizeof(*ctx->regs) * ctx->register_count);
ctx->pc = 0;
ctx->next_code_start = code + sizeof(*ctx) + ctx->register_count * sizeof(*ctx->regs);
ctx->overall_code_size = allocated_code_len;
return ctx;
}
void init_seccomp() { void init_seccomp() {
// TODO: // TODO:
} }
@@ -114,13 +91,12 @@ void exec_code(uint8_t *code) {
_exit(res); _exit(res);
} }
void gen_code(uint8_t *code, Context *ctx, Instruction *program) { void gen_code(uint8_t *code, Instruction *program) {
(void)code; (void)code;
(void)ctx;
(void)program; (void)program;
} }
int run_jit(Instruction *program, size_t len, size_t register_count) { int run_jit(Instruction *program, size_t len) {
// TODO: // TODO:
size_t expected_code_len = 0; size_t expected_code_len = 0;
// page alignment // page alignment
@@ -132,10 +108,7 @@ int run_jit(Instruction *program, size_t len, size_t register_count) {
puts("Cannot mmap memory for code."); puts("Cannot mmap memory for code.");
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
gen_code(code, program);
Context *ctx = init_ctx(code, len, register_count);
gen_code(code, ctx, program);
// make code executable and non-writeable // make code executable and non-writeable
if (mprotect(code, allocated_code_len, PROT_READ | PROT_EXEC) != 0) { if (mprotect(code, allocated_code_len, PROT_READ | PROT_EXEC) != 0) {
@@ -190,19 +163,18 @@ int main() {
Instruction *program; Instruction *program;
size_t program_len; size_t program_len;
size_t register_count;
int exit_code; int exit_code;
while (true) { while (true) {
// TODO: check for password and enable premium mode // TODO: check for password and enable premium mode
program = get_program(&program_len, &register_count); program = get_program(&program_len);
if (!validate_program(program, program_len, register_count)) { if (!validate_program(program, program_len)) {
puts("Your program is not valid. You possible use invalid register ids!"); puts("Your program is not valid. You possibly use invalid opcodes or registers!");
free(program); free(program);
continue; continue;
} }
exit_code = run_jit(program, program_len, register_count); exit_code = run_jit(program, program_len);
printf("Your program exited with %d\n", exit_code); printf("Your program exited with %d\n", exit_code);
free(program); free(program);