postgres_native_tls/
lib.rs

1//! TLS support for `tokio-postgres` and `postgres` via `native-tls`.
2//!
3//! # Examples
4//!
5//! ```no_run
6//! use native_tls::{Certificate, TlsConnector};
7//! # #[cfg(feature = "runtime")]
8//! use postgres_native_tls::MakeTlsConnector;
9//! use std::fs;
10//!
11//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
12//! # #[cfg(feature = "runtime")] {
13//! let cert = fs::read("database_cert.pem")?;
14//! let cert = Certificate::from_pem(&cert)?;
15//! let connector = TlsConnector::builder()
16//!     .add_root_certificate(cert)
17//!     .build()?;
18//! let connector = MakeTlsConnector::new(connector);
19//!
20//! let connect_future = tokio_postgres::connect(
21//!     "host=localhost user=postgres sslmode=require",
22//!     connector,
23//! );
24//! # }
25//!
26//! // ...
27//! # Ok(())
28//! # }
29//! ```
30//!
31//! ```no_run
32//! use native_tls::{Certificate, TlsConnector};
33//! # #[cfg(feature = "runtime")]
34//! use postgres_native_tls::MakeTlsConnector;
35//! use std::fs;
36//!
37//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
38//! # #[cfg(feature = "runtime")] {
39//! let cert = fs::read("database_cert.pem")?;
40//! let cert = Certificate::from_pem(&cert)?;
41//! let connector = TlsConnector::builder()
42//!     .add_root_certificate(cert)
43//!     .build()?;
44//! let connector = MakeTlsConnector::new(connector);
45//!
46//! let client = postgres::Client::connect(
47//!     "host=localhost user=postgres sslmode=require",
48//!     connector,
49//! )?;
50//! # }
51//! # Ok(())
52//! # }
53//! ```
54#![warn(rust_2018_idioms, clippy::all, missing_docs)]
55
56use native_tls::TlsConnectorBuilder;
57use std::future::Future;
58use std::io;
59use std::pin::Pin;
60use std::task::{Context, Poll};
61use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
62use tokio_postgres::tls;
63#[cfg(feature = "runtime")]
64use tokio_postgres::tls::MakeTlsConnect;
65use tokio_postgres::tls::{ChannelBinding, TlsConnect};
66
67#[cfg(test)]
68mod test;
69
/// A `native-tls`-backed factory implementing `MakeTlsConnect`.
///
/// It stores one configured `native_tls::TlsConnector` and, whenever `make_tls_connect`
/// is called for a domain, hands out a [`TlsConnector`] built from a clone of it.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);
76
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Wraps an already-configured `native_tls::TlsConnector`.
    ///
    /// The configuration is cloned into every `TlsConnector` this factory creates.
    pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector {
        Self(connector)
    }
}
84
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Builds a connector bound to `domain` from a clone of the stored configuration.
    ///
    /// This never returns `Err`; the `Result` is dictated by the trait signature.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, native_tls::Error> {
        let configured = self.0.clone();
        let per_domain = TlsConnector::new(configured, domain);
        Ok(per_domain)
    }
}
98
99/// A `TlsConnect` implementation using the `native-tls` crate.
100pub struct TlsConnector {
101    connector: tokio_native_tls::TlsConnector,
102    domain: String,
103}
104
105impl TlsConnector {
106    /// Creates a new connector configured to connect to the specified domain.
107    pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
108        TlsConnector {
109            connector: tokio_native_tls::TlsConnector::from(connector),
110            domain: domain.to_string(),
111        }
112    }
113}
114
115impl<S> TlsConnect<S> for TlsConnector
116where
117    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
118{
119    type Stream = TlsStream<S>;
120    type Error = native_tls::Error;
121    #[allow(clippy::type_complexity)]
122    type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
123
124    fn connect(self, stream: S) -> Self::Future {
125        let stream = BufReader::with_capacity(8192, stream);
126        let future = async move {
127            let stream = self.connector.connect(&self.domain, stream).await?;
128
129            Ok(TlsStream(stream))
130        };
131
132        Box::pin(future)
133    }
134}
135
136/// The stream returned by `TlsConnector`.
137pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
138
139impl<S> AsyncRead for TlsStream<S>
140where
141    S: AsyncRead + AsyncWrite + Unpin,
142{
143    fn poll_read(
144        mut self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146        buf: &mut ReadBuf<'_>,
147    ) -> Poll<io::Result<()>> {
148        Pin::new(&mut self.0).poll_read(cx, buf)
149    }
150}
151
152impl<S> AsyncWrite for TlsStream<S>
153where
154    S: AsyncRead + AsyncWrite + Unpin,
155{
156    fn poll_write(
157        mut self: Pin<&mut Self>,
158        cx: &mut Context<'_>,
159        buf: &[u8],
160    ) -> Poll<io::Result<usize>> {
161        Pin::new(&mut self.0).poll_write(cx, buf)
162    }
163
164    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165        Pin::new(&mut self.0).poll_flush(cx)
166    }
167
168    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169        Pin::new(&mut self.0).poll_shutdown(cx)
170    }
171}
172
173impl<S> tls::TlsStream for TlsStream<S>
174where
175    S: AsyncRead + AsyncWrite + Unpin,
176{
177    fn channel_binding(&self) -> ChannelBinding {
178        match self.0.get_ref().tls_server_end_point().ok().flatten() {
179            Some(buf) => ChannelBinding::tls_server_end_point(buf),
180            None => ChannelBinding::none(),
181        }
182    }
183}
184
185/// Set ALPN for `TlsConnectorBuilder`
186///
187/// This is required when using `sslnegotiation=direct`
188pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) {
189    builder.request_alpns(&["postgresql"]);
190}