diff --git a/activation_key.txt b/activation_key.txt new file mode 100644 index 0000000..9b84bd6 --- /dev/null +++ b/activation_key.txt @@ -0,0 +1 @@ +THIS IS A TEST. JUST A TEST. NOTHING BUT A TEST. YOU'VE BEEN WARNED. \ No newline at end of file diff --git a/debug_docker/Dockerfile b/debug_docker/Dockerfile index 3d3bbc2..b006986 100644 --- a/debug_docker/Dockerfile +++ b/debug_docker/Dockerfile @@ -6,7 +6,7 @@ FROM debian:bullseye -RUN apt update -y && apt upgrade -y && apt install -y build-essential wget cmake tar gdb libc6-dbg python3 file +RUN apt update -y && apt upgrade -y && apt install -y build-essential wget cmake tar gdb libc6-dbg python3 file strace ############### INSTALL FNETD @@ -41,6 +41,7 @@ RUN make WORKDIR / RUN cp /home/pwn/build/vuln /home/pwn/vuln +RUN cp /home/pwn/source/activation_key.txt /home/pwn/activation_key.txt RUN chmod 0755 /home/pwn/vuln diff --git a/tests/common.py b/tests/common.py index d8a5ebb..733a292 100644 --- a/tests/common.py +++ b/tests/common.py @@ -43,39 +43,45 @@ def instr_r(opcode, reg1, reg2): return bytes([opcode, reg1, 0, 0, reg2, 0, 0, 0]) -def exec_program(program: bytes, debug: bool = False) -> int | None: +def exec_program(program: bytes, assert_f, debug: bool = False) -> int | None: if debug: context.log_level = 'debug' + else: + context.log_level = 'warn' + with remote("localhost", 1337, fam="ipv4") as p: - p.recvuntil(b"Password: ", timeout=1) + msg = p.recvuntil(b"Password: ", timeout=1) + assert_f(msg != b'') p.sendline(b"1234") - msg = p.recvuntil(b"always a Surprise)", timeout=1) - if msg == b'': - return None - print(msg.decode()) + msg = p.recvuntil(b"always a Surprise)\n", timeout=1) + assert_f(msg == b'Welcome to JIT-aaS (Just In Time - always a Surprise)\n') + + msg = p.recvuntil(b"? (y/N):", timeout=1) + assert_f(msg == b'Do you want to activate the premium version? (y/N):') + + p.sendline(b"N") + msg = p.recvuntil(b"Using the demo version!\n") + assert_f(msg == b"Using the demo version!\n") msg = p.recvuntil(b"should it bee?", timeout=1) - if msg == b'': - return None - print(msg.decode()) + assert_f(msg == b'Now to your next program: How long should it bee?') len_msg = str(len(program) // INSTR_LEN).encode() - log.info(f"Sending: {len_msg}") + log.debug(f"Sending: {len_msg}") p.sendline(len_msg) msg = p.recvuntil(b"Now your program:", timeout=1) - if msg == b'': - return None - print(msg.decode()) + assert_f(msg == b'Now your program:') - log.info(f"Sending program: {list(program)}") + log.debug(f"Sending program: {list(program)}") p.send(program) msg = p.recvuntil(b"Your program exited with ", timeout=1) - if msg == b'': - return None - print(msg.decode()) + assert_f(msg == b'Your program exited with ') - exit_code = int(p.recvuntil(b"!", drop=True, timeout=1)) + exit_code_msg = p.recvuntil(b"!", drop=True, timeout=1) + assert_f(exit_code_msg != b'') + + exit_code = int(exit_code_msg) return exit_code diff --git a/tests/test_add.py b/tests/test_add.py index a47d1f8..6827a97 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -6,7 +6,7 @@ class BasicInstructionTest(unittest.TestCase): def test_addi(self): for val in [0, 1, 50, 256, 0x543210, 0xFFFFFFFF]: program = instr_i(Opcode.ADDI, Register.A, val) - exit_code = exec_program(program) + exit_code = exec_program(program, self.assertTrue) self.assertIsNotNone(exit_code, "Connection timeout!") self.assertEqual(exit_code, val & 0xFF, "Computed the wrong result!") @@ -17,7 +17,7 @@ class BasicInstructionTest(unittest.TestCase): val_sum = sum(vals) for val in vals: program += instr_i(Opcode.ADDI, Register.A, val) - exit_code = exec_program(program) + exit_code = exec_program(program, self.assertTrue) self.assertIsNotNone(exit_code, "Connection timeout!") self.assertEqual(exit_code, val_sum & 0xFF, "Computed the wrong result!") @@ -25,7 +25,7 @@ class BasicInstructionTest(unittest.TestCase): for val1, val2 in [(0, 0), (5, 8), (0xFFFFFFFF, 0xFFFFFFFF), (0xFFFFFFFF, 1)]: program = instr_i(Opcode.ADDI, Register.A, val1) + instr_i(Opcode.ADDI, Register.B, val2) + instr_r( Opcode.ADD, Register.A, Register.B) - exit_code = exec_program(program) + exit_code = exec_program(program, self.assertTrue) self.assertIsNotNone(exit_code, "Connection timeout!") self.assertEqual(exit_code, (val1 + val2) & 0xFF, "Computed the wrong result!") diff --git a/vuln.c b/vuln.c index 6d0d1b0..06d1184 100644 --- a/vuln.c +++ b/vuln.c @@ -1,3 +1,4 @@ +#include #include #include #include @@ -8,6 +9,9 @@ #include #define MAX_PROGRAM_LEN 0x1000 +#define ACTIVATION_KEY_LEN 0x80 + +static char activation_key[ACTIVATION_KEY_LEN] = {0}; typedef enum Opcode { ADD = 0, ADDI = 1, SUB = 2, COPY = 3, LOADI = 4, COUNT_OPCODES } Opcode; @@ -242,41 +246,87 @@ uint8_t run_jit(Instruction *program, size_t len) { exit(EXIT_FAILURE); } - int child_pid = fork(); - switch (child_pid) { - case -1: - puts("I'm infertile, I cannot have a child \U0001F62D"); + if (premium_activated) { + // premium version does not use sandbox + exec_func_t exec_func = (exec_func_t)code; + return exec_func(); + } else { + int child_pid = fork(); + switch (child_pid) { + case -1: + puts("I'm infertile, I cannot have a child \U0001F62D"); + exit(EXIT_FAILURE); + case 0: + // child + exec_code(code); + __builtin_unreachable(); + default: + // parent + break; + } + + // continue in the parent; child never gets here + + // unmap allocated memory + if (munmap(code, allocated_code_len) != 0) { + puts("Cannot unmap code."); + exit(EXIT_FAILURE); + } + + // wait for child and extract exit code + int wstatus = 0; + if (waitpid(child_pid, &wstatus, 0) == -1) { + puts("waitpid failed!"); + exit(EXIT_FAILURE); + } + + if (!WIFEXITED(wstatus)) { + puts("Program crashed! WHAT?"); + exit(EXIT_FAILURE); + } + + return WEXITSTATUS(wstatus); + } +} + +void check_premium() { + printf("Do you want to activate the premium version? (y/N):"); + char answer[3] = {0}; + if (fgets(answer, sizeof(answer), stdin) == NULL) { exit(EXIT_FAILURE); - case 0: - // child - exec_code(code); - __builtin_unreachable(); - default: - // parent + } + + switch (answer[0]) { + case 'y': break; + default: + premium_activated = false; + return; } - // continue in the parent; child never gets here - - // unmap allocated memory - if (munmap(code, allocated_code_len) != 0) { - puts("Cannot unmap code."); + int secret_fd = open("activation_key.txt", O_RDONLY); + if (secret_fd < 0) { + puts("Cannot open activation key file!"); exit(EXIT_FAILURE); } - // wait for child and extract exit code - int wstatus = 0; - if (waitpid(child_pid, &wstatus, 0) == -1) { - puts("waitpid failed!"); + read(secret_fd, activation_key, sizeof(activation_key)); + close(secret_fd); + + printf("Then please enter your activation key:"); + char buf[ACTIVATION_KEY_LEN] = {0}; + + ssize_t read_bytes = read(0, buf, sizeof(buf)); + if (read_bytes <= 0) { + puts("Cannot read activation key from user!"); exit(EXIT_FAILURE); } - if (!WIFEXITED(wstatus)) { - puts("Program crashed! WHAT?"); - exit(EXIT_FAILURE); + if (buf[read_bytes - 1] == '\n') { + buf[read_bytes - 1] = 0; } - return WEXITSTATUS(wstatus); + premium_activated = memcmp(activation_key, buf, sizeof(activation_key)) == 0; } int main() { @@ -289,12 +339,18 @@ int main() { // TODO: better pun, add reference to pop-culture puts("Welcome to JIT-aaS (Just In Time - always a Surprise)"); + check_premium(); + if (premium_activated) { + puts("Using premium version! No sandbox for you!"); + } else { + puts("Using the demo version!"); + } + Instruction *program; size_t program_len; int exit_code; while (true) { - // TODO: check for password and enable premium mode program = get_program(&program_len); if (!validate_program(program, program_len)) { puts("Your program is not valid. You possibly use invalid opcodes or registers!");