bankai_verify/bankai/
mmr.rs1extern crate alloc;
2use alloc::vec::Vec;
3
4use alloy_primitives::{keccak256, B256};
5use bankai_types::{fetch::evm::MmrProof, proofs::HashingFunctionDto, utils::mmr::hash_to_leaf};
6use starknet_crypto::{poseidon_hash, Felt};
7
8use crate::VerifyError;
9trait MmrHasher {
10 type Word: Copy + PartialEq;
11
12 fn from_b256(x: &B256) -> Self::Word;
14
15 fn hash_pair(left: &Self::Word, right: &Self::Word) -> Self::Word;
17
18 fn leaf_from_header_hash(header_hash: &B256) -> Self::Word;
20
21 fn bag_peaks(peaks: &[Self::Word]) -> Self::Word {
23 match peaks.len() {
24 0 => Self::from_b256(&B256::ZERO),
25 1 => peaks[0],
26 _ => {
27 let mut acc = *peaks.last().unwrap();
28 for i in (0..peaks.len() - 1).rev() {
29 acc = Self::hash_pair(&peaks[i], &acc);
30 }
31 acc
32 }
33 }
34 }
35
36 fn mmr_root(elements_count: u128, bag: &Self::Word) -> Self::Word;
38}
39
40struct KeccakHasher;
42
43impl MmrHasher for KeccakHasher {
44 type Word = B256;
45
46 #[inline]
47 fn from_b256(x: &B256) -> Self::Word {
48 *x
49 }
50
51 #[inline]
52 fn hash_pair(left: &Self::Word, right: &Self::Word) -> Self::Word {
53 let mut buf = [0u8; 64];
54 buf[..32].copy_from_slice(left.as_slice());
55 buf[32..].copy_from_slice(right.as_slice());
56 keccak256(buf)
57 }
58
59 #[inline]
60 fn leaf_from_header_hash(header_hash: &B256) -> Self::Word {
61 hash_to_leaf(*header_hash, &HashingFunctionDto::Keccak)
62 }
63
64 #[inline]
65 fn mmr_root(elements_count: u128, bag: &Self::Word) -> Self::Word {
66 let mut size_be = [0u8; 32];
67 size_be[16..].copy_from_slice(&elements_count.to_be_bytes());
68 let mut buf = [0u8; 64];
69 buf[..32].copy_from_slice(&size_be);
70 buf[32..].copy_from_slice(bag.as_slice());
71 keccak256(buf)
72 }
73}
74
75struct PoseidonHasher;
77
78impl MmrHasher for PoseidonHasher {
79 type Word = Felt;
80
81 #[inline]
82 fn from_b256(x: &B256) -> Self::Word {
83 Felt::from_bytes_be_slice(x.as_slice())
84 }
85
86 #[inline]
87 fn hash_pair(left: &Self::Word, right: &Self::Word) -> Self::Word {
88 poseidon_hash(*left, *right)
89 }
90
91 #[inline]
92 fn leaf_from_header_hash(header_hash: &B256) -> Self::Word {
93 let leaf_bytes = hash_to_leaf(*header_hash, &HashingFunctionDto::Poseidon);
95 Felt::from_bytes_be_slice(leaf_bytes.as_slice())
96 }
97
98 #[inline]
99 fn mmr_root(elements_count: u128, bag: &Self::Word) -> Self::Word {
100 let size_felt = Felt::from_bytes_be_slice(&elements_count.to_be_bytes());
101 poseidon_hash(size_felt, *bag)
102 }
103}
104
105fn hash_subtree_path_iter<H: MmrHasher>(
110 mut element: H::Word,
111 mut height: usize,
112 mut position: usize,
113 path: &[H::Word],
114) -> H::Word {
115 if path.is_empty() {
116 return element;
117 }
118 for sibling in path.iter() {
119 let position_height = compute_height(position);
120 let next_height = compute_height(position + 1);
121 if next_height == position_height + 1 {
122 element = H::hash_pair(sibling, &element);
123 position += 1;
124 } else {
125 element = H::hash_pair(&element, sibling);
126 position += 1usize << (height + 1);
127 }
128 height += 1;
129 }
130 element
131}
132
133pub struct MmrVerifier;
134
135impl MmrVerifier {
136 pub fn verify_mmr_proof(proof: &MmrProof) -> Result<bool, VerifyError> {
138 assert_mmr_size_is_valid(proof.elements_count as usize)?;
139
140 let expected_peaks_len = compute_expected_peaks_len(proof.elements_count as usize)?;
141 if proof.peaks.len() != expected_peaks_len {
142 return Err(VerifyError::InvalidMmrTree);
143 }
144
145 match proof.hashing_function {
146 HashingFunctionDto::Keccak => verify_with_hasher::<KeccakHasher>(proof),
147 HashingFunctionDto::Poseidon => verify_with_hasher::<PoseidonHasher>(proof),
148 }
149 }
150}
151
152fn verify_with_hasher<H: MmrHasher>(proof: &MmrProof) -> Result<bool, VerifyError> {
154 let elements_count = proof.elements_count as usize;
155 let element_index = proof.elements_index as usize;
156
157 let (peak_index, peak_height) =
158 get_peak_info(elements_count, element_index).ok_or(VerifyError::InvalidMmrProof)?;
159
160 if element_index != elements_count {
161 if proof.path.len() != peak_height {
162 return Err(VerifyError::InvalidMmrProof);
163 }
164 } else if !proof.path.is_empty() {
165 return Err(VerifyError::InvalidMmrProof);
166 }
167
168 let leaf = H::leaf_from_header_hash(&proof.header_hash);
169
170 let computed_peak = if element_index == elements_count {
171 leaf
172 } else {
173 let siblings: Vec<H::Word> = proof.path.iter().map(H::from_b256).collect();
174 hash_subtree_path_iter::<H>(leaf, 0, element_index, &siblings)
175 };
176
177 let peaks: Vec<H::Word> = proof.peaks.iter().map(H::from_b256).collect();
178 if peaks[peak_index] != computed_peak {
179 return Err(VerifyError::InvalidMmrProof);
180 }
181
182 let bag = H::bag_peaks(&peaks);
183 let root = H::mmr_root(elements_count as u128, &bag);
184 let proof_root = H::from_b256(&proof.root);
185 if root != proof_root {
186 return Err(VerifyError::InvalidMmrRoot);
187 }
188 Ok(true)
189}
190
191fn assert_mmr_size_is_valid(x: usize) -> Result<(), VerifyError> {
193 if x == 0 {
194 return Err(VerifyError::InvalidMmrTree);
195 }
196
197 let mut n = x;
198 let mut prev_peak = 0usize;
199 while n > 0 {
200 let i = bit_length(n);
201 if i == 0 {
202 return Err(VerifyError::InvalidMmrTree);
203 }
204 let peak_tmp = (1usize << i) - 1;
205 let peak = if n < peak_tmp {
206 (1usize << (i - 1)) - 1
207 } else {
208 peak_tmp
209 };
210 if peak == 0 || peak == prev_peak {
211 return Err(VerifyError::InvalidMmrTree);
212 }
213 n -= peak;
214 prev_peak = peak;
215 }
216 Ok(())
217}
218
219fn compute_expected_peaks_len(mmr_size: usize) -> Result<usize, VerifyError> {
221 assert_mmr_size_is_valid(mmr_size)?;
222 let mut n = mmr_size;
223 let mut count = 0usize;
224 let mut prev_peak = 0usize;
225 while n > 0 {
226 let i = bit_length(n);
227 let peak_tmp = (1usize << i) - 1;
228 let peak = if n < peak_tmp {
229 (1usize << (i - 1)) - 1
230 } else {
231 peak_tmp
232 };
233 if peak == 0 || peak == prev_peak {
234 return Err(VerifyError::InvalidMmrTree);
235 }
236 count += 1;
237 n -= peak;
238 prev_peak = peak;
239 }
240 Ok(count)
241}
242
243fn get_peak_info(mut elements_count: usize, mut element_index: usize) -> Option<(usize, usize)> {
246 if element_index == 0 || element_index > elements_count {
247 return None;
248 }
249 let mut mountain_height = bit_length(elements_count);
250 let mut mountain_elements_count = (1usize << mountain_height) - 1;
251 let mut mountain_index = 0usize;
252 loop {
253 if mountain_elements_count <= elements_count {
254 if element_index <= mountain_elements_count {
255 return Some((mountain_index, mountain_height.saturating_sub(1)));
256 }
257 elements_count -= mountain_elements_count;
258 element_index -= mountain_elements_count;
259 mountain_index += 1;
260 }
261 mountain_elements_count >>= 1;
262 mountain_height = mountain_height.saturating_sub(1);
263 }
264}
265
266fn compute_height(mut x: usize) -> usize {
270 loop {
271 let bit_len = bit_length(x);
272 if bit_len == 0 {
273 return 0;
274 }
275 let n = 1usize << (bit_len - 1);
276 let n2 = 1usize << bit_len; if x == n2 - 1 {
278 return bit_len - 1;
279 } else {
280 x = x - n + 1;
282 }
283 }
284}
285
/// Number of significant bits in `n`: one more than the index of the highest
/// set bit, or 0 when `n == 0`.
fn bit_length(n: usize) -> usize {
    let significant = usize::BITS - n.leading_zeros();
    significant as usize
}