21
21
22
22
#include <gnutls/gnutls.h>
23
23
#include <gnutls/x509.h>
24
+ #include <gnutls/abstract.h>
24
25
#include <ma_global.h>
25
26
#include <ma_sys.h>
26
27
#include <ma_common.h>
@@ -39,6 +40,12 @@ extern unsigned int mariadb_deinitialize_ssl;
39
40
40
41
static int my_verify_callback (gnutls_session_t ssl );
41
42
43
+ struct st_gnutls_data {
44
+ MYSQL * mysql ;
45
+ gnutls_privkey_t key ;
46
+ gnutls_pcert_st cert ;
47
+ };
48
+
42
49
struct st_cipher_map {
43
50
const char * openssl_name ;
44
51
const char * priority ;
@@ -96,6 +103,20 @@ const struct st_cipher_map gtls_ciphers[]=
96
103
{NULL , NULL , 0 , 0 , 0 }
97
104
};
98
105
106
+ /* free data assigned to the connection */
107
+ static void free_gnutls_data (struct st_gnutls_data * data )
108
+ {
109
+ if (data )
110
+ {
111
+ if (data -> key )
112
+ gnutls_privkey_deinit (data -> key );
113
+ gnutls_pcert_deinit (& data -> cert );
114
+ free (data );
115
+ }
116
+ }
117
+
118
+ /* map the gnutls cipher suite (defined by key exchange algorithm, cipher
119
+ and mac algorithm) to the corresponding OpenSSL cipher name */
99
120
static const char * openssl_cipher_name (gnutls_kx_algorithm_t kx ,
100
121
gnutls_cipher_algorithm_t cipher ,
101
122
gnutls_mac_algorithm_t mac )
@@ -112,6 +133,7 @@ static const char *openssl_cipher_name(gnutls_kx_algorithm_t kx,
112
133
return NULL ;
113
134
}
114
135
136
+ /* get priority string for a given openssl cipher name */
115
137
static const char * get_priority (const char * cipher_name )
116
138
{
117
139
unsigned int i = 0 ;
@@ -126,7 +148,7 @@ static const char *get_priority(const char *cipher_name)
126
148
127
149
#define MAX_SSL_ERR_LEN 100
128
150
129
- static void ma_tls_set_error (MYSQL * mysql , int ssl_errno )
151
+ static void ma_tls_set_error (MYSQL * mysql , void * ssl , int ssl_errno )
130
152
{
131
153
char ssl_error [MAX_SSL_ERR_LEN ];
132
154
const char * ssl_error_reason ;
@@ -137,6 +159,21 @@ static void ma_tls_set_error(MYSQL *mysql, int ssl_errno)
137
159
pvio -> set_error (mysql , CR_SSL_CONNECTION_ERROR , SQLSTATE_UNKNOWN , "Unknown SSL error" );
138
160
return ;
139
161
}
162
+
163
+ /* give a more descriptive error message for alerts */
164
+ if (ssl_errno == GNUTLS_E_FATAL_ALERT_RECEIVED )
165
+ {
166
+ gnutls_alert_description_t alert_desc ;
167
+ const char * alert_name ;
168
+ alert_desc = gnutls_alert_get ((gnutls_session_t )ssl );
169
+ alert_name = gnutls_alert_get_name (alert_desc );
170
+ snprintf (ssl_error , MAX_SSL_ERR_LEN , "fatal alert received: %s" ,
171
+ alert_name );
172
+ pvio -> set_error (mysql , CR_SSL_CONNECTION_ERROR , SQLSTATE_UNKNOWN , 0 ,
173
+ ssl_error );
174
+ return ;
175
+ }
176
+
140
177
if ((ssl_error_reason = gnutls_strerror (ssl_errno )))
141
178
{
142
179
pvio -> set_error (mysql , CR_SSL_CONNECTION_ERROR , SQLSTATE_UNKNOWN , 0 ,
@@ -249,15 +286,7 @@ static int ma_gnutls_set_ciphers(gnutls_session_t ssl, char *cipher_str)
249
286
while (token )
250
287
{
251
288
const char * p = get_priority (token );
252
- /* if cipher was not found, we pass the original token to
253
- the priority string, this will allow to specify gnutls
254
- specific settings via cipher */
255
- if (!p )
256
- {
257
- strncat (prio , ":" , PRIO_SIZE - strlen (prio ) - 1 );
258
- strncat (prio , token , PRIO_SIZE - strlen (prio ) - 1 );
259
- }
260
- else
289
+ if (p )
261
290
strncat (prio , p , PRIO_SIZE - strlen (prio ) - 1 );
262
291
token = strtok (NULL , ":" );
263
292
}
@@ -266,8 +295,6 @@ static int ma_gnutls_set_ciphers(gnutls_session_t ssl, char *cipher_str)
266
295
267
296
static int ma_tls_set_certs (MYSQL * mysql )
268
297
{
269
- char * certfile = mysql -> options .ssl_cert ,
270
- * keyfile = mysql -> options .ssl_key ;
271
298
int ssl_error = 0 ;
272
299
273
300
if (mysql -> options .ssl_ca )
@@ -282,32 +309,78 @@ static int ma_tls_set_certs(MYSQL *mysql)
282
309
gnutls_certificate_set_verify_function (GNUTLS_xcred ,
283
310
my_verify_callback );
284
311
285
- /* GNUTLS doesn't support ca_path */
312
+ return 1 ;
313
+
314
+ error :
315
+ return ssl_error ;
316
+ }
317
+
318
+ static int
319
+ client_cert_callback (gnutls_session_t session ,
320
+ const gnutls_datum_t * req_ca_rdn __attribute__((unused )),
321
+ int nreqs __attribute__((unused )),
322
+ const gnutls_pk_algorithm_t * sign_algos __attribute__((unused )),
323
+ int sign_algos_length __attribute__((unused )),
324
+ gnutls_pcert_st * * pcert ,
325
+ unsigned int * pcert_length , gnutls_privkey_t * pkey )
326
+ {
327
+ struct st_gnutls_data * session_data ;
328
+ char * certfile , * keyfile ;
329
+ gnutls_datum_t data ;
330
+ MYSQL * mysql ;
331
+ gnutls_certificate_type_t type = gnutls_certificate_type_get (session );
332
+
333
+ session_data = (struct st_gnutls_data * )gnutls_session_get_ptr (session );
334
+
335
+ if (!session_data -> mysql ||
336
+ type != GNUTLS_CRT_X509 )
337
+ return -1 ;
338
+
339
+ mysql = session_data -> mysql ;
286
340
341
+ certfile = session_data -> mysql -> options .ssl_cert ;
342
+ keyfile = session_data -> mysql -> options .ssl_key ;
343
+
344
+ if (!certfile && !keyfile )
345
+ return 0 ;
287
346
if (keyfile && !certfile )
288
347
certfile = keyfile ;
289
348
if (certfile && !keyfile )
290
349
keyfile = certfile ;
291
350
292
- /* set key */
293
- if (certfile || keyfile )
351
+ if (gnutls_load_file (certfile , & data ) < 0 )
352
+ return -1 ;
353
+ if (gnutls_pcert_import_x509_raw (& session_data -> cert , & data , GNUTLS_X509_FMT_PEM , 0 ) < 0 )
294
354
{
295
- if ((ssl_error = gnutls_certificate_set_x509_key_file2 (GNUTLS_xcred ,
296
- certfile , keyfile , GNUTLS_X509_FMT_PEM ,
297
- OPT_HAS_EXT_VAL (mysql , tls_pw ) ? mysql -> options .extension -> tls_pw : NULL ,
298
- 0 )) < 0 )
299
- goto error ;
355
+ gnutls_free (data .data );
356
+ return -1 ;
300
357
}
301
- return 1 ;
358
+ gnutls_free (data .data );
359
+
360
+ if (gnutls_load_file (keyfile , & data ) < 0 )
361
+ return -1 ;
362
+ gnutls_privkey_init (& session_data -> key );
363
+ if (gnutls_privkey_import_x509_raw (session_data -> key , & data ,
364
+ GNUTLS_X509_FMT_PEM ,
365
+ mysql -> options .extension ? mysql -> options .extension -> tls_pw : NULL ,
366
+ 0 ) < 0 )
367
+ {
368
+ gnutls_free (data .data );
369
+ return -1 ;
370
+ }
371
+ gnutls_free (data .data );
302
372
303
- error :
304
- return ssl_error ;
373
+ * pcert_length = 1 ;
374
+ * pcert = & session_data -> cert ;
375
+ * pkey = session_data -> key ;
376
+ return 0 ;
305
377
}
306
378
307
379
void * ma_tls_init (MYSQL * mysql )
308
380
{
309
381
gnutls_session_t ssl = NULL ;
310
382
int ssl_error = 0 ;
383
+ struct st_gnutls_data * data = NULL ;
311
384
312
385
pthread_mutex_lock (& LOCK_gnutls_config );
313
386
@@ -316,19 +389,28 @@ void *ma_tls_init(MYSQL *mysql)
316
389
317
390
if ((ssl_error = gnutls_init (& ssl , GNUTLS_CLIENT & GNUTLS_NONBLOCK )) < 0 )
318
391
goto error ;
319
- gnutls_session_set_ptr (ssl , (void * )mysql );
392
+
393
+ if (!(data = (struct st_gnutls_data * )calloc (1 , sizeof (struct st_gnutls_data ))))
394
+ goto error ;
395
+
396
+ data -> mysql = mysql ;
397
+ gnutls_certificate_set_retrieve_function2 (GNUTLS_xcred , client_cert_callback );
398
+ gnutls_session_set_ptr (ssl , (void * )data );
320
399
321
400
ssl_error = ma_gnutls_set_ciphers (ssl , mysql -> options .ssl_cipher );
322
401
if (ssl_error < 0 )
323
402
goto error ;
324
403
404
+ /* we don't load private key and cert by default - if the server requests
405
+ a client certificate we will send it via callback function */
325
406
if ((ssl_error = gnutls_credentials_set (ssl , GNUTLS_CRD_CERTIFICATE , GNUTLS_xcred )) < 0 )
326
407
goto error ;
327
408
328
409
pthread_mutex_unlock (& LOCK_gnutls_config );
329
410
return (void * )ssl ;
330
411
error :
331
- ma_tls_set_error (mysql , ssl_error );
412
+ free_gnutls_data (data );
413
+ ma_tls_set_error (mysql , ssl , ssl_error );
332
414
if (ssl )
333
415
gnutls_deinit (ssl );
334
416
pthread_mutex_unlock (& LOCK_gnutls_config );
@@ -362,7 +444,9 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls)
362
444
MYSQL * mysql ;
363
445
MARIADB_PVIO * pvio ;
364
446
int ret ;
365
- mysql = (MYSQL * )gnutls_session_get_ptr (ssl );
447
+ struct st_gnutls_data * data ;
448
+ data = (struct st_gnutls_data * )gnutls_session_get_ptr (ssl );
449
+ mysql = data -> mysql ;
366
450
367
451
if (!mysql )
368
452
return 1 ;
@@ -386,16 +470,16 @@ my_bool ma_tls_connect(MARIADB_TLS *ctls)
386
470
387
471
if (ret < 0 )
388
472
{
389
- ma_tls_set_error (mysql , ret );
473
+ ma_tls_set_error (mysql , ssl , ret );
390
474
/* restore blocking mode */
391
475
gnutls_deinit ((gnutls_session_t )ctls -> ssl );
476
+ free_gnutls_data (data );
392
477
ctls -> ssl = NULL ;
393
478
if (!blocking )
394
479
pvio -> methods -> blocking (pvio , FALSE, 0 );
395
480
return 1 ;
396
481
}
397
482
ctls -> ssl = (void * )ssl ;
398
-
399
483
return 0 ;
400
484
}
401
485
@@ -413,7 +497,15 @@ my_bool ma_tls_close(MARIADB_TLS *ctls)
413
497
{
414
498
if (ctls -> ssl )
415
499
{
416
- gnutls_bye ((gnutls_session_t )ctls -> ssl , GNUTLS_SHUT_WR );
500
+ MARIADB_PVIO * pvio = ctls -> pvio ;
501
+ struct st_gnutls_data * data =
502
+ (struct st_gnutls_data * )gnutls_session_get_ptr (ctls -> ssl );
503
+ /* this would be the correct way, however can't dectect afterwards
504
+ if the socket is closed or not, so we don't send encrypted
505
+ finish alert.
506
+ rc= gnutls_bye((gnutls_session_t )ctls->ssl, GNUTLS_SHUT_WR);
507
+ */
508
+ free_gnutls_data (data );
417
509
gnutls_deinit ((gnutls_session_t )ctls -> ssl );
418
510
ctls -> ssl = NULL ;
419
511
}
@@ -443,14 +535,19 @@ const char *ma_tls_get_cipher(MARIADB_TLS *ctls)
443
535
444
536
static int my_verify_callback (gnutls_session_t ssl )
445
537
{
446
- unsigned int status ;
538
+ unsigned int status = 0 ;
447
539
const gnutls_datum_t * cert_list ;
448
540
unsigned int cert_list_size ;
449
- MYSQL * mysql = (MYSQL * )gnutls_session_get_ptr (ssl );
450
- MARIADB_PVIO * pvio = mysql -> net .pvio ;
541
+ struct st_gnutls_data * data = (struct st_gnutls_data * )gnutls_session_get_ptr (ssl );
542
+ MYSQL * mysql ;
543
+ MARIADB_PVIO * pvio ;
544
+
451
545
gnutls_x509_crt_t cert ;
452
546
const char * hostname ;
453
547
548
+ mysql = data -> mysql ;
549
+ pvio = mysql -> net .pvio ;
550
+
454
551
/* read hostname */
455
552
hostname = mysql -> host ;
456
553
@@ -518,11 +615,13 @@ unsigned int ma_tls_get_finger_print(MARIADB_TLS *ctls, char *fp, unsigned int l
518
615
size_t fp_len = len ;
519
616
const gnutls_datum_t * cert_list ;
520
617
unsigned int cert_list_size ;
618
+ struct st_gnutls_data * data ;
521
619
522
620
if (!ctls || !ctls -> ssl )
523
621
return 0 ;
524
622
525
- mysql = (MYSQL * )gnutls_session_get_ptr (ctls -> ssl );
623
+ data = (struct st_gnutls_data * )gnutls_session_get_ptr (ctls -> ssl );
624
+ mysql = (MYSQL * )data -> mysql ;
526
625
527
626
cert_list = gnutls_certificate_get_peers (ctls -> ssl , & cert_list_size );
528
627
if (cert_list == NULL )
0 commit comments