Refactor SASL exchange to return tri-state status
authorDaniel Gustafsson <[email protected]>
Thu, 21 Mar 2024 13:45:46 +0000 (14:45 +0100)
committerDaniel Gustafsson <[email protected]>
Thu, 21 Mar 2024 13:45:46 +0000 (14:45 +0100)
The SASL exchange callback returned state in to output variables:
done and success.  This refactors that logic by introducing a new
return variable of type SASLStatus which makes the code easier to
read and understand, and prepares for future SASL exchanges which
operate asynchronously.

This was extracted from a larger patchset to introduce OAuthBearer
authentication and authorization.

Author: Jacob Champion <[email protected]>
Discussion: https://postgr.es/m/d1b467a78e0e36ed85a09adf979d04cf124a9d4b[email protected]

src/interfaces/libpq/fe-auth-sasl.h
src/interfaces/libpq/fe-auth-scram.c
src/interfaces/libpq/fe-auth.c
src/tools/pgindent/typedefs.list

index ee5d1525b559ee0790e14e054a7d26e37c46097c..4eecf53a1503097d25dbec3518c55fc61ef83fdb 100644 (file)
 
 #include "libpq-fe.h"
 
+/*
+ * Possible states for the SASL exchange, see the comment on exchange for an
+ * explanation of these.
+ */
+typedef enum
+{
+   SASL_COMPLETE = 0,
+   SASL_FAILED,
+   SASL_CONTINUE,
+} SASLStatus;
+
 /*
  * Frontend SASL mechanism callbacks.
  *
@@ -59,7 +70,8 @@ typedef struct pg_fe_sasl_mech
     * Produces a client response to a server challenge.  As a special case
     * for client-first SASL mechanisms, exchange() is called with a NULL
     * server response once at the start of the authentication exchange to
-    * generate an initial response.
+    * generate an initial response. Returns a SASLStatus indicating the
+    * state and status of the exchange.
     *
     * Input parameters:
     *
@@ -79,22 +91,23 @@ typedef struct pg_fe_sasl_mech
     *
     *  output:    A malloc'd buffer containing the client's response to
     *             the server (can be empty), or NULL if the exchange should
-    *             be aborted.  (*success should be set to false in the
+    *             be aborted.  (The callback should return SASL_FAILED in the
     *             latter case.)
     *
     *  outputlen: The length (0 or higher) of the client response buffer,
     *             ignored if output is NULL.
     *
-    *  done:      Set to true if the SASL exchange should not continue,
-    *             because the exchange is either complete or failed
+    * Return value:
     *
-    *  success:   Set to true if the SASL exchange completed successfully.
-    *             Ignored if *done is false.
+    *  SASL_CONTINUE:  The output buffer is filled with a client response.
+    *                  Additional server challenge is expected
+    *  SASL_COMPLETE:  The SASL exchange has completed successfully.
+    *  SASL_FAILED:    The exchange has failed and the connection should be
+    *                  dropped.
     *--------
     */
-   void        (*exchange) (void *state, char *input, int inputlen,
-                            char **output, int *outputlen,
-                            bool *done, bool *success);
+   SASLStatus  (*exchange) (void *state, char *input, int inputlen,
+                            char **output, int *outputlen);
 
    /*--------
     * channel_bound()
index 04f0e5713d0685d0caf9d7cccc085ec381018489..0bb820e0d97419509d51f5bac206afdb49d75e54 100644 (file)
@@ -24,9 +24,8 @@
 /* The exported SCRAM callback mechanism. */
 static void *scram_init(PGconn *conn, const char *password,
                        const char *sasl_mechanism);
-static void scram_exchange(void *opaq, char *input, int inputlen,
-                          char **output, int *outputlen,
-                          bool *done, bool *success);
+static SASLStatus scram_exchange(void *opaq, char *input, int inputlen,
+                                char **output, int *outputlen);
 static bool scram_channel_bound(void *opaq);
 static void scram_free(void *opaq);
 
@@ -202,17 +201,14 @@ scram_free(void *opaq)
 /*
  * Exchange a SCRAM message with backend.
  */
-static void
+static SASLStatus
 scram_exchange(void *opaq, char *input, int inputlen,
-              char **output, int *outputlen,
-              bool *done, bool *success)
+              char **output, int *outputlen)
 {
    fe_scram_state *state = (fe_scram_state *) opaq;
    PGconn     *conn = state->conn;
    const char *errstr = NULL;
 
-   *done = false;
-   *success = false;
    *output = NULL;
    *outputlen = 0;
 
@@ -225,12 +221,12 @@ scram_exchange(void *opaq, char *input, int inputlen,
        if (inputlen == 0)
        {
            libpq_append_conn_error(conn, "malformed SCRAM message (empty message)");
-           goto error;
+           return SASL_FAILED;
        }
        if (inputlen != strlen(input))
        {
            libpq_append_conn_error(conn, "malformed SCRAM message (length mismatch)");
-           goto error;
+           return SASL_FAILED;
        }
    }
 
@@ -240,61 +236,59 @@ scram_exchange(void *opaq, char *input, int inputlen,
            /* Begin the SCRAM handshake, by sending client nonce */
            *output = build_client_first_message(state);
            if (*output == NULL)
-               goto error;
+               return SASL_FAILED;
 
            *outputlen = strlen(*output);
-           *done = false;
            state->state = FE_SCRAM_NONCE_SENT;
-           break;
+           return SASL_CONTINUE;
 
        case FE_SCRAM_NONCE_SENT:
            /* Receive salt and server nonce, send response. */
            if (!read_server_first_message(state, input))
-               goto error;
+               return SASL_FAILED;
 
            *output = build_client_final_message(state);
            if (*output == NULL)
-               goto error;
+               return SASL_FAILED;
 
            *outputlen = strlen(*output);
-           *done = false;
            state->state = FE_SCRAM_PROOF_SENT;
-           break;
+           return SASL_CONTINUE;
 
        case FE_SCRAM_PROOF_SENT:
-           /* Receive server signature */
-           if (!read_server_final_message(state, input))
-               goto error;
-
-           /*
-            * Verify server signature, to make sure we're talking to the
-            * genuine server.
-            */
-           if (!verify_server_signature(state, success, &errstr))
-           {
-               libpq_append_conn_error(conn, "could not verify server signature: %s", errstr);
-               goto error;
-           }
-
-           if (!*success)
            {
-               libpq_append_conn_error(conn, "incorrect server signature");
+               bool        match;
+
+               /* Receive server signature */
+               if (!read_server_final_message(state, input))
+                   return SASL_FAILED;
+
+               /*
+                * Verify server signature, to make sure we're talking to the
+                * genuine server.
+                */
+               if (!verify_server_signature(state, &match, &errstr))
+               {
+                   libpq_append_conn_error(conn, "could not verify server signature: %s", errstr);
+                   return SASL_FAILED;
+               }
+
+               if (!match)
+               {
+                   libpq_append_conn_error(conn, "incorrect server signature");
+               }
+               state->state = FE_SCRAM_FINISHED;
+               state->conn->client_finished_auth = true;
+               return match ? SASL_COMPLETE : SASL_FAILED;
            }
-           *done = true;
-           state->state = FE_SCRAM_FINISHED;
-           state->conn->client_finished_auth = true;
-           break;
 
        default:
            /* shouldn't happen */
            libpq_append_conn_error(conn, "invalid SCRAM exchange state");
-           goto error;
+           break;
    }
-   return;
 
-error:
-   *done = true;
-   *success = false;
+   return SASL_FAILED;
 }
 
 /*
index 1a8e4f6fbfa354af3cccc537d8caeee45ab5b9e7..cf8af4c62e53d6da5a0245ec207dd0a8a75898e7 100644 (file)
@@ -423,11 +423,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 {
    char       *initialresponse = NULL;
    int         initialresponselen;
-   bool        done;
-   bool        success;
    const char *selected_mechanism;
    PQExpBufferData mechanism_buf;
    char       *password;
+   SASLStatus  status;
 
    initPQExpBuffer(&mechanism_buf);
 
@@ -575,12 +574,11 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        goto oom_error;
 
    /* Get the mechanism-specific Initial Client Response, if any */
-   conn->sasl->exchange(conn->sasl_state,
-                        NULL, -1,
-                        &initialresponse, &initialresponselen,
-                        &done, &success);
+   status = conn->sasl->exchange(conn->sasl_state,
+                                 NULL, -1,
+                                 &initialresponse, &initialresponselen);
 
-   if (done && !success)
+   if (status == SASL_FAILED)
        goto error;
 
    /*
@@ -629,10 +627,9 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 {
    char       *output;
    int         outputlen;
-   bool        done;
-   bool        success;
    int         res;
    char       *challenge;
+   SASLStatus  status;
 
    /* Read the SASL challenge from the AuthenticationSASLContinue message. */
    challenge = malloc(payloadlen + 1);
@@ -651,13 +648,12 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
    /* For safety and convenience, ensure the buffer is NULL-terminated. */
    challenge[payloadlen] = '\0';
 
-   conn->sasl->exchange(conn->sasl_state,
-                        challenge, payloadlen,
-                        &output, &outputlen,
-                        &done, &success);
+   status = conn->sasl->exchange(conn->sasl_state,
+                                 challenge, payloadlen,
+                                 &output, &outputlen);
    free(challenge);            /* don't need the input anymore */
 
-   if (final && !done)
+   if (final && status == SASL_CONTINUE)
    {
        if (outputlen != 0)
            free(output);
@@ -670,7 +666,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
     * If the exchange is not completed yet, we need to make sure that the
     * SASL mechanism has generated a message to send back.
     */
-   if (output == NULL && !done)
+   if (output == NULL && status == SASL_CONTINUE)
    {
        libpq_append_conn_error(conn, "no client response found after SASL exchange success");
        return STATUS_ERROR;
@@ -692,7 +688,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
            return STATUS_ERROR;
    }
 
-   if (done && !success)
+   if (status == SASL_FAILED)
        return STATUS_ERROR;
 
    return STATUS_OK;
index 3b8cec58abc8d823254dedb6d76f9f77fe2510b0..e2a0525dd4a5027750af8c0f2d14e4d5d3e1a893 100644 (file)
@@ -2442,6 +2442,7 @@ RuleLock
 RuleStmt
 RunningTransactions
 RunningTransactionsData
+SASLStatus
 SC_HANDLE
 SECURITY_ATTRIBUTES
 SECURITY_STATUS