From 7e7f04f1252d3708287074c4c46c6463e572e808 Mon Sep 17 00:00:00 2001
From: paumann
Date: Tue, 11 Oct 2022 00:02:25 +0200
Subject: [PATCH] Initial commit

---
 challenge_evaluation.py | 295 ++++++++++++++++++++++++++++++++++++++++
 machine_evaluation.py   | 293 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 588 insertions(+)
 create mode 100644 challenge_evaluation.py
 create mode 100644 machine_evaluation.py

diff --git a/challenge_evaluation.py b/challenge_evaluation.py
new file mode 100644
index 0000000..8e580c7
--- /dev/null
+++ b/challenge_evaluation.py
@@ -0,0 +1,295 @@
+import argparse
+import zipfile
+import json
+import os
+
+# Costs
+CONSTANT_COST = 1
+REGISTER_COST = 1
+OP_COST = {
+    "SUB": 4,
+    "ADD": 4,
+    "INC": 4,
+    "DEC": 4,
+    "MUL": 10,
+    "DIV": 10,
+    "MOD": 10,
+    "OR": 8,
+    "AND": 8,
+    "XOR": 8,
+    "INV": 8,
+    "SL": 5,
+    "SR": 5,
+    "SLU": 5,
+    "SRU": 5,
+    "ROTL": 7,
+    "ROTR": 7,
+}
+ALG_LINE_COST = 0.5
+
+# ansi escape codes
+ul = "\033[4m"  # underline
+end = "\033[0m"  # reset
+ylw = "\033[33m"  # yellow
+
+
+def is_empty_row(row, pedantic=False):
+    """
+    Returns True if a signal table row has no effect on the score.
+    A row with at least one signal is never empty. Without 'pedantic' every other
+    row is empty; with 'pedantic' a row is only empty if it also has none of the
+    keys "unconditional-jump", "conditional-jump" or "label".
+    """
+    # Check if signal table is non-empty
+    if row.get("signal"):
+        return False
+
+    # Only check if keys are set if pedantic is true
+    if not pedantic:
+        return True
+    # Check if "unconditional-jump", "conditional-jump" or "label" is set
+    keys = ("unconditional-jump", "conditional-jump", "label")
+    return not any(key in row for key in keys)
+
+
+def _load_mux_constants(machine, mux_type):
+    """
+    Returns the values of all constant inputs of multiplexer 'mux_type' as a list
+    of integers, or None if the machine has no multiplexer of that type.
+    """
+    for mux in machine["machine"]["muxInputs"]:
+        if mux["muxType"] == mux_type:
+            return [
+                int(mux_input["value"])
+                for mux_input in mux["input"]
+                if mux_input["type"] == "constant"
+            ]
+    return None
+
+
+def evaluate(filepath, verbose=False, pedantic=False):
+    """
+    Computes the score of the simulator save file at 'filepath' and prints a summary.
+    The score is the sum of the costs of all non-empty signal table rows, all
+    constants, all registers and all operations that were added to the base machine.
+    With 'verbose', intermediate results are printed as well; 'pedantic' is passed
+    on to 'is_empty_row'.
+    Returns the total score, or -1 if the file had to be skipped.
+    """
+    filename = os.path.basename(filepath)
+    if not filepath.endswith(".zip"):
+        print(
+            f"{filename} :: Supplied file does not have .zip file extension. Skipping .."
+        )
+        return -1
+
+    try:
+        with zipfile.ZipFile(filepath, "r") as savefile:
+            with savefile.open("machine.json", "r") as machinefile, savefile.open(
+                "signal.json", "r"
+            ) as signalfile:
+                machine = json.load(machinefile)
+                signal = json.load(signalfile)
+    except (OSError, KeyError, ValueError, zipfile.BadZipFile) as error:
+        # A missing or broken save file must not abort the evaluation of the others
+        print(f"{filename} :: Could not read save file ({error}). Skipping file ..")
+        return -1
+
+    # Load lines of code (rows)
+    total_rows = signal["signaltable"]["row"]
+    # Remove rows without effect (used for formatting for example)
+    rows = [row for row in total_rows if not is_empty_row(row, pedantic=pedantic)]
+    if verbose:
+        print(f"{filename} :: Total number of rows: {len(total_rows)}")
+        print(f"{filename} :: Number of rows after excluding empty: {len(rows)}")
+
+    # Check if IR or PC register was used
+    def is_written(write_signal):
+        # in pedantic mode a kept row may have no "signal" entry at all
+        return any(
+            entry["name"] == write_signal and entry["value"] == "1"
+            for row in rows
+            for entry in row.get("signal", [])
+        )
+
+    pc_used = is_written("PC.W")
+    ir_used = is_written("IR.W")
+
+    if verbose:
+        if pc_used:
+            print(f"{filename} :: PC Register was used in signal table row.")
+        if ir_used:
+            print(f"{filename} :: IR Register was used in signal table row.")
+
+    # Load used multiplexer constants
+    mux_consts = {}
+    for mux_type in ("A", "B"):
+        mux_consts[mux_type] = _load_mux_constants(machine, mux_type)
+        if mux_consts[mux_type] is None:
+            print(
+                f"{filename} :: Couldn't find input for multiplexer {mux_type}. Is the file corrupted? Skipping file .."
+            )
+            return -1
+    mux_consts_a = mux_consts["A"]
+    mux_consts_b = mux_consts["B"]
+
+    # Base machine has constants 0 and 1 at multiplexer A. All other constants are extensions.
+    base_muxt_a = (0, 1)
+    for base_input in base_muxt_a:
+        try:
+            mux_consts_a.remove(base_input)
+        except ValueError:
+            pass
+
+    constants = set(mux_consts_a + mux_consts_b)
+
+    if verbose:
+        print(
+            f"{filename} :: Found {len(mux_consts_a)} constants for multiplexer A: {mux_consts_a}"
+        )
+        print(
+            f"{filename} :: Found {len(mux_consts_b)} constants for multiplexer B: {mux_consts_b}"
+        )
+        print(
+            f"{filename} :: Found {len(constants)} total unique constants: [{', '.join([str(c) for c in constants])}]"
+        )
+
+    # Load used registers
+    registers = machine["machine"]["registers"]["register"]
+    registers = [register["name"] for register in registers]
+
+    if pc_used:
+        registers.append("PC_ALT")
+    if ir_used:
+        registers.append("IR_ALT")
+
+    if verbose:
+        print(f"{filename} :: Found {len(registers)} additional registers: {registers}")
+
+    # Load used operations
+    operations = machine["machine"]["alu"]["operation"]
+    base_operations = ("A_ADD_B", "B_SUB_A", "TRANS_A", "TRANS_B")
+    for base_op in base_operations:
+        try:
+            operations.remove(base_op)
+        except ValueError:
+            pass
+
+    # Extract operation (remove operands)
+    def get_op(op_str):
+        parts = [part for part in op_str.split("_") if part not in ("A", "B")]
+        # fall back to the full name so that it is reported as unknown below
+        return parts[0] if parts else op_str
+
+    operations = list(map(get_op, operations))
+
+    if verbose:
+        print(
+            f"{filename} :: Found {len(operations)} additional operations: {operations}"
+        )
+
+    # An operation without an entry in OP_COST cannot be scored
+    unknown_operations = [op for op in operations if op not in OP_COST]
+    if unknown_operations:
+        print(
+            f"{filename} :: Found operations without cost: {unknown_operations}. Skipping file .."
+        )
+        return -1
+
+    # Sum points
+    alg_line_costs = len(rows) * ALG_LINE_COST  # every line of code
+    constant_costs = len(constants) * CONSTANT_COST  # constants at both multiplexers
+    register_costs = len(registers) * REGISTER_COST  # registers
+    operation_costs = sum(OP_COST[operation] for operation in operations)  # operations
+
+    total = alg_line_costs + constant_costs + register_costs + operation_costs
+
+    # Summarize
+    costs = (alg_line_costs, constant_costs, register_costs, operation_costs, total)
+    # Number of decimals needed to print every cost without trailing zeros (at most 2)
+    precision = min(
+        max(len(str(float(cost)).split(".")[1].rstrip("0")) for cost in costs), 2
+    )
+
+    if verbose:
+        print("")
+
+    print(f"{ul}Summary for {filename}:{end}\n")
+    print(f"  {alg_line_costs:5.{precision}f} LINES (excluding empty lines)")
+    print(f"+ {constant_costs:5.{precision}f} CONSTANTS")
+    print(f"+ {register_costs:5.{precision}f} REGISTERS")
+    print(f"+ {operation_costs:5.{precision}f} OPERATIONS")
+    print("-------------")
+    print(f"= {ylw}{total:5.{precision}f} TOTAL{end}\n\n")
+
+    return total
+
+
+# Evaluation
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "source",
+        type=str,
+        nargs="+",
+        help="Either ZIP file(s) generated by simulator or the submission root folder",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        dest="verbose",
+        action="store_true",
+        help="Prints additional information.",
+    )
+    parser.add_argument(
+        "-p",
+        "--pedantic",
+        dest="pedantic",
+        action="store_true",
+        help="Extra pedantic (for example when checking for empty lines)",
+    )
+    parser.add_argument(
+        "-t",
+        "--top",
+        dest="top",
+        type=int,
+        help="Print top n candidates (defaults to 7)",
+    )
+    args = parser.parse_args()
+
+    verbose = args.verbose
+    pedantic = args.pedantic
+    top_n = args.top if args.top is not None else 7
+
+    # output score for each file
+    scores = []
+
+    # check if source argument is folder
+    savefiles = []
+    for source in args.source:
+        if os.path.isdir(source):
+            # add all zip from subdirectories
+            for d in [
+                e for e in os.listdir(source) if os.path.isdir(os.path.join(source, e))
+            ]:
+                savefiles.append(os.path.join(source, d, f"{d}.zip"))
+        elif source.endswith(".zip"):
+            savefiles.append(source)
+        else:
+            print(f"Source '{source}' is not a ZIP file.")
+
+    for savefile in savefiles:
+        score = evaluate(savefile, verbose=verbose, pedantic=pedantic)
+        if score == -1:
+            continue
+
+        scores.append([savefile, score])
+
+    # if there is more than one file, output the top n candidates
+    scores.sort(key=lambda x: x[1])
+
+    if len(savefiles) > 1:
+        print(f"{ul}Leaderboard:{end}")
+        # skipped files have no score, so only rank the files that were evaluated
+        for rank, (file, score) in enumerate(scores[: max(top_n, 0)], start=1):
+            print(f"#{rank} - {ylw}{score:5.2f} TOTAL{end} - {file}")
diff --git a/machine_evaluation.py b/machine_evaluation.py
new file mode 100644
index 0000000..c9a2770
--- /dev/null
+++ b/machine_evaluation.py
@@ -0,0 +1,293 @@
+"""
+    Checks if all submissions output the correct decrypted data.
+    Directory structure should be as follows:
+
+    submissions
+    ├── TEAM_NAME
+    |   ├── mem_layout.json [optional]
+    |   └── SAVE_FILE.zip
+    ├── TEAM_NAME
+    |   ├── mem_layout.json [optional]
+    |   └── SAVE_FILE.zip
+    ...
+
+    group
+    ├── sBox
+    ├── key
+    ├── data_encrypted
+    └── data_decrypted
+
+    mem_layout.json must contain the keys 'sbox', 'key' and 'data', with the respective addresses as values.
+    If the file does not exist, the user will be prompted to enter these addresses. The data
+    will then be saved as 'mem_layout.json', to be used in subsequent calls.
+
+    If more than one simulator ZIP file is present, the user will also be prompted to choose.
+"""
+
+import argparse
+import json
+import math
+import sys
+import os
+import subprocess
+
+from zipfile import BadZipFile, ZipFile
+
+SIMULATOR_PATH = "./minimax_simulator-2.0.0-cli.jar"
+
+# helper strings
+OK = "\033[32mOK\033[0m"
+ERROR = "\033[31mERROR\033[0m"
+
+
+def compare_result(actual_file: str, expected_file: str) -> bool:
+    """
+    Compares the file at path 'actual_file' with the file at path 'expected_file' bytewise.
+    If the actual file is larger than the expected file, additional bytes will be ignored.
+    Returns False if the actual file does not exist (the simulator wrote no result).
+    """
+    try:
+        with open(actual_file, "rb") as actual:
+            actual_bytes = actual.read()
+    except FileNotFoundError:
+        return False
+    with open(expected_file, "rb") as expected:
+        expected_bytes = expected.read()
+
+    return actual_bytes[: len(expected_bytes)] == expected_bytes
+
+
+def create_mem_layout() -> dict:
+    """
+    Prompts the user to enter addresses for the sbox, key and data.
+    Returns a dictionary with these three keys.
+    """
+    mem_layout = {}
+    for key in ("sbox", "key", "data"):
+        while True:
+            try:
+                value = input(f"{key} address (prefix hex numbers with '0x'): ")
+                base = 16 if value.startswith("0x") else 10
+                value = int(value, base=base)
+                break
+            except ValueError:
+                print("Adress is not an integer.")
+            except KeyboardInterrupt:
+                sys.exit("Aborted.")
+        mem_layout[key] = value
+    return mem_layout
+
+
+def is_valid_zip(file: str) -> bool:
+    """
+    Checks if a zip file is a save file from the minimax simulator, i.e. if the contents
+    are a 'machine.json' and 'signal.json' file. Archives that cannot be read are not valid.
+    """
+    if not file.endswith(".zip"):
+        return False
+    try:
+        with ZipFile(file) as machine_zip:
+            zip_content = machine_zip.namelist()
+    except (BadZipFile, OSError):
+        return False
+    return set(zip_content) == {"machine.json", "signal.json"}
+
+
+def select_zip(zips: list) -> str:
+    """
+    Prompts the user to select a single zip file from a list, and returns it.
+    The list must not be empty, otherwise no selection can ever be accepted.
+    """
+    print("Multiple zip files found. Please select one:")
+    for index, f in enumerate(zips, start=1):
+        print(f"[{index}] {f}")
+    while True:
+        try:
+            selection = input("Enter the number of the zip file to select: ")
+            selection = int(selection) - 1  # the list is printed starting at 1
+            if not 0 <= selection < len(zips):
+                print(f"Please select a number between 1 and {len(zips)}.")
+            else:
+                return zips[selection]
+        except ValueError:
+            print("Please enter a valid integer.")
+        except KeyboardInterrupt:
+            sys.exit("Aborted")
+
+
+def evaluate(
+    zip_file: str,
+    sbox_file: str,
+    key_file: str,
+    data_file: str,
+    result_file: str,
+    mem_layout: dict,
+    simulator: str,
+) -> None:
+    """
+    Runs the minimax simulator on the given input. The resulting file is saved in 'result_file'.
+    The sbox, key and data files are imported at the addresses given in 'mem_layout'. The export
+    starts at the data address; its end is the data address plus the size of 'data_file'.
+    """
+    args = [
+        "java",
+        "-jar",
+        simulator,
+        zip_file,
+        "--import-file",
+        sbox_file,
+        "--import-from",
+        mem_layout["sbox"],
+        "--import-file",
+        key_file,
+        "--import-from",
+        mem_layout["key"],
+        "--import-file",
+        data_file,
+        "--import-from",
+        mem_layout["data"],
+        "--export-file",
+        result_file,
+        "--export-from",
+        mem_layout["data"],
+        "--export-to",
+        mem_layout["data"] + math.ceil(os.path.getsize(data_file)),
+    ]
+
+    args = [
+        str(arg) for arg in args
+    ]  # subprocess.run requires all arguments to be strings
+
+    print(f"Running simulator, storing result in '{result_file}'")
+
+    result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    print("\033[38;5;245m")
+
+    print(result.stdout.decode("utf-8", errors="replace"))
+
+    print("\033[38;5;124m")
+
+    print(result.stderr.decode("utf-8", errors="replace"))
+
+    print("\033[0m")
+
+
+if __name__ == "__main__":  # only run if executed as script
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("submissions", type=str, help="Submissions root directory")
+    parser.add_argument(
+        "group", type=str, help="Group directory, contains all project files"
+    )
+    parser.add_argument(
+        "-e",
+        "--file-extension",
+        dest="file_ext",
+        type=str,
+        help="Result file extension",
+    )
+    parser.add_argument("-j", "--jar", dest="jar", type=str, help="Simulator jar file")
+    args = parser.parse_args()
+
+    # Load teams
+    teams = [
+        e
+        for e in os.listdir(args.submissions)
+        if os.path.isdir(os.path.join(args.submissions, e))
+    ]
+
+    print("The following teams were found:")
+    for team in teams:
+        print(f"* {team}")
+
+    # Check directory structure
+    for file in ("sBox", "key", "data_encrypted", "data_decrypted"):
+        if not os.path.exists(os.path.join(args.group, file)):
+            sys.exit(f"Group project file '{file}' is missing.")
+
+    # Check file extension
+    if args.file_ext is None:
+        args.file_ext = ""
+    elif not args.file_ext.startswith("."):
+        args.file_ext = f".{args.file_ext}"
+
+    # Update simulator path if given
+    simulator = SIMULATOR_PATH if args.jar is None else args.jar
+
+    # Store evaluation results
+    expected_file = os.path.join(args.group, "data_decrypted")
+    results = []
+
+    # Evaluate each team
+    for team in teams:
+
+        # load memory layout file if available (otherwise create and store it)
+        try:
+            with open(
+                os.path.join(args.submissions, team, "mem_layout.json"), "r"
+            ) as mem_layout_file:
+                mem_layout = json.load(mem_layout_file)
+        except FileNotFoundError:
+            mem_layout = create_mem_layout()
+            with open(
+                os.path.join(args.submissions, team, "mem_layout.json"), "w"
+            ) as mem_layout_file:
+                json.dump(mem_layout, mem_layout_file)
+
+        # Select project file (if there is more than one zip file)
+        zip_files = [
+            os.path.join(args.submissions, team, f)
+            for f in os.listdir(os.path.join(args.submissions, team))
+            if is_valid_zip(os.path.join(args.submissions, team, f))
+        ]
+        if not zip_files:
+            # nothing to run; still record a result to keep 'results' aligned with 'teams'
+            print(f"No simulator save file found for team '{team}'.")
+            results.append(False)
+            continue
+        zip_file = zip_files[0] if len(zip_files) == 1 else select_zip(zip_files)
+
+        # check memory layout and convert hexadecimal addresses
+        for key in ("sbox", "key", "data"):
+            if key not in mem_layout:
+                sys.exit(f"memory_layout does not contain key '{key}'")
+            try:
+                addr = int(mem_layout[key])
+            except (TypeError, ValueError):
+                try:
+                    addr = int(mem_layout[key], base=16)
+                except (TypeError, ValueError):
+                    sys.exit(f"Invalid address '{mem_layout[key]}' for key '{key}'")
+            mem_layout[key] = addr
+
+        # evaluate team
+        sbox_file = os.path.join(args.group, "sBox")
+        key_file = os.path.join(args.group, "key")
+        data_file = os.path.join(args.group, "data_encrypted")
+        result_file = os.path.join(
+            args.submissions, team, f"data_decrypted{args.file_ext}"
+        )
+
+        evaluate(
+            zip_file,
+            sbox_file,
+            key_file,
+            data_file,
+            result_file,
+            mem_layout,
+            simulator=simulator,
+        )
+
+        # Compare against expected result
+        results.append(compare_result(result_file, expected_file))
+
+    # Print result summary
+    print("Summary:")
+    for team, result in zip(teams, results):
+        if result:
+            print(f"[{OK}] - {team}")
+        else:
+            print(f"[{ERROR}] - {team}")
+
+    print()