sui_crypto/zklogin/poseidon/
mod.rs

1//! Poseidon Hash implementation using ark-ff
2//!
3//! This module is vendored from <https://github.com/arnaucube/poseidon-ark> at commit
4//! 6d2487aa1308d9d3860a2b724c485d73095c1c68 with a few minor changes on top.
5
6#![allow(clippy::needless_range_loop)]
7
8use ark_bn254::Fr;
9use ark_ff::fields::Field;
10use ark_std::str::FromStr;
11use ark_std::Zero;
12use core::ops::AddAssign;
13use core::ops::MulAssign;
14
15mod constants;
16
17pub static POSEIDON: std::sync::LazyLock<Poseidon> = std::sync::LazyLock::new(Poseidon::new);
18
/// Parsed Poseidon parameters for every supported input length, as built by
/// `load_constants`.
///
/// The per-width vectors (`c`, `m`, `n_rounds_p`) are indexed by `t - 2`, where `t` is
/// the state width (number of inputs + 1); see how `Poseidon::hash` reads them.
#[derive(Debug)]
struct Constants {
    /// Round constants per width index. For a state of width `t`, round `r` uses the
    /// `t` consecutive entries starting at offset `r * t`.
    pub c: Vec<Vec<Fr>>,
    /// Mixing matrix per width index, multiplied into the state after every round.
    pub m: Vec<Vec<Vec<Fr>>>,
    /// Total number of full rounds; half run before the partial rounds, half after.
    pub n_rounds_f: usize,
    /// Number of partial rounds per width index. Its length is also the maximum number
    /// of inputs `Poseidon::hash` accepts.
    pub n_rounds_p: Vec<usize>,
}
26
27fn load_constants() -> Constants {
28    let (c_str, m_str) = constants::constants();
29    let mut c: Vec<Vec<Fr>> = Vec::new();
30    for i in 0..c_str.len() {
31        let mut cci: Vec<Fr> = Vec::new();
32        for j in 0..c_str[i].len() {
33            let b: Fr = Fr::from_str(c_str[i][j]).unwrap();
34            cci.push(b);
35        }
36        c.push(cci);
37    }
38    let mut m: Vec<Vec<Vec<Fr>>> = Vec::new();
39    for i in 0..m_str.len() {
40        let mut mi: Vec<Vec<Fr>> = Vec::new();
41        for j in 0..m_str[i].len() {
42            let mut mij: Vec<Fr> = Vec::new();
43            for k in 0..m_str[i][j].len() {
44                let b: Fr = Fr::from_str(m_str[i][j][k]).unwrap();
45                mij.push(b);
46            }
47            mi.push(mij);
48        }
49        m.push(mi);
50    }
51    Constants {
52        c,
53        m,
54        n_rounds_f: 8,
55        n_rounds_p: vec![
56            56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
57        ],
58    }
59}
60
61pub struct Poseidon {
62    constants: Constants,
63}
64
65impl Poseidon {
66    pub fn new() -> Poseidon {
67        Poseidon {
68            constants: load_constants(),
69        }
70    }
71
72    fn ark(&self, state: &mut [Fr], c: &[Fr], it: usize) {
73        for i in 0..state.len() {
74            state[i].add_assign(&c[it + i]);
75        }
76    }
77
78    fn sbox(&self, n_rounds_f: usize, n_rounds_p: usize, state: &mut [Fr], i: usize) {
79        if i < n_rounds_f / 2 || i >= n_rounds_f / 2 + n_rounds_p {
80            for j in 0..state.len() {
81                let aux = state[j];
82                state[j] = state[j].square();
83                state[j] = state[j].square();
84                state[j].mul_assign(&aux);
85            }
86        } else {
87            let aux = state[0];
88            state[0] = state[0].square();
89            state[0] = state[0].square();
90            state[0].mul_assign(&aux);
91        }
92    }
93
94    fn mix(&self, state: &[Fr], m: &[Vec<Fr>]) -> Vec<Fr> {
95        let mut new_state: Vec<Fr> = Vec::new();
96        for i in 0..state.len() {
97            new_state.push(Fr::zero());
98            for j in 0..state.len() {
99                let mut mij = m[i][j];
100                mij.mul_assign(&state[j]);
101                new_state[i].add_assign(&mij);
102            }
103        }
104        new_state.clone()
105    }
106
107    pub fn hash(&self, inp: &[Fr]) -> Result<Fr, String> {
108        let t = inp.len() + 1;
109        if inp.is_empty() || inp.len() > self.constants.n_rounds_p.len() {
110            return Err("Wrong inputs length".to_string());
111        }
112        let n_rounds_f = self.constants.n_rounds_f;
113        let n_rounds_p = self.constants.n_rounds_p[t - 2];
114
115        let mut state = vec![Fr::zero(); t];
116        state[1..].clone_from_slice(inp);
117
118        for i in 0..(n_rounds_f + n_rounds_p) {
119            self.ark(&mut state, &self.constants.c[t - 2], i * t);
120            self.sbox(n_rounds_f, n_rounds_p, &mut state, i);
121            state = self.mix(&state, &self.constants.m[t - 2]);
122        }
123
124        Ok(state[0])
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Builds a field element from a small integer via its decimal string.
    fn fr(n: u64) -> Fr {
        Fr::from_str(&n.to_string()).unwrap()
    }

    /// Converts a list of small integers into field elements.
    fn frs(values: &[u64]) -> Vec<Fr> {
        values.iter().copied().map(fr).collect()
    }

    #[test]
    fn test_load_constants() {
        let cons = load_constants();
        assert_eq!(
            cons.c[0][0].to_string(),
            "4417881134626180770308697923359573201005643519861877412381846989312604493735"
        );
        assert_eq!(
            cons.c[cons.c.len() - 1][0].to_string(),
            "21579410516734741630578831791708254656585702717204712919233299001262271512412"
        );
        assert_eq!(
            cons.m[0][0][0].to_string(),
            "2910766817845651019878574839501801340070030115151021261302834310722729507541"
        );
        assert_eq!(
            cons.m[cons.m.len() - 1][0][0].to_string(),
            "11497693837059016825308731789443585196852778517742143582474723527597064448312"
        );
    }

    #[test]
    fn test_hash() {
        // (inputs, expected digest) pairs over a range of widths, including zero-padded
        // variants of the same prefix and the 16-input maximum.
        let cases: &[(&[u64], &str)] = &[
            (
                &[1],
                "18586133768512220936620570745912940619677854269274689475585506675881198879027",
            ),
            (
                &[1, 2],
                "7853200120776062878684798364095072458815029376092732009249414926327459813530",
            ),
            (
                &[1, 2, 0, 0, 0],
                "1018317224307729531995786483840663576608797660851238720571059489595066344487",
            ),
            (
                &[1, 2, 0, 0, 0, 0],
                "15336558801450556532856248569924170992202208561737609669134139141992924267169",
            ),
            (
                &[3, 4, 0, 0, 0],
                "5811595552068139067952687508729883632420015185677766880877743348592482390548",
            ),
            (
                &[3, 4, 0, 0, 0, 0],
                "12263118664590987767234828103155242843640892839966517009184493198782366909018",
            ),
            (
                &[1, 2, 3, 4, 5, 6],
                "20400040500897583745843009878988256314335038853985262692600694741116813247201",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
                "8354478399926161176778659061636406690034081872658507739535256090879947077494",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0],
                "5540388656744764564518487011617040650780060800286365721923524861648744699539",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0],
                "11882816200654282475720830292386643970958445617880627439994635298904836126497",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
                "9989051620750914585850546081941653841776809718687451684622678807385399211877",
            ),
        ];

        let poseidon = Poseidon::new();
        for (inputs, expected) in cases {
            let h = poseidon.hash(&frs(inputs)).unwrap();
            assert_eq!(h.to_string(), *expected, "inputs: {inputs:?}");
        }
    }

    #[test]
    fn test_static_instance_matches_fresh_instance() {
        let inputs = frs(&[1, 2]);
        assert_eq!(
            POSEIDON.hash(&inputs).unwrap(),
            Poseidon::new().hash(&inputs).unwrap()
        );
    }

    #[test]
    fn test_wrong_inputs() {
        let poseidon = Poseidon::new();

        // 17 inputs: one more than the 16-input maximum.
        let mut too_long = frs(&[1, 2]);
        too_long.resize(17, Fr::zero());
        assert_eq!(poseidon.hash(&too_long).unwrap_err(), "Wrong inputs length");

        // At least one input is required.
        assert_eq!(poseidon.hash(&[]).unwrap_err(), "Wrong inputs length");
    }
}