diff options
author | Marko Kreen | 2008-02-28 09:27:25 +0000 |
---|---|---|
committer | Marko Kreen | 2008-02-28 09:27:25 +0000 |
commit | 012e25634c81aaa5ad32084d4d588ec43fd5c838 (patch) | |
tree | 5f9f8cdb66b37cc079fab43f7232c4b59cad3e07 | |
parent | 8408b4ae8498b96695e4c730054beb092a2c7967 (diff) |
bring new quoting & parsing code to head
-rw-r--r-- | python/modules/cquoting.c | 637 | ||||
-rw-r--r-- | python/skytools/__init__.py | 2 | ||||
-rw-r--r-- | python/skytools/_pyquoting.py | 153 | ||||
-rw-r--r-- | python/skytools/parsing.py | 272 | ||||
-rw-r--r-- | python/skytools/quoting.py | 247 | ||||
-rwxr-xr-x | setup.py | 1 |
6 files changed, 1073 insertions, 239 deletions
diff --git a/python/modules/cquoting.c b/python/modules/cquoting.c new file mode 100644 index 00000000..7d22e1ec --- /dev/null +++ b/python/modules/cquoting.c @@ -0,0 +1,637 @@ + +#define PY_SSIZE_T_CLEAN +#include <Python.h> + +#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +#endif + +typedef enum { false = 0, true = 1 } bool; + +/* + * Common buffer management. + */ + +struct Buf { + unsigned char *ptr; + unsigned long pos; + unsigned long alloc; +}; + +static unsigned char *buf_init(struct Buf *buf, unsigned init_size) +{ + if (init_size < 256) + init_size = 256; + buf->ptr = PyMem_Malloc(init_size); + if (buf->ptr) { + buf->pos = 0; + buf->alloc = init_size; + } + return buf->ptr; +} + +/* return new pos */ +static unsigned char *buf_enlarge(struct Buf *buf, unsigned need_room) +{ + unsigned alloc = buf->alloc; + unsigned need_size = buf->pos + need_room; + unsigned char *ptr; + + /* no alloc needed */ + if (need_size < alloc) + return buf->ptr + buf->pos; + + if (alloc <= need_size / 2) + alloc = need_size; + else + alloc = alloc * 2; + + ptr = PyMem_Realloc(buf->ptr, alloc); + if (!ptr) + return NULL; + + buf->ptr = ptr; + buf->alloc = alloc; + return buf->ptr + buf->pos; +} + +static void buf_free(struct Buf *buf) +{ + PyMem_Free(buf->ptr); + buf->ptr = NULL; + buf->pos = buf->alloc = 0; +} + +static inline unsigned char *buf_get_target_for(struct Buf *buf, unsigned len) +{ + if (buf->pos + len <= buf->alloc) + return buf->ptr + buf->pos; + else + return buf_enlarge(buf, len); +} + +static inline void buf_set_target(struct Buf *buf, unsigned char *newpos) +{ + assert(buf->ptr + buf->pos <= newpos); + assert(buf->ptr + buf->alloc >= newpos); + + buf->pos = newpos - buf->ptr; +} + +static inline int buf_put(struct Buf *buf, unsigned char c) +{ + if (buf->pos < buf->alloc) { + buf->ptr[buf->pos++] = c; + return 1; + } else if (buf_enlarge(buf, 1)) { + buf->ptr[buf->pos++] = c; + return 1; + } + return 0; +} + +static PyObject *buf_pystr(struct Buf *buf, unsigned start_pos, unsigned char *newpos) +{ + PyObject *res; + if (newpos) + buf_set_target(buf, newpos); + res = PyString_FromStringAndSize((char *)buf->ptr + start_pos, buf->pos - start_pos); + buf_free(buf); + return res; +} + +/* + * Get string data + */ + +static Py_ssize_t get_buffer(PyObject *obj, unsigned char **buf_p, PyObject **tmp_obj_p) +{ + PyBufferProcs *bfp; + PyObject *str = NULL; + Py_ssize_t res; + + /* check for None */ + if (obj == Py_None) { + PyErr_Format(PyExc_TypeError, "None is not allowed here"); + return -1; + } + + /* is string or unicode ? */ + if (PyString_Check(obj) || PyUnicode_Check(obj)) { + if (PyString_AsStringAndSize(obj, (char**)buf_p, &res) < 0) + return -1; + return res; + } + + /* try to get buffer */ + bfp = obj->ob_type->tp_as_buffer; + if (bfp && bfp->bf_getsegcount(obj, NULL) == 1) + return bfp->bf_getreadbuffer(obj, 0, (void**)buf_p); + + /* + * Not a string-like object, run str() or it. + */ + + /* are we in recursion? */ + if (tmp_obj_p == NULL) { + PyErr_Format(PyExc_TypeError, "Cannot convert to string - get_buffer() recusively failed"); + return -1; + } + + /* do str() then */ + str = PyObject_Str(obj); + res = -1; + if (str != NULL) { + res = get_buffer(str, buf_p, NULL); + if (res >= 0) { + *tmp_obj_p = str; + } else { + Py_CLEAR(str); + } + } + return res; +} + +/* + * Common argument parsing. + */ + +typedef PyObject *(*quote_fn)(unsigned char *src, Py_ssize_t src_len); + +static PyObject *common_quote(PyObject *args, quote_fn qfunc) +{ + unsigned char *src = NULL; + Py_ssize_t src_len = 0; + PyObject *arg, *res, *strtmp = NULL; + if (!PyArg_ParseTuple(args, "O", &arg)) + return NULL; + if (arg != Py_None) { + src_len = get_buffer(arg, &src, &strtmp); + if (src_len < 0) + return NULL; + } + res = qfunc(src, src_len); + Py_CLEAR(strtmp); + return res; +} + +/* + * Simple quoting functions. + */ + +static const char doc_quote_literal[] = +"Quote a literal value for SQL.\n" +"\n" +"If string contains '\\', it is quoted and result is prefixed with E.\n" +"Input value of None results in string \"null\" without quotes.\n" +"\n" +"C implementation.\n"; + +static PyObject *quote_literal_body(unsigned char *src, Py_ssize_t src_len) +{ + struct Buf buf; + unsigned char *esc, *dst, *src_end = src + src_len; + unsigned int start_ofs = 1; + + if (src == NULL) + return PyString_FromString("null"); + + esc = dst = buf_init(&buf, src_len * 2 + 2 + 1); + if (!dst) + return NULL; + + *dst++ = ' '; + *dst++ = '\''; + while (src < src_end) { + if (*src == '\\') { + *dst++ = '\\'; + start_ofs = 0; + } else if (*src == '\'') { + *dst++ = '\''; + } + *dst++ = *src++; + } + *dst++ = '\''; + if (start_ofs == 0) + *esc = 'E'; + return buf_pystr(&buf, start_ofs, dst); +} + +static PyObject *quote_literal(PyObject *self, PyObject *args) +{ + return common_quote(args, quote_literal_body); +} + +/* COPY field */ +static const char doc_quote_copy[] = +"Quoting for COPY data. None is converted to \\N.\n\n" +"C implementation."; + +static PyObject *quote_copy_body(unsigned char *src, Py_ssize_t src_len) +{ + unsigned char *dst, *src_end = src + src_len; + struct Buf buf; + + if (src == NULL) + return PyString_FromString("\\N"); + + dst = buf_init(&buf, src_len * 2); + if (!dst) + return NULL; + + while (src < src_end) { + switch (*src) { + case '\t': *dst++ = '\\'; *dst++ = 't'; src++; break; + case '\n': *dst++ = '\\'; *dst++ = 'n'; src++; break; + case '\r': *dst++ = '\\'; *dst++ = 'r'; src++; break; + case '\\': *dst++ = '\\'; *dst++ = '\\'; src++; break; + default: *dst++ = *src++; break; + } + } + return buf_pystr(&buf, 0, dst); +} + +static PyObject *quote_copy(PyObject *self, PyObject *args) +{ + return common_quote(args, quote_copy_body); +} + +/* raw bytea for byteain() */ +static const char doc_quote_bytea_raw[] = +"Quoting for bytea parser. Returns None as None.\n" +"\n" +"C implementation."; + +static PyObject *quote_bytea_raw_body(unsigned char *src, Py_ssize_t src_len) +{ + unsigned char *dst, *src_end = src + src_len; + struct Buf buf; + + if (src == NULL) { + Py_INCREF(Py_None); + return Py_None; + } + + dst = buf_init(&buf, src_len * 4); + if (!dst) + return NULL; + + while (src < src_end) { + if (*src < 0x20 || *src >= 0x7F) { + *dst++ = '\\'; + *dst++ = '0' + (*src >> 6); + *dst++ = '0' + ((*src >> 3) & 7); + *dst++ = '0' + (*src & 7); + src++; + } else { + if (*src == '\\') + *dst++ = '\\'; + *dst++ = *src++; + } + } + return buf_pystr(&buf, 0, dst); +} + +static PyObject *quote_bytea_raw(PyObject *self, PyObject *args) +{ + return common_quote(args, quote_bytea_raw_body); +} + +/* C unescape */ +static const char doc_unescape[] = +"Unescape C-style escaped string.\n\n" +"C implementation."; + +static PyObject *unescape_body(unsigned char *src, Py_ssize_t src_len) +{ + unsigned char *dst, *src_end = src + src_len; + struct Buf buf; + + if (src == NULL) { + PyErr_Format(PyExc_TypeError, "None not allowed"); + return NULL; + } + + dst = buf_init(&buf, src_len); + if (!dst) + return NULL; + + while (src < src_end) { + if (*src != '\\') { + *dst++ = *src++; + continue; + } + if (++src >= src_end) + goto failed; + switch (*src) { + case 't': *dst++ = '\t'; src++; break; + case 'n': *dst++ = '\n'; src++; break; + case 'r': *dst++ = '\r'; src++; break; + case 'a': *dst++ = '\a'; src++; break; + case 'b': *dst++ = '\b'; src++; break; + default: + if (*src >= '0' && *src <= '7') { + unsigned char c = *src++ - '0'; + if (src < src_end && *src >= '0' && *src <= '7') { + c = (c << 3) | ((*src++) - '0'); + if (src < src_end && *src >= '0' && *src <= '7') + c = (c << 3) | ((*src++) - '0'); + } + *dst++ = c; + } else { + *dst++ = *src++; + } + } + } + return buf_pystr(&buf, 0, dst); +failed: + PyErr_Format(PyExc_ValueError, "Broken string - \\ at the end"); + return NULL; +} + +static PyObject *unescape(PyObject *self, PyObject *args) +{ + return common_quote(args, unescape_body); +} + +/* + * urlencode of dict + */ + +static bool urlenc(struct Buf *buf, PyObject *obj) +{ + Py_ssize_t len; + unsigned char *src, *dst; + PyObject *strtmp = NULL; + static const unsigned char hextbl[] = "0123456789abcdef"; + bool ok = false; + + len = get_buffer(obj, &src, &strtmp); + if (len < 0) + goto failed; + + dst = buf_get_target_for(buf, len * 3); + if (!dst) + goto failed; + + while (len--) { + if ((*src >= 'a' && *src <= 'z') || + (*src >= 'A' && *src <= 'Z') || + (*src >= '0' && *src <= '9') || + (*src == '.' || *src == '_' || *src == '-')) + { + *dst++ = *src++; + } else if (*src == ' ') { + *dst++ = '+'; src++; + } else { + *dst++ = '%'; + *dst++ = hextbl[*src >> 4]; + *dst++ = hextbl[*src & 0xF]; + src++; + } + } + buf_set_target(buf, dst); + ok = true; +failed: + Py_CLEAR(strtmp); + return ok; +} + +/* urlencode key+val pair. val can be None */ +static bool urlenc_keyval(struct Buf *buf, PyObject *key, PyObject *value, bool needAmp) +{ + if (needAmp && !buf_put(buf, '&')) + return false; + if (!urlenc(buf, key)) + return false; + if (value != Py_None) { + if (!buf_put(buf, '=')) + return false; + if (!urlenc(buf, value)) + return false; + } + return true; +} + +/* encode native dict using PyDict_Next */ +static PyObject *encode_dict(PyObject *data) +{ + PyObject *key, *value; + Py_ssize_t pos = 0; + bool needAmp = false; + struct Buf buf; + if (!buf_init(&buf, 1024)) + return NULL; + while (PyDict_Next(data, &pos, &key, &value)) { + if (!urlenc_keyval(&buf, key, value, needAmp)) + goto failed; + needAmp = true; + } + return buf_pystr(&buf, 0, NULL); +failed: + buf_free(&buf); + return NULL; +} + +/* encode custom object using .iteritems() */ +static PyObject *encode_dictlike(PyObject *data) +{ + PyObject *key = NULL, *value = NULL, *tup, *iter; + struct Buf buf; + bool needAmp = false; + + if (!buf_init(&buf, 1024)) + return NULL; + + iter = PyObject_CallMethod(data, "iteritems", NULL); + if (iter == NULL) { + buf_free(&buf); + return NULL; + } + + while ((tup = PyIter_Next(iter))) { + key = PySequence_GetItem(tup, 0); + value = key ? PySequence_GetItem(tup, 1) : NULL; + Py_CLEAR(tup); + if (!key || !value) + goto failed; + + if (!urlenc_keyval(&buf, key, value, needAmp)) + goto failed; + needAmp = true; + + Py_CLEAR(key); + Py_CLEAR(value); + } + /* allow error from iterator */ + if (PyErr_Occurred()) + goto failed; + + Py_CLEAR(iter); + return buf_pystr(&buf, 0, NULL); +failed: + buf_free(&buf); + Py_CLEAR(iter); + Py_CLEAR(key); + Py_CLEAR(value); + return NULL; +} + +static const char doc_db_urlencode[] = +"Urlencode for database records.\n" +"If a value is None the key is output without '='.\n" +"\n" +"C implementation."; + +static PyObject *db_urlencode(PyObject *self, PyObject *args) +{ + PyObject *data; + if (!PyArg_ParseTuple(args, "O", &data)) + return NULL; + if (PyDict_Check(data)) { + return encode_dict(data); + } else { + return encode_dictlike(data); + } +} + +/* + * urldecode to dict + */ + +static inline int gethex(unsigned char c) +{ + if (c >= '0' && c <= '9') return c - '0'; + c |= 0x20; + if (c >= 'a' && c <= 'f') return c - 'a' + 10; + return -1; +} + +static PyObject *get_elem(unsigned char *buf, unsigned char **src_p, unsigned char *src_end) +{ + int c1, c2; + unsigned char *src = *src_p; + unsigned char *dst = buf; + + while (src < src_end) { + switch (*src) { + case '%': + if (++src + 2 > src_end) + goto hex_incomplete; + if ((c1 = gethex(*src++)) < 0) + goto hex_invalid; + if ((c2 = gethex(*src++)) < 0) + goto hex_invalid; + *dst++ = (c1 << 4) | c2; + break; + case '+': + *dst++ = ' '; src++; + break; + case '&': + case '=': + goto gotit; + default: + *dst++ = *src++; + } + } +gotit: + *src_p = src; + return PyString_FromStringAndSize((char *)buf, dst - buf); + +hex_incomplete: + PyErr_Format(PyExc_ValueError, "Incomplete hex code"); + return NULL; +hex_invalid: + PyErr_Format(PyExc_ValueError, "Invalid hex code"); + return NULL; +} + +static const char doc_db_urldecode[] = +"Urldecode from string to dict.\n" +"NULL are detected by missing '='.\n" +"Duplicate keys are ignored - only latest is kept.\n" +"\n" +"C implementation."; + +static PyObject *db_urldecode(PyObject *self, PyObject *args) +{ + unsigned char *src, *src_end; + Py_ssize_t src_len; + PyObject *dict = NULL, *key = NULL, *value = NULL; + struct Buf buf; + + if (!PyArg_ParseTuple(args, "t#", &src, &src_len)) + return NULL; + if (!buf_init(&buf, src_len)) + return NULL; + + dict = PyDict_New(); + if (!dict) { + buf_free(&buf); + return NULL; + } + + src_end = src + src_len; + while (src < src_end) { + if (*src == '&') { + src++; + continue; + } + + key = get_elem(buf.ptr, &src, src_end); + if (!key) + goto failed; + + if (src < src_end && *src == '=') { + src++; + value = get_elem(buf.ptr, &src, src_end); + if (value == NULL) + goto failed; + } else { + Py_INCREF(Py_None); + value = Py_None; + } + + /* lessen memory usage by intering */ + PyString_InternInPlace(&key); + + if (PyDict_SetItem(dict, key, value) < 0) + goto failed; + Py_CLEAR(key); + Py_CLEAR(value); + } + buf_free(&buf); + return dict; +failed: + buf_free(&buf); + Py_CLEAR(key); + Py_CLEAR(value); + Py_CLEAR(dict); + return NULL; +} + +/* + * Module initialization + */ + +static PyMethodDef +cquoting_methods[] = { + { "quote_literal", quote_literal, METH_VARARGS, doc_quote_literal }, + { "quote_copy", quote_copy, METH_VARARGS, doc_quote_copy }, + { "quote_bytea_raw", quote_bytea_raw, METH_VARARGS, doc_quote_bytea_raw }, + { "unescape", unescape, METH_VARARGS, doc_unescape }, + { "db_urlencode", db_urlencode, METH_VARARGS, doc_db_urlencode }, + { "db_urldecode", db_urldecode, METH_VARARGS, doc_db_urldecode }, + { NULL } +}; + +PyMODINIT_FUNC +init_cquoting(void) +{ + PyObject *module; + module = Py_InitModule("_cquoting", cquoting_methods); + PyModule_AddStringConstant(module, "__doc__", "fast quoting for skytools"); +} + diff --git a/python/skytools/__init__.py b/python/skytools/__init__.py index 7b7dd126..89884095 100644 --- a/python/skytools/__init__.py +++ b/python/skytools/__init__.py @@ -9,6 +9,7 @@ from gzlog import * from scripting import * from sqltools import * from quoting import * +from parsing import * __all__ = (psycopgwrapper.__all__ + config.__all__ @@ -16,5 +17,6 @@ __all__ = (psycopgwrapper.__all__ + gzlog.__all__ + scripting.__all__ + sqltools.__all__ + + parsing.__all__ + quoting.__all__ ) diff --git a/python/skytools/_pyquoting.py b/python/skytools/_pyquoting.py new file mode 100644 index 00000000..28a57577 --- /dev/null +++ b/python/skytools/_pyquoting.py @@ -0,0 +1,153 @@ +# _pyquoting.py + +"""Various helpers for string quoting/unquoting. + +Here is pure Python that should match C code in _cquoting. +""" + +import urllib, re + +__all__ = [ + "quote_literal", "quote_copy", "quote_bytea_raw", + "db_urlencode", "db_urldecode", "unescape", +] + +# +# SQL quoting +# + +def quote_literal(s): + """Quote a literal value for SQL. + + If string contains '\\', it is quoted and result is prefixed with E. + Input value of None results in string "null" without quotes. + + Python implementation. + """ + + if s == None: + return "null" + s = str(s).replace("'", "''") + s2 = s.replace("\\", "\\\\") + if len(s) != len(s2): + return "E'" + s2 + "'" + return "'" + s2 + "'" + +def quote_copy(s): + """Quoting for copy command. None is converted to \\N. + + Python implementation. + """ + + if s == None: + return "\\N" + s = str(s) + s = s.replace("\\", "\\\\") + s = s.replace("\t", "\\t") + s = s.replace("\n", "\\n") + s = s.replace("\r", "\\r") + return s + +_bytea_map = None +def quote_bytea_raw(s): + """Quoting for bytea parser. Returns None as None. + + Python implementation. + """ + global _bytea_map + if s == None: + return None + if 1 and _bytea_map is None: + _bytea_map = {} + for i in xrange(256): + c = chr(i) + if i < 0x20 or i >= 0x7F: + _bytea_map[c] = "\\%03o" % i + elif c == "\\": + _bytea_map[c] = r"\\" + else: + _bytea_map[c] = c + return "".join([_bytea_map[c] for c in s]) + # faster but does not match c code + #return s.replace("\\", "\\\\").replace("\0", "\\000") + +# +# Database specific urlencode and urldecode. +# + +def db_urlencode(dict): + """Database specific urlencode. + + Encode None as key without '='. That means that in "foo&bar=", + foo is NULL and bar is empty string. + + Python implementation. + """ + + elem_list = [] + for k, v in dict.items(): + if v is None: + elem = urllib.quote_plus(str(k)) + else: + elem = urllib.quote_plus(str(k)) + '=' + urllib.quote_plus(str(v)) + elem_list.append(elem) + return '&'.join(elem_list) + +def db_urldecode(qs): + """Database specific urldecode. + + Decode key without '=' as None. + This also does not support one key several times. + + Python implementation. + """ + + res = {} + for elem in qs.split('&'): + if not elem: + continue + pair = elem.split('=', 1) + name = urllib.unquote_plus(pair[0]) + + # keep only one instance around + name = intern(str(name)) + + if len(pair) == 1: + res[name] = None + else: + res[name] = urllib.unquote_plus(pair[1]) + return res + +# +# Remove C-like backslash escapes +# + +_esc_re = r"\\([0-7]{1,3}|.)" +_esc_rc = re.compile(_esc_re) +_esc_map = { + 't': '\t', + 'n': '\n', + 'r': '\r', + 'a': '\a', + 'b': '\b', + "'": "'", + '"': '"', + '\\': '\\', +} + +def _sub_unescape(m): + v = m.group(1) + if (len(v) == 1) and (v < '0' or v > '7'): + try: + return _esc_map[v] + except KeyError: + return v + else: + return chr(int(v, 8)) + +def unescape(val): + """Removes C-style escapes from string. + Python implementation. + """ + return _esc_rc.sub(_sub_unescape, val) + diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py new file mode 100644 index 00000000..1f4dd781 --- /dev/null +++ b/python/skytools/parsing.py @@ -0,0 +1,272 @@ + +"""Various parsers for Postgres-specific data formats.""" + +import re + +from skytools.quoting import unescape + +__all__ = ["parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements"] + +_rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X) + +# _parse_pgarray +def parse_pgarray(array): + """ Parse Postgres array and return list of items inside it + Used to deserialize data recived from service layer parameters + """ + if not array or array[0] != "{": + raise Exception("bad array format: must start with {") + res = [] + pos = 1 + while 1: + m = _rc_listelem.search(array, pos) + if not m: + break + pos2 = m.end() + item = array[pos:pos2] + if len(item) > 0 and item[0] == '"': + item = item[1:-1] + item = unescape(item) + res.append(item) + + pos = pos2 + 1 + if array[pos2] == "}": + break + elif array[pos2] != ",": + raise Exception("bad array format: expected ,} got " + array[pos2]) + return res + +# +# parse logtriga partial sql +# + +class _logtriga_parser: + token_re = r""" + [ \t\r\n]* + ( [a-z][a-z0-9_]* + | ["] ( [^"\\]+ | \\. )* ["] + | ['] ( [^'\\]+ | \\. | [']['] )* ['] + | [^ \t\r\n] + )""" + token_rc = None + + def tokenizer(self, sql): + if not _logtriga_parser.token_rc: + _logtriga_parser.token_rc = re.compile(self.token_re, re.X | re.I) + rc = self.token_rc + + pos = 0 + while 1: + m = rc.match(sql, pos) + if not m: + break + pos = m.end() + yield m.group(1) + + def unquote_data(self, fields, values): + # unquote data and column names + data = {} + for k, v in zip(fields, values): + if k[0] == '"': + k = unescape(k[1:-1]) + if len(v) == 4 and v.lower() == "null": + v = None + elif v[0] == "'": + v = unescape(v[1:-1]) + data[k] = v + return data + + def parse_insert(self, tk, fields, values): + # (col1, col2) values ('data', null) + if tk.next() != "(": + raise Exception("syntax error") + while 1: + fields.append(tk.next()) + t = tk.next() + if t == ")": + break + elif t != ",": + raise Exception("syntax error") + if tk.next().lower() != "values": + raise Exception("syntax error") + if tk.next() != "(": + raise Exception("syntax error") + while 1: + t = tk.next() + if t == ")": + break + if t == ",": + continue + values.append(t) + tk.next() + + def parse_update(self, tk, fields, values): + # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2' + while 1: + fields.append(tk.next()) + if tk.next() != "=": + raise Exception("syntax error") + values.append(tk.next()) + + t = tk.next() + if t == ",": + continue + elif t.lower() == "where": + break + else: + raise Exception("syntax error") + while 1: + t = tk.next() + fields.append(t) + if tk.next() != "=": + raise Exception("syntax error") + values.append(tk.next()) + t = tk.next() + if t.lower() != "and": + raise Exception("syntax error") + + def parse_delete(self, tk, fields, values): + # pk1 = 'pk1' and pk2 = 'pk2' + while 1: + t = tk.next() + if t == "and": + continue + fields.append(t) + if tk.next() != "=": + raise Exception("syntax error") + values.append(tk.next()) + + def parse_sql(self, op, sql): + tk = self.tokenizer(sql) + fields = [] + values = [] + try: + if op == "I": + self.parse_insert(tk, fields, values) + elif op == "U": + self.parse_update(tk, fields, values) + elif op == "D": + self.parse_delete(tk, fields, values) + raise Exception("syntax error") + except StopIteration: + # last sanity check + if len(fields) == 0 or len(fields) != len(values): + raise Exception("syntax error") + + return self.unquote_data(fields, values) + +def parse_logtriga_sql(op, sql): + """Parse partial SQL used by logtriga() back to data values. + + Parser has following limitations: + - Expects standard_quoted_strings = off + - Does not support dollar quoting. + - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1)) + - WHERE expression must not contain IS (NOT) NULL + - Does not support updateing pk value. + + Returns dict of col->data pairs. + """ + return _logtriga_parser().parse_sql(op, sql) + + +def parse_tabbed_table(txt): + """Parse a tab-separated table into list of dicts. + + Expect first row to be column names. + + Very primitive. + """ + + txt = txt.replace("\r\n", "\n") + fields = None + data = [] + for ln in txt.split("\n"): + if not ln: + continue + if not fields: + fields = ln.split("\t") + continue + cols = ln.split("\t") + if len(cols) != len(fields): + continue + row = dict(zip(fields, cols)) + data.append(row) + return data + + +_sql_token_re = r""" + ( [a-z][a-z0-9_$]* + | ["] ( [^"\\]+ | \\. )* ["] + | ['] ( [^'\\]+ | \\. | [']['] )* ['] + | [$] ([_a-z][_a-z0-9]*)? [$] + | (?P<ws> \s+ | [/][*] | [-][-][^\n]* ) + | . + )""" +_sql_token_rc = None +_copy_from_stdin_re = "copy.*from\s+stdin" +_copy_from_stdin_rc = None + +def _sql_tokenizer(sql): + global _sql_token_rc, _copy_from_stdin_rc + if not _sql_token_rc: + _sql_token_rc = re.compile(_sql_token_re, re.X | re.I) + _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I) + rc = _sql_token_rc + + pos = 0 + while 1: + m = rc.match(sql, pos) + if not m: + break + pos = m.end() + tok = m.group(1) + ws = m.start('ws') >= 0 # it tok empty? + if tok == "/*": + end = sql.find("*/", pos) + if end < 0: + raise Exception("unterminated c comment") + pos = end + 2 + tok = sql[ m.start() : pos] + elif len(tok) > 1 and tok[0] == "$" and tok[-1] == "$": + end = sql.find(tok, pos) + if end < 0: + raise Exception("unterminated dollar string") + pos = end + len(tok) + tok = sql[ m.start() : pos] + yield (ws, tok) + +def parse_statements(sql): + """Parse multi-statement string into separate statements. + + Returns list of statements. + """ + + tk = _sql_tokenizer(sql) + tokens = [] + pcount = 0 # '(' level + while 1: + try: + ws, t = tk.next() + except StopIteration: + break + # skip whitespace and comments before statement + if len(tokens) == 0 and ws: + continue + # keep the rest + tokens.append(t) + if t == "(": + pcount += 1 + elif t == ")": + pcount -= 1 + elif t == ";" and pcount == 0: + sql = "".join(tokens) + if _copy_from_stdin_rc.match(sql): + raise Exception("copy from stdin not supported") + yield ("".join(tokens)) + tokens = [] + if len(tokens) > 0: + yield ("".join(tokens)) + if pcount != 0: + raise Exception("syntax error - unbalanced parenthesis") + diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py index 594646a4..10d4626a 100644 --- a/python/skytools/quoting.py +++ b/python/skytools/quoting.py @@ -4,49 +4,23 @@ import urllib, re -from skytools.psycopgwrapper import QuotedString - __all__ = [ "quote_literal", "quote_copy", "quote_bytea_raw", + "db_urlencode", "db_urldecode", "unescape", + "quote_bytea_literal", "quote_bytea_copy", "quote_statement", - "quote_ident", "quote_fqident", "quote_json", - "db_urlencode", "db_urldecode", "unescape", "unescape_copy" + "quote_ident", "quote_fqident", "quote_json", "unescape_copy" ] +try: + from _cquoting import * +except ImportError: + from _pyquoting import * + # # SQL quoting # -def quote_literal(s): - """Quote a literal value for SQL. - - Surronds it with single-quotes. - """ - - if s == None: - return "null" - s = QuotedString(str(s)) - return str(s) - -def quote_copy(s): - """Quoting for copy command.""" - - if s == None: - return "\\N" - s = str(s) - s = s.replace("\\", "\\\\") - s = s.replace("\t", "\\t") - s = s.replace("\n", "\\n") - s = s.replace("\r", "\\r") - return s - -def quote_bytea_raw(s): - """Quoting for bytea parser.""" - - if s == None: - return None - return s.replace("\\", "\\\\").replace("\0", "\\000") - def quote_bytea_literal(s): """Quote bytea for regular SQL.""" @@ -125,214 +99,9 @@ def quote_json(s): return "null" return '"%s"' % _jsre.sub(_json_quote_char, s) -# -# Database specific urlencode and urldecode. -# - -def db_urlencode(dict): - """Database specific urlencode. - - Encode None as key without '='. That means that in "foo&bar=", - foo is NULL and bar is empty string. - """ - - elem_list = [] - for k, v in dict.items(): - if v is None: - elem = urllib.quote_plus(str(k)) - else: - elem = urllib.quote_plus(str(k)) + '=' + urllib.quote_plus(str(v)) - elem_list.append(elem) - return '&'.join(elem_list) - -def db_urldecode(qs): - """Database specific urldecode. - - Decode key without '=' as None. - This also does not support one key several times. - """ - - res = {} - for elem in qs.split('&'): - if not elem: - continue - pair = elem.split('=', 1) - name = urllib.unquote_plus(pair[0]) - - # keep only one instance around - name = intern(name) - - if len(pair) == 1: - res[name] = None - else: - res[name] = urllib.unquote_plus(pair[1]) - return res - -# -# Remove C-like backslash escapes -# - -_esc_re = r"\\([0-7][0-7][0-7]|.)" -_esc_rc = re.compile(_esc_re) -_esc_map = { - 't': '\t', - 'n': '\n', - 'r': '\r', - 'a': '\a', - 'b': '\b', - "'": "'", - '"': '"', - '\\': '\\', -} - -def _sub_unescape(m): - v = m.group(1) - if len(v) == 1: - return _esc_map[v] - else: - return chr(int(v, 8)) - -def unescape(val): - """Removes C-style escapes from string.""" - return _esc_rc.sub(_sub_unescape, val) - def unescape_copy(val): """Removes C-style escapes, also converts "\N" to None.""" if val == r"\N": return None return unescape(val) - -# -# parse logtriga partial sql -# - -class _logtriga_parser: - token_re = r""" - [ \t\r\n]* - ( [a-z][a-z0-9_]* - | ["] ( [^"\\]+ | \\. )* ["] - | ['] ( [^'\\]+ | \\. | [']['] )* ['] - | [^ \t\r\n] - )""" - token_rc = None - - def tokenizer(self, sql): - if not _logtriga_parser.token_rc: - _logtriga_parser.token_rc = re.compile(self.token_re, re.X | re.I) - rc = self.token_rc - - pos = 0 - while 1: - m = rc.match(sql, pos) - if not m: - break - pos = m.end() - yield m.group(1) - - def unquote_data(self, fields, values): - # unquote data and column names - data = {} - for k, v in zip(fields, values): - if k[0] == '"': - k = unescape(k[1:-1]) - if len(v) == 4 and v.lower() == "null": - v = None - elif v[0] == "'": - v = unescape(v[1:-1]) - data[k] = v - return data - - def parse_insert(self, tk, fields, values): - # (col1, col2) values ('data', null) - if tk.next() != "(": - raise Exception("syntax error") - while 1: - fields.append(tk.next()) - t = tk.next() - if t == ")": - break - elif t != ",": - raise Exception("syntax error") - if tk.next().lower() != "values": - raise Exception("syntax error") - if tk.next() != "(": - raise Exception("syntax error") - while 1: - t = tk.next() - if t == ")": - break - if t == ",": - continue - values.append(t) - tk.next() - - def parse_update(self, tk, fields, values): - # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2' - while 1: - fields.append(tk.next()) - if tk.next() != "=": - raise Exception("syntax error") - values.append(tk.next()) - - t = tk.next() - if t == ",": - continue - elif t.lower() == "where": - break - else: - raise Exception("syntax error") - while 1: - t = tk.next() - fields.append(t) - if tk.next() != "=": - raise Exception("syntax error") - values.append(tk.next()) - t = tk.next() - if t.lower() != "and": - raise Exception("syntax error") - - def parse_delete(self, tk, fields, values): - # pk1 = 'pk1' and pk2 = 'pk2' - while 1: - t = tk.next() - if t == "and": - continue - fields.append(t) - if tk.next() != "=": - raise Exception("syntax error") - values.append(tk.next()) - - def parse_sql(self, op, sql): - tk = self.tokenizer(sql) - fields = [] - values = [] - try: - if op == "I": - self.parse_insert(tk, fields, values) - elif op == "U": - self.parse_update(tk, fields, values) - elif op == "D": - self.parse_delete(tk, fields, values) - raise Exception("syntax error") - except StopIteration: - # last sanity check - if len(fields) == 0 or len(fields) != len(values): - raise Exception("syntax error") - - return self.unquote_data(fields, values) - -def parse_logtriga_sql(op, sql): - """Parse partial SQL used by logtriga() back to data values. - - Parser has following limitations: - - Expects standard_quoted_strings = off - - Does not support dollar quoting. - - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1)) - - WHERE expression must not contain IS (NOT) NULL - - Does not support updateing pk value. - - Returns dict of col->data pairs. - """ - return _logtriga_parser().parse_sql(op, sql) - @@ -53,5 +53,6 @@ setup( 'scripts/scriptmgr.ini.templ', ]), ('share/skytools', share_dup_files)], + ext_modules=[Extension("skytools._cquoting", ['python/modules/cquoting.c'])], ) |