summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarko Kreen2009-10-14 13:37:17 +0000
committerMarko Kreen2009-10-14 13:37:17 +0000
commit66c72793edd322c55dd02744abd8fab822e211b4 (patch)
tree5339b7070d5fa9cf19d6de1117a940ea4a48f8c7
parentfc41ef5f83372e766ba31165d01d7ed2d74b034c (diff)
python/skytools: add doctest-based regtests to few non-sql functions
Seems to be better testing method than ad-hoc scripts. They will serve as examples too. Also fix few minor problems found in the process: - parse_pgarray: check if str ends with } - parse_pgarray: support NULL - quote_fqident: add 'public.' schema to idents without schema - fq_name_parts: return always list
-rw-r--r--python/skytools/parsing.py68
-rw-r--r--python/skytools/quoting.py23
-rw-r--r--python/skytools/sqltools.py60
3 files changed, 128 insertions, 23 deletions
diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py
index d50b14c0..4b92306e 100644
--- a/python/skytools/parsing.py
+++ b/python/skytools/parsing.py
@@ -14,11 +14,18 @@ _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X)
# _parse_pgarray
def parse_pgarray(array):
- """ Parse Postgres array and return list of items inside it
- Used to deserialize data recived from service layer parameters
+ r"""Parse Postgres array and return list of items inside it.
+
+ Examples:
+ >>> parse_pgarray('{}')
+ []
+ >>> parse_pgarray('{a,b,null,"null"}')
+ ['a', 'b', None, 'null']
+ >>> parse_pgarray(r'{"a,a","b\"b","c\\c"}')
+ ['a,a', 'b"b', 'c\\c']
"""
- if not array or array[0] != "{":
- raise Exception("bad array format: must start with {")
+ if not array or array[0] != "{" or array[-1] != '}':
+ raise Exception("bad array format: must be surrounded with {}")
res = []
pos = 1
while 1:
@@ -27,16 +34,19 @@ def parse_pgarray(array):
break
pos2 = m.end()
item = array[pos:pos2]
- if len(item) > 0 and item[0] == '"':
- item = item[1:-1]
- item = unescape(item)
- res.append(item)
+ if len(item) == 4 and item.upper() == "NULL":
+ val = None
+ else:
+ if len(item) > 0 and item[0] == '"':
+ item = item[1:-1]
+ val = unescape(item)
+ res.append(val)
pos = pos2 + 1
if array[pos2] == "}":
break
elif array[pos2] != ",":
- raise Exception("bad array format: expected ,} got " + array[pos2])
+ raise Exception("bad array format: expected ,} got " + repr(array[pos2]))
return res
#
@@ -136,26 +146,45 @@ class _logtriga_parser:
return dbdict(zip(fields, values))
def parse_logtriga_sql(op, sql):
- """Parse partial SQL used by logtriga() back to data values.
+ return parse_sqltriga_sql(op, sql)
+
+def parse_sqltriga_sql(op, sql):
+ """Parse partial SQL used by pgq.sqltriga() back to data values.
Parser has following limitations:
- Expects standard_quoted_strings = off
- Does not support dollar quoting.
- Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1))
- WHERE expression must not contain IS (NOT) NULL
- - Does not support updateing pk value.
+ - Does not support updating pk value.
Returns dict of col->data pairs.
+
+ Insert event:
+ >>> parse_logtriga_sql('I', '(id, data) values (1, null)')
+ {'data': None, 'id': '1'}
+
+ Update event:
+ >>> parse_logtriga_sql('U', "data='foo' where id = 1")
+ {'data': 'foo', 'id': '1'}
+
+ Delete event:
+ >>> parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'")
+ {'id2': "str'val", 'id': '1'}
"""
return _logtriga_parser().parse_sql(op, sql)
def parse_tabbed_table(txt):
- """Parse a tab-separated table into list of dicts.
+ r"""Parse a tab-separated table into list of dicts.
Expect first row to be column names.
Very primitive.
+
+ Example:
+ >>> parse_tabbed_table('col1\tcol2\nval1\tval2\n')
+ [{'col2': 'val2', 'col1': 'val1'}]
"""
txt = txt.replace("\r\n", "\n")
@@ -194,9 +223,15 @@ _ext_sql = r"""(?: (?P<str> [E]? %s ) | %s )""" % (_extstr, _base_sql)
_std_sql_rc = _ext_sql_rc = None
def sql_tokenizer(sql, standard_quoting = False, ignore_whitespace = False):
- """Parser SQL to tokens.
+ r"""Parser SQL to tokens.
Iterator, returns (toktype, tokstr) tuples.
+
+ Example
+ >>> [x for x in sql_tokenizer("select * from a.b", ignore_whitespace=True)]
+ [('ident', 'select'), ('sym', '*'), ('ident', 'from'), ('ident', 'a'), ('sym', '.'), ('ident', 'b')]
+ >>> [x for x in sql_tokenizer("\"c olumn\",'str''val'")]
+ [('ident', '"c olumn"'), ('sym', ','), ('str', "'str''val'")]
"""
global _std_sql_rc, _ext_sql_rc
if not _std_sql_rc:
@@ -224,6 +259,9 @@ def parse_statements(sql, standard_quoting = False):
"""Parse multi-statement string into separate statements.
Returns list of statements.
+
+ >>> [sql for sql in parse_statements("begin; select 1; select 'foo'; end;")]
+ ['begin;', 'select 1;', "select 'foo';", 'end;']
"""
global _copy_from_stdin_rc
@@ -252,3 +290,7 @@ def parse_statements(sql, standard_quoting = False):
if pcount != 0:
raise Exception("syntax error - unbalanced parenthesis")
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
+
diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py
index ce860040..e83d2495 100644
--- a/python/skytools/quoting.py
+++ b/python/skytools/quoting.py
@@ -82,8 +82,17 @@ def quote_fqident(s):
The '.' is taken as namespace separator and
all parts are quoted separately
+
+ Example:
+ >>> quote_fqident('tbl')
+ 'public.tbl'
+ >>> quote_fqident('Baz.Foo.Bar')
+ '"Baz"."Foo.Bar"'
"""
- return '.'.join(map(quote_ident, s.split('.', 1)))
+ tmp = s.split('.', 1)
+ if len(tmp) == 1:
+ return 'public.' + quote_ident(s)
+ return '.'.join(map(quote_ident, tmp))
#
# quoting for JSON strings
@@ -110,7 +119,14 @@ def quote_json(s):
return '"%s"' % _jsre.sub(_json_quote_char, s)
def unescape_copy(val):
- """Removes C-style escapes, also converts "\N" to None."""
+ r"""Removes C-style escapes, also converts "\N" to None.
+
+ Example:
+ >>> unescape_copy(r'baz\tfo\'o')
+ "baz\tfo'o"
+ >>> unescape_copy(r'\N') is None
+ True
+ """
if val == r"\N":
return None
return unescape(val)
@@ -129,3 +145,6 @@ def unquote_fqident(val):
tmp = val.split('.', 1)
return "%s.%s" % (unquote_ident(tmp[0]), unquote_ident(tmp[1]))
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py
index 3037df69..902bf9a6 100644
--- a/python/skytools/sqltools.py
+++ b/python/skytools/sqltools.py
@@ -44,18 +44,34 @@ class dbdict(dict):
#
def fq_name_parts(tbl):
- "Return fully qualified name parts."
+ """Return fully qualified name parts.
+
+ >>> fq_name_parts('tbl')
+ ['public', 'tbl']
+ >>> fq_name_parts('foo.tbl')
+ ['foo', 'tbl']
+ >>> fq_name_parts('foo.tbl.baz')
+ ['foo', 'tbl.baz']
+ """
tmp = tbl.split('.', 1)
if len(tmp) == 1:
- return ('public', tbl)
+ return ['public', tbl]
elif len(tmp) == 2:
return tmp
else:
raise Exception('Syntax error in table name:'+tbl)
def fq_name(tbl):
- "Return fully qualified name."
+ """Return fully qualified name.
+
+ >>> fq_name('tbl')
+ 'public.tbl'
+ >>> fq_name('foo.tbl')
+ 'foo.tbl'
+ >>> fq_name('foo.tbl.baz')
+ 'foo.tbl.baz'
+ """
return '.'.join(fq_name_parts(tbl))
#
@@ -171,7 +187,19 @@ def exists_temp_table(curs, tbl):
#
class Snapshot(object):
- "Represents a PostgreSQL snapshot."
+ """Represents a PostgreSQL snapshot.
+
+ Example:
+ >>> sn = Snapshot('11:20:11,12,15')
+ >>> sn.contains(9)
+ True
+ >>> sn.contains(11)
+ False
+ >>> sn.contains(17)
+ True
+ >>> sn.contains(20)
+ False
+ """
def __init__(self, str):
"Create snapshot from string."
@@ -235,11 +263,15 @@ def _gen_list_insert(tbl, row, fields, qfields):
return fmt % (tbl, ",".join(qfields), ",".join(tmp))
def magic_insert(curs, tablename, data, fields = None, use_insert = 0):
- """Copy/insert a list of dict/list data to database.
-
+ r"""Copy/insert a list of dict/list data to database.
+
If curs == None, then the copy or insert statements are returned
as string. For list of dict the field list is optional, as its
possible to guess them from dict keys.
+
+ Example:
+ >>> magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2'])
+ 'COPY public.tbl (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n'
"""
if len(data) == 0:
return
@@ -486,7 +518,11 @@ def installer_apply_file(db, filename, log):
#
def mk_insert_sql(row, tbl, pkey_list = None, field_map = None):
- """Generate INSERT statement from dict data."""
+ """Generate INSERT statement from dict data.
+
+ >>> mk_insert_sql({'id': '1', 'data': None}, 'tbl')
+ "insert into public.tbl (data, id) values (null, '1');"
+ """
col_list = []
val_list = []
@@ -504,7 +540,11 @@ def mk_insert_sql(row, tbl, pkey_list = None, field_map = None):
quote_fqident(tbl), col_str, val_str)
def mk_update_sql(row, tbl, pkey_list, field_map = None):
- """Generate UPDATE statement from dict data."""
+ r"""Generate UPDATE statement from dict data.
+
+ >>> mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2'])
+ 'update only public."Table" set data = E\'str\\\\\' where id = \'0\' and id2 = \'2\';'
+ """
if len(pkey_list) < 1:
raise Exception("update needs pkeys")
@@ -787,3 +827,7 @@ class PLPyQueryBuilder(QueryBuilder):
res = [dbdict(r) for r in res]
return res
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
+