Engineering Note

プログラミングなどの技術的なメモ

Python x86エミュレータの作成(in/out命令)

cpu

本記事はPythonで簡単なx86エミュレータを作成します。

前回 では条件分岐命令で使用されるeflagsの使い方について学びました。

今回はIOポートの読み書きに使用されるin/out命令について学んでいきます。

 

 

in/out命令とは

ここでは、in/out命令の概要と、今回エミュレータ上で動かすサンプルプログラムについて説明します。

 

 

CPUはレジスタやメモリへのアクセス以外にも、キーボードやマウスなど様々なデバイス(周辺機器)へのアクセスも行います。

これらのデバイスにはIOポートという部品を介してアクセスし、その際の読み書きに使われる命令がin/out命令になります。

また、一般的なx86では0x0000~0xffffのI/Oアドレス空間(メモリのアドレス空間とは別の空間)に各デバイスがマッピングされているため、そのアドレス(ポート番号)を指定し、in/out命令によって読み書きを実行します。

 

今回は以下の参考書籍と同じアセンブリ言語プログラムを使用します。

 

 ;select.asm
; Interactive sample: prints a '>' prompt, reads one character from I/O port
; 0x03f8 and prints "hello" for 'h', "world" for 'w', or stops for 'q'.
; Any other character is ignored and the next one is read.
; The emulator in this article loads this image at 0x7c00 and maps port
; 0x03f8 to stdin (in) / stdout (out).
BITS 32
    org 0x7c00
start:
    ; EDX keeps the I/O port number for the whole program (in/out use DX).
    mov edx, 0x03f8
mainloop:
    ; Emit the prompt character.
    mov al, '>'
    out dx, al
input:
    ; Read one character into AL and dispatch on it.
    in al, dx 
    cmp al, 'h'
    je puthello
    cmp al, 'w'
    je putworld
    cmp al, 'q'
    je fin 
    ; Unrecognised character: read again without re-printing the prompt.
    jmp input
puthello:
    mov esi, msghello
    call puts
    jmp mainloop
putworld:
    mov esi, msgworld
    call puts
    jmp mainloop
fin:
    ; Jump to address 0; the emulator's main loop stops once EIP is 0.
    jmp 0

; puts: write the zero-terminated string at ESI to the port in DX.
; Modifies AL and ESI; ESI ends up one past the terminating zero.
puts:
    mov al, [esi]
    inc esi
    ; A zero byte marks the end of the string.
    cmp al, 0
    je putsend
    out dx, al
    jmp puts
putsend:
    ret

; Messages: text followed by CR (0x0d), LF (0x0a) and a terminating zero.
msghello:
    db "hello", 0x0d, 0x0a, 0
msgworld:
    db "world", 0x0d, 0x0a, 0

上記は標準入力から1文字を読み取り、それが'h'であれば"hello"を、'w'であれば"world"を標準出力に出力し、'q'であればプログラムを終了します(それ以外の文字では何もせず、次の入力を待ちます)。

なお、標準入出力の部分はsys.stdin.read()およびsys.stdout.write()でエミュレートしておきます。

 

上記を逆アセンブルしたものは以下になります。

00000000  BAF8030000        mov edx,0x3f8
00000005  B03E              mov al,0x3e
00000007  EE                out dx,al
00000008  EC                in al,dx
00000009  3C68              cmp al,0x68
0000000B  740A              jz 0x17
0000000D  3C77              cmp al,0x77
0000000F  7412              jz 0x23
00000011  3C71              cmp al,0x71
00000013  741A              jz 0x2f
00000015  EBF1              jmp short 0x8
00000017  BE3F7C0000        mov esi,0x7c3f
0000001C  E813000000        call 0x34
00000021  EBE2              jmp short 0x5
00000023  BE477C0000        mov esi,0x7c47
00000028  E807000000        call 0x34
0000002D  EBD6              jmp short 0x5
0000002F  E9CC83FFFF        jmp 0xffff8400
00000034  8A06              mov al,[esi]
00000036  46                inc esi
00000037  3C00              cmp al,0x0
00000039  7403              jz 0x3e
0000003B  EE                out dx,al
0000003C  EBF6              jmp short 0x34
0000003E  C3                ret
0000003F  68656C6C6F        push dword 0x6f6c6c65
00000044  0D0A00776F        or eax,0x6f77000a
00000049  726C              jc 0xb7
0000004B  64                fs
0000004C  0D                db 0x0d
0000004D  0A00              or al,[eax]

 

上記では0x7c3f番地に"hello"が、0x7c47番地に"world"の文字列が配置されています。入力された一文字が'h'か'w'かを確認し、該当すればmov esi,addressでesiに文字列の先頭番地をコピーしてputsを呼び出します。putsではmov al,[esi]でalレジスタに一文字コピーした後、inc esiでesiの指しているアドレスを一つ進め、cmp al,0x0で文字列の終端(0x00)であればputsを終了し、そうでなければout dx,alで一文字出力して次の文字に進みます。なお、逆アセンブル結果の0x3f以降は、文字列データが命令として解釈されて表示されているだけで、実際の命令ではありません。

'q'が入力されればjmp 0で0番地にジャンプし、エミュレータはEIPが0になった時点でプログラムを終了します(逆アセンブル結果でjmp 0xffff8400と表示されているのは、org 0x7c00を考慮せずに0番地起点で逆アセンブルしているためです)。

 

Pythonによるスクリプトの作成

それでは、Pythonでin/out命令を実装していきます。

なお、参考書籍と同様に処理ごとにEIPとオペコードの状態を表示させると見づらくなるので、'-q'オプションで非表示にさせます。

 

# emulator.py
import sys

class ModRM:
    """Container for the fields decoded from a ModRM byte.

    All fields live in the ``modrm`` dict and start at zero:
    ``mod``, ``opecode`` / ``reg_index`` (two names for the same 3-bit
    field), ``rm``, ``sib``, ``disp8`` and ``disp32``.
    """

    def __init__(self):
        field_names = (
            "mod",
            "opecode",
            "reg_index",
            "rm",
            "sib",
            "disp8",
            "disp32",
        )
        # A fresh dict per instance; every field starts cleared.
        self.modrm = dict.fromkeys(field_names, 0x00)

class Emulator:
    def __init__(self):
        self.register_name_extended = [
            "EAX", "ECX", "EDX", "EBX", "ESP",
            "EBP", "ESI", "EDI"]
        self.register_name = [
            "AL", "CL", "DL", "BL", "AH", "CH", "DH", "BH"
            ]
        self.registers = {
            "EAX": 0x00,
            "ECX": 0x00,
            "EDX": 0x00,
            "EBX": 0x00,
            "ESP": 0x00,
            "EBP": 0x00,
            "ESI": 0x00,
            "EDI": 0x00,
            }
        self.eflags = 0x00
        self.memory = None
        self.eip = 0x00
        self.instructions = [None for i in range(256)]

    def init_instructions(self):
        self.instructions[0x01] = self.add_rm32_r32
        self.instructions[0x3b] = self.cmp_r32_rm32
        self.instructions[0x3c] = self.cmp_al_imm8
        for i in range(8):
            self.instructions[0x40+i] = self.inc_r32
        for i in range(8):
            self.instructions[0x50+i] = self.push_r32
        for i in range(8):
            self.instructions[0x58+i] = self.pop_r32
        self.instructions[0x68] = self.push_imm32
        self.instructions[0x6a] = self.push_imm8
        self.instructions[0x70] = self.jo
        self.instructions[0x71] = self.jno
        self.instructions[0x72] = self.jc
        self.instructions[0x73] = self.jnc
        self.instructions[0x74] = self.jz
        self.instructions[0x75] = self.jnz
        self.instructions[0x78] = self.js
        self.instructions[0x79] = self.jns
        self.instructions[0x7c] = self.jl
        self.instructions[0x7e] = self.jle
        self.instructions[0x83] = self.code_83
        self.instructions[0x89] = self.mov_rm32_r32
        self.instructions[0x8a] = self.mov_r8_rm8
        self.instructions[0x8b] = self.mov_r32_rm32
        for i in range(8):
            self.instructions[0xb0 + i] = self.mov_r8_imm8
        for i in range(8):
            self.instructions[0xb8 + i] = self.mov_r32_imm32
        self.instructions[0xc3] = self.ret
        self.instructions[0xc7] = self.mov_rm32_imm32
        self.instructions[0xc9] = self.leave
        self.instructions[0xe8] = self.call_rel32
        self.instructions[0xe9] = self.near_jump
        self.instructions[0xeb] = self.short_jump
        self.instructions[0xec] = self.in_al_dx
        self.instructions[0xee] = self.out_dx_al
        self.instructions[0xff] = self.code_ff

    def create_emu(self, size, eip, esp):
        self.eip = eip
        self.registers["ESP"] = esp
        self.memory = [0x00 for _ in range(size)]

    def dump_registers(self):
        for i in range(8):
            name = self.register_name_extended[i]
            print("{} = 0x{:08x}".format(name, self.registers[name]))
        print("EIP = 0x{:08x}".format(self.eip))

    def mov_r32_imm32(self):
        reg = self.get_code8(0) - 0xb8
        value = self.get_code32(1)
        reg_name = self.register_name_extended[reg]
        self.registers[reg_name] = value
        self.eip += 5
        if self.eip >= 0x100000000:
            self.eip ^= 0x100000000

    def short_jump(self):
        diff = self.get_sign_code8(1)
        if diff & 0x80:
            diff -= 0x100
        self.eip += (diff + 2)

    def get_code8(self, index):
        code = self.memory[self.eip + index]
        if not type(code) == int:
            code = int.from_bytes(code, 'little')
        return code

    def get_sign_code8(self, index):
        code =  self.memory[self.eip + index]
        if not type(code) == int:
            code = int.from_bytes(code, 'little')
        return code & 0xff

    def get_code32(self, index):
        ret = 0x00
        for i in range(4):
            ret |= self.get_code8(index + i) << (i * 8)
        return ret

    def get_sign_code32(self, index):
        return  self.get_code32(index)

    def near_jump(self):
        diff = self.get_sign_code32(1)
        if diff & 0x80000000:
            diff -= 0x100000000
        self.eip += (diff + 5)

    def parse_modrm(self):
        m = ModRM()
        code = self.get_code8(0)
        m.modrm["mod"] = ((code & 0xc0) >> 6)
        m.modrm["opecode"] = m.modrm["reg_index"] = ((code & 0x38) >> 3)
        m.modrm["rm"] = code & 0x07

        self.eip += 1
        if (m.modrm["mod"] != 3 and m.modrm["rm"] == 4):
            m.modrm["sib"] = self.get_code8(0)
            eip += 1
        if (m.modrm["mod"] == 0 and m.modrm["rm"] == 5) or m.modrm["mod"] == 2:
            m.modrm["disp32"] = self.get_sign_code32(0)
            m.modrm["disp8"] = m.modrm["disp32"] & 0xff
            eip += 4
        elif m.modrm["mod"] == 1:
            m.modrm["disp8"] = m.modrm["disp32"] = self.get_sign_code8(0)
            self.eip += 1

        return m

    def mov_rm32_imm32(self):
        self.eip += 1
        m = self.parse_modrm()
        value = self.get_code32(0)
        self.eip += 4
        self.set_rm32(m, value)

    def set_rm32(self, m, value):
        if m.modrm["mod"] == 3:
            self.set_register32(m.modrm["rm"], value)
        else:
            address = self.calc_memory_address(m)
            self.set_memory32(address, value)

    def set_memory8(self, address, value):
        self.memory[address] = value & 0xff

    def set_memory32(self, address, value):
        for i in range(4):
            self.set_memory8(address+i, value >> (i*8))

    def calc_memory_address(self, m):
        if m.modrm["mod"] == 0:
            if m.modrm["rm"] == 4:
                print("not implemented ModRM mod = 0, rm = 4")
                sys.exit(0)
            elif m.modrm["rm"] == 5:
                return m.modrm["disp32"]
            else:
                return self.get_register32(m.modrm["rm"])
        elif m.modrm["mod"] == 1:
            if m.modrm["rm"] == 4:
                print("not implemented ModRM mod = 1, rm = 4")
                sys.exit(0)
            else:
                return self.get_register32(m.modrm["rm"]) + m.modrm["disp8"]
        elif m.modrm["mod"] == 2:
            if m.modrm["rm"] == 4:
                print("not implemented ModRM mod = 2, rm = 4")
                sys.exit(0)
            else:
                return self.get_register32(m.modrm["rm"]) + m.modrm["disp32"]
        else:
            print("not implemented ModRM mod = 3")
            sys.exit(0)

    def mov_rm32_r32(self):
        self.eip += 1
        m = self.parse_modrm()
        r32 = self.get_r32(m)
        self.set_rm32(m, r32)

    def mov_r32_rm32(self):
        self.eip += 1
        m = self.parse_modrm()
        rm32 = self.get_rm32(m)
        self.set_r32(m, rm32)

    def get_rm32(self, m):
        if m.modrm["mod"] == 3:
            return self.get_register32(m.modrm["rm"])
        else:
            address = self.calc_memory_address(m)
            return self.get_memory32(address)

    def get_memory8(self, address):
        return self.memory[address]

    def get_memory32(self, address):
        ret = 0
        for i in range(4):
            mem = self.get_memory8(address + i)
            if not type(mem) == int:
                mem = ord(mem)
            ret |= mem << (8*i)
        return ret

    def set_r32(self, m, value):
        self.set_register32(m.modrm["reg_index"], value)

    def get_r32(self, m):
        return self.get_register32(m.modrm["reg_index"])

    def add_rm32_r32(self):
        self.eip += 1
        m = self.parse_modrm()
        r32 = self.get_r32(m)
        rm32 = self.get_rm32(m)
        self.set_rm32(m, rm32 + r32)

    def sub_rm32_imm8(self, m):
        rm32 = self.get_rm32(m)
        imm8 = self.get_sign_code8(0)
        self.eip += 1
        result = rm32 - imm8
        self.set_rm32(m, result)
        self.update_eflags_sub(rm32, imm8, result)

    def code_83(self):
        self.eip += 1
        m = self.parse_modrm()
        if m.modrm["opecode"] == 0:
            self.add_rm32_imm8(m)
        elif m.modrm["opecode"] == 5:
            self.sub_rm32_imm8(m)
        elif m.modrm["opecode"] == 7:
            self.cmp_rm32_imm8(m)
        else:
            print("not implemented: 83 /{}".format(m.modrm["opecode"]))
            sys.exit(1)

    def inc_rm32(self, m):
        value = self.get_rm32(m)
        self.set_rm32(m, value + 1)

    def inc_r32(self):
        reg = self.get_code8(0) - 0x40
        self.set_register32(reg, self.get_register32(reg)+1)
        self.eip += 1

    def code_ff(self):
        self.eip += 1
        m = self.parse_modrm()

        if m.modrm["opecode"] == 0:
            self.inc_rm32(m)
        else:
            print("not implemented: FF /{}".format(m.modrm["opecode"]))
            sys.exit(1)

    def get_register32(self, index):
        reg = self.register_name_extended[index]
        return self.registers[reg]

    def set_register32(self, index, value):
        reg = self.register_name_extended[index]
        self.registers[reg] = value

    def push_r32(self):
        reg = self.get_code8(0) - 0x50
        self.push32(self.get_register32(reg))
        self.eip += 1

    def pop_r32(self):
        reg = self.get_code8(0) - 0x58
        self.set_register32(reg, self.pop32())
        self.eip += 1

    def push32(self, value):
        esp = self.register_name_extended.index("ESP")
        address = self.get_register32(esp) - 4
        self.set_register32(esp, address)
        self.set_memory32(address, value)

    def pop32(self):
        esp = self.register_name_extended.index("ESP")
        address = self.get_register32(esp)
        ret = self.get_memory32(address)
        self.set_register32(esp, address + 4)
        return ret

    def call_rel32(self):
        diff = self.get_sign_code32(1)
        if diff & 0x80000000:
            diff -= 0x100000000
        self.push32(self.eip + 5)
        self.eip += (diff + 5)

    def ret(self):
        self.eip = self.pop32()

    def leave(self):
        ebp = self.get_register32(self.register_name_extended.index("EBP"))
        self.set_register32(self.register_name_extended.index("ESP"), ebp)
        self.set_register32(self.register_name_extended.index("EBP"), self.pop32())
        self.eip += 1

    def push_imm8(self):
        value = self.get_code8(1)
        self.push32(value)
        self.eip += 2

    def push_imm32(self):
        value = self.get_code32(1)
        self.push32(value)
        self.eip += 5

    def add_rm32_imm8(self, m):
        rm32 = self.get_rm32(m)
        imm8 = self.get_sign_code8(0)
        self.eip += 1
        self.set_rm32(m, rm32+imm8)

    def cmp_r32_rm32(self):
        self.eip += 1
        m = self.parse_modrm()
        r32 = self.get_r32(m)
        rm32 = self.get_rm32(m)
        result = r32 - rm32
        self.update_eflags_sub(r32, rm32, result)

    def cmp_rm32_imm8(self, m):
        rm32 = self.get_rm32(m)
        imm8 = self.get_sign_code8(0)
        print(rm32, imm8)
        self.eip += 1
        result = rm32 - imm8
        self.update_eflags_sub(rm32, imm8, result)

    def update_eflags_sub(self, v1, v2, result):
        sign1 = v1 >> 31
        sign2 = v2 >> 31
        signr = (result >> 31) & 1
        self.set_carry(result >> 32)
        self.set_zero(result == 0)
        self.set_sign(signr)
        self.set_overflow(sign1 != sign2 and sign1 != signr)

    def set_carry(self, is_carry):
        if is_carry:
            self.eflags |= CARRY_FLAG
        else:
            self.eflags &= ~CARRY_FLAG

    def set_zero(self, is_zero):
        if is_zero:
            self.eflags |= ZERO_FLAG
        else:
            self.eflags &= ~ZERO_FLAG

    def set_sign(self, is_sign):
        if is_sign:
            self.eflags |= SIGN_FLAG
        else:
            self.eflags &= ~SIGN_FLAG

    def set_overflow(self, is_overflow):
        if is_overflow:
            self.eflags |= OVERFLOW_FLAG
        else:
            self.eflags &= ~OVERFLOW_FLAG

    def is_carry(self):
        return (self.eflags & CARRY_FLAG) != 0

    def is_zero(self):
        return (self.eflags & ZERO_FLAG) != 0

    def is_sign(self):
        return (self.eflags & SIGN_FLAG) != 0

    def is_overflow(self):
        return (self.eflags & OVERFLOW_FLAG) != 0

    def j(func):
        def wrapper(self, *args, **kwargs):
            if func(self, *args, **kwargs):
                diff = self.get_sign_code8(1)
            else:
                diff = 0
            self.eip += (diff + 2)
        return wrapper

    def jn(func):
        def wrapper(self, *args, **kwargs):
            if func(self, *args, **kwargs):
                diff = 0
            else:
                diff = self.get_sign_code8(1)
            self.eip += (diff + 2)
        return wrapper

    @j
    def jc(self):
        return self.is_carry()

    @jn
    def jnc(self):
        return self.is_carry()

    @j
    def js(self):
        return self.is_sign()

    @jn
    def jns(self):
        return self.is_sign()

    @j
    def jz(self):
        return self.is_zero()

    @jn
    def jnz(self):
        return self.is_zero()

    @j
    def jo(self):
        return self.is_overflow()

    @jn
    def jno(self):
        return self.is_overflow()

    def jl(self):
        if self.is_sign() != self.is_overflow():
            diff = self.get_sign_code8(1)
            if diff & 0x80:
                diff -= 0x100
        else:
            diff = 0
        self.eip += (diff + 2)

    def jle(self):
        if self.is_zero() or self.is_sign() != self.is_overflow():
            diff = self.get_sign_code8(1)
            if diff & 0x80:
                diff -= 0x100
        else:
            diff = 0
        self.eip += (diff + 2)

    def in_al_dx(self):
        address = self.get_register32(self.register_name_extended.index("EDX")) & 0xffff
        value = ord(self.io_in8(address))
        self.set_register8(self.register_name.index("AL"), value)
        self.eip += 1

    def out_dx_al(self):
        address = self.get_register32(self.register_name_extended.index("EDX")) & 0xffff
        value = self.get_register32(self.register_name.index("AL")) & 0xff
        self.io_out8(address, value)
        self.eip += 1

    def io_in8(self, address):
        if address == 0x03f8:
            return sys.stdin.read(1)

    def io_out8(self, address, value):
        if address == 0x03f8:
            sys.stdout.write(chr(value))
            sys.stdout.flush()

    def get_register8(self, index):
        if index < 4:
            reg_name = self.register_name_extended[index]
            return self.registers[reg_name] & 0xff
        else:
            reg_name = self.register_name_extended[index-4]
            return (self.registers[reg_name] >> 8) & 0xff

    def set_register8(self, index, value):
        if index < 4:
            reg_name = self.register_name_extended[index]
            r = self.registers[reg_name] & 0xffffff00
            self.registers[reg_name] = r | value
        else:
            reg_name = self.register_name_extended[index-4]
            r = self.registers[reg_name] & 0xffff00ff
            self.registers[reg_name] = r | (value << 8)

    def mov_r8_imm8(self):
        reg = self.get_code8(0) - 0xB0
        self.set_register8(reg, self.get_code8(1))
        self.eip += 2

    def cmp_al_imm8(self):
        value = self.get_code8(1)
        al = self.get_register8(self.register_name_extended.index("EAX")) & 0xff
        result = al - value
        self.update_eflags_sub(al, value, result)
        self.eip += 2

    def mov_r8_rm8(self):
        self.eip += 1
        m = self.parse_modrm()
        rm8 = self.get_rm8(m)
        self.set_r8(m, rm8)

    def set_r8(self, m, value):
        self.set_register8(m.modrm["reg_index"], value)

    def get_rm8(self, m):
        if m.modrm["mod"] == 3:
            return self.get_register32(m["rm"])
        else:
            address = self.calc_memory_address(m)
            return self.get_memory32(address)

# EFLAGS bit masks read by Emulator's flag helpers.
CARRY_FLAG = 1
ZERO_FLAG = 1 << 6
SIGN_FLAG = 1 << 7
OVERFLOW_FLAG = 1 << 11

# Size of the emulated flat memory in bytes.
mem_size = 1024 * 1024

emu = Emulator()
# Execution starts at 0x7c00, where the image is loaded; ESP starts at the
# same address and pushes move it downwards.
emu.create_emu(mem_size, 0x7c00, 0x7c00)

# Read the whole image in one call and let the context manager close the
# file even if loading fails.
with open('select.bin', 'rb') as binary:
    image = binary.read()

# Iterating over bytes yields ints, so memory holds plain integers.
offset = 0x7c00
for byte in image:
    emu.memory[offset] = byte
    offset += 1

# '-q' suppresses the per-instruction EIP/opcode trace.
quiet = 0
if '-q' in sys.argv:
    quiet = 1

emu.init_instructions()
while emu.eip < mem_size:
    code = emu.get_code8(0)
    if not quiet:
        print("EIP = 0x{:02x}, Code = 0x{:02x}".format(emu.eip, code))
    if emu.instructions[code] is None:
        print("\n\nNot Implemented: 0x{:02x}".format(code))
        break
    emu.instructions[code]()
    # A jump to address 0 is the program's way of signalling completion.
    if emu.eip == 0x00:
        print("\n\nend of program.\n\n")
        break

emu.dump_registers()

 

動作確認

それでは、上記で作成したスクリプトを実行してみます。

なお、事前にアセンブリ言語のプログラムはbinファイルとしてビルドしておきます。

 

 > python .\emulator.py -q
 >h
 hello
 >w
 world
 >q


 end of program.


 EAX = 0x7f7f7f71
 ECX = 0x00000000
 EDX = 0x000003f8
 EBX = 0x00000000
 ESP = 0x00007c00
 EBP = 0x00000000
 ESI = 0x00007c4f
 EDI = 0x00000000
 EIP = 0x00000000

 

問題なくin/out命令が実行され、alに最後に入力した'q'(0x71)が格納できたことが確認できました。

 

参考書籍

自作エミュレータで学ぶx86アーキテクチャ-コンピュータが動く仕組みを徹底理解!

デバッガによるx86プログラム解析入門【x64対応版】