sui_crypto/zklogin/poseidon/
mod.rs1#![allow(clippy::needless_range_loop)]
7
8use ark_bn254::Fr;
9use ark_ff::fields::Field;
10use ark_std::str::FromStr;
11use ark_std::Zero;
12use core::ops::AddAssign;
13use core::ops::MulAssign;
14
15mod constants;
16
/// Shared [`Poseidon`] instance, constructed on first access so the constant
/// tables are parsed only once per process.
pub static POSEIDON: std::sync::LazyLock<Poseidon> = std::sync::LazyLock::new(Poseidon::new);
18
/// Parsed Poseidon parameters for every supported input length.
///
/// Each per-width table is indexed by `k = inputs.len() - 1`. Entry `k` serves
/// a state of width `t = k + 2` (one leading zero element plus the inputs).
#[derive(Debug)]
struct Constants {
    /// Round constants. `c[k]` is consumed `t` elements per round by the add-round-key step.
    pub c: Vec<Vec<Fr>>,
    /// Mix matrices. `m[k]` is multiplied by the `t`-element state after every round.
    pub m: Vec<Vec<Vec<Fr>>>,
    /// Number of full rounds, split evenly before and after the partial rounds.
    pub n_rounds_f: usize,
    /// Partial-round count per width. Its length is the maximum number of inputs accepted.
    pub n_rounds_p: Vec<usize>,
}
26
27fn load_constants() -> Constants {
28 let (c_str, m_str) = constants::constants();
29 let mut c: Vec<Vec<Fr>> = Vec::new();
30 for i in 0..c_str.len() {
31 let mut cci: Vec<Fr> = Vec::new();
32 for j in 0..c_str[i].len() {
33 let b: Fr = Fr::from_str(c_str[i][j]).unwrap();
34 cci.push(b);
35 }
36 c.push(cci);
37 }
38 let mut m: Vec<Vec<Vec<Fr>>> = Vec::new();
39 for i in 0..m_str.len() {
40 let mut mi: Vec<Vec<Fr>> = Vec::new();
41 for j in 0..m_str[i].len() {
42 let mut mij: Vec<Fr> = Vec::new();
43 for k in 0..m_str[i][j].len() {
44 let b: Fr = Fr::from_str(m_str[i][j][k]).unwrap();
45 mij.push(b);
46 }
47 mi.push(mij);
48 }
49 m.push(mi);
50 }
51 Constants {
52 c,
53 m,
54 n_rounds_f: 8,
55 n_rounds_p: vec![
56 56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
57 ],
58 }
59}
60
/// Poseidon hasher over the BN254 scalar field (`ark_bn254::Fr`).
///
/// Holds the parsed round constants and mix matrices for every supported input
/// length. Construct with [`Poseidon::new`] or use the shared [`POSEIDON`] instance.
pub struct Poseidon {
    constants: Constants,
}
64
65impl Poseidon {
66 pub fn new() -> Poseidon {
67 Poseidon {
68 constants: load_constants(),
69 }
70 }
71
72 fn ark(&self, state: &mut [Fr], c: &[Fr], it: usize) {
73 for i in 0..state.len() {
74 state[i].add_assign(&c[it + i]);
75 }
76 }
77
78 fn sbox(&self, n_rounds_f: usize, n_rounds_p: usize, state: &mut [Fr], i: usize) {
79 if i < n_rounds_f / 2 || i >= n_rounds_f / 2 + n_rounds_p {
80 for j in 0..state.len() {
81 let aux = state[j];
82 state[j] = state[j].square();
83 state[j] = state[j].square();
84 state[j].mul_assign(&aux);
85 }
86 } else {
87 let aux = state[0];
88 state[0] = state[0].square();
89 state[0] = state[0].square();
90 state[0].mul_assign(&aux);
91 }
92 }
93
94 fn mix(&self, state: &[Fr], m: &[Vec<Fr>]) -> Vec<Fr> {
95 let mut new_state: Vec<Fr> = Vec::new();
96 for i in 0..state.len() {
97 new_state.push(Fr::zero());
98 for j in 0..state.len() {
99 let mut mij = m[i][j];
100 mij.mul_assign(&state[j]);
101 new_state[i].add_assign(&mij);
102 }
103 }
104 new_state.clone()
105 }
106
107 pub fn hash(&self, inp: &[Fr]) -> Result<Fr, String> {
108 let t = inp.len() + 1;
109 if inp.is_empty() || inp.len() > self.constants.n_rounds_p.len() {
110 return Err("Wrong inputs length".to_string());
111 }
112 let n_rounds_f = self.constants.n_rounds_f;
113 let n_rounds_p = self.constants.n_rounds_p[t - 2];
114
115 let mut state = vec![Fr::zero(); t];
116 state[1..].clone_from_slice(inp);
117
118 for i in 0..(n_rounds_f + n_rounds_p) {
119 self.ark(&mut state, &self.constants.c[t - 2], i * t);
120 self.sbox(n_rounds_f, n_rounds_p, &mut state, i);
121 state = self.mix(&state, &self.constants.m[t - 2]);
122 }
123
124 Ok(state[0])
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    // On wasm32, route `#[test]` to wasm-bindgen-test's runner.
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Builds a field element from a small integer via its decimal representation.
    fn fr(n: u64) -> Fr {
        Fr::from_str(&n.to_string()).unwrap()
    }

    #[test]
    fn test_load_constants() {
        let cons = load_constants();
        assert_eq!(
            cons.c[0][0].to_string(),
            "4417881134626180770308697923359573201005643519861877412381846989312604493735"
        );
        assert_eq!(
            cons.c[cons.c.len() - 1][0].to_string(),
            "21579410516734741630578831791708254656585702717204712919233299001262271512412"
        );
        assert_eq!(
            cons.m[0][0][0].to_string(),
            "2910766817845651019878574839501801340070030115151021261302834310722729507541"
        );
        assert_eq!(
            cons.m[cons.m.len() - 1][0][0].to_string(),
            "11497693837059016825308731789443585196852778517742143582474723527597064448312"
        );
    }

    /// Known-answer vectors over a spread of input lengths, including the
    /// minimum (1) and maximum (16).
    #[test]
    fn test_hash() {
        let poseidon = Poseidon::new();

        let cases: &[(&[u64], &str)] = &[
            (
                &[1],
                "18586133768512220936620570745912940619677854269274689475585506675881198879027",
            ),
            (
                &[1, 2],
                "7853200120776062878684798364095072458815029376092732009249414926327459813530",
            ),
            (
                &[1, 2, 0, 0, 0],
                "1018317224307729531995786483840663576608797660851238720571059489595066344487",
            ),
            (
                &[1, 2, 0, 0, 0, 0],
                "15336558801450556532856248569924170992202208561737609669134139141992924267169",
            ),
            (
                &[3, 4, 0, 0, 0],
                "5811595552068139067952687508729883632420015185677766880877743348592482390548",
            ),
            (
                &[3, 4, 0, 0, 0, 0],
                "12263118664590987767234828103155242843640892839966517009184493198782366909018",
            ),
            (
                &[1, 2, 3, 4, 5, 6],
                "20400040500897583745843009878988256314335038853985262692600694741116813247201",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
                "8354478399926161176778659061636406690034081872658507739535256090879947077494",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0],
                "5540388656744764564518487011617040650780060800286365721923524861648744699539",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0],
                "11882816200654282475720830292386643970958445617880627439994635298904836126497",
            ),
            (
                &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
                "9989051620750914585850546081941653841776809718687451684622678807385399211877",
            ),
        ];

        for (inputs, expected) in cases {
            let inp: Vec<Fr> = inputs.iter().map(|&n| fr(n)).collect();
            let h = poseidon.hash(&inp).unwrap();
            assert_eq!(h.to_string(), *expected, "inputs: {inputs:?}");
        }
    }

    #[test]
    fn test_wrong_inputs() {
        let poseidon = Poseidon::new();

        // Empty input is rejected rather than hashing an all-zero state.
        assert_eq!(poseidon.hash(&[]).unwrap_err(), "Wrong inputs length");

        // One more input than the 16 configured widths is rejected.
        let mut too_long = vec![fr(1), fr(2)];
        too_long.resize(17, fr(0));
        assert_eq!(poseidon.hash(&too_long).unwrap_err(), "Wrong inputs length");
    }
}