105 lines
2.2 KiB
Python
105 lines
2.2 KiB
Python
|
# https://en.wikipedia.org/wiki/Treap
|
||
|
|
||
|
import random
|
||
|
|
||
|
import time
|
||
|
|
||
|
|
||
|
class Treap:
|
||
|
def __init__(self, key):
|
||
|
self.key = key
|
||
|
self.prio = random.randint(0, 1000000000)
|
||
|
self.size = 1
|
||
|
self.left = None
|
||
|
self.right = None
|
||
|
|
||
|
def update(self):
|
||
|
self.size = 1 + size(self.left) + size(self.right)
|
||
|
|
||
|
|
||
|
def size(treap):
|
||
|
return 0 if treap is None else treap.size
|
||
|
|
||
|
|
||
|
def split(root, minRight):
|
||
|
if root is None:
|
||
|
return None, None
|
||
|
if root.key >= minRight:
|
||
|
left, right = split(root.left, minRight)
|
||
|
root.left = right
|
||
|
root.update()
|
||
|
return left, root
|
||
|
else:
|
||
|
left, right = split(root.right, minRight)
|
||
|
root.right = left
|
||
|
root.update()
|
||
|
return root, right
|
||
|
|
||
|
|
||
|
def merge(left, right):
|
||
|
if left is None:
|
||
|
return right
|
||
|
if right is None:
|
||
|
return left
|
||
|
if left.prio > right.prio:
|
||
|
left.right = merge(left.right, right)
|
||
|
left.update()
|
||
|
return left
|
||
|
else:
|
||
|
right.left = merge(left, right.left)
|
||
|
right.update()
|
||
|
return right
|
||
|
|
||
|
|
||
|
def insert(root, key):
|
||
|
left, right = split(root, key)
|
||
|
return merge(merge(left, Treap(key)), right)
|
||
|
|
||
|
|
||
|
def remove(root, key):
|
||
|
left, right = split(root, key)
|
||
|
return merge(left, split(right, key + 1)[1])
|
||
|
|
||
|
|
||
|
def kth(root, k):
|
||
|
if k < size(root.left):
|
||
|
return kth(root.left, k)
|
||
|
elif k > size(root.left):
|
||
|
return kth(root.right, k - size(root.left) - 1)
|
||
|
return root.key
|
||
|
|
||
|
|
||
|
def print_treap(root):
|
||
|
def dfs_print(root):
|
||
|
if root is None:
|
||
|
return
|
||
|
dfs_print(root.left)
|
||
|
print(str(root.key) + ' ', end='')
|
||
|
dfs_print(root.right)
|
||
|
|
||
|
dfs_print(root)
|
||
|
print()
|
||
|
|
||
|
|
||
|
def test():
|
||
|
start = time.time()
|
||
|
treap = None
|
||
|
s = set()
|
||
|
for i in range(100000):
|
||
|
key = random.randint(0, 10000)
|
||
|
if random.randint(0, 1) == 0:
|
||
|
if key in s:
|
||
|
treap = remove(treap, key)
|
||
|
s.remove(key)
|
||
|
elif key not in s:
|
||
|
treap = insert(treap, key)
|
||
|
s.add(key)
|
||
|
assert len(s) == size(treap)
|
||
|
|
||
|
for i in range(size(treap)):
|
||
|
assert kth(treap, i) in s
|
||
|
print(time.time() - start)
|
||
|
|
||
|
|
||
|
test()
|