sui_rpc/field/field_mask_util.rs

1use super::FieldMaskTree;
2use super::MessageField;
3use super::MessageFields;
4use super::FIELD_PATH_SEPARATOR;
5use super::FIELD_PATH_WILDCARD;
6use super::FIELD_SEPARATOR;
7
8use prost_types::FieldMask;
9
10pub trait FieldMaskUtil: sealed::Sealed {
11    fn normalize(self) -> FieldMask;
12
13    fn from_str(s: &str) -> FieldMask;
14
15    fn from_paths<I: AsRef<str>, T: IntoIterator<Item = I>>(paths: T) -> FieldMask;
16
17    fn display(&self) -> impl std::fmt::Display + '_;
18
19    fn validate<M: MessageFields>(&self) -> Result<(), &str>;
20}
21
22impl FieldMaskUtil for FieldMask {
23    fn normalize(self) -> FieldMask {
24        FieldMaskTree::from(self).to_field_mask()
25    }
26
27    fn from_str(s: &str) -> FieldMask {
28        Self::from_paths(s.split(FIELD_PATH_SEPARATOR))
29    }
30
31    fn from_paths<I: AsRef<str>, T: IntoIterator<Item = I>>(paths: T) -> FieldMask {
32        FieldMask {
33            paths: paths
34                .into_iter()
35                .filter_map(|path| {
36                    let path = path.as_ref();
37                    if path.is_empty() {
38                        None
39                    } else {
40                        Some(path.to_owned())
41                    }
42                })
43                .collect(),
44        }
45    }
46
47    fn display(&self) -> impl std::fmt::Display + '_ {
48        FieldMaskDisplay(self)
49    }
50
51    fn validate<M: MessageFields>(&self) -> Result<(), &str> {
52        // Determine if the provided path matches one of the provided fields. If a path matches a
53        // field and that field is a message type (which can have its own set of fields), attempt
54        // to match the remainder of the path against a field in the sub_message.
55        fn is_valid_path(mut fields: &[&MessageField], mut path: &str) -> bool {
56            loop {
57                let (field_name, remainder) = path
58                    .split_once(FIELD_SEPARATOR)
59                    .map(|(field, remainder)| (field, (!remainder.is_empty()).then_some(remainder)))
60                    .unwrap_or((path, None));
61
62                if let Some(field) = fields.iter().find(|field| field.name == field_name) {
63                    match (field.message_fields, remainder) {
64                        (None, None) | (Some(_), None) => return true,
65                        (None, Some(_)) => return false,
66                        (Some(sub_message_fields), Some(remainder)) => {
67                            fields = sub_message_fields;
68                            path = remainder;
69                        }
70                    }
71                } else {
72                    return false;
73                }
74            }
75        }
76
77        for path in &self.paths {
78            if path == FIELD_PATH_WILDCARD {
79                continue;
80            }
81            if !is_valid_path(M::FIELDS, path) {
82                return Err(path);
83            }
84        }
85
86        Ok(())
87    }
88}
89
90struct FieldMaskDisplay<'a>(&'a FieldMask);
91
92impl std::fmt::Display for FieldMaskDisplay<'_> {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        use std::fmt::Write;
95
96        let mut first = true;
97
98        for path in &self.0.paths {
99            // Ignore empty paths
100            if path.is_empty() {
101                continue;
102            }
103
104            // If this isn't the first path we've printed,
105            // we need to print a FIELD_PATH_SEPARATOR character
106            if first {
107                first = false;
108            } else {
109                f.write_char(FIELD_PATH_SEPARATOR)?;
110            }
111            f.write_str(path)?;
112        }
113
114        Ok(())
115    }
116}
117
118mod sealed {
119    pub trait Sealed {}
120
121    impl Sealed for prost_types::FieldMask {}
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_string() {
        assert!(FieldMask::display(&FieldMask::default())
            .to_string()
            .is_empty());

        let mask = FieldMask::from_paths(["foo"]);
        assert_eq!(FieldMask::display(&mask).to_string(), "foo");
        assert_eq!(mask.display().to_string(), "foo");
        let mask = FieldMask::from_paths(["foo", "bar"]);
        assert_eq!(FieldMask::display(&mask).to_string(), "foo,bar");

        // empty paths are ignored
        let mask = FieldMask::from_paths(["", "foo", "", "bar", ""]);
        assert_eq!(FieldMask::display(&mask).to_string(), "foo,bar");

        // empty paths placed directly into the mask (bypassing from_paths) are skipped too
        let mask = FieldMask {
            paths: vec![String::new(), "foo".to_owned(), String::new()],
        };
        assert_eq!(mask.display().to_string(), "foo");
    }

    #[test]
    fn test_from_str() {
        let mask = FieldMask::from_str("");
        assert!(mask.paths.is_empty());

        let mask = FieldMask::from_str("foo");
        assert_eq!(mask.paths.len(), 1);
        assert_eq!(mask.paths[0], "foo");

        let mask = FieldMask::from_str("foo,bar.baz");
        assert_eq!(mask.paths.len(), 2);
        assert_eq!(mask.paths[0], "foo");
        assert_eq!(mask.paths[1], "bar.baz");

        // empty field paths are ignored
        let mask = FieldMask::from_str(",foo,,bar,");
        assert_eq!(mask.paths.len(), 2);
        assert_eq!(mask.paths[0], "foo");
        assert_eq!(mask.paths[1], "bar");
    }

    #[test]
    fn test_from_paths_owned() {
        // from_paths accepts owned strings as well as &str, and still drops empty entries
        let mask = FieldMask::from_paths(vec![
            String::from("foo"),
            String::new(),
            String::from("bar.baz"),
        ]);
        assert_eq!(mask.paths, ["foo", "bar.baz"]);
    }

    #[test]
    fn test_validate() {
        struct Foo;
        impl MessageFields for Foo {
            const FIELDS: &'static [&'static MessageField] = &[
                &MessageField::new("bar").with_message_fields(Bar::FIELDS),
                &MessageField::new("baz"),
            ];
        }
        struct Bar;

        impl MessageFields for Bar {
            const FIELDS: &'static [&'static MessageField] = &[
                &MessageField {
                    name: "a",
                    json_name: "a",
                    number: 1,
                    message_fields: None,
                },
                &MessageField {
                    name: "b",
                    json_name: "b",
                    number: 2,
                    message_fields: None,
                },
            ];
        }

        let mask = FieldMask::from_str("");
        assert_eq!(mask.validate::<Foo>(), Ok(()));
        let mask = FieldMask::from_str("bar");
        assert_eq!(mask.validate::<Foo>(), Ok(()));
        let mask = FieldMask::from_str("bar.a");
        assert_eq!(mask.validate::<Foo>(), Ok(()));
        let mask = FieldMask::from_str("bar.a,bar.b");
        assert_eq!(mask.validate::<Foo>(), Ok(()));
        let mask = FieldMask::from_str("bar.a,bar.b,bar.c");
        assert_eq!(mask.validate::<Foo>(), Err("bar.c"));
        let mask = FieldMask::from_str("baz");
        assert_eq!(mask.validate::<Foo>(), Ok(()));
        let mask = FieldMask::from_str("baz.a");
        assert_eq!(mask.validate::<Foo>(), Err("baz.a"));
        let mask = FieldMask::from_str("foobar");
        assert_eq!(mask.validate::<Foo>(), Err("foobar"));

        // a scalar field inside a sub-message cannot be descended into further
        let mask = FieldMask::from_str("bar.a.b");
        assert_eq!(mask.validate::<Foo>(), Err("bar.a.b"));

        // when several paths are invalid, the first one in mask order is reported
        let mask = FieldMask::from_str("foobar,bar.c");
        assert_eq!(mask.validate::<Foo>(), Err("foobar"));

        // the wildcard path is accepted without being resolved against any field...
        let mask = FieldMask::from_paths([FIELD_PATH_WILDCARD]);
        assert_eq!(mask.validate::<Foo>(), Ok(()));
        // ...but it does not excuse other invalid paths in the same mask
        let mask = FieldMask::from_paths([FIELD_PATH_WILDCARD, "foobar"]);
        assert_eq!(mask.validate::<Foo>(), Err("foobar"));
    }
}