postgres_native_tls/
lib.rs1#![warn(rust_2018_idioms, clippy::all, missing_docs)]
55
56use native_tls::TlsConnectorBuilder;
57use std::future::Future;
58use std::io;
59use std::pin::Pin;
60use std::task::{Context, Poll};
61use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
62use tokio_postgres::tls;
63#[cfg(feature = "runtime")]
64use tokio_postgres::tls::MakeTlsConnect;
65use tokio_postgres::tls::{ChannelBinding, TlsConnect};
66
67#[cfg(test)]
68mod test;
69
#[cfg(feature = "runtime")]
/// A [`MakeTlsConnect`] implementation backed by a [`native_tls::TlsConnector`].
///
/// Every call to [`MakeTlsConnect::make_tls_connect`] clones the wrapped
/// connector into a fresh per-connection [`TlsConnector`] bound to the
/// requested domain. Only compiled when the `runtime` feature is enabled.
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);
76
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a connector factory from an already-configured `native-tls`
    /// connector.
    ///
    /// The connector is cloned for each connection attempt, so any settings
    /// applied to it before this call (root certificates, ALPN, ...) are
    /// shared by every connection made through this factory.
    pub fn new(connector: native_tls::TlsConnector) -> Self {
        Self(connector)
    }
}
84
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Builds a per-connection [`TlsConnector`] for `domain`.
    ///
    /// This never returns an error: the shared `native-tls` connector is
    /// cloned and paired with the domain that will be handed to the TLS
    /// handshake.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        let per_connection = self.0.clone();
        Ok(TlsConnector::new(per_connection, domain))
    }
}
98
99pub struct TlsConnector {
101 connector: tokio_native_tls::TlsConnector,
102 domain: String,
103}
104
105impl TlsConnector {
106 pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
108 TlsConnector {
109 connector: tokio_native_tls::TlsConnector::from(connector),
110 domain: domain.to_string(),
111 }
112 }
113}
114
115impl<S> TlsConnect<S> for TlsConnector
116where
117 S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
118{
119 type Stream = TlsStream<S>;
120 type Error = native_tls::Error;
121 #[allow(clippy::type_complexity)]
122 type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
123
124 fn connect(self, stream: S) -> Self::Future {
125 let stream = BufReader::with_capacity(8192, stream);
126 let future = async move {
127 let stream = self.connector.connect(&self.domain, stream).await?;
128
129 Ok(TlsStream(stream))
130 };
131
132 Box::pin(future)
133 }
134}
135
136pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
138
139impl<S> AsyncRead for TlsStream<S>
140where
141 S: AsyncRead + AsyncWrite + Unpin,
142{
143 fn poll_read(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 buf: &mut ReadBuf<'_>,
147 ) -> Poll<io::Result<()>> {
148 Pin::new(&mut self.0).poll_read(cx, buf)
149 }
150}
151
152impl<S> AsyncWrite for TlsStream<S>
153where
154 S: AsyncRead + AsyncWrite + Unpin,
155{
156 fn poll_write(
157 mut self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 buf: &[u8],
160 ) -> Poll<io::Result<usize>> {
161 Pin::new(&mut self.0).poll_write(cx, buf)
162 }
163
164 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165 Pin::new(&mut self.0).poll_flush(cx)
166 }
167
168 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169 Pin::new(&mut self.0).poll_shutdown(cx)
170 }
171}
172
173impl<S> tls::TlsStream for TlsStream<S>
174where
175 S: AsyncRead + AsyncWrite + Unpin,
176{
177 fn channel_binding(&self) -> ChannelBinding {
178 match self.0.get_ref().tls_server_end_point().ok().flatten() {
179 Some(buf) => ChannelBinding::tls_server_end_point(buf),
180 None => ChannelBinding::none(),
181 }
182 }
183}
184
185pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) {
189 builder.request_alpns(&["postgresql"]);
190}