1use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2use crate::config::{self, Config};
3use crate::connect_tls::connect_tls;
4use crate::maybe_tls_stream::MaybeTlsStream;
5use crate::tls::{TlsConnect, TlsStream};
6use crate::{Client, Connection, Error};
7use bytes::BytesMut;
8use fallible_iterator::FallibleIterator;
9use futures_channel::mpsc;
10use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
11use postgres_protocol::authentication;
12use postgres_protocol::authentication::sasl;
13use postgres_protocol::authentication::sasl::ScramSha256;
14use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
15use postgres_protocol::message::frontend;
16use std::borrow::Cow;
17use std::collections::{HashMap, VecDeque};
18use std::io;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_util::codec::Framed;
23
/// Connection wrapper used while the startup/authentication handshake runs.
///
/// `connect_raw` builds one of these around the stream returned by `connect_tls`, drives the
/// handshake through it, and then takes `inner` and `delayed` apart to build the `Connection`.
/// It is a `Sink` of `FrontendMessage`s and a `Stream` of individual backend `Message`s.
pub struct StartupStream<S, T> {
    /// Framed transport (the `MaybeTlsStream` from `connect_tls`) decoded with `PostgresCodec`.
    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// Messages from the most recently received `Normal` frame that have not been yielded yet.
    buf: BackendMessages,
    /// Messages set aside during startup (notices collected by `read_info`) that are handed
    /// to the `Connection` once it is constructed.
    delayed: VecDeque<BackendMessage>,
}
29
30impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
31where
32 S: AsyncRead + AsyncWrite + Unpin,
33 T: AsyncRead + AsyncWrite + Unpin,
34{
35 type Error = io::Error;
36
37 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
38 Pin::new(&mut self.inner).poll_ready(cx)
39 }
40
41 fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
42 Pin::new(&mut self.inner).start_send(item)
43 }
44
45 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46 Pin::new(&mut self.inner).poll_flush(cx)
47 }
48
49 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
50 Pin::new(&mut self.inner).poll_close(cx)
51 }
52}
53
54impl<S, T> Stream for StartupStream<S, T>
55where
56 S: AsyncRead + AsyncWrite + Unpin,
57 T: AsyncRead + AsyncWrite + Unpin,
58{
59 type Item = io::Result<Message>;
60
61 fn poll_next(
62 mut self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 ) -> Poll<Option<io::Result<Message>>> {
65 loop {
66 match self.buf.next() {
67 Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
68 Ok(None) => {}
69 Err(e) => return Poll::Ready(Some(Err(e))),
70 }
71
72 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
73 Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
74 Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
75 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
76 None => return Poll::Ready(None),
77 }
78 }
79 }
80}
81
82pub async fn connect_raw<S, T>(
83 stream: S,
84 tls: T,
85 has_hostname: bool,
86 config: &Config,
87) -> Result<(Client, Connection<S, T::Stream>), Error>
88where
89 S: AsyncRead + AsyncWrite + Unpin,
90 T: TlsConnect<S>,
91{
92 let stream = connect_tls(
93 stream,
94 config.ssl_mode,
95 config.ssl_negotiation,
96 tls,
97 has_hostname,
98 )
99 .await?;
100
101 let mut stream = StartupStream {
102 inner: Framed::new(stream, PostgresCodec),
103 buf: BackendMessages::empty(),
104 delayed: VecDeque::new(),
105 };
106
107 let user = config
108 .user
109 .as_deref()
110 .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed);
111
112 startup(&mut stream, config, &user).await?;
113 authenticate(&mut stream, config, &user).await?;
114 let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
115
116 let (sender, receiver) = mpsc::unbounded();
117 let client = Client::new(
118 sender,
119 config.ssl_mode,
120 config.ssl_negotiation,
121 process_id,
122 secret_key,
123 );
124 let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
125
126 Ok((client, connection))
127}
128
129async fn startup<S, T>(
130 stream: &mut StartupStream<S, T>,
131 config: &Config,
132 user: &str,
133) -> Result<(), Error>
134where
135 S: AsyncRead + AsyncWrite + Unpin,
136 T: AsyncRead + AsyncWrite + Unpin,
137{
138 let mut params = vec![("client_encoding", "UTF8")];
139 params.push(("user", user));
140 if let Some(dbname) = &config.dbname {
141 params.push(("database", &**dbname));
142 }
143 if let Some(options) = &config.options {
144 params.push(("options", &**options));
145 }
146 if let Some(application_name) = &config.application_name {
147 params.push(("application_name", &**application_name));
148 }
149
150 let mut buf = BytesMut::new();
151 frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
152
153 stream
154 .send(FrontendMessage::Raw(buf.freeze()))
155 .await
156 .map_err(Error::io)
157}
158
159async fn authenticate<S, T>(
160 stream: &mut StartupStream<S, T>,
161 config: &Config,
162 user: &str,
163) -> Result<(), Error>
164where
165 S: AsyncRead + AsyncWrite + Unpin,
166 T: TlsStream + Unpin,
167{
168 match stream.try_next().await.map_err(Error::io)? {
169 Some(Message::AuthenticationOk) => {
170 can_skip_channel_binding(config)?;
171 return Ok(());
172 }
173 Some(Message::AuthenticationCleartextPassword) => {
174 can_skip_channel_binding(config)?;
175
176 let pass = config
177 .password
178 .as_ref()
179 .ok_or_else(|| Error::config("password missing".into()))?;
180
181 authenticate_password(stream, pass).await?;
182 }
183 Some(Message::AuthenticationMd5Password(body)) => {
184 can_skip_channel_binding(config)?;
185
186 let pass = config
187 .password
188 .as_ref()
189 .ok_or_else(|| Error::config("password missing".into()))?;
190
191 let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
192 authenticate_password(stream, output.as_bytes()).await?;
193 }
194 Some(Message::AuthenticationSasl(body)) => {
195 authenticate_sasl(stream, body, config).await?;
196 }
197 Some(Message::AuthenticationKerberosV5)
198 | Some(Message::AuthenticationScmCredential)
199 | Some(Message::AuthenticationGss)
200 | Some(Message::AuthenticationSspi) => {
201 return Err(Error::authentication(
202 "unsupported authentication method".into(),
203 ))
204 }
205 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
206 Some(_) => return Err(Error::unexpected_message()),
207 None => return Err(Error::closed()),
208 }
209
210 match stream.try_next().await.map_err(Error::io)? {
211 Some(Message::AuthenticationOk) => Ok(()),
212 Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
213 Some(_) => Err(Error::unexpected_message()),
214 None => Err(Error::closed()),
215 }
216}
217
218fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
219 match config.channel_binding {
220 config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
221 config::ChannelBinding::Require => Err(Error::authentication(
222 "server did not use channel binding".into(),
223 )),
224 }
225}
226
227async fn authenticate_password<S, T>(
228 stream: &mut StartupStream<S, T>,
229 password: &[u8],
230) -> Result<(), Error>
231where
232 S: AsyncRead + AsyncWrite + Unpin,
233 T: AsyncRead + AsyncWrite + Unpin,
234{
235 let mut buf = BytesMut::new();
236 frontend::password_message(password, &mut buf).map_err(Error::encode)?;
237
238 stream
239 .send(FrontendMessage::Raw(buf.freeze()))
240 .await
241 .map_err(Error::io)
242}
243
244async fn authenticate_sasl<S, T>(
245 stream: &mut StartupStream<S, T>,
246 body: AuthenticationSaslBody,
247 config: &Config,
248) -> Result<(), Error>
249where
250 S: AsyncRead + AsyncWrite + Unpin,
251 T: TlsStream + Unpin,
252{
253 let password = config
254 .password
255 .as_ref()
256 .ok_or_else(|| Error::config("password missing".into()))?;
257
258 let mut has_scram = false;
259 let mut has_scram_plus = false;
260 let mut mechanisms = body.mechanisms();
261 while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
262 match mechanism {
263 sasl::SCRAM_SHA_256 => has_scram = true,
264 sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
265 _ => {}
266 }
267 }
268
269 let channel_binding = stream
270 .inner
271 .get_ref()
272 .channel_binding()
273 .tls_server_end_point
274 .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
275 .map(sasl::ChannelBinding::tls_server_end_point);
276
277 let (channel_binding, mechanism) = if has_scram_plus {
278 match channel_binding {
279 Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
280 None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
281 }
282 } else if has_scram {
283 match channel_binding {
284 Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
285 None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
286 }
287 } else {
288 return Err(Error::authentication("unsupported SASL mechanism".into()));
289 };
290
291 if mechanism != sasl::SCRAM_SHA_256_PLUS {
292 can_skip_channel_binding(config)?;
293 }
294
295 let mut scram = ScramSha256::new(password, channel_binding);
296
297 let mut buf = BytesMut::new();
298 frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
299 stream
300 .send(FrontendMessage::Raw(buf.freeze()))
301 .await
302 .map_err(Error::io)?;
303
304 let body = match stream.try_next().await.map_err(Error::io)? {
305 Some(Message::AuthenticationSaslContinue(body)) => body,
306 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
307 Some(_) => return Err(Error::unexpected_message()),
308 None => return Err(Error::closed()),
309 };
310
311 scram
312 .update(body.data())
313 .map_err(|e| Error::authentication(e.into()))?;
314
315 let mut buf = BytesMut::new();
316 frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
317 stream
318 .send(FrontendMessage::Raw(buf.freeze()))
319 .await
320 .map_err(Error::io)?;
321
322 let body = match stream.try_next().await.map_err(Error::io)? {
323 Some(Message::AuthenticationSaslFinal(body)) => body,
324 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
325 Some(_) => return Err(Error::unexpected_message()),
326 None => return Err(Error::closed()),
327 };
328
329 scram
330 .finish(body.data())
331 .map_err(|e| Error::authentication(e.into()))?;
332
333 Ok(())
334}
335
336async fn read_info<S, T>(
337 stream: &mut StartupStream<S, T>,
338) -> Result<(i32, i32, HashMap<String, String>), Error>
339where
340 S: AsyncRead + AsyncWrite + Unpin,
341 T: AsyncRead + AsyncWrite + Unpin,
342{
343 let mut process_id = 0;
344 let mut secret_key = 0;
345 let mut parameters = HashMap::new();
346
347 loop {
348 match stream.try_next().await.map_err(Error::io)? {
349 Some(Message::BackendKeyData(body)) => {
350 process_id = body.process_id();
351 secret_key = body.secret_key();
352 }
353 Some(Message::ParameterStatus(body)) => {
354 parameters.insert(
355 body.name().map_err(Error::parse)?.to_string(),
356 body.value().map_err(Error::parse)?.to_string(),
357 );
358 }
359 Some(msg @ Message::NoticeResponse(_)) => {
360 stream.delayed.push_back(BackendMessage::Async(msg))
361 }
362 Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
363 Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
364 Some(_) => return Err(Error::unexpected_message()),
365 None => return Err(Error::closed()),
366 }
367 }
368}