"""Succinct binary-recursive wavelet tree with rank-support."""
from __future__ import annotations
from typing import List
__all__ = ["WaveletTree"]
class WaveletTree:
    """Wavelet tree over a *bytes* sequence providing per-symbol **rank** queries.

    Every internal node splits its alphabet in half and stores one bit per
    element of its sequence (0 = symbol belongs to the left half, 1 = right
    half), packed MSB-first into ``bitvec``.  The number of 1-bits preceding
    every *sample_rate*-th position is stored in ``prefix_ranks`` so that a
    bit-rank query scans fewer than *sample_rate* bits.

    The packed bitvectors hold roughly *n log σ* bits in total; the rank
    samples and Python object overhead come on top of that.

    Attributes:
        length: Number of symbols stored in this node's sequence.
        alphabet: Symbols handled by this node (a single symbol for a leaf).
        bitvec: Packed routing bits, or ``None`` for a leaf.
        sample_rate: Distance, in bits, between consecutive rank samples.
        prefix_ranks: ``prefix_ranks[k]`` is the number of 1-bits in
            ``[0, k * sample_rate)``; ``None`` for a leaf.
        left, right: Child nodes for the two alphabet halves, or ``None``.
    """

    __slots__ = (
        "length",
        "alphabet",
        "bitvec",
        "sample_rate",
        "prefix_ranks",
        "left",
        "right",
    )

    def __init__(self, data: bytes, alphabet: bytes, *, sample_rate: int = 32):
        """Build the (sub)tree for *data* restricted to *alphabet*.

        Args:
            data: The sequence to index.
            alphabet: The symbols that may occur in *data*.
            sample_rate: Number of bits between stored rank samples (>= 1).

        Raises:
            ValueError: If *sample_rate* is smaller than 1, or if *data*
                contains a byte that is not part of *alphabet* (such a byte
                would otherwise be counted as the last alphabet symbol).
        """
        if sample_rate < 1:
            raise ValueError(f"sample_rate must be >= 1, got {sample_rate!r}")

        self.length = len(data)
        self.alphabet = alphabet
        self.sample_rate = sample_rate

        if len(alphabet) == 1:
            # Leaf: every element is the same symbol, so no bitvector is
            # needed -- but only if the data really is that one symbol.
            for pos, b in enumerate(data):
                if b not in alphabet:
                    raise ValueError(
                        f"symbol {b!r} at position {pos} is not in the alphabet"
                    )
            self.bitvec = None
            self.prefix_ranks = None
            self.left = self.right = None
            return

        mid = len(alphabet) // 2
        left_alpha, right_alpha = alphabet[:mid], alphabet[mid:]
        # Sets give O(1) membership instead of a linear scan of the bytes.
        left_set = frozenset(left_alpha)
        right_set = frozenset(right_alpha)

        left_bytes = bytearray()
        right_bytes = bytearray()
        packed = bytearray()
        samples = []
        ones = 0  # number of 1-bits emitted so far
        cur, cnt = 0, 0  # partially filled byte and its bit count

        for pos, b in enumerate(data):
            # 0 -> left, 1 -> right.  The left half is tested first so a
            # symbol listed in both halves is consistently routed left.
            if b in left_set:
                bit = 0
                left_bytes.append(b)
            elif b in right_set:
                bit = 1
                right_bytes.append(b)
            else:
                raise ValueError(
                    f"symbol {b!r} at position {pos} is not in the alphabet"
                )

            # Record the rank *before* this bit at every sample boundary so
            # that samples[k] == rank1(k * sample_rate).
            if pos % sample_rate == 0:
                samples.append(ones)
            ones += bit

            cur = (cur << 1) | bit
            cnt += 1
            if cnt == 8:
                packed.append(cur)
                cur, cnt = 0, 0

        if cnt:
            # Left-align the trailing bits; the padding bits are zero.
            packed.append(cur << (8 - cnt))

        # Sentinel: covers the lookup for i == length when length is a
        # multiple of sample_rate (including the empty sequence).
        samples.append(ones)

        self.bitvec = bytes(packed)
        self.prefix_ranks = samples

        # Recurse.  The halves are only empty when *alphabet* itself is empty.
        self.left = (
            WaveletTree(bytes(left_bytes), left_alpha, sample_rate=sample_rate)
            if left_alpha
            else None
        )
        self.right = (
            WaveletTree(bytes(right_bytes), right_alpha, sample_rate=sample_rate)
            if right_alpha
            else None
        )

    # ------ Public API ------

    def rank(self, symbol: int, i: int) -> int:
        """Return the number of occurrences of *symbol* in positions ``[0, i)``.

        *i* is clamped to ``[0, length]``; a *symbol* that is not part of the
        alphabet has rank 0.
        """
        if symbol not in self.alphabet:
            return 0

        node = self
        count = min(max(i, 0), self.length)
        # Walk down towards the leaf of *symbol*, mapping the prefix length
        # into the child's coordinate system at every level.
        while count > 0 and len(node.alphabet) > 1:
            mid = len(node.alphabet) // 2
            ones = node._rank1(count)
            if symbol in node.alphabet[:mid]:
                count -= ones  # zeros in the prefix went to the left child
                node = node.left
            else:
                count = ones
                node = node.right
        return count

    # ------ Internal helpers ------

    def _rank1(self, i: int) -> int:
        """Return the number of 1-bits in ``bitvec[0:i]`` (*i* is clamped)."""
        if i <= 0:
            return 0
        i = min(i, self.length)
        block = i // self.sample_rate
        rank = self.prefix_ranks[block]
        bitvec = self.bitvec
        # Scan the (fewer than sample_rate) bits after the sampled position.
        for pos in range(block * self.sample_rate, i):
            rank += (bitvec[pos >> 3] >> (7 - (pos & 7))) & 1
        return rank