#!/usr/bin/env python3 from dataclasses import dataclass import sys from typing import Callable, Iterable from .parser import parser from . import parser_types as pt from .simple_assembler import encode_instruction @dataclass class MnemonicInfo: opcode: str num_args: int supports_jmp: bool class Program: def __init__(self): self.labels: dict[str, int] = {} self.instructions: bytes = b"" self.pc: int = 0 def encode(self, ins: pt.Instruction): arg1 = self._resolve(ins, lambda: ins.arg1) arg2 = self._resolve(ins, lambda: ins.arg2) op = encode_instruction( ins.opcode, ins.dest.name if ins.dest else "", arg1, arg2, (ins.jumptarget or ""), ) self.instructions += op.to_bytes(length=2, byteorder="little") def _resolve( self, ins: pt.Instruction, get_prop: Callable[[], pt.Symbol | pt.Register | pt.Immediate | str | None], ): arg = get_prop() if isinstance(arg, pt.Symbol): ret = self.labels.get(arg.name, None) if ret is None: raise ValueError(f"Line {ins.lineno}: Label {arg.name} not defined") elif isinstance(arg, pt.Register): ret = arg.name elif isinstance(arg, pt.Immediate): ret = arg.value else: ret = arg return ret def write_to_file(self, filename: str) -> None: with open(filename, "wb") as outfile: _ = outfile.write(self.instructions) print(f"Output written to {filename}") opcode_infos: dict[str, MnemonicInfo] = { "and": MnemonicInfo("and", 3, True), "or": MnemonicInfo("or", 3, True), "xor": MnemonicInfo("xor", 3, True), "not": MnemonicInfo("not", 2, True), "mov": MnemonicInfo("mov", 2, True), "add": MnemonicInfo("add", 3, True), "inc": MnemonicInfo("inc", 2, True), "sub": MnemonicInfo("sub", 3, True), "dec": MnemonicInfo("dec", 2, True), "cmp": MnemonicInfo("cmp", 2, True), "neg": MnemonicInfo("neg", 2, True), "hlt": MnemonicInfo("hlt", 0, False), "nop": MnemonicInfo("nop", 0, False), } def get_op_info(instruction: pt.Instruction) -> MnemonicInfo | None: """Get information about a given opcode in a instruction.""" return opcode_infos.get(instruction.opcode, None) def check_instructions( instructions: Iterable[pt.AsmLine], ) -> Iterable[pt.ErrorInstruction]: """Given an iterable of assembler lines, check for errors.""" for ins in instructions: # If instruction already is an error generated by the parser, just return that. if isinstance(ins, pt.ErrorInstruction): yield ins continue if not isinstance(ins, pt.Instruction): continue if ( ins.arg1 is not None and ins.arg2 is not None and not isinstance(ins.arg1, pt.Register) and not isinstance(ins.arg2, pt.Register) ): yield pt.ErrorInstruction( lineno=ins.lineno, opcode=ins.opcode, error_message="At least one argument must be a register.", ) opinfo = get_op_info(ins) if opinfo is None: yield pt.ErrorInstruction( lineno=ins.lineno, opcode=ins.opcode, error_message="Unknown instruction", ) continue if opinfo.num_args != ins.num_args: yield pt.ErrorInstruction( lineno=ins.lineno, opcode=ins.opcode, error_message=f"Expected {opinfo.num_args} args, got {ins.num_args}.", ) if not opinfo.supports_jmp and ins.jumptarget: yield pt.ErrorInstruction( lineno=ins.lineno, opcode=ins.opcode, error_message="OPcode got a jump, but it's not supported here.", ) def assemble(instructions: Iterable[pt.AsmLine]) -> Program: prog = Program() prog.pc = 0 # first pass: populate symbols for ins in instructions: match ins: case pt.JumpTarget(): lblname = ins.label.name if lblname in prog.labels: print( f"WARNING: Label {lblname} redefined on line {ins.lineno}. Using previous definition.", file=sys.stderr, ) else: prog.labels[lblname] = prog.pc case pt.Instruction(): prog.pc += 1 case _: pass prog.pc = 0 # second pass: assemble with resolve for ins in instructions: match ins: case pt.Instruction(): prog.encode(ins) prog.pc += 1 case _: pass return prog with open(sys.argv[1], "rb") as infile: data = infile.read() data2 = data.decode("ascii") result: list[pt.AsmLine] result = parser.parse(data2, tracking=True) errors = check_instructions(result) errors = list(errors) if errors: for e in errors: print(f"ERROR: On line {e.lineno}: {e.opcode} : {e.error_message}") sys.exit(1) print("Instruction checks passed") p = assemble(result) if len(sys.argv) >= 3: dest = sys.argv[2] else: dest = sys.argv[1] + ".bin" p.write_to_file(dest)