mysten_network/
client.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    config::Config,
6    multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
7};
8use eyre::{Context, Result, eyre};
9use hyper_util::client::legacy::connect::{HttpConnector, dns::Name};
10use once_cell::sync::OnceCell;
11use std::{
12    collections::HashMap,
13    fmt,
14    future::Future,
15    io,
16    net::{SocketAddr, ToSocketAddrs},
17    pin::Pin,
18    sync::{Arc, Mutex},
19    task::{self, Poll},
20    time::Instant,
21    vec,
22};
23use tokio::task::JoinHandle;
24use tokio_rustls::rustls::ClientConfig;
25use tonic::transport::{Channel, Endpoint, Uri};
26use tower::Service;
27use tracing::{info, trace};
28
29pub async fn connect(address: &Multiaddr, tls_config: ClientConfig) -> Result<Channel> {
30    let channel = endpoint_from_multiaddr(address, tls_config)?
31        .connect()
32        .await?;
33    Ok(channel)
34}
35
36pub fn connect_lazy(address: &Multiaddr, tls_config: ClientConfig) -> Result<Channel> {
37    let channel = endpoint_from_multiaddr(address, tls_config)?.connect_lazy();
38    Ok(channel)
39}
40
41pub(crate) async fn connect_with_config(
42    address: &Multiaddr,
43    tls_config: ClientConfig,
44    config: &Config,
45) -> Result<Channel> {
46    let channel = endpoint_from_multiaddr(address, tls_config)?
47        .apply_config(config)
48        .connect()
49        .await?;
50    Ok(channel)
51}
52
53pub(crate) fn connect_lazy_with_config(
54    address: &Multiaddr,
55    tls_config: ClientConfig,
56    config: &Config,
57) -> Result<Channel> {
58    let channel = endpoint_from_multiaddr(address, tls_config)?
59        .apply_config(config)
60        .connect_lazy();
61    Ok(channel)
62}
63
64fn endpoint_from_multiaddr(addr: &Multiaddr, tls_config: ClientConfig) -> Result<MyEndpoint> {
65    let mut iter = addr.iter();
66
67    let channel = match iter.next().ok_or_else(|| eyre!("address is empty"))? {
68        Protocol::Dns(_) => {
69            let (dns_name, tcp_port, http_or_https) = parse_dns(addr)?;
70            let uri = format!("{http_or_https}://{dns_name}:{tcp_port}");
71            MyEndpoint::try_from_uri(uri, tls_config)?
72        }
73        Protocol::Ip4(_) => {
74            let (socket_addr, http_or_https) = parse_ip4(addr)?;
75            let uri = format!("{http_or_https}://{socket_addr}");
76            MyEndpoint::try_from_uri(uri, tls_config)?
77        }
78        Protocol::Ip6(_) => {
79            let (socket_addr, http_or_https) = parse_ip6(addr)?;
80            let uri = format!("{http_or_https}://{socket_addr}");
81            MyEndpoint::try_from_uri(uri, tls_config)?
82        }
83        unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
84    };
85
86    Ok(channel)
87}
88
89struct MyEndpoint {
90    endpoint: Endpoint,
91    tls_config: ClientConfig,
92}
93
94static DISABLE_CACHING_RESOLVER: OnceCell<bool> = OnceCell::new();
95
96impl MyEndpoint {
97    fn new(endpoint: Endpoint, tls_config: ClientConfig) -> Self {
98        Self {
99            endpoint,
100            tls_config,
101        }
102    }
103
104    fn try_from_uri(uri: String, tls_config: ClientConfig) -> Result<Self> {
105        let uri: Uri = uri
106            .parse()
107            .with_context(|| format!("unable to create Uri from '{uri}'"))?;
108        let endpoint = Endpoint::from(uri);
109        Ok(Self::new(endpoint, tls_config))
110    }
111
112    fn apply_config(mut self, config: &Config) -> Self {
113        self.endpoint = apply_config_to_endpoint(config, self.endpoint);
114        self
115    }
116
117    fn connect_lazy(self) -> Channel {
118        let disable_caching_resolver = *DISABLE_CACHING_RESOLVER.get_or_init(|| {
119            let disable_caching_resolver = std::env::var("DISABLE_CACHING_RESOLVER").is_ok();
120            info!("DISABLE_CACHING_RESOLVER: {disable_caching_resolver}");
121            disable_caching_resolver
122        });
123
124        if disable_caching_resolver {
125            let mut http = HttpConnector::new();
126            http.enforce_http(false);
127            http.set_nodelay(true);
128            http.set_keepalive(None);
129            http.set_connect_timeout(None);
130
131            Channel::new(
132                hyper_rustls::HttpsConnectorBuilder::new()
133                    .with_tls_config(self.tls_config)
134                    .https_only()
135                    .enable_http2()
136                    .wrap_connector(http),
137                self.endpoint,
138            )
139        } else {
140            let mut http = HttpConnector::new_with_resolver(CachingResolver::new());
141            http.enforce_http(false);
142            http.set_nodelay(true);
143            http.set_keepalive(None);
144            http.set_connect_timeout(None);
145
146            let https = hyper_rustls::HttpsConnectorBuilder::new()
147                .with_tls_config(self.tls_config)
148                .https_only()
149                .enable_http2()
150                .wrap_connector(http);
151            Channel::new(https, self.endpoint)
152        }
153    }
154
155    async fn connect(self) -> Result<Channel> {
156        let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
157            .with_tls_config(self.tls_config)
158            .https_only()
159            .enable_http2()
160            .build();
161        Channel::connect(https_connector, self.endpoint)
162            .await
163            .map_err(Into::into)
164    }
165}
166
167fn apply_config_to_endpoint(config: &Config, mut endpoint: Endpoint) -> Endpoint {
168    if let Some(limit) = config.concurrency_limit_per_connection {
169        endpoint = endpoint.concurrency_limit(limit);
170    }
171
172    if let Some(timeout) = config.request_timeout {
173        endpoint = endpoint.timeout(timeout);
174    }
175
176    if let Some(timeout) = config.connect_timeout {
177        endpoint = endpoint.connect_timeout(timeout);
178    }
179
180    if let Some(http2_keepalive_interval) = config.http2_keepalive_interval {
181        endpoint = endpoint.http2_keep_alive_interval(http2_keepalive_interval);
182    }
183
184    if let Some(http2_keepalive_timeout) = config.http2_keepalive_timeout {
185        endpoint = endpoint.keep_alive_timeout(http2_keepalive_timeout);
186    }
187
188    if let Some((limit, duration)) = config.rate_limit {
189        endpoint = endpoint.rate_limit(limit, duration);
190    }
191
192    endpoint
193        .initial_stream_window_size(config.http2_initial_stream_window_size)
194        .initial_connection_window_size(config.http2_initial_connection_window_size)
195        .tcp_keepalive(config.tcp_keepalive)
196}
197
198type CacheEntry = (Instant, Vec<SocketAddr>);
199
200/// A caching resolver based on hyper_util GaiResolver
201#[derive(Clone)]
202pub struct CachingResolver {
203    cache: Arc<Mutex<HashMap<Name, CacheEntry>>>,
204}
205
206type SocketAddrs = vec::IntoIter<SocketAddr>;
207
208pub struct CachingFuture {
209    inner: JoinHandle<Result<SocketAddrs, io::Error>>,
210}
211
212impl CachingResolver {
213    pub fn new() -> Self {
214        CachingResolver {
215            cache: Arc::new(Mutex::new(HashMap::new())),
216        }
217    }
218}
219
220impl Default for CachingResolver {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226impl Service<Name> for CachingResolver {
227    type Response = SocketAddrs;
228    type Error = io::Error;
229    type Future = CachingFuture;
230
231    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
232        Poll::Ready(Ok(()))
233    }
234
235    fn call(&mut self, name: Name) -> Self::Future {
236        let blocking = {
237            let cache = self.cache.clone();
238            tokio::task::spawn_blocking(move || {
239                let entry = cache.lock().unwrap().get(&name).cloned();
240
241                if let Some((when, addrs)) = entry {
242                    trace!("cached host={:?}", name.as_str());
243
244                    if when.elapsed().as_secs() > 60 {
245                        trace!("refreshing cache for host={:?}", name.as_str());
246                        // Start a new task to update the cache later.
247                        tokio::task::spawn_blocking(move || {
248                            if let Ok(addrs) = (name.as_str(), 0).to_socket_addrs() {
249                                let addrs: Vec<_> = addrs.collect();
250                                trace!("updating cached host={:?}", name.as_str());
251                                cache
252                                    .lock()
253                                    .unwrap()
254                                    .insert(name, (Instant::now(), addrs.clone()));
255                            }
256                        });
257                    }
258
259                    Ok(addrs.into_iter())
260                } else {
261                    trace!("resolving host={:?}", name.as_str());
262                    match (name.as_str(), 0).to_socket_addrs() {
263                        Ok(addrs) => {
264                            let addrs: Vec<_> = addrs.collect();
265                            cache
266                                .lock()
267                                .unwrap()
268                                .insert(name, (Instant::now(), addrs.clone()));
269                            Ok(addrs.into_iter())
270                        }
271                        res => res,
272                    }
273                }
274            })
275        };
276
277        CachingFuture { inner: blocking }
278    }
279}
280
281impl fmt::Debug for CachingResolver {
282    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283        f.pad("CachingResolver")
284    }
285}
286
287impl Future for CachingFuture {
288    type Output = Result<SocketAddrs, io::Error>;
289
290    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
291        Pin::new(&mut self.inner).poll(cx).map(|res| match res {
292            Ok(Ok(addrs)) => Ok(addrs),
293            Ok(Err(err)) => Err(err),
294            Err(join_err) => {
295                if join_err.is_cancelled() {
296                    Err(io::Error::new(io::ErrorKind::Interrupted, join_err))
297                } else {
298                    panic!("background task failed: {:?}", join_err)
299                }
300            }
301        })
302    }
303}
304
305impl fmt::Debug for CachingFuture {
306    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307        f.pad("CachingFuture")
308    }
309}
310
311impl Drop for CachingFuture {
312    fn drop(&mut self) {
313        self.inner.abort();
314    }
315}