Engineering Note

プログラミングなどの技術的なメモ

B木 (Pythonによるアルゴリズムとデータ構造)

b tree

本記事は、ソフトバンクパブリッシングから発行されている「定本 Cプログラマのためのアルゴリズムとデータ構造 (SOFTBANK BOOKS)」を参考にPythonアルゴリズムとデータ構造について学習していきます。

今回は、平衡木の中のB木について学んでいきます。

 

 

平衡木とは

前回の二分探索木では、平均的な計算量はO(log n)でありながらも最悪な場合はO(n)になってしまうことを学びました。

 

 

上記の最悪な場合が発生する要因は、根から見た左右の部分木の高さが異なっていることが原因となります。

これを改善するには、木の高さをlog2 n程度に調整する必要があります。

このような木を平衡木(Balanced Tree)と言います。

 

B木とは

前回では最初の平衡木としてAVL木について学びましたが、今回のB木のほうが性能が良いため、AVL木はあまり実用的な価値は無いようです。

B木とは、1972年にBayerとMcCreightによって考案されたものです。

主にブロック単位のランダムアクセスが可能な外部記憶上(ハードディスクなど)での探索に適しているため、とても実用的な価値のあるデータ構造です。

AVL木では二分木をベースとしていますが、B木ではm分木を基としたデータ構造になっております。

これを多分木(Multi-Way Tree)と呼び、このような探索木を多分探索木(Multi-Way Search Tree)と呼び、以下の条件に従います。

 

  1. 根は、葉もしくは2からm個の子を持っている
  2. 根、葉以外の節はm/2からm個の子を持っている
  3. 根からすべての葉までの距離が等しい

 

葉にはそれぞれKeyとDataが格納されており、節には境界を示すKeyと葉へのポインタ(節が根の場合は、子へのポインタ)がそれぞれ格納されています。

 

b_tree

fig1. B木のデータ構造

fig1では、3つのリスト上のブロックが節を意味し、丸の部分が葉を意味しています。

節の上部は境界と呼ばれるもので、探索などをする際に葉を辿るための道しるべになり、下部は節、もしくは葉を格納しています。

 

PythonでB木を実装してみる

それでは、PythonでB木を実装してみます。

 

# b_tree.py
MAX_CHILD   = 5
HALF_CHILD  = int(((MAX_CHILD+1)/2))
OK          = 1
REMOVED     = 2
NEED_REORG  = 3

class Internal:
    def __init__(self):
        self.nchilds = 0
        self.child = [None for _ in range(MAX_CHILD)]
        self.low   = [None for _ in range(MAX_CHILD)]

class Leaf:
    def __init__(self, key, data=None):
        self.key = key
        self.data = data

class BTree:
    def __init__(self):
        self.root = None

    def locate_subtree(self, node, key):
        for i in range(node.nchilds-1, 0, -1):
            if key >= node.low[i]:
                return i
        return 0

    def search(self, key):
        if not self.root:
            print("there is no tree.")
            return None
        else:
            tmp = self.root
            while(tmp.__class__.__name__ == 'Internal'):
                index = self.locate_subtree(tmp, key)
                tmp = tmp.child[index]
            if key == tmp.key:
                print("Data '{}' found.".format(key))
                return tmp
            else:
                print("Data '{}' not found".format(key))
                return None

    def insert_aux(self, parent, node, key, pos=None):
        newnode = lowest = retv = None
        if(node.__class__.__name__ == 'Leaf'):
            if node.key == key:
                print('Key already exists.')
                return retv, newnode, lowest
            else:
                new = Leaf(key)
                if key < node.key:
                    if parent.__class__.__name__ == 'Internal':
                        parent.child[pos] = new
                    else:
                        self.root = new
                    lowest = node.key
                    newnode = node
                else:
                    lowest = key
                    newnode = new
                return retv, newnode, lowest
        else:
            if node == self.root and node.child[0] == None:
                self.root = Leaf(key)
                return self.root, newnode, lowest
            pos = self.locate_subtree(node, key)
            retv, xnode, xlow = self.insert_aux(node, node.child[pos], key, pos)
            if not xnode:
                return retv, xnode, xlow
            if node.nchilds < MAX_CHILD:
                for i in range(node.nchilds-1, pos, -1):
                    node.child[i+1] = node.child[i]
                    node.low[i+1] = node.low[i]
                node.child[pos+1] = xnode
                node.low[pos+1] = xlow
                node.nchilds += 1
                return retv, newnode, lowest
            else:
                new = Internal()
                if pos < HALF_CHILD - 1:
                    for j, i in enumerate(range(HALF_CHILD-1, MAX_CHILD)):
                        new.child[j] = node.child[i]
                        new.low[j] = node.low[i]
                    for i in range(HALF_CHILD-2, pos, -1):
                        node.child[i+1] = node.child[i]
                        node.low[i+1] = node.low[i]
                    node.child[pos+1] = xnode
                    node.low[pos+1] = xlow
                else:
                    j = MAX_CHILD - HALF_CHILD
                    for i in range(MAX_CHILD-1, HALF_CHILD-1, -1):
                        if i == pos:
                            new.child[j] = xnode
                            new.low[j] = xlow
                            j -= 1
                        new.child[j] = node.child[i]
                        new.low[j] = node.low[i]
                        j -= 1
                    if pos < HALF_CHILD:
                        new.child[0] = xnode
                        new.low[0] = xlow
                node.nchilds = HALF_CHILD
                new.nchilds = (MAX_CHILD+1) - HALF_CHILD
                newnode = new
                lowest = new.low[0]
                return retv, newnode, lowest

    def insert(self, key):
        if not self.root:
            self.root = Leaf(key)
            return self.root
        else:
            retv, newnode, lowest = self.insert_aux(self.root, self.root, key)
            if newnode:
                new = Internal()
                new.nchilds = 2;
                new.child[0] = self.root
                new.child[1] = newnode
                new.low[1] = lowest
                self.root = new
            return retv

    def merge_nodes(self, node, x):
        a = node.child[x]
        b = node.child[x+1]
        b.low[0] = node.low[x+1]
        an = a.nchilds
        bn = b.nchilds
        if an + bn <= MAX_CHILD:
            for i in range(bn):
                a.child[i+an] = b.child[i]
                a.low[i+an] = b.low[i]
            a.nchilds += bn
            return 1
        else:
            n = int((an + bn) / 2)
            if an > n:
                move = an - n
                for i in range(bn-1, -1, -1):
                    b.child[i+move] = b.child[i]
                    b.low[i+move] = b.low[i]
                for i in range(move):
                    b.child[i] = a.child[i+n]
                    b.low[i] = a.low[i+n]
            else:
                move = n - an
                for i in range(move):
                    a.child[i+an] = b.child[i]
                    a.low[i+an] = b.low[i]
                for i in range(bn-move):
                    b.child[i] = b.child[i+move]
                    b.low[i] = b.low[i+move]
            a.nchilds = n
            b.nchilds = an + bn - n
            node.low[x+1] = b.low[0]
            return 0

    def delete_aux(self, parent, node, key, pos=None):
        result = OK
        if node.__class__.__name__ == 'Leaf':
            if node.key == key:
                if node == self.root:
                    self.root = None
                else:
                    parent.child[pos] = None
                result = REMOVED
                return 1, result
            else:
                print('No data found.')
                return 0, result
        else:
            pos = self.locate_subtree(node, key)
            retv, condition = self.delete_aux(node, node.child[pos], key, pos)
            if condition == OK:
                return retv, result
            if condition == NEED_REORG:
                sub = 0 if pos == 0 else pos - 1
                joined = self.merge_nodes(node, sub)
                if joined:
                    pos = sub + 1
            if (condition == REMOVED) or joined:
                for i in range(pos, node.nchilds-1):
                    node.child[i] = node.child[i+1]
                    node.low[i] = node.low[i+1]
                node.nchilds -= 1
                if node.nchilds < HALF_CHILD:
                    result = NEED_REORG
            return retv, result

    def delete(self, key):
        if not self.root:
            print("There is no tree.")
            return 0
        else:
            root = self.root
            retv, result = self.delete_aux(root, root, key)
            if result == REMOVED:
                self.root = None
            elif result == NEED_REORG and root.nchilds == 1:
                self.root = root.child[0]
            return retv


    def print_tree(self):
        if not self.root:
            print("There is no tree.")
            return
        else:
            self.get_tree(self.root)

    def get_tree(self, node):
        try:
            if(node.__class__.__name__ == 'Leaf'):
                print("leaf val={}".format(node.key))
            else:
                print("{} childs: ".format(node.nchilds), end="")
                for i in range(1, node.nchilds):
                    print("{} ".format(node.low[i]), end="")
                print()
                for i in range(node.nchilds):
                    self.get_tree(node.child[i])
        except:
            pass

t = BTree()
data = [6,5,3,8,11,4,15]
for i in data:
    t.insert(i)
t.print_tree()
print()
print("+n:insert n\t-n:delete n\tn:search n")
print("'q' to quit.")
while True:
    res = input("> ")
    if res == 'q':
        break
    elif res.startswith('+'):
        t.insert(int(res[1:]))
    elif res.startswith('-'):
        t.delete(int(res[1:]))
    elif res.isdigit():
        t.search(int(res))
    else:
        continue
    t.print_tree()
    print()

 

今回はできるだけ書籍通り(C言語風)にコーディングをしたので、結構なコード量になりました。

 

動作確認

それでは上記で作成したスクリプトを実行してみます。

 

 > python b_tree.py
 2 childs: 6
 3 childs: 4 5
 leaf val=3
 leaf val=4
 leaf val=5
 4 childs: 8 11 15
 leaf val=6
 leaf val=8
 leaf val=11
 leaf val=15
 
 +n:insert n     -n:delete n     n:search n
 'q' to quit.
 > +9
 2 childs: 6
 3 childs: 4 5
 leaf val=3
 leaf val=4
 leaf val=5
 5 childs: 8 9 11 15
 leaf val=6
 leaf val=8
 leaf val=9
 leaf val=11
 leaf val=15
 
 > +10
 3 childs: 6 10
 3 childs: 4 5
 leaf val=3
 leaf val=4
 leaf val=5
 3 childs: 8 9
 leaf val=6
 leaf val=8
 leaf val=9
 3 childs: 11 15
 leaf val=10
 leaf val=11
 leaf val=15
 
 > -5
 2 childs: 10
 5 childs: 4 6 8 9
 leaf val=3
 leaf val=4
 leaf val=6
 leaf val=8
 leaf val=9
 3 childs: 11 15
 leaf val=10
 leaf val=11
 leaf val=15

 

まずデフォルト状態のB木が表示され、根から2つの子(節)が出来上がっています。

さらに9および10を加えると、子の数が3つに分割されたのが確認できます。

その後5を削除するとマージが実行され、子の数が2つに集約されたのが確認できます。

 

B木はかなり実用的なデータ構造でありますが、節が持つ子の最小単位がm/2個のため、記憶領域の半分は無駄となります。

そのため、節が持つ子の数を2m/3 ~ m個にしたB*木と呼ばれるデータ構造のほうが、効率的に記憶領域を使うことが出来るようです。

 

参考書籍

定本 Cプログラマのためのアルゴリズムとデータ構造 (SOFTBANK BOOKS)