bankai_verify/bankai/
mmr.rs

1extern crate alloc;
2use alloc::vec::Vec;
3
4use alloy_primitives::{keccak256, B256};
5use bankai_types::{fetch::evm::MmrProof, proofs::HashingFunctionDto, utils::mmr::hash_to_leaf};
6use starknet_crypto::{poseidon_hash, Felt};
7
8use crate::VerifyError;
9trait MmrHasher {
10    type Word: Copy + PartialEq;
11
12    /// Convert a 32-byte value from proofs into the internal word type.
13    fn from_b256(x: &B256) -> Self::Word;
14
15    /// Hash an ordered pair `(left, right)`.
16    fn hash_pair(left: &Self::Word, right: &Self::Word) -> Self::Word;
17
18    /// Compute the leaf from the raw header hash.
19    fn leaf_from_header_hash(header_hash: &B256) -> Self::Word;
20
21    /// Bag peaks right-to-left: `H(peak1, H(peak2, ... H(peakN)))`.
22    fn bag_peaks(peaks: &[Self::Word]) -> Self::Word {
23        match peaks.len() {
24            0 => Self::from_b256(&B256::ZERO),
25            1 => peaks[0],
26            _ => {
27                let mut acc = *peaks.last().unwrap();
28                for i in (0..peaks.len() - 1).rev() {
29                    acc = Self::hash_pair(&peaks[i], &acc);
30                }
31                acc
32            }
33        }
34    }
35
36    /// Compute the MMR root from the elements count and bagged peaks.
37    fn mmr_root(elements_count: u128, bag: &Self::Word) -> Self::Word;
38}
39
40/// Keccak-based hasher using `B256` as the word type.
41struct KeccakHasher;
42
43impl MmrHasher for KeccakHasher {
44    type Word = B256;
45
46    #[inline]
47    fn from_b256(x: &B256) -> Self::Word {
48        *x
49    }
50
51    #[inline]
52    fn hash_pair(left: &Self::Word, right: &Self::Word) -> Self::Word {
53        let mut buf = [0u8; 64];
54        buf[..32].copy_from_slice(left.as_slice());
55        buf[32..].copy_from_slice(right.as_slice());
56        keccak256(buf)
57    }
58
59    #[inline]
60    fn leaf_from_header_hash(header_hash: &B256) -> Self::Word {
61        hash_to_leaf(*header_hash, &HashingFunctionDto::Keccak)
62    }
63
64    #[inline]
65    fn mmr_root(elements_count: u128, bag: &Self::Word) -> Self::Word {
66        let mut size_be = [0u8; 32];
67        size_be[16..].copy_from_slice(&elements_count.to_be_bytes());
68        let mut buf = [0u8; 64];
69        buf[..32].copy_from_slice(&size_be);
70        buf[32..].copy_from_slice(bag.as_slice());
71        keccak256(buf)
72    }
73}
74
75/// Poseidon-based hasher using `Felt` as the word type.
76struct PoseidonHasher;
77
78impl MmrHasher for PoseidonHasher {
79    type Word = Felt;
80
81    #[inline]
82    fn from_b256(x: &B256) -> Self::Word {
83        Felt::from_bytes_be_slice(x.as_slice())
84    }
85
86    #[inline]
87    fn hash_pair(left: &Self::Word, right: &Self::Word) -> Self::Word {
88        poseidon_hash(*left, *right)
89    }
90
91    #[inline]
92    fn leaf_from_header_hash(header_hash: &B256) -> Self::Word {
93        // Reuse existing leaf definition then convert to Felt
94        let leaf_bytes = hash_to_leaf(*header_hash, &HashingFunctionDto::Poseidon);
95        Felt::from_bytes_be_slice(leaf_bytes.as_slice())
96    }
97
98    #[inline]
99    fn mmr_root(elements_count: u128, bag: &Self::Word) -> Self::Word {
100        let size_felt = Felt::from_bytes_be_slice(&elements_count.to_be_bytes());
101        poseidon_hash(size_felt, *bag)
102    }
103}
104
105/// Iterative subtree path hashing generic over the hasher.
106///
107/// Replays the sibling path from the leaf to its peak, following the same
108/// left/right rules as the Cairo implementation but without recursion.
109fn hash_subtree_path_iter<H: MmrHasher>(
110    mut element: H::Word,
111    mut height: usize,
112    mut position: usize,
113    path: &[H::Word],
114) -> H::Word {
115    if path.is_empty() {
116        return element;
117    }
118    for sibling in path.iter() {
119        let position_height = compute_height(position);
120        let next_height = compute_height(position + 1);
121        if next_height == position_height + 1 {
122            element = H::hash_pair(sibling, &element);
123            position += 1;
124        } else {
125            element = H::hash_pair(&element, sibling);
126            position += 1usize << (height + 1);
127        }
128        height += 1;
129    }
130    element
131}
132
133pub struct MmrVerifier;
134
135impl MmrVerifier {
136    /// Verifies a single MMR proof, dispatching to the appropriate hasher.
137    pub fn verify_mmr_proof(proof: &MmrProof) -> Result<bool, VerifyError> {
138        assert_mmr_size_is_valid(proof.elements_count as usize)?;
139
140        let expected_peaks_len = compute_expected_peaks_len(proof.elements_count as usize)?;
141        if proof.peaks.len() != expected_peaks_len {
142            return Err(VerifyError::InvalidMmrTree);
143        }
144
145        match proof.hashing_function {
146            HashingFunctionDto::Keccak => verify_with_hasher::<KeccakHasher>(proof),
147            HashingFunctionDto::Poseidon => verify_with_hasher::<PoseidonHasher>(proof),
148        }
149    }
150}
151
152/// Generic verification with a concrete `MmrHasher` implementation.
153fn verify_with_hasher<H: MmrHasher>(proof: &MmrProof) -> Result<bool, VerifyError> {
154    let elements_count = proof.elements_count as usize;
155    let element_index = proof.elements_index as usize;
156
157    let (peak_index, peak_height) =
158        get_peak_info(elements_count, element_index).ok_or(VerifyError::InvalidMmrProof)?;
159
160    if element_index != elements_count {
161        if proof.path.len() != peak_height {
162            return Err(VerifyError::InvalidMmrProof);
163        }
164    } else if !proof.path.is_empty() {
165        return Err(VerifyError::InvalidMmrProof);
166    }
167
168    let leaf = H::leaf_from_header_hash(&proof.header_hash);
169
170    let computed_peak = if element_index == elements_count {
171        leaf
172    } else {
173        let siblings: Vec<H::Word> = proof.path.iter().map(H::from_b256).collect();
174        hash_subtree_path_iter::<H>(leaf, 0, element_index, &siblings)
175    };
176
177    let peaks: Vec<H::Word> = proof.peaks.iter().map(H::from_b256).collect();
178    if peaks[peak_index] != computed_peak {
179        return Err(VerifyError::InvalidMmrProof);
180    }
181
182    let bag = H::bag_peaks(&peaks);
183    let root = H::mmr_root(elements_count as u128, &bag);
184    let proof_root = H::from_b256(&proof.root);
185    if root != proof_root {
186        return Err(VerifyError::InvalidMmrRoot);
187    }
188    Ok(true)
189}
190
191/// Validates that an MMR size can be decomposed into distinct peaks of the form `(2^k - 1)`.
192fn assert_mmr_size_is_valid(x: usize) -> Result<(), VerifyError> {
193    if x == 0 {
194        return Err(VerifyError::InvalidMmrTree);
195    }
196
197    let mut n = x;
198    let mut prev_peak = 0usize;
199    while n > 0 {
200        let i = bit_length(n);
201        if i == 0 {
202            return Err(VerifyError::InvalidMmrTree);
203        }
204        let peak_tmp = (1usize << i) - 1;
205        let peak = if n < peak_tmp {
206            (1usize << (i - 1)) - 1
207        } else {
208            peak_tmp
209        };
210        if peak == 0 || peak == prev_peak {
211            return Err(VerifyError::InvalidMmrTree);
212        }
213        n -= peak;
214        prev_peak = peak;
215    }
216    Ok(())
217}
218
219/// Computes how many peaks an MMR of `mmr_size` elements should have.
220fn compute_expected_peaks_len(mmr_size: usize) -> Result<usize, VerifyError> {
221    assert_mmr_size_is_valid(mmr_size)?;
222    let mut n = mmr_size;
223    let mut count = 0usize;
224    let mut prev_peak = 0usize;
225    while n > 0 {
226        let i = bit_length(n);
227        let peak_tmp = (1usize << i) - 1;
228        let peak = if n < peak_tmp {
229            (1usize << (i - 1)) - 1
230        } else {
231            peak_tmp
232        };
233        if peak == 0 || peak == prev_peak {
234            return Err(VerifyError::InvalidMmrTree);
235        }
236        count += 1;
237        n -= peak;
238        prev_peak = peak;
239    }
240    Ok(count)
241}
242
/// Returns `(peak_index, peak_height)` for the 1-indexed `element_index` in an MMR with
/// `elements_count` total elements. The height is the number of edges from a leaf of that
/// mountain to its peak.
///
/// Returns `None` if `element_index` is 0 or greater than `elements_count`, or if no mountain
/// containing the element can be found (which only happens for sizes that are not valid MMR
/// sizes).
fn get_peak_info(mut elements_count: usize, mut element_index: usize) -> Option<(usize, usize)> {
    if element_index == 0 || element_index > elements_count {
        return None;
    }

    // Largest all-ones value (`2^k - 1`) with the same bit length as `elements_count`.
    // Built by shifting `usize::MAX` so it cannot overflow for very large counts.
    let mut mountain_elements_count = usize::MAX >> elements_count.leading_zeros();
    let mut mountain_index = 0usize;

    // Bounded loop: for an invalid MMR size the candidate shrinks to 0 and we give up
    // instead of spinning forever.
    while mountain_elements_count > 0 {
        if mountain_elements_count <= elements_count {
            if element_index <= mountain_elements_count {
                // A mountain of `2^k - 1` nodes has `k` levels, i.e. leaf-to-peak height `k - 1`.
                let height = mountain_elements_count.count_ones() as usize - 1;
                return Some((mountain_index, height));
            }
            elements_count -= mountain_elements_count;
            element_index -= mountain_elements_count;
            mountain_index += 1;
        }
        mountain_elements_count >>= 1;
    }
    None
}
265
/// Computes the height in the implicit binary tree for a 1-indexed MMR position `x`
/// (leaves have height 0; `0` is mapped to 0).
///
/// Repeatedly jumps from `x` to the matching position in the left sibling subtree until
/// it lands on a perfect-subtree peak `2^k - 1`, then returns `k - 1`.
fn compute_height(mut x: usize) -> usize {
    while x != 0 {
        // `2^bits - 1` for the bit length of `x`, computed without `1 << bits`, which
        // overflows when `x >= 2^(usize::BITS - 1)`.
        let all_ones = usize::MAX >> x.leading_zeros();
        if x == all_ones {
            return all_ones.count_ones() as usize - 1;
        }
        // Jump left by the size of the left subtree: x = x - (2^(bits-1) - 1).
        x -= all_ones >> 1;
    }
    0
}
285
/// Number of bits needed to represent `n`: 0 for `n == 0`, otherwise `floor(log2(n)) + 1`.
fn bit_length(n: usize) -> usize {
    // `leading_zeros` never exceeds `usize::BITS`, so the u32 subtraction cannot underflow.
    (usize::BITS - n.leading_zeros()) as usize
}