1use serde::{Deserialize, Serialize, de::Deserializer};
5use serde_with::serde_as;
6use std::path::PathBuf;
7
/// Default for [`FreqThresholdConfig::sketch_capacity`] (see `default_sketch_capacity`).
pub const DEFAULT_SKETCH_CAPACITY: usize = 50_000;
/// Default for [`FreqThresholdConfig::sketch_probability`] (see `default_sketch_probability`).
pub const DEFAULT_SKETCH_PROBABILITY: f64 = 0.999;
/// Default for [`FreqThresholdConfig::sketch_tolerance`] (see `default_sketch_tolerance`).
pub const DEFAULT_SKETCH_TOLERANCE: f64 = 0.2;
15use rand::distributions::Distribution;
16
/// Default drain timeout in seconds (five minutes), returned by `default_drain_timeout`.
const TRAFFIC_SINK_TIMEOUT_SEC: u64 = 5 * 60;
18
19#[derive(Clone, Debug, Deserialize, Serialize, Default)]
67#[serde(rename_all = "kebab-case")]
68pub enum ClientIdSource {
69 #[default]
70 SocketAddr,
71 XForwardedFor(usize),
72}
73
74#[derive(Clone, Debug, Deserialize, Serialize)]
75pub struct TrafficControlReconfigParams {
76 pub error_threshold: Option<u64>,
77 pub spam_threshold: Option<u64>,
78 pub dry_run: Option<bool>,
79}
80
81#[derive(Clone, Debug, Deserialize, Serialize)]
82pub struct Weight(f32);
83
84impl Weight {
85 pub fn new(value: f32) -> Result<Self, &'static str> {
86 if (0.0..=1.0).contains(&value) {
87 Ok(Self(value))
88 } else {
89 Err("Weight must be between 0.0 and 1.0")
90 }
91 }
92
93 pub fn one() -> Self {
94 Self(1.0)
95 }
96
97 pub fn zero() -> Self {
98 Self(0.0)
99 }
100
101 pub fn value(&self) -> f32 {
102 self.0
103 }
104
105 pub fn is_sampled(&self) -> bool {
106 let mut rng = rand::thread_rng();
107 let sample = rand::distributions::Uniform::new(0.0, 1.0).sample(&mut rng);
108 sample <= self.value()
109 }
110}
111
112fn validate_sample_rate<'de, D>(deserializer: D) -> Result<Weight, D::Error>
113where
114 D: Deserializer<'de>,
115{
116 let value = f32::deserialize(deserializer)?;
117 Weight::new(value)
118 .map_err(|_| serde::de::Error::custom("spam-sample-rate must be between 0.0 and 1.0"))
119}
120
121impl PartialEq for Weight {
122 fn eq(&self, other: &Self) -> bool {
123 self.value() == other.value()
124 }
125}
126
127#[serde_as]
128#[derive(Clone, Debug, Deserialize, Serialize)]
129#[serde(rename_all = "kebab-case")]
130pub struct RemoteFirewallConfig {
131 pub remote_fw_url: String,
132 pub destination_port: u16,
133 #[serde(default)]
134 pub delegate_spam_blocking: bool,
135 #[serde(default)]
136 pub delegate_error_blocking: bool,
137 #[serde(default = "default_drain_path")]
138 pub drain_path: PathBuf,
139 #[serde(default = "default_drain_timeout")]
142 pub drain_timeout_secs: u64,
143}
144
/// Serde default for `RemoteFirewallConfig::drain_path`.
fn default_drain_path() -> PathBuf {
    "/tmp/drain".into()
}
148
149fn default_drain_timeout() -> u64 {
150 TRAFFIC_SINK_TIMEOUT_SEC
151}
152
153#[serde_as]
154#[derive(Clone, Debug, Deserialize, Serialize)]
155#[serde(rename_all = "kebab-case")]
156pub struct FreqThresholdConfig {
157 #[serde(default = "default_client_threshold")]
158 pub client_threshold: u64,
159 #[serde(default = "default_proxied_client_threshold")]
160 pub proxied_client_threshold: u64,
161 #[serde(default = "default_window_size_secs")]
162 pub window_size_secs: u64,
163 #[serde(default = "default_update_interval_secs")]
164 pub update_interval_secs: u64,
165 #[serde(default = "default_sketch_capacity")]
166 pub sketch_capacity: usize,
167 #[serde(default = "default_sketch_probability")]
168 pub sketch_probability: f64,
169 #[serde(default = "default_sketch_tolerance")]
170 pub sketch_tolerance: f64,
171}
172
173impl Default for FreqThresholdConfig {
174 fn default() -> Self {
175 Self {
176 client_threshold: default_client_threshold(),
177 proxied_client_threshold: default_proxied_client_threshold(),
178 window_size_secs: default_window_size_secs(),
179 update_interval_secs: default_update_interval_secs(),
180 sketch_capacity: default_sketch_capacity(),
181 sketch_probability: default_sketch_probability(),
182 sketch_tolerance: default_sketch_tolerance(),
183 }
184 }
185}
186
/// Serde default for `FreqThresholdConfig::client_threshold`.
fn default_client_threshold() -> u64 {
    1_000_000
}

/// Serde default for `FreqThresholdConfig::proxied_client_threshold`.
fn default_proxied_client_threshold() -> u64 {
    10
}

/// Serde default for `FreqThresholdConfig::window_size_secs`.
fn default_window_size_secs() -> u64 {
    30
}

/// Serde default for `FreqThresholdConfig::update_interval_secs`.
fn default_update_interval_secs() -> u64 {
    5
}
208
209fn default_sketch_capacity() -> usize {
210 DEFAULT_SKETCH_CAPACITY
211}
212
213fn default_sketch_probability() -> f64 {
214 DEFAULT_SKETCH_PROBABILITY
215}
216
217fn default_sketch_tolerance() -> f64 {
218 DEFAULT_SKETCH_TOLERANCE
219}
220
221#[derive(Clone, Serialize, Deserialize, Debug, Default)]
224pub enum PolicyType {
225 #[default]
227 NoOp,
228
229 #[serde(rename = "freq-threshold", alias = "FreqThreshold")]
233 FreqThreshold(FreqThresholdConfig),
234
235 TestNConnIP(u64),
241 TestPanicOnInvocation,
244}
245
246#[serde_as]
247#[derive(Clone, Debug, Deserialize, Serialize)]
248#[serde(rename_all = "kebab-case")]
249pub struct PolicyConfig {
250 #[serde(default = "default_client_id_source")]
251 pub client_id_source: ClientIdSource,
252 #[serde(default = "default_connection_blocklist_ttl_sec")]
253 pub connection_blocklist_ttl_sec: u64,
254 #[serde(default)]
255 pub proxy_blocklist_ttl_sec: u64,
256 #[serde(default)]
257 pub spam_policy_type: PolicyType,
258 #[serde(default)]
259 pub error_policy_type: PolicyType,
260 #[serde(default = "default_channel_capacity")]
261 pub channel_capacity: usize,
262 #[serde(
263 default = "default_spam_sample_rate",
264 deserialize_with = "validate_sample_rate"
265 )]
266 pub spam_sample_rate: Weight,
272 #[serde(default = "default_dry_run")]
273 pub dry_run: bool,
274 #[serde(default)]
278 pub allow_list: Option<Vec<String>>,
279}
280
281impl Default for PolicyConfig {
282 fn default() -> Self {
283 Self {
284 client_id_source: default_client_id_source(),
285 connection_blocklist_ttl_sec: 0,
286 proxy_blocklist_ttl_sec: 0,
287 spam_policy_type: PolicyType::NoOp,
288 error_policy_type: PolicyType::NoOp,
289 channel_capacity: 100,
290 spam_sample_rate: default_spam_sample_rate(),
291 dry_run: default_dry_run(),
292 allow_list: None,
293 }
294 }
295}
296
297impl PolicyConfig {
298 pub fn default_dos_protection_policy() -> PolicyConfig {
299 PolicyConfig {
300 client_id_source: ClientIdSource::SocketAddr,
301 spam_policy_type: PolicyType::FreqThreshold(FreqThresholdConfig {
302 client_threshold: 1000,
303 window_size_secs: 5,
304 update_interval_secs: 1,
305 ..FreqThresholdConfig::default()
306 }),
307 error_policy_type: PolicyType::FreqThreshold(FreqThresholdConfig {
308 client_threshold: 50,
309 window_size_secs: 5,
310 update_interval_secs: 1,
311 ..FreqThresholdConfig::default()
312 }),
313 channel_capacity: 6000,
314 spam_sample_rate: Weight::new(1.0).unwrap(),
315 dry_run: true,
316 ..PolicyConfig::default()
317 }
318 }
319}
320
321pub fn default_client_id_source() -> ClientIdSource {
322 ClientIdSource::SocketAddr
323}
324
/// Serde default for `PolicyConfig::connection_blocklist_ttl_sec`.
pub fn default_connection_blocklist_ttl_sec() -> u64 {
    60
}

/// Serde default for `PolicyConfig::channel_capacity`.
pub fn default_channel_capacity() -> usize {
    100
}

/// Serde default for `PolicyConfig::dry_run`: dry-run is on unless disabled.
pub fn default_dry_run() -> bool {
    true
}
335
336pub fn default_spam_sample_rate() -> Weight {
337 Weight::new(0.2).unwrap()
338}