Add premium mode. Closes #3.

This commit is contained in:
Johannes Maier
2024-01-16 15:58:52 +01:00
parent f244d69127
commit 4dd7d3d9d6
5 changed files with 110 additions and 46 deletions

View File

@@ -43,39 +43,45 @@ def instr_r(opcode, reg1, reg2):
return bytes([opcode, reg1, 0, 0, reg2, 0, 0, 0])
def exec_program(program: bytes, debug: bool = False) -> int | None:
def exec_program(program: bytes, assert_f, debug: bool = False) -> int | None:
if debug:
context.log_level = 'debug'
else:
context.log_level = 'warn'
with remote("localhost", 1337, fam="ipv4") as p:
p.recvuntil(b"Password: ", timeout=1)
msg = p.recvuntil(b"Password: ", timeout=1)
assert_f(msg != b'')
p.sendline(b"1234")
msg = p.recvuntil(b"always a Surprise)", timeout=1)
if msg == b'':
return None
print(msg.decode())
msg = p.recvuntil(b"always a Surprise)\n", timeout=1)
assert_f(msg == b'Welcome to JIT-aaS (Just In Time - always a Surprise)\n')
msg = p.recvuntil(b"? (y/N):", timeout=1)
assert_f(msg == b'Do you want to activate the premium version? (y/N):')
p.sendline(b"N")
msg = p.recvuntil(b"Using the demo version!\n")
assert_f(msg == b"Using the demo version!\n")
msg = p.recvuntil(b"should it bee?", timeout=1)
if msg == b'':
return None
print(msg.decode())
assert_f(msg == b'Now to your next program: How long should it bee?')
len_msg = str(len(program) // INSTR_LEN).encode()
log.info(f"Sending: {len_msg}")
log.debug(f"Sending: {len_msg}")
p.sendline(len_msg)
msg = p.recvuntil(b"Now your program:", timeout=1)
if msg == b'':
return None
print(msg.decode())
assert_f(msg == b'Now your program:')
log.info(f"Sending program: {list(program)}")
log.debug(f"Sending program: {list(program)}")
p.send(program)
msg = p.recvuntil(b"Your program exited with ", timeout=1)
if msg == b'':
return None
print(msg.decode())
assert_f(msg == b'Your program exited with ')
exit_code = int(p.recvuntil(b"!", drop=True, timeout=1))
exit_code_msg = p.recvuntil(b"!", drop=True, timeout=1)
assert_f(exit_code_msg != b'')
exit_code = int(exit_code_msg)
return exit_code

View File

@@ -6,7 +6,7 @@ class BasicInstructionTest(unittest.TestCase):
def test_addi(self):
for val in [0, 1, 50, 256, 0x543210, 0xFFFFFFFF]:
program = instr_i(Opcode.ADDI, Register.A, val)
exit_code = exec_program(program)
exit_code = exec_program(program, self.assertTrue)
self.assertIsNotNone(exit_code, "Connection timeout!")
self.assertEqual(exit_code, val & 0xFF, "Computed the wrong result!")
@@ -17,7 +17,7 @@ class BasicInstructionTest(unittest.TestCase):
val_sum = sum(vals)
for val in vals:
program += instr_i(Opcode.ADDI, Register.A, val)
exit_code = exec_program(program)
exit_code = exec_program(program, self.assertTrue)
self.assertIsNotNone(exit_code, "Connection timeout!")
self.assertEqual(exit_code, val_sum & 0xFF, "Computed the wrong result!")
@@ -25,7 +25,7 @@ class BasicInstructionTest(unittest.TestCase):
for val1, val2 in [(0, 0), (5, 8), (0xFFFFFFFF, 0xFFFFFFFF), (0xFFFFFFFF, 1)]:
program = instr_i(Opcode.ADDI, Register.A, val1) + instr_i(Opcode.ADDI, Register.B, val2) + instr_r(
Opcode.ADD, Register.A, Register.B)
exit_code = exec_program(program)
exit_code = exec_program(program, self.assertTrue)
self.assertIsNotNone(exit_code, "Connection timeout!")
self.assertEqual(exit_code, (val1 + val2) & 0xFF, "Computed the wrong result!")