Engineering Note

プログラミングなどの技術的なメモ

Python x86エミュレータの作成(bios)

cpu

本記事はPythonで簡単なx86エミュレータを作成します。

前回ではIOポートの読み書きに使用されるin/out命令について学びました。

今回はBIOSの機能を使った文字出力について学んでいきます。

 

 

BIOSとは

前回ではIOポートの読み書きに使用されるin/out命令について学びました。

 

 

in/out命令によってIOポートに接続された周辺機器などのデバイスにアクセスし、読み書きすることが可能になりますが、機器ごとの仕様に差異があると、読み書きをするタイミングが異なり調整が必要になってきます。

そこでこれらの差異を吸収するために用意されたインターフェースがBIOSBasic Input Output System:基本入出力システム)になります。

BIOSIntel 8086の時代に使われたリアルモードで動作するプログラムなので、プロテクトモードで動作する現代的なOSでは殆ど直接的に触れることはありません。

しかし、PCの起動直後にハードウェアに異常がないかをチェックし、HDDに記録されたOSをメモリにロードして実行(ブートストラップローダ)する大切な役割はBIOSが担っています。

 

アセンブリ言語プログラム

今回は以下の参考書籍と同じアセンブリ言語プログラムを使用します。

 

 ;subroutine32.asm
BITS 32
    org 0x7c00
start:
    mov esi, msg
    call puts
    jmp 0

puts:
    mov al, [esi]
    inc esi
    cmp al, 0
    je  puts_end
    mov ah, 0x0e
    mov ebx, 10
    int 0x10
    jmp puts
puts_end:
    ret

msg:
    db "Hello, World!", 0x0d, 0x0a, 0

 

上記ではputs内の14行目~16行目がBIOSの処理になっていて、mov ah, 0x0eで1文字を表示させるテレタイプ出力を指定し、mov ebx, 10で出力する文字色(ここでは緑)を指定し、int 0x10BIOSの機能を呼び出します。

int命令はCPUに対してソフトウェア割り込み(Interrupt)をする命令になります。

 

Pythonによるスクリプトの作成

それでは、PythonBIOSのテレタイプ出力機能を実装していきます。

 

# emulator.py
import sys

class ModRM:
    def __init__(self):
        self.modrm = {
            "mod"       :0x00,
            "opecode"   :0x00,
            "reg_index" :0x00,
            "rm"        :0x00,
            "sib"       :0x00,
            "disp8"     :0x00,
            "disp32"    :0x00
        }

class Emulator:
    def __init__(self):
        self.register_name_extended = [
            "EAX", "ECX", "EDX", "EBX", "ESP",
            "EBP", "ESI", "EDI"]
        self.register_name = [
            "AL", "CL", "DL", "BL", "AH", "CH", "DH", "BH"
            ]
        self.registers = {
            "EAX": 0x00,
            "ECX": 0x00,
            "EDX": 0x00,
            "EBX": 0x00,
            "ESP": 0x00,
            "EBP": 0x00,
            "ESI": 0x00,
            "EDI": 0x00,
            }
        self.eflags = 0x00
        self.memory = None
        self.eip = 0x00
        self.instructions = [None for i in range(256)]

    def init_instructions(self):
        self.instructions[0x01] = self.add_rm32_r32
        self.instructions[0x3b] = self.cmp_r32_rm32
        self.instructions[0x3c] = self.cmp_al_imm8
        for i in range(8):
            self.instructions[0x40+i] = self.inc_r32
        for i in range(8):
            self.instructions[0x50+i] = self.push_r32
        for i in range(8):
            self.instructions[0x58+i] = self.pop_r32
        self.instructions[0x68] = self.push_imm32
        self.instructions[0x6a] = self.push_imm8
        self.instructions[0x70] = self.jo
        self.instructions[0x71] = self.jno
        self.instructions[0x72] = self.jc
        self.instructions[0x73] = self.jnc
        self.instructions[0x74] = self.jz
        self.instructions[0x75] = self.jnz
        self.instructions[0x78] = self.js
        self.instructions[0x79] = self.jns
        self.instructions[0x7c] = self.jl
        self.instructions[0x7e] = self.jle
        self.instructions[0x83] = self.code_83
        self.instructions[0x89] = self.mov_rm32_r32
        self.instructions[0x8a] = self.mov_r8_rm8
        self.instructions[0x8b] = self.mov_r32_rm32
        for i in range(8):
            self.instructions[0xb0 + i] = self.mov_r8_imm8
        for i in range(8):
            self.instructions[0xb8 + i] = self.mov_r32_imm32
        self.instructions[0xc3] = self.ret
        self.instructions[0xc7] = self.mov_rm32_imm32
        self.instructions[0xc9] = self.leave
        self.instructions[0xcd] = self.swi
        self.instructions[0xe8] = self.call_rel32
        self.instructions[0xe9] = self.near_jump
        self.instructions[0xeb] = self.short_jump
        self.instructions[0xec] = self.in_al_dx
        self.instructions[0xee] = self.out_dx_al
        self.instructions[0xff] = self.code_ff

    def create_emu(self, size, eip, esp):
        self.eip = eip
        self.registers["ESP"] = esp
        self.memory = [0x00 for _ in range(size)]

    def dump_registers(self):
        for i in range(8):
            name = self.register_name_extended[i]
            print("{} = 0x{:08x}".format(name, self.registers[name]))
        print("EIP = 0x{:08x}".format(self.eip))

    def mov_r32_imm32(self):
        reg = self.get_code8(0) - 0xb8
        value = self.get_code32(1)
        reg_name = self.register_name_extended[reg]
        self.registers[reg_name] = value
        self.eip += 5
        if self.eip >= 0x100000000:
            self.eip ^= 0x100000000

    def short_jump(self):
        diff = self.get_sign_code8(1)
        if diff & 0x80:
            diff -= 0x100
        self.eip += (diff + 2)

    def get_code8(self, index):
        code = self.memory[self.eip + index]
        if not type(code) == int:
            code = int.from_bytes(code, 'little')
        return code

    def get_sign_code8(self, index):
        code =  self.memory[self.eip + index]
        if not type(code) == int:
            code = int.from_bytes(code, 'little')
        return code & 0xff

    def get_code32(self, index):
        ret = 0x00
        for i in range(4):
            ret |= self.get_code8(index + i) << (i * 8)
        return ret

    def get_sign_code32(self, index):
        return  self.get_code32(index)

    def near_jump(self):
        diff = self.get_sign_code32(1)
        if diff & 0x80000000:
            diff -= 0x100000000
        self.eip += (diff + 5)

    def parse_modrm(self):
        m = ModRM()
        code = self.get_code8(0)
        m.modrm["mod"] = ((code & 0xc0) >> 6)
        m.modrm["opecode"] = m.modrm["reg_index"] = ((code & 0x38) >> 3)
        m.modrm["rm"] = code & 0x07

        self.eip += 1
        if (m.modrm["mod"] != 3 and m.modrm["rm"] == 4):
            m.modrm["sib"] = self.get_code8(0)
            eip += 1
        if (m.modrm["mod"] == 0 and m.modrm["rm"] == 5) or m.modrm["mod"] == 2:
            m.modrm["disp32"] = self.get_sign_code32(0)
            m.modrm["disp8"] = m.modrm["disp32"] & 0xff
            eip += 4
        elif m.modrm["mod"] == 1:
            m.modrm["disp8"] = m.modrm["disp32"] = self.get_sign_code8(0)
            self.eip += 1

        return m

    def mov_rm32_imm32(self):
        self.eip += 1
        m = self.parse_modrm()
        value = self.get_code32(0)
        self.eip += 4
        self.set_rm32(m, value)

    def set_rm32(self, m, value):
        if m.modrm["mod"] == 3:
            self.set_register32(m.modrm["rm"], value)
        else:
            address = self.calc_memory_address(m)
            self.set_memory32(address, value)

    def set_memory8(self, address, value):
        self.memory[address] = value & 0xff

    def set_memory32(self, address, value):
        for i in range(4):
            self.set_memory8(address+i, value >> (i*8))

    def calc_memory_address(self, m):
        if m.modrm["mod"] == 0:
            if m.modrm["rm"] == 4:
                print("not implemented ModRM mod = 0, rm = 4")
                sys.exit(0)
            elif m.modrm["rm"] == 5:
                return m.modrm["disp32"]
            else:
                return self.get_register32(m.modrm["rm"])
        elif m.modrm["mod"] == 1:
            if m.modrm["rm"] == 4:
                print("not implemented ModRM mod = 1, rm = 4")
                sys.exit(0)
            else:
                return self.get_register32(m.modrm["rm"]) + m.modrm["disp8"]
        elif m.modrm["mod"] == 2:
            if m.modrm["rm"] == 4:
                print("not implemented ModRM mod = 2, rm = 4")
                sys.exit(0)
            else:
                return self.get_register32(m.modrm["rm"]) + m.modrm["disp32"]
        else:
            print("not implemented ModRM mod = 3")
            sys.exit(0)

    def mov_rm32_r32(self):
        self.eip += 1
        m = self.parse_modrm()
        r32 = self.get_r32(m)
        self.set_rm32(m, r32)

    def mov_r32_rm32(self):
        self.eip += 1
        m = self.parse_modrm()
        rm32 = self.get_rm32(m)
        self.set_r32(m, rm32)

    def get_rm32(self, m):
        if m.modrm["mod"] == 3:
            return self.get_register32(m.modrm["rm"])
        else:
            address = self.calc_memory_address(m)
            return self.get_memory32(address)

    def get_memory8(self, address):
        return self.memory[address]

    def get_memory32(self, address):
        ret = 0
        for i in range(4):
            mem = self.get_memory8(address + i)
            if not type(mem) == int:
                mem = ord(mem)
            ret |= mem << (8*i)
        return ret

    def set_r32(self, m, value):
        self.set_register32(m.modrm["reg_index"], value)

    def get_r32(self, m):
        return self.get_register32(m.modrm["reg_index"])

    def add_rm32_r32(self):
        self.eip += 1
        m = self.parse_modrm()
        r32 = self.get_r32(m)
        rm32 = self.get_rm32(m)
        self.set_rm32(m, rm32 + r32)

    def sub_rm32_imm8(self, m):
        rm32 = self.get_rm32(m)
        imm8 = self.get_sign_code8(0)
        self.eip += 1
        result = rm32 - imm8
        self.set_rm32(m, result)
        self.update_eflags_sub(rm32, imm8, result)

    def code_83(self):
        self.eip += 1
        m = self.parse_modrm()
        if m.modrm["opecode"] == 0:
            self.add_rm32_imm8(m)
        elif m.modrm["opecode"] == 5:
            self.sub_rm32_imm8(m)
        elif m.modrm["opecode"] == 7:
            self.cmp_rm32_imm8(m)
        else:
            print("not implemented: 83 /{}".format(m.modrm["opecode"]))
            sys.exit(1)

    def inc_rm32(self, m):
        value = self.get_rm32(m)
        self.set_rm32(m, value + 1)

    def inc_r32(self):
        reg = self.get_code8(0) - 0x40
        self.set_register32(reg, self.get_register32(reg)+1)
        self.eip += 1

    def code_ff(self):
        self.eip += 1
        m = self.parse_modrm()

        if m.modrm["opecode"] == 0:
            self.inc_rm32(m)
        else:
            print("not implemented: FF /{}".format(m.modrm["opecode"]))
            sys.exit(1)

    def get_register32(self, index):
        reg = self.register_name_extended[index]
        return self.registers[reg]

    def set_register32(self, index, value):
        reg = self.register_name_extended[index]
        self.registers[reg] = value

    def push_r32(self):
        reg = self.get_code8(0) - 0x50
        self.push32(self.get_register32(reg))
        self.eip += 1

    def pop_r32(self):
        reg = self.get_code8(0) - 0x58
        self.set_register32(reg, self.pop32())
        self.eip += 1

    def push32(self, value):
        esp = self.register_name_extended.index("ESP")
        address = self.get_register32(esp) - 4
        self.set_register32(esp, address)
        self.set_memory32(address, value)

    def pop32(self):
        esp = self.register_name_extended.index("ESP")
        address = self.get_register32(esp)
        ret = self.get_memory32(address)
        self.set_register32(esp, address + 4)
        return ret

    def call_rel32(self):
        diff = self.get_sign_code32(1)
        if diff & 0x80000000:
            diff -= 0x100000000
        self.push32(self.eip + 5)
        self.eip += (diff + 5)

    def ret(self):
        self.eip = self.pop32()

    def leave(self):
        ebp = self.get_register32(self.register_name_extended.index("EBP"))
        self.set_register32(self.register_name_extended.index("ESP"), ebp)
        self.set_register32(self.register_name_extended.index("EBP"), self.pop32())
        self.eip += 1

    def push_imm8(self):
        value = self.get_code8(1)
        self.push32(value)
        self.eip += 2

    def push_imm32(self):
        value = self.get_code32(1)
        self.push32(value)
        self.eip += 5

    def add_rm32_imm8(self, m):
        rm32 = self.get_rm32(m)
        imm8 = self.get_sign_code8(0)
        self.eip += 1
        self.set_rm32(m, rm32+imm8)

    def cmp_r32_rm32(self):
        self.eip += 1
        m = self.parse_modrm()
        r32 = self.get_r32(m)
        rm32 = self.get_rm32(m)
        result = r32 - rm32
        self.update_eflags_sub(r32, rm32, result)

    def cmp_rm32_imm8(self, m):
        rm32 = self.get_rm32(m)
        imm8 = self.get_sign_code8(0)
        print(rm32, imm8)
        self.eip += 1
        result = rm32 - imm8
        self.update_eflags_sub(rm32, imm8, result)

    def update_eflags_sub(self, v1, v2, result):
        sign1 = v1 >> 31
        sign2 = v2 >> 31
        signr = (result >> 31) & 1
        self.set_carry(result >> 32)
        self.set_zero(result == 0)
        self.set_sign(signr)
        self.set_overflow(sign1 != sign2 and sign1 != signr)

    def set_carry(self, is_carry):
        if is_carry:
            self.eflags |= CARRY_FLAG
        else:
            self.eflags &= ~CARRY_FLAG

    def set_zero(self, is_zero):
        if is_zero:
            self.eflags |= ZERO_FLAG
        else:
            self.eflags &= ~ZERO_FLAG

    def set_sign(self, is_sign):
        if is_sign:
            self.eflags |= SIGN_FLAG
        else:
            self.eflags &= ~SIGN_FLAG

    def set_overflow(self, is_overflow):
        if is_overflow:
            self.eflags |= OVERFLOW_FLAG
        else:
            self.eflags &= ~OVERFLOW_FLAG

    def is_carry(self):
        return (self.eflags & CARRY_FLAG) != 0

    def is_zero(self):
        return (self.eflags & ZERO_FLAG) != 0

    def is_sign(self):
        return (self.eflags & SIGN_FLAG) != 0

    def is_overflow(self):
        return (self.eflags & OVERFLOW_FLAG) != 0

    def j(func):
        def wrapper(self, *args, **kwargs):
            if func(self, *args, **kwargs):
                diff = self.get_sign_code8(1)
            else:
                diff = 0
            self.eip += (diff + 2)
        return wrapper

    def jn(func):
        def wrapper(self, *args, **kwargs):
            if func(self, *args, **kwargs):
                diff = 0
            else:
                diff = self.get_sign_code8(1)
            self.eip += (diff + 2)
        return wrapper

    @j
    def jc(self):
        return self.is_carry()

    @jn
    def jnc(self):
        return self.is_carry()

    @j
    def js(self):
        return self.is_sign()

    @jn
    def jns(self):
        return self.is_sign()

    @j
    def jz(self):
        return self.is_zero()

    @jn
    def jnz(self):
        return self.is_zero()

    @j
    def jo(self):
        return self.is_overflow()

    @jn
    def jno(self):
        return self.is_overflow()

    def jl(self):
        if self.is_sign() != self.is_overflow():
            diff = self.get_sign_code8(1)
            if diff & 0x80:
                diff -= 0x100
        else:
            diff = 0
        self.eip += (diff + 2)

    def jle(self):
        if self.is_zero() or self.is_sign() != self.is_overflow():
            diff = self.get_sign_code8(1)
            if diff & 0x80:
                diff -= 0x100
        else:
            diff = 0
        self.eip += (diff + 2)

    def in_al_dx(self):
        address = self.get_register32(self.register_name_extended.index("EDX")) & 0xffff
        value = ord(self.io_in8(address))
        self.set_register8(self.register_name.index("AL"), value)
        self.eip += 1

    def out_dx_al(self):
        address = self.get_register32(self.register_name_extended.index("EDX")) & 0xffff
        value = self.get_register32(self.register_name.index("AL")) & 0xff
        self.io_out8(address, value)
        self.eip += 1

    def io_in8(self, address):
        if address == 0x03f8:
            return sys.stdin.read(1)

    def io_out8(self, address, value):
        if address == 0x03f8:
            sys.stdout.write(chr(value))
            sys.stdout.flush()

    def get_register8(self, index):
        if index < 4:
            reg_name = self.register_name_extended[index]
            return self.registers[reg_name] & 0xff
        else:
            reg_name = self.register_name_extended[index-4]
            return (self.registers[reg_name] >> 8) & 0xff

    def set_register8(self, index, value):
        if index < 4:
            reg_name = self.register_name_extended[index]
            r = self.registers[reg_name] & 0xffffff00
            self.registers[reg_name] = r | value
        else:
            reg_name = self.register_name_extended[index-4]
            r = self.registers[reg_name] & 0xffff00ff
            self.registers[reg_name] = r | (value << 8)

    def mov_r8_imm8(self):
        reg = self.get_code8(0) - 0xB0
        self.set_register8(reg, self.get_code8(1))
        self.eip += 2

    def cmp_al_imm8(self):
        value = self.get_code8(1)
        al = self.get_register8(self.register_name_extended.index("EAX")) & 0xff
        result = al - value
        self.update_eflags_sub(al, value, result)
        self.eip += 2

    def mov_r8_rm8(self):
        self.eip += 1
        m = self.parse_modrm()
        rm8 = self.get_rm8(m)
        self.set_r8(m, rm8)

    def set_r8(self, m, value):
        self.set_register8(m.modrm["reg_index"], value)

    def get_rm8(self, m):
        if m.modrm["mod"] == 3:
            return self.get_register32(m["rm"])
        else:
            address = self.calc_memory_address(m)
            return self.get_memory32(address)

    def swi(self):
        int_index = self.get_code8(1)
        self.eip += 2
        if int_index == 0x10:
            self.bios_video()
        else:
            print("unknown interrupt: 0x{:02x}".format(int_index))

    def put_string(self, string, size):
        for i in range(size):
            self.io_out8(0x03f8, ord(string[i]))

    def bios_video_teletype(self):
        color = self.get_register8(self.register_name.index("BL")) & 0x0f
        ch = self.get_register8(self.register_name.index("AL"))
        terminal_color = bios_to_terminal[color & 0x07]
        if color & 0x08:
            bright = 1
        else:
            bright = 0
        buf = "\x1b[{};{}m{}\x1b[0m".format(bright, terminal_color, chr(ch))

        self.put_string(buf, len(buf))

    def bios_video(self):
        func = self.get_register8(self.register_name.index("AH"))
        if func == 0x0e:
            self.bios_video_teletype()
        else:
            print("not implemented BIOS video function: 0x{:02x}".format(func))

bios_to_terminal = [30, 34, 32, 36, 31, 35, 33, 37]

CARRY_FLAG = 1
ZERO_FLAG = 1 << 6
SIGN_FLAG = 1 << 7
OVERFLOW_FLAG = 1 << 11

mem_size = 1024 * 1024

emu = Emulator()
emu.create_emu(mem_size, 0x7c00, 0x7c00)
binary = open('subroutine32.bin', 'rb')
offset = 0x7c00
while True:
    b = binary.read(1)
    if b == b'':
        break
    emu.memory[offset] = b
    offset += 1
binary.close()

quiet = 0
if '-q' in sys.argv:
    quiet = 1

emu.init_instructions()
while emu.eip < mem_size:
    code = emu.get_code8(0)
    if not quiet:
        print("EIP = 0x{:02x}, Code = 0x{:02x}".format(emu.eip, code))
    if emu.instructions[code] == None:
        print("\n\nNot Implemented: 0x{:02x}".format(code))
        break
    emu.instructions[code]()
    if emu.eip == 0x00:
        print("\n\nend of program.\n\n")
        break

emu.dump_registers()

 

なお、上記では指定された文字色をターミナルに出力するためにANSIエスケープシーケンスと 呼ばれるものを利用しています。

"\x1b[<輝度>;<色番号>m<文字列>\x1b[0m"をターミナルに出力することで指定した輝度と色で文字列が出力されます。

 

動作確認

それでは、上記で作成したスクリプトを実行してみます。

なお、事前にアセンブリ言語のプログラムはbinファイルとしてビルドしておきます。

また、WindowsコマンドプロンプトPowerShellではANSIエスケープシーケンスに対応していないため、WSL上のUbuntuで実行しました。

 

 > python3 emulator.py -q
 Hello, World! # 緑色で出力される


 end of program.


 EAX = 0x7f7f0e00
 ECX = 0x00000000
 EDX = 0x00000000
 EBX = 0x0000000a
 ESP = 0x00007c00
 EBP = 0x00000000
 ESI = 0x00007c32
 EDI = 0x00000000
 EIP = 0x00000000

 

問題なくBIOSの機能を使って"Hello, World!"の文字列を出力することが確認できました。

 

参考書籍

自作エミュレータで学ぶx86アーキテクチャ-コンピュータが動く仕組みを徹底理解!