本記事は、ソフトバンクパブリッシングから発行されている「定本 Cプログラマのためのアルゴリズムとデータ構造 (SOFTBANK BOOKS)」を参考にPythonでアルゴリズムとデータ構造について学習していきます。
今回は、平衡木の中のB木について学んでいきます。
平衡木とは
前回の二分探索木では、平均的な計算量はO(log n)でありながらも最悪な場合はO(n)になってしまうことを学びました。
上記の最悪な場合が発生する要因は、根から見た左右の部分木の高さが異なっていることが原因となります。
これを改善するには、木の高さをlog2 n程度に調整する必要があります。
このような木を平衡木(Balanced Tree)と言います。
B木とは
前回では最初の平衡木としてAVL木について学びましたが、今回のB木のほうが性能が良いため、AVL木はあまり実用的な価値は無いようです。
B木とは、1972年にBayerとMcCreightによって考案されたものです。
主にブロック単位のランダムアクセスが可能な外部記憶上(ハードディスクなど)での探索に適しているため、とても実用的な価値のあるデータ構造です。
AVL木では二分木をベースとしていますが、B木ではm分木を基としたデータ構造になっております。
これを多分木(Multi-Way Tree)と呼び、このような探索木を多分探索木(Multi-Way Search Tree)と呼び、以下の条件に従います。
- 根は、葉もしくは2からm個の子を持っている
- 根、葉以外の節はm/2からm個の子を持っている
- 根からすべての葉までの距離が等しい
葉にはそれぞれKeyとDataが格納されており、節には境界を示すKeyと葉へのポインタ(節が根の場合は、子へのポインタ)がそれぞれ格納されています。
fig1では、3つのリスト上のブロックが節を意味し、丸の部分が葉を意味しています。
節の上部は境界と呼ばれるもので、探索などをする際に葉を辿るための道しるべになり、下部は節、もしくは葉を格納しています。
PythonでB木を実装してみる
それでは、PythonでB木を実装してみます。
# b_tree.py MAX_CHILD = 5 HALF_CHILD = int(((MAX_CHILD+1)/2)) OK = 1 REMOVED = 2 NEED_REORG = 3 class Internal: def __init__(self): self.nchilds = 0 self.child = [None for _ in range(MAX_CHILD)] self.low = [None for _ in range(MAX_CHILD)] class Leaf: def __init__(self, key, data=None): self.key = key self.data = data class BTree: def __init__(self): self.root = None def locate_subtree(self, node, key): for i in range(node.nchilds-1, 0, -1): if key >= node.low[i]: return i return 0 def search(self, key): if not self.root: print("there is no tree.") return None else: tmp = self.root while(tmp.__class__.__name__ == 'Internal'): index = self.locate_subtree(tmp, key) tmp = tmp.child[index] if key == tmp.key: print("Data '{}' found.".format(key)) return tmp else: print("Data '{}' not found".format(key)) return None def insert_aux(self, parent, node, key, pos=None): newnode = lowest = retv = None if(node.__class__.__name__ == 'Leaf'): if node.key == key: print('Key already exists.') return retv, newnode, lowest else: new = Leaf(key) if key < node.key: if parent.__class__.__name__ == 'Internal': parent.child[pos] = new else: self.root = new lowest = node.key newnode = node else: lowest = key newnode = new return retv, newnode, lowest else: if node == self.root and node.child[0] == None: self.root = Leaf(key) return self.root, newnode, lowest pos = self.locate_subtree(node, key) retv, xnode, xlow = self.insert_aux(node, node.child[pos], key, pos) if not xnode: return retv, xnode, xlow if node.nchilds < MAX_CHILD: for i in range(node.nchilds-1, pos, -1): node.child[i+1] = node.child[i] node.low[i+1] = node.low[i] node.child[pos+1] = xnode node.low[pos+1] = xlow node.nchilds += 1 return retv, newnode, lowest else: new = Internal() if pos < HALF_CHILD - 1: for j, i in enumerate(range(HALF_CHILD-1, MAX_CHILD)): new.child[j] = node.child[i] new.low[j] = node.low[i] for i in range(HALF_CHILD-2, pos, -1): node.child[i+1] = node.child[i] node.low[i+1] = node.low[i] node.child[pos+1] = xnode node.low[pos+1] = xlow else: j = MAX_CHILD - HALF_CHILD for i in range(MAX_CHILD-1, HALF_CHILD-1, -1): if i == pos: new.child[j] = xnode new.low[j] = xlow j -= 1 new.child[j] = node.child[i] new.low[j] = node.low[i] j -= 1 if pos < HALF_CHILD: new.child[0] = xnode new.low[0] = xlow node.nchilds = HALF_CHILD new.nchilds = (MAX_CHILD+1) - HALF_CHILD newnode = new lowest = new.low[0] return retv, newnode, lowest def insert(self, key): if not self.root: self.root = Leaf(key) return self.root else: retv, newnode, lowest = self.insert_aux(self.root, self.root, key) if newnode: new = Internal() new.nchilds = 2; new.child[0] = self.root new.child[1] = newnode new.low[1] = lowest self.root = new return retv def merge_nodes(self, node, x): a = node.child[x] b = node.child[x+1] b.low[0] = node.low[x+1] an = a.nchilds bn = b.nchilds if an + bn <= MAX_CHILD: for i in range(bn): a.child[i+an] = b.child[i] a.low[i+an] = b.low[i] a.nchilds += bn return 1 else: n = int((an + bn) / 2) if an > n: move = an - n for i in range(bn-1, -1, -1): b.child[i+move] = b.child[i] b.low[i+move] = b.low[i] for i in range(move): b.child[i] = a.child[i+n] b.low[i] = a.low[i+n] else: move = n - an for i in range(move): a.child[i+an] = b.child[i] a.low[i+an] = b.low[i] for i in range(bn-move): b.child[i] = b.child[i+move] b.low[i] = b.low[i+move] a.nchilds = n b.nchilds = an + bn - n node.low[x+1] = b.low[0] return 0 def delete_aux(self, parent, node, key, pos=None): result = OK if node.__class__.__name__ == 'Leaf': if node.key == key: if node == self.root: self.root = None else: parent.child[pos] = None result = REMOVED return 1, result else: print('No data found.') return 0, result else: pos = self.locate_subtree(node, key) retv, condition = self.delete_aux(node, node.child[pos], key, pos) if condition == OK: return retv, result if condition == NEED_REORG: sub = 0 if pos == 0 else pos - 1 joined = self.merge_nodes(node, sub) if joined: pos = sub + 1 if (condition == REMOVED) or joined: for i in range(pos, node.nchilds-1): node.child[i] = node.child[i+1] node.low[i] = node.low[i+1] node.nchilds -= 1 if node.nchilds < HALF_CHILD: result = NEED_REORG return retv, result def delete(self, key): if not self.root: print("There is no tree.") return 0 else: root = self.root retv, result = self.delete_aux(root, root, key) if result == REMOVED: self.root = None elif result == NEED_REORG and root.nchilds == 1: self.root = root.child[0] return retv def print_tree(self): if not self.root: print("There is no tree.") return else: self.get_tree(self.root) def get_tree(self, node): try: if(node.__class__.__name__ == 'Leaf'): print("leaf val={}".format(node.key)) else: print("{} childs: ".format(node.nchilds), end="") for i in range(1, node.nchilds): print("{} ".format(node.low[i]), end="") print() for i in range(node.nchilds): self.get_tree(node.child[i]) except: pass t = BTree() data = [6,5,3,8,11,4,15] for i in data: t.insert(i) t.print_tree() print() print("+n:insert n\t-n:delete n\tn:search n") print("'q' to quit.") while True: res = input("> ") if res == 'q': break elif res.startswith('+'): t.insert(int(res[1:])) elif res.startswith('-'): t.delete(int(res[1:])) elif res.isdigit(): t.search(int(res)) else: continue t.print_tree() print()
今回はできるだけ書籍通り(C言語風)にコーディングをしたので、結構なコード量になりました。
動作確認
それでは上記で作成したスクリプトを実行してみます。
> python b_tree.py 2 childs: 6 3 childs: 4 5 leaf val=3 leaf val=4 leaf val=5 4 childs: 8 11 15 leaf val=6 leaf val=8 leaf val=11 leaf val=15 +n:insert n -n:delete n n:search n 'q' to quit. > +9 2 childs: 6 3 childs: 4 5 leaf val=3 leaf val=4 leaf val=5 5 childs: 8 9 11 15 leaf val=6 leaf val=8 leaf val=9 leaf val=11 leaf val=15 > +10 3 childs: 6 10 3 childs: 4 5 leaf val=3 leaf val=4 leaf val=5 3 childs: 8 9 leaf val=6 leaf val=8 leaf val=9 3 childs: 11 15 leaf val=10 leaf val=11 leaf val=15 > -5 2 childs: 10 5 childs: 4 6 8 9 leaf val=3 leaf val=4 leaf val=6 leaf val=8 leaf val=9 3 childs: 11 15 leaf val=10 leaf val=11 leaf val=15
まずデフォルト状態のB木が表示され、根から2つの子(節)が出来上がっています。
さらに9および10を加えると、子の数が3つに分割されたのが確認できます。
その後5を削除するとマージが実行され、子の数が2つに集約されたのが確認できます。
B木はかなり実用的なデータ構造でありますが、節が持つ子の最小単位がm/2個のため、記憶領域の半分は無駄となります。
そのため、節が持つ子の数を2m/3 ~ m個にしたB*木と呼ばれるデータ構造のほうが、効率的に記憶領域を使うことが出来るようです。