15
15
16
16
#include <ctype.h>
17
17
18
- #include "common/hex_decode .h"
18
+ #include "common/hex .h"
19
19
#include "mb/pg_wchar.h"
20
20
#include "utils/builtins.h"
21
21
#include "utils/memutils.h"
32
32
*/
33
33
struct pg_encoding
34
34
{
35
- uint64 (* encode_len ) (const char * data , size_t dlen );
36
- uint64 (* decode_len ) (const char * data , size_t dlen );
37
- uint64 (* encode ) (const char * data , size_t dlen , char * res );
38
- uint64 (* decode ) (const char * data , size_t dlen , char * res );
35
+ uint64 (* encode_len ) (const char * src , size_t srclen );
36
+ uint64 (* decode_len ) (const char * src , size_t srclen );
37
+ uint64 (* encode ) (const char * src , size_t srclen ,
38
+ char * dst , size_t dstlen );
39
+ uint64 (* decode ) (const char * src , size_t srclen ,
40
+ char * dst , size_t dstlen );
39
41
};
40
42
41
43
static const struct pg_encoding * pg_find_encoding (const char * name );
@@ -81,11 +83,7 @@ binary_encode(PG_FUNCTION_ARGS)
81
83
82
84
result = palloc (VARHDRSZ + resultlen );
83
85
84
- res = enc -> encode (dataptr , datalen , VARDATA (result ));
85
-
86
- /* Make this FATAL 'cause we've trodden on memory ... */
87
- if (res > resultlen )
88
- elog (FATAL , "overflow - encode estimate too small" );
86
+ res = enc -> encode (dataptr , datalen , VARDATA (result ), resultlen );
89
87
90
88
SET_VARSIZE (result , VARHDRSZ + res );
91
89
@@ -129,11 +127,7 @@ binary_decode(PG_FUNCTION_ARGS)
129
127
130
128
result = palloc (VARHDRSZ + resultlen );
131
129
132
- res = enc -> decode (dataptr , datalen , VARDATA (result ));
133
-
134
- /* Make this FATAL 'cause we've trodden on memory ... */
135
- if (res > resultlen )
136
- elog (FATAL , "overflow - decode estimate too small" );
130
+ res = enc -> decode (dataptr , datalen , VARDATA (result ), resultlen );
137
131
138
132
SET_VARSIZE (result , VARHDRSZ + res );
139
133
@@ -145,32 +139,20 @@ binary_decode(PG_FUNCTION_ARGS)
145
139
* HEX
146
140
*/
147
141
148
- static const char hextbl [] = "0123456789abcdef" ;
149
-
150
- uint64
151
- hex_encode (const char * src , size_t len , char * dst )
152
- {
153
- const char * end = src + len ;
154
-
155
- while (src < end )
156
- {
157
- * dst ++ = hextbl [(* src >> 4 ) & 0xF ];
158
- * dst ++ = hextbl [* src & 0xF ];
159
- src ++ ;
160
- }
161
- return (uint64 ) len * 2 ;
162
- }
163
-
142
+ /*
143
+ * Those two wrappers are still needed to match with the layer of
144
+ * src/common/.
145
+ */
164
146
static uint64
165
147
hex_enc_len (const char * src , size_t srclen )
166
148
{
167
- return ( uint64 ) srclen << 1 ;
149
+ return pg_hex_enc_len ( srclen ) ;
168
150
}
169
151
170
152
static uint64
171
153
hex_dec_len (const char * src , size_t srclen )
172
154
{
173
- return ( uint64 ) srclen >> 1 ;
155
+ return pg_hex_dec_len ( srclen ) ;
174
156
}
175
157
176
158
/*
@@ -192,12 +174,12 @@ static const int8 b64lookup[128] = {
192
174
};
193
175
194
176
static uint64
195
- pg_base64_encode (const char * src , size_t len , char * dst )
177
+ pg_base64_encode (const char * src , size_t srclen , char * dst , size_t dstlen )
196
178
{
197
179
char * p ,
198
180
* lend = dst + 76 ;
199
181
const char * s ,
200
- * end = src + len ;
182
+ * end = src + srclen ;
201
183
int pos = 2 ;
202
184
uint32 buf = 0 ;
203
185
@@ -213,6 +195,8 @@ pg_base64_encode(const char *src, size_t len, char *dst)
213
195
/* write it out */
214
196
if (pos < 0 )
215
197
{
198
+ if ((p - dst + 4 ) > dstlen )
199
+ elog (ERROR , "overflow of destination buffer in base64 encoding" );
216
200
* p ++ = _base64 [(buf >> 18 ) & 0x3f ];
217
201
* p ++ = _base64 [(buf >> 12 ) & 0x3f ];
218
202
* p ++ = _base64 [(buf >> 6 ) & 0x3f ];
@@ -223,25 +207,30 @@ pg_base64_encode(const char *src, size_t len, char *dst)
223
207
}
224
208
if (p >= lend )
225
209
{
210
+ if ((p - dst + 1 ) > dstlen )
211
+ elog (ERROR , "overflow of destination buffer in base64 encoding" );
226
212
* p ++ = '\n' ;
227
213
lend = p + 76 ;
228
214
}
229
215
}
230
216
if (pos != 2 )
231
217
{
218
+ if ((p - dst + 4 ) > dstlen )
219
+ elog (ERROR , "overflow of destination buffer in base64 encoding" );
232
220
* p ++ = _base64 [(buf >> 18 ) & 0x3f ];
233
221
* p ++ = _base64 [(buf >> 12 ) & 0x3f ];
234
222
* p ++ = (pos == 0 ) ? _base64 [(buf >> 6 ) & 0x3f ] : '=' ;
235
223
* p ++ = '=' ;
236
224
}
237
225
226
+ Assert ((p - dst ) <= dstlen );
238
227
return p - dst ;
239
228
}
240
229
241
230
static uint64
242
- pg_base64_decode (const char * src , size_t len , char * dst )
231
+ pg_base64_decode (const char * src , size_t srclen , char * dst , size_t dstlen )
243
232
{
244
- const char * srcend = src + len ,
233
+ const char * srcend = src + srclen ,
245
234
* s = src ;
246
235
char * p = dst ;
247
236
char c ;
@@ -289,11 +278,21 @@ pg_base64_decode(const char *src, size_t len, char *dst)
289
278
pos ++ ;
290
279
if (pos == 4 )
291
280
{
281
+ if ((p - dst + 1 ) > dstlen )
282
+ elog (ERROR , "overflow of destination buffer in base64 decoding" );
292
283
* p ++ = (buf >> 16 ) & 255 ;
293
284
if (end == 0 || end > 1 )
285
+ {
286
+ if ((p - dst + 1 ) > dstlen )
287
+ elog (ERROR , "overflow of destination buffer in base64 decoding" );
294
288
* p ++ = (buf >> 8 ) & 255 ;
289
+ }
295
290
if (end == 0 || end > 2 )
291
+ {
292
+ if ((p - dst + 1 ) > dstlen )
293
+ elog (ERROR , "overflow of destination buffer in base64 decoding" );
296
294
* p ++ = buf & 255 ;
295
+ }
297
296
buf = 0 ;
298
297
pos = 0 ;
299
298
}
@@ -305,6 +304,7 @@ pg_base64_decode(const char *src, size_t len, char *dst)
305
304
errmsg ("invalid base64 end sequence" ),
306
305
errhint ("Input data is missing padding, is truncated, or is otherwise corrupted." )));
307
306
307
+ Assert ((p - dst ) <= dstlen );
308
308
return p - dst ;
309
309
}
310
310
@@ -340,7 +340,7 @@ pg_base64_dec_len(const char *src, size_t srclen)
340
340
#define DIG (VAL ) ((VAL) + '0')
341
341
342
342
static uint64
343
- esc_encode (const char * src , size_t srclen , char * dst )
343
+ esc_encode (const char * src , size_t srclen , char * dst , size_t dstlen )
344
344
{
345
345
const char * end = src + srclen ;
346
346
char * rp = dst ;
@@ -352,6 +352,8 @@ esc_encode(const char *src, size_t srclen, char *dst)
352
352
353
353
if (c == '\0' || IS_HIGHBIT_SET (c ))
354
354
{
355
+ if ((rp - dst + 4 ) > dstlen )
356
+ elog (ERROR , "overflow of destination buffer in escape encoding" );
355
357
rp [0 ] = '\\' ;
356
358
rp [1 ] = DIG (c >> 6 );
357
359
rp [2 ] = DIG ((c >> 3 ) & 7 );
@@ -361,25 +363,30 @@ esc_encode(const char *src, size_t srclen, char *dst)
361
363
}
362
364
else if (c == '\\' )
363
365
{
366
+ if ((rp - dst + 2 ) > dstlen )
367
+ elog (ERROR , "overflow of destination buffer in escape encoding" );
364
368
rp [0 ] = '\\' ;
365
369
rp [1 ] = '\\' ;
366
370
rp += 2 ;
367
371
len += 2 ;
368
372
}
369
373
else
370
374
{
375
+ if ((rp - dst + 1 ) > dstlen )
376
+ elog (ERROR , "overflow of destination buffer in escape encoding" );
371
377
* rp ++ = c ;
372
378
len ++ ;
373
379
}
374
380
375
381
src ++ ;
376
382
}
377
383
384
+ Assert ((rp - dst ) <= dstlen );
378
385
return len ;
379
386
}
380
387
381
388
static uint64
382
- esc_decode (const char * src , size_t srclen , char * dst )
389
+ esc_decode (const char * src , size_t srclen , char * dst , size_t dstlen )
383
390
{
384
391
const char * end = src + srclen ;
385
392
char * rp = dst ;
@@ -388,7 +395,11 @@ esc_decode(const char *src, size_t srclen, char *dst)
388
395
while (src < end )
389
396
{
390
397
if (src [0 ] != '\\' )
398
+ {
399
+ if ((rp - dst + 1 ) > dstlen )
400
+ elog (ERROR , "overflow of destination buffer in escape decoding" );
391
401
* rp ++ = * src ++ ;
402
+ }
392
403
else if (src + 3 < end &&
393
404
(src [1 ] >= '0' && src [1 ] <= '3' ) &&
394
405
(src [2 ] >= '0' && src [2 ] <= '7' ) &&
@@ -400,12 +411,16 @@ esc_decode(const char *src, size_t srclen, char *dst)
400
411
val <<= 3 ;
401
412
val += VAL (src [2 ]);
402
413
val <<= 3 ;
414
+ if ((rp - dst + 1 ) > dstlen )
415
+ elog (ERROR , "overflow of destination buffer in escape decoding" );
403
416
* rp ++ = val + VAL (src [3 ]);
404
417
src += 4 ;
405
418
}
406
419
else if (src + 1 < end &&
407
420
(src [1 ] == '\\' ))
408
421
{
422
+ if ((rp - dst + 1 ) > dstlen )
423
+ elog (ERROR , "overflow of destination buffer in escape decoding" );
409
424
* rp ++ = '\\' ;
410
425
src += 2 ;
411
426
}
@@ -423,6 +438,7 @@ esc_decode(const char *src, size_t srclen, char *dst)
423
438
len ++ ;
424
439
}
425
440
441
+ Assert ((rp - dst ) <= dstlen );
426
442
return len ;
427
443
}
428
444
@@ -504,7 +520,7 @@ static const struct
504
520
{
505
521
"hex" ,
506
522
{
507
- hex_enc_len , hex_dec_len , hex_encode , hex_decode
523
+ hex_enc_len , hex_dec_len , pg_hex_encode , pg_hex_decode
508
524
}
509
525
},
510
526
{
0 commit comments