sui_open_rpc/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4extern crate core;
5
6use std::collections::btree_map::Entry::Occupied;
7use std::collections::{BTreeMap, HashMap};
8
9use schemars::JsonSchema;
10use schemars::r#gen::{SchemaGenerator, SchemaSettings};
11use schemars::schema::SchemaObject;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use versions::Versioning;
15
/// OpenRPC document model following the OpenRPC specification <https://spec.open-rpc.org>.
/// The implementation is partial: only the required fields and a subset of the optional
/// fields in the specification are implemented, catered to Sui's needs.
#[derive(Serialize, Deserialize, Clone)]
pub struct Project {
    /// Version string of the OpenRPC specification; `Project::new` sets it to "1.2.6".
    openrpc: String,
    /// Document metadata (title, description, contact, license, version).
    info: Info,
    /// Documented methods; `add_module` re-sorts this list by method name.
    methods: Vec<Method>,
    /// Content descriptors and schemas merged in from every added module.
    components: Components,
    // Method routing for backward compatibility, not part of the OpenRPC spec.
    // Excluded from (de)serialization via `#[serde(skip)]`.
    #[serde(skip)]
    pub method_routing: HashMap<String, MethodRouting>,
}
29
30impl Project {
31    pub fn new(
32        version: &str,
33        title: &str,
34        description: &str,
35        contact_name: &str,
36        url: &str,
37        email: &str,
38        license: &str,
39        license_url: &str,
40    ) -> Self {
41        let openrpc = "1.2.6".to_string();
42        Self {
43            openrpc,
44            info: Info {
45                title: title.to_string(),
46                description: Some(description.to_string()),
47                contact: Some(Contact {
48                    name: contact_name.to_string(),
49                    url: Some(url.to_string()),
50                    email: Some(email.to_string()),
51                }),
52                license: Some(License {
53                    name: license.to_string(),
54                    url: Some(license_url.to_string()),
55                }),
56                version: version.to_owned(),
57                ..Default::default()
58            },
59            methods: vec![],
60            components: Components {
61                content_descriptors: Default::default(),
62                schemas: Default::default(),
63            },
64            method_routing: Default::default(),
65        }
66    }
67
68    pub fn add_module(&mut self, module: Module) {
69        self.methods.extend(module.methods);
70
71        self.methods.sort_by(|m, n| m.name.cmp(&n.name));
72
73        self.components.schemas.extend(module.components.schemas);
74        self.components
75            .content_descriptors
76            .extend(module.components.content_descriptors);
77        self.method_routing.extend(module.method_routing);
78    }
79
80    pub fn add_examples(&mut self, mut example_provider: BTreeMap<String, Vec<ExamplePairing>>) {
81        for method in &mut self.methods {
82            if let Occupied(entry) = example_provider.entry(method.name.clone()) {
83                let examples = entry.remove();
84                let param_names = method
85                    .params
86                    .iter()
87                    .map(|p| p.name.clone())
88                    .collect::<Vec<_>>();
89
90                // Make sure example's parameters are correct.
91                for example in examples.iter() {
92                    let example_param_names = example
93                        .params
94                        .iter()
95                        .map(|param| param.name.clone())
96                        .collect::<Vec<_>>();
97                    assert_eq!(
98                        param_names, example_param_names,
99                        "Provided example parameters doesn't match the function parameters."
100                    );
101                }
102
103                method.examples = examples
104            } else {
105                println!("No example found for method: {}", method.name);
106            }
107        }
108    }
109}
110
/// Documentation for one RPC module, as produced by `RpcModuleDocBuilder::build` and
/// consumed by `Project::add_module`.
pub struct Module {
    /// Methods documented by the module.
    methods: Vec<Method>,
    /// Content descriptors and schemas that accompany the methods.
    components: Components,
    /// Routing rules keyed by the full (namespace-prefixed) method name.
    method_routing: BTreeMap<String, MethodRouting>,
}
116
/// Incrementally collects the documentation of one RPC module; `build` turns the
/// collected state into a `Module`.
pub struct RpcModuleDocBuilder {
    /// Generator used to derive the JSON schema of each content descriptor; the
    /// definitions it accumulates become the module's schemas in `build`.
    schema_generator: SchemaGenerator,
    /// Methods keyed by their full `{namespace}_{name}` identifier.
    methods: BTreeMap<String, Method>,
    /// Routing rules keyed by the full `{namespace}_{name}` identifier.
    method_routing: BTreeMap<String, MethodRouting>,
    /// Content descriptors handed to `Module::components` unchanged by `build`.
    content_descriptors: BTreeMap<String, ContentDescriptor>,
}
123
124#[derive(Serialize, Deserialize, Default, Clone)]
125pub struct ContentDescriptor {
126    name: String,
127    #[serde(skip_serializing_if = "Option::is_none")]
128    summary: Option<String>,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    description: Option<String>,
131    #[serde(skip_serializing_if = "default")]
132    required: bool,
133    schema: SchemaObject,
134    #[serde(skip_serializing_if = "default")]
135    deprecated: bool,
136}
137
138#[derive(Serialize, Deserialize, Default, Clone)]
139struct Method {
140    name: String,
141    #[serde(skip_serializing_if = "Vec::is_empty")]
142    tags: Vec<Tag>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    description: Option<String>,
145    params: Vec<ContentDescriptor>,
146    #[serde(skip_serializing_if = "Option::is_none")]
147    result: Option<ContentDescriptor>,
148    #[serde(skip_serializing_if = "Vec::is_empty")]
149    examples: Vec<ExamplePairing>,
150    #[serde(skip_serializing_if = "std::ops::Not::not")]
151    deprecated: bool,
152}
/// Routing rule that maps a range of versions to the method named by `route_to`.
///
/// `min` and `max` are inclusive bounds as evaluated by `MethodRouting::matches`; a
/// `None` bound leaves that side of the range open. A rule with both bounds `None`
/// never matches.
#[derive(Clone, Debug)]
pub struct MethodRouting {
    /// Inclusive lower bound, or `None` for no lower bound.
    min: Option<Versioning>,
    /// Inclusive upper bound, or `None` for no upper bound.
    max: Option<Versioning>,
    /// Name of the method that matching versions are routed to.
    pub route_to: String,
}
159
160impl MethodRouting {
161    pub fn le(version: &str, route_to: &str) -> Self {
162        Self {
163            min: None,
164            max: Some(Versioning::new(version).unwrap()),
165            route_to: route_to.to_string(),
166        }
167    }
168
169    pub fn eq(version: &str, route_to: &str) -> Self {
170        Self {
171            min: Some(Versioning::new(version).unwrap()),
172            max: Some(Versioning::new(version).unwrap()),
173            route_to: route_to.to_string(),
174        }
175    }
176
177    pub fn matches(&self, version: &str) -> bool {
178        let version = Versioning::new(version);
179        match (&version, &self.min, &self.max) {
180            (Some(version), None, Some(max)) => version <= max,
181            (Some(version), Some(min), None) => version >= min,
182            (Some(version), Some(min), Some(max)) => version >= min && version <= max,
183            (_, _, _) => false,
184        }
185    }
186}
187
/// Checks `MethodRouting::matches` for both the exact (`eq`) and upper-bounded (`le`) rules.
#[test]
fn test_version_matching() {
    // Exact rule: only the pinned version is accepted.
    let exact = MethodRouting::eq("1.5", "test");
    assert!(exact.matches("1.5"));
    for other in ["1.6", "1.4"] {
        assert!(!exact.matches(other));
    }

    // Upper-bounded rule: the bound itself and anything below it are accepted.
    let at_most = MethodRouting::le("1.5", "test");
    for accepted in ["1.5", "1.4.5", "1.4", "1.3"] {
        assert!(at_most.matches(accepted));
    }

    // Anything above the bound, including a patch release of it, is rejected.
    for rejected in ["1.6", "1.5.1"] {
        assert!(!at_most.matches(rejected));
    }
}
204
/// A named example for a method: a list of example parameter values together with the
/// example result. `None` descriptions/summaries are left out of the serialized form.
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct ExamplePairing {
    /// Name of the example.
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    summary: Option<String>,
    /// Example parameters; `Project::add_examples` requires their names to equal the
    /// method's parameter names, in order.
    params: Vec<Example>,
    /// Example result value.
    result: Example,
}
215
216impl ExamplePairing {
217    pub fn new(name: &str, params: Vec<(&str, Value)>, result: Value) -> Self {
218        Self {
219            name: name.to_string(),
220            description: None,
221            summary: None,
222            params: params
223                .into_iter()
224                .map(|(name, value)| Example {
225                    name: name.to_string(),
226                    summary: None,
227                    description: None,
228                    value,
229                })
230                .collect(),
231            result: Example {
232                name: "Result".to_string(),
233                summary: None,
234                description: None,
235                value: result,
236            },
237        }
238    }
239}
240
/// A single named example value, used for both parameters and results of an
/// `ExamplePairing`. `None` summaries/descriptions are left out of the serialized form.
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct Example {
    /// Name of the parameter this value belongs to ("Result" for result examples built
    /// by `ExamplePairing::new`).
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    summary: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// The example value as arbitrary JSON.
    value: Value,
}
250
/// A tag attached to a method. `None` summaries/descriptions are left out of the
/// serialized form.
#[derive(Serialize, Deserialize, Default, Clone)]
struct Tag {
    /// Tag name, e.g. the "Websocket" and "PubSub" tags added to subscriptions.
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    summary: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
}
259
260impl Tag {
261    pub fn new(name: &str) -> Self {
262        Self {
263            name: name.to_string(),
264            summary: None,
265            description: None,
266        }
267    }
268}
269
/// Metadata section of the document. Field names are serialized in camelCase (so
/// `terms_of_service` becomes `termsOfService`) and `None` fields are left out.
#[derive(Serialize, Deserialize, Default, Clone)]
#[serde(rename_all = "camelCase")]
struct Info {
    /// Title of the documented API.
    title: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    // Never set by `Project::new`; stays `None` unless provided by deserialized input.
    #[serde(skip_serializing_if = "Option::is_none")]
    terms_of_service: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    contact: Option<Contact>,
    #[serde(skip_serializing_if = "Option::is_none")]
    license: Option<License>,
    /// Version of the documented API, as passed to `Project::new`.
    version: String,
}
284
/// Returns `true` when `value` equals its type's `Default` value.
///
/// Used as a serde `skip_serializing_if` predicate so that default-valued fields are
/// left out of the serialized document.
fn default<T: Default + PartialEq>(value: &T) -> bool {
    let baseline = T::default();
    *value == baseline
}
291
/// Contact information shown in the document's `info` section. `None` fields are left
/// out of the serialized form.
#[derive(Serialize, Deserialize, Default, Clone)]
struct Contact {
    /// Name of the contact person or organization.
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    email: Option<String>,
}
/// License information shown in the document's `info` section. A `None` URL is left
/// out of the serialized form.
#[derive(Serialize, Deserialize, Default, Clone)]
struct License {
    /// Name of the license.
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
}
306
307impl Default for RpcModuleDocBuilder {
308    fn default() -> Self {
309        let schema_generator = SchemaSettings::default()
310            .with(|s| {
311                s.definitions_path = "#/components/schemas/".to_string();
312            })
313            .into_generator();
314
315        Self {
316            schema_generator,
317            methods: BTreeMap::new(),
318            method_routing: Default::default(),
319            content_descriptors: BTreeMap::new(),
320        }
321    }
322}
323
324impl RpcModuleDocBuilder {
325    pub fn build(mut self) -> Module {
326        Module {
327            methods: self.methods.into_values().collect(),
328            components: Components {
329                content_descriptors: self.content_descriptors,
330                schemas: self
331                    .schema_generator
332                    .root_schema_for::<u8>()
333                    .definitions
334                    .into_iter()
335                    .map(|(name, schema)| (name, schema.into_object()))
336                    .collect::<BTreeMap<_, _>>(),
337            },
338            method_routing: self.method_routing,
339        }
340    }
341
342    pub fn add_method_routing(
343        &mut self,
344        namespace: &str,
345        name: &str,
346        route_to: &str,
347        comparator: &str,
348        version: &str,
349    ) {
350        let name = format!("{namespace}_{name}");
351        let route_to = format!("{namespace}_{route_to}");
352        let routing = match comparator {
353            "<=" => MethodRouting::le(version, &route_to),
354            "=" => MethodRouting::eq(version, &route_to),
355            _ => panic!("Unsupported version comparator {comparator}"),
356        };
357        if self.method_routing.insert(name.clone(), routing).is_some() {
358            panic!("Routing for method [{name}] already exists.")
359        }
360    }
361
362    pub fn add_method(
363        &mut self,
364        namespace: &str,
365        name: &str,
366        params: Vec<ContentDescriptor>,
367        result: Option<ContentDescriptor>,
368        doc: &str,
369        tag: Option<String>,
370        deprecated: bool,
371    ) {
372        let tags = tag.map(|t| Tag::new(&t)).into_iter().collect::<Vec<_>>();
373        self.add_method_internal(namespace, name, params, result, doc, tags, deprecated)
374    }
375
376    pub fn add_subscription(
377        &mut self,
378        namespace: &str,
379        name: &str,
380        params: Vec<ContentDescriptor>,
381        result: Option<ContentDescriptor>,
382        doc: &str,
383        tag: Option<String>,
384        deprecated: bool,
385    ) {
386        let mut tags = tag.map(|t| Tag::new(&t)).into_iter().collect::<Vec<_>>();
387        tags.push(Tag::new("Websocket"));
388        tags.push(Tag::new("PubSub"));
389        self.add_method_internal(namespace, name, params, result, doc, tags, deprecated)
390    }
391
392    fn add_method_internal(
393        &mut self,
394        namespace: &str,
395        name: &str,
396        params: Vec<ContentDescriptor>,
397        result: Option<ContentDescriptor>,
398        doc: &str,
399        tags: Vec<Tag>,
400        deprecated: bool,
401    ) {
402        let description = if doc.trim().is_empty() {
403            None
404        } else {
405            Some(doc.trim().to_string())
406        };
407        let name = format!("{}_{}", namespace, name);
408        self.methods.insert(
409            name.clone(),
410            Method {
411                name,
412                description,
413                params,
414                result,
415                tags,
416                examples: Vec::new(),
417                deprecated,
418            },
419        );
420    }
421
422    pub fn create_content_descriptor<T: JsonSchema>(
423        &mut self,
424        name: &str,
425        summary: Option<String>,
426        description: Option<String>,
427        required: bool,
428    ) -> ContentDescriptor {
429        let schema = self.schema_generator.subschema_for::<T>().into_object();
430        ContentDescriptor {
431            name: name.replace(' ', ""),
432            summary,
433            description,
434            required,
435            schema,
436            deprecated: false,
437        }
438    }
439}
440
441#[derive(Serialize, Deserialize, Clone)]
442#[serde(rename_all = "camelCase")]
443struct Components {
444    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
445    content_descriptors: BTreeMap<String, ContentDescriptor>,
446    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
447    schemas: BTreeMap<String, SchemaObject>,
448}