typed_store_derive/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::BTreeMap;
5
6use itertools::Itertools;
7use proc_macro::TokenStream;
8use proc_macro2::Ident;
9use quote::quote;
10use syn::Type::{self};
11use syn::{
12    AngleBracketedGenericArguments, Attribute, Generics, ItemStruct, Lit, Meta, PathArguments,
13    parse_macro_input,
14};
15
/// Path of the function used to build a table's options when the field carries no
/// `#[default_options_override_fn = "..."]` attribute.
const DEFAULT_DB_OPTIONS_CUSTOM_FN: &str = "typed_store::rocks::default_db_options";
/// Name of the field attribute that points at a custom function returning this table's
/// options, overriding the default: `#[default_options_override_fn = "path::to::fn"]`.
const DB_OPTIONS_CUSTOM_FUNCTION: &str = "default_options_override_fn";
/// Name of the field attribute that gives the column family a different name than the
/// field identifier: `#[rename = "cf_name"]`.
const DB_OPTIONS_RENAME: &str = "rename";
/// Name of the field attribute that marks a column family as deprecated. Deprecated column
/// families are dropped when the tables are opened read-write with `remove_deprecated_tables`.
const DB_OPTIONS_DEPRECATE: &str = "deprecated";
24
25/// Options can either be simplified form or
26enum GeneralTableOptions {
27    OverrideFunction(String),
28}
29
30impl Default for GeneralTableOptions {
31    fn default() -> Self {
32        Self::OverrideFunction(DEFAULT_DB_OPTIONS_CUSTOM_FN.to_owned())
33    }
34}
35
36// Extracts the field names, field types, inner types (K,V in {map_type_name}<K, V>), and the options attrs
37fn extract_struct_info(input: ItemStruct) -> ExtractedStructInfo {
38    let mut deprecated_cfs = vec![];
39
40    let info = input.fields.iter().map(|f| {
41        let attrs: BTreeMap<_, _> = f
42            .attrs
43            .iter()
44            .filter(|a| {
45                a.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION)
46                    || a.path.is_ident(DB_OPTIONS_RENAME)
47                    || a.path.is_ident(DB_OPTIONS_DEPRECATE)
48            })
49            .map(|a| (a.path.get_ident().unwrap().to_string(), a))
50            .collect();
51
52        let options = if let Some(options) = attrs.get(DB_OPTIONS_CUSTOM_FUNCTION) {
53            GeneralTableOptions::OverrideFunction(get_options_override_function(options).unwrap())
54        } else {
55            GeneralTableOptions::default()
56        };
57
58        let ty = &f.ty;
59        if let Type::Path(p) = ty {
60            let type_info = &p.path.segments.first().unwrap();
61            let inner_type =
62                if let PathArguments::AngleBracketed(angle_bracket_type) = &type_info.arguments {
63                    angle_bracket_type.clone()
64                } else {
65                    panic!("All struct members must be of type DBMap");
66                };
67
68            let type_str = format!("{}", &type_info.ident);
69            if type_str == "DBMap" {
70                let field_name = f.ident.as_ref().unwrap().clone();
71                let cf_name = if let Some(rename) = attrs.get(DB_OPTIONS_RENAME) {
72                    match rename.parse_meta().expect("Cannot parse meta of attribute") {
73                        Meta::NameValue(val) => {
74                            if let Lit::Str(s) = val.lit {
75                                // convert to ident
76                                s.parse().expect("Rename value must be identifier")
77                            } else {
78                                panic!("Expected string value for rename")
79                            }
80                        }
81                        _ => panic!("Expected string value for rename"),
82                    }
83                } else {
84                    field_name.clone()
85                };
86                if attrs.contains_key(DB_OPTIONS_DEPRECATE) {
87                    deprecated_cfs.push(field_name.clone());
88                }
89
90                return ((field_name, cf_name, type_str), (inner_type, options));
91            } else {
92                panic!("All struct members must be of type DBMap");
93            }
94        }
95        panic!("All struct members must be of type DBMap");
96    });
97
98    let (field_info, inner_types_with_opts): (Vec<_>, Vec<_>) = info.unzip();
99    let (field_names, cf_names, simple_field_type_names): (Vec<_>, Vec<_>, Vec<_>) =
100        field_info.into_iter().multiunzip();
101
102    // Check for homogeneous types
103    if let Some(first) = simple_field_type_names.first() {
104        simple_field_type_names.iter().for_each(|q| {
105            if q != first {
106                panic!("All struct members must be of same type");
107            }
108        })
109    } else {
110        panic!("Cannot derive on empty struct");
111    };
112
113    let (inner_types, options): (Vec<_>, Vec<_>) = inner_types_with_opts.into_iter().unzip();
114
115    ExtractedStructInfo {
116        field_names,
117        cf_names,
118        inner_types,
119        derived_table_options: options,
120        deprecated_cfs,
121    }
122}
123
124/// Extracts the table options override function
125/// The function must take no args and return Options
126fn get_options_override_function(attr: &Attribute) -> syn::Result<String> {
127    let meta = attr.parse_meta()?;
128
129    let val = match meta.clone() {
130        Meta::NameValue(val) => val,
131        _ => {
132            return Err(syn::Error::new_spanned(
133                meta,
134                format!(
135                    "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
136                ),
137            ));
138        }
139    };
140
141    if !val.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION) {
142        return Err(syn::Error::new_spanned(
143            meta,
144            format!(
145                "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
146            ),
147        ));
148    }
149
150    let fn_name = match val.lit {
151        Lit::Str(fn_name) => fn_name,
152        _ => {
153            return Err(syn::Error::new_spanned(
154                meta,
155                format!(
156                    "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
157                ),
158            ));
159        }
160    };
161    Ok(fn_name.value())
162}
163
164fn extract_generics_names(generics: &Generics) -> Vec<Ident> {
165    generics
166        .params
167        .iter()
168        .map(|g| match g {
169            syn::GenericParam::Type(t) => t.ident.clone(),
170            _ => panic!("Unsupported generic type"),
171        })
172        .collect()
173}
174
/// Information gathered from the annotated struct by `extract_struct_info`.
///
/// `field_names`, `cf_names`, `inner_types` and `derived_table_options` hold one entry per
/// struct field, in declaration order.
struct ExtractedStructInfo {
    /// Identifiers of the struct's fields.
    field_names: Vec<Ident>,
    /// Column family name of each field (the field name unless `#[rename = "..."]` is given).
    cf_names: Vec<Ident>,
    /// Generic arguments (`<K, V>`) of each field's `DBMap` type.
    inner_types: Vec<AngleBracketedGenericArguments>,
    /// Options (override function path) for each table.
    derived_table_options: Vec<GeneralTableOptions>,
    /// Entries for fields marked `#[deprecated]`; the generated code compares these against
    /// CF names and passes them to `drop_cf`.
    deprecated_cfs: Vec<Ident>,
}
182
183#[proc_macro_derive(
184    DBMapUtils,
185    attributes(default_options_override_fn, rename, tidehunter)
186)]
187pub fn derive_dbmap_utils_general(input: TokenStream) -> TokenStream {
188    let input = parse_macro_input!(input as ItemStruct);
189    let name = &input.ident;
190    let is_tidehunter = input
191        .attrs
192        .iter()
193        .any(|attr| attr.path.is_ident("tidehunter"));
194    let generics = &input.generics;
195    let generics_names = extract_generics_names(generics);
196
197    // TODO: use `parse_quote` over `parse()`
198    let ExtractedStructInfo {
199        field_names,
200        cf_names,
201        inner_types,
202        derived_table_options,
203        deprecated_cfs,
204    } = extract_struct_info(input.clone());
205
206    let (key_names, value_names): (Vec<_>, Vec<_>) = inner_types
207        .iter()
208        .map(|q| (q.args.first().unwrap(), q.args.last().unwrap()))
209        .unzip();
210
211    let default_options_override_fn_names: Vec<proc_macro2::TokenStream> = derived_table_options
212        .iter()
213        .map(|q| {
214            let GeneralTableOptions::OverrideFunction(fn_name) = q;
215            fn_name.parse().unwrap()
216        })
217        .collect();
218
219    let generics_bounds =
220        "std::fmt::Debug + serde::Serialize + for<'de> serde::de::Deserialize<'de>";
221    let generics_bounds_token: proc_macro2::TokenStream = generics_bounds.parse().unwrap();
222
223    let intermediate_db_map_struct_name_str = format!("{name}IntermediateDBMapStructPrimary");
224    let intermediate_db_map_struct_name: proc_macro2::TokenStream =
225        intermediate_db_map_struct_name_str.parse().unwrap();
226
227    let secondary_db_map_struct_name_str = format!("{name}ReadOnly");
228    let secondary_db_map_struct_name: proc_macro2::TokenStream =
229        secondary_db_map_struct_name_str.parse().unwrap();
230
231    let base_code = quote! {
232        /// Create an intermediate struct used to open the DBMap tables in primary mode
233        /// This is only used internally
234        struct #intermediate_db_map_struct_name #generics {
235                #(
236                    pub #field_names : DBMap #inner_types,
237                )*
238        }
239
240        impl <
241            #(
242                #generics_names: #generics_bounds_token,
243            )*
244        > #name #generics {
245            /// Returns a list of the tables name and type pairs
246            pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
247                vec![#(
248                    (stringify!(#cf_names).to_owned(), (stringify!(#key_names).to_owned(), stringify!(#value_names).to_owned())),
249                )*].into_iter().collect()
250            }
251        }
252    };
253
254    let secondary_code = quote! {
255        pub struct #secondary_db_map_struct_name #generics {
256            #(
257                pub #field_names : DBMap #inner_types,
258            )*
259        }
260
261        impl <
262                #(
263                    #generics_names: #generics_bounds_token,
264                )*
265            > #secondary_db_map_struct_name #generics {
266            /// Open in read only mode. No limitation on number of processes to do this
267            pub fn open_tables_read_only(
268                primary_path: std::path::PathBuf,
269                with_secondary_path: Option<std::path::PathBuf>,
270                metric_conf: typed_store::rocks::MetricConf,
271                global_db_options_override: Option<typed_store::rocksdb::Options>,
272            ) -> Self {
273                let inner = match with_secondary_path {
274                    Some(q) => #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(q), metric_conf, global_db_options_override, None, false),
275                    None => {
276                        let p: std::path::PathBuf = tempfile::tempdir()
277                        .expect("Failed to open temporary directory")
278                        .keep();
279                        #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(p), metric_conf, global_db_options_override, None, false)
280                    }
281                };
282                Self {
283                    #(
284                        #field_names: inner.#field_names,
285                    )*
286                }
287            }
288
289            fn cf_name_to_table_name(cf_name: &str) -> eyre::Result<&'static str> {
290                Ok(match cf_name {
291                    #(
292                        stringify!(#cf_names) => stringify!(#field_names),
293                    )*
294                    _ => eyre::bail!("No such cf name: {}", cf_name),
295                })
296            }
297
298            /// Dump all key-value pairs in the page at the given table name
299            /// Tables must be opened in read only mode using `open_tables_read_only`
300            pub fn dump(&self, cf_name: &str, page_size: u16, page_number: usize) -> eyre::Result<std::collections::BTreeMap<String, String>> {
301                let table_name = Self::cf_name_to_table_name(cf_name)?;
302
303                Ok(match table_name {
304                    #(
305                        stringify!(#field_names) => {
306                            typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
307                            typed_store::traits::Map::safe_iter(&self.#field_names)
308                                .skip((page_number * (page_size) as usize))
309                                .take(page_size as usize)
310                                .map(|result| result.map(|(k, v)| (format!("{:?}", k), format!("{:?}", v))))
311                                .collect::<eyre::Result<std::collections::BTreeMap<_, _>, _>>()?
312                        }
313                    )*
314
315                    _ => eyre::bail!("No such table name: {}", table_name),
316                })
317            }
318
319            /// Get key value sizes from the db
320            /// Tables must be opened in read only mode using `open_tables_read_only`
321            pub fn table_summary(&self, table_name: &str) -> eyre::Result<typed_store::traits::TableSummary> {
322                let mut count = 0;
323                let mut key_bytes = 0;
324                let mut value_bytes = 0;
325                match table_name {
326                    #(
327                        stringify!(#field_names) => {
328                            typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
329                            self.#field_names.table_summary()
330                        }
331                    )*
332
333                    _ => eyre::bail!("No such table name: {}", table_name),
334                }
335            }
336
337            pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
338                vec![#(
339                    (stringify!(#cf_names).to_owned(), (stringify!(#key_names).to_owned(), stringify!(#value_names).to_owned())),
340                )*].into_iter().collect()
341            }
342
343            /// Try catch up with primary for all tables. This can be a slow operation
344            /// Tables must be opened in read only mode using `open_tables_read_only`
345            pub fn try_catch_up_with_primary_all(&self) -> eyre::Result<()> {
346                #(
347                    typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
348                )*
349                Ok(())
350            }
351        }
352    };
353
354    if is_tidehunter {
355        TokenStream::from(quote! {
356            #base_code
357            impl <
358                    #(
359                        #generics_names: #generics_bounds_token,
360                    )*
361                > #intermediate_db_map_struct_name #generics {
362                /// Opens a set of tables in read-write mode
363                /// If as_secondary_with_path is set, the DB is opened in read only mode with the path specified
364                pub fn open_tables_impl(
365                    path: std::path::PathBuf,
366                    metric_conf: typed_store::rocks::MetricConf,
367                    cf_configs: std::collections::BTreeMap<String, typed_store::tidehunter_util::ThConfig>,
368                ) -> Self {
369                    let mut builder = typed_store::tidehunter_util::KeyShapeBuilder::new();
370                    let (
371                        #(
372                            #field_names,
373                        )*
374                    ) = (
375                        #(
376                            typed_store::tidehunter_util::add_key_space(
377                                &mut builder,
378                                stringify!(#cf_names),
379                                &cf_configs[stringify!(#cf_names)],
380                            ),
381                        )*
382                    );
383                    let key_shape = builder.build();
384                    let inner_db = typed_store::tidehunter_util::open(path.as_path(), key_shape, metric_conf.db_name.clone());
385                    let db = std::sync::Arc::new(typed_store::rocks::Database::new(
386                        typed_store::rocks::Storage::TideHunter(inner_db),
387                        metric_conf));
388                    let (
389                        #(
390                            #field_names
391                        ),*
392                    ) = (#(
393                        DBMap::#inner_types::reopen_th(
394                            db.clone(), stringify!(#cf_names), #field_names,
395                            cf_configs[stringify!(#cf_names)].prefix.clone()
396                        )
397                    ),*);
398                    Self {
399                        #(
400                            #field_names,
401                        )*
402                    }
403                }
404            }
405
406            impl <
407                #(
408                    #generics_names: #generics_bounds_token,
409                )*
410            > #name #generics {
411                pub fn open_tables_read_write(
412                    path: std::path::PathBuf,
413                    metric_conf: typed_store::rocks::MetricConf,
414                    cf_configs: std::collections::BTreeMap<String, typed_store::tidehunter_util::ThConfig>,
415                ) -> Self {
416                    let inner = #intermediate_db_map_struct_name::open_tables_impl(path, metric_conf, cf_configs);
417                    Self {
418                        #(
419                            #field_names: inner.#field_names,
420                        )*
421                    }
422                }
423
424                pub fn get_read_only_handle (
425                    _: std::path::PathBuf,
426                    _: Option<std::path::PathBuf>,
427                    _: Option<typed_store::rocksdb::Options>,
428                    _: typed_store::rocks::MetricConf,
429                ) -> #secondary_db_map_struct_name #generics {
430                    unimplemented!("read only mode is not supported for TideHunter");
431                }
432            }
433
434            pub struct #secondary_db_map_struct_name;
435        })
436    } else {
437        TokenStream::from(quote! {
438            #base_code
439
440            impl <
441                    #(
442                        #generics_names: #generics_bounds_token,
443                    )*
444                > #intermediate_db_map_struct_name #generics {
445                /// Opens a set of tables in read-write mode
446                /// If as_secondary_with_path is set, the DB is opened in read only mode with the path specified
447                pub fn open_tables_impl(
448                    path: std::path::PathBuf,
449                    as_secondary_with_path: Option<std::path::PathBuf>,
450                    metric_conf: typed_store::rocks::MetricConf,
451                    global_db_options_override: Option<typed_store::rocksdb::Options>,
452                    tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
453                    remove_deprecated_tables: bool,
454                ) -> Self {
455                    let path = &path;
456                    let default_cf_opt = if let Some(opt) = global_db_options_override.as_ref() {
457                        typed_store::rocks::DBOptions {
458                            options: opt.clone(),
459                            rw_options: typed_store::rocks::default_db_options().rw_options,
460                        }
461                    } else {
462                        typed_store::rocks::default_db_options()
463                    };
464                    let (db, rwopt_cfs) = {
465                        let opt_cfs = match tables_db_options_override {
466                            None => [
467                                #(
468                                    (stringify!(#cf_names).to_owned(), #default_options_override_fn_names()),
469                                )*
470                            ],
471                            Some(o) => [
472                                #(
473                                    (stringify!(#cf_names).to_owned(), o.to_map().get(stringify!(#cf_names)).unwrap_or(&default_cf_opt).clone()),
474                                )*
475                            ]
476                        };
477                        // Safe to call unwrap because we will have at least one field_name entry in the struct
478                        let rwopt_cfs: std::collections::HashMap<String, typed_store::rocks::ReadWriteOptions> = opt_cfs.iter().map(|q| (q.0.as_str().to_string(), q.1.rw_options.clone())).collect();
479                        let opt_cfs: Vec<_> = opt_cfs.iter().map(|q| (q.0.as_str(), q.1.options.clone())).collect();
480                        let db = match as_secondary_with_path.clone() {
481                            Some(p) => typed_store::rocks::open_cf_opts_secondary(path, Some(&p), global_db_options_override, metric_conf, &opt_cfs),
482                            _ => typed_store::rocks::open_cf_opts(path, global_db_options_override, metric_conf, &opt_cfs)
483                        };
484                        db.map(|d| (d, rwopt_cfs))
485                    }.expect(&format!("Cannot open DB at {:?}", path));
486                    let deprecated_tables = vec![#(stringify!(#deprecated_cfs),)*];
487                    let (
488                            #(
489                                #field_names
490                            ),*
491                    ) = (#(
492                            DBMap::#inner_types::reopen(&db, Some(stringify!(#cf_names)), rwopt_cfs.get(stringify!(#cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()), remove_deprecated_tables && deprecated_tables.contains(&stringify!(#cf_names))).expect(&format!("Cannot open {} CF.", stringify!(#cf_names))[..])
493                        ),*);
494
495                    if as_secondary_with_path.is_none() && remove_deprecated_tables {
496                        #(
497                            db.drop_cf(stringify!(#deprecated_cfs)).expect("failed to drop a deprecated cf");
498                        )*
499                    }
500                    Self {
501                        #(
502                            #field_names,
503                        )*
504                    }
505                }
506            }
507
508            // <----------- This section generates the read-write open logic and other common utils -------------->
509            impl <
510                    #(
511                        #generics_names: #generics_bounds_token,
512                    )*
513                > #name #generics {
514                /// Opens a set of tables in read-write mode
515                /// Only one process is allowed to do this at a time
516                /// `global_db_options_override` apply to the whole DB
517                /// `tables_db_options_override` apply to each table. If `None`, the attributes from `default_options_override_fn` are used if any
518                #[allow(unused_parens)]
519                pub fn open_tables_read_write(
520                    path: std::path::PathBuf,
521                    metric_conf: typed_store::rocks::MetricConf,
522                    global_db_options_override: Option<typed_store::rocksdb::Options>,
523                    tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
524                ) -> Self {
525                    let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, metric_conf, global_db_options_override, tables_db_options_override, false);
526                    Self {
527                        #(
528                            #field_names: inner.#field_names,
529                        )*
530                    }
531                }
532
533                #[allow(unused_parens)]
534                pub fn open_tables_read_write_with_deprecation_option(
535                    path: std::path::PathBuf,
536                    metric_conf: typed_store::rocks::MetricConf,
537                    global_db_options_override: Option<typed_store::rocksdb::Options>,
538                    tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
539                    remove_deprecated_tables: bool,
540                ) -> Self {
541                    let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, metric_conf, global_db_options_override, tables_db_options_override, remove_deprecated_tables);
542                    Self {
543                        #(
544                            #field_names: inner.#field_names,
545                        )*
546                    }
547                }
548
549                /// This opens the DB in read only mode and returns a struct which exposes debug features
550                pub fn get_read_only_handle (
551                    primary_path: std::path::PathBuf,
552                    with_secondary_path: Option<std::path::PathBuf>,
553                    global_db_options_override: Option<typed_store::rocksdb::Options>,
554                    metric_conf: typed_store::rocks::MetricConf,
555                    ) -> #secondary_db_map_struct_name #generics {
556                    #secondary_db_map_struct_name::open_tables_read_only(primary_path, with_secondary_path, metric_conf, global_db_options_override)
557                }
558            }
559            #secondary_code
560        })
561    }
562}