diff options
author | Marko Kreen | 2009-10-14 13:37:17 +0000 |
---|---|---|
committer | Marko Kreen | 2009-10-14 13:37:17 +0000 |
commit | 66c72793edd322c55dd02744abd8fab822e211b4 (patch) | |
tree | 5339b7070d5fa9cf19d6de1117a940ea4a48f8c7 | |
parent | fc41ef5f83372e766ba31165d01d7ed2d74b034c (diff) |
python/skytools: add doctest-based regtests to few non-sql functions
Seems to be better testing method than ad-hoc scripts. They will
serve as examples too.
Also fix few minor problems found in the process:
- parse_pgarray: check if str ends with }
- parse_pgarray: support NULL
- quote_fqident: add 'public.' schema to idents without schema
- fq_name_parts: return always list
-rw-r--r-- | python/skytools/parsing.py | 68 | ||||
-rw-r--r-- | python/skytools/quoting.py | 23 | ||||
-rw-r--r-- | python/skytools/sqltools.py | 60 |
3 files changed, 128 insertions, 23 deletions
diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py index d50b14c0..4b92306e 100644 --- a/python/skytools/parsing.py +++ b/python/skytools/parsing.py @@ -14,11 +14,18 @@ _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X) # _parse_pgarray def parse_pgarray(array): - """ Parse Postgres array and return list of items inside it - Used to deserialize data recived from service layer parameters + r"""Parse Postgres array and return list of items inside it. + + Examples: + >>> parse_pgarray('{}') + [] + >>> parse_pgarray('{a,b,null,"null"}') + ['a', 'b', None, 'null'] + >>> parse_pgarray(r'{"a,a","b\"b","c\\c"}') + ['a,a', 'b"b', 'c\\c'] """ - if not array or array[0] != "{": - raise Exception("bad array format: must start with {") + if not array or array[0] != "{" or array[-1] != '}': + raise Exception("bad array format: must be surrounded with {}") res = [] pos = 1 while 1: @@ -27,16 +34,19 @@ def parse_pgarray(array): break pos2 = m.end() item = array[pos:pos2] - if len(item) > 0 and item[0] == '"': - item = item[1:-1] - item = unescape(item) - res.append(item) + if len(item) == 4 and item.upper() == "NULL": + val = None + else: + if len(item) > 0 and item[0] == '"': + item = item[1:-1] + val = unescape(item) + res.append(val) pos = pos2 + 1 if array[pos2] == "}": break elif array[pos2] != ",": - raise Exception("bad array format: expected ,} got " + array[pos2]) + raise Exception("bad array format: expected ,} got " + repr(array[pos2])) return res # @@ -136,26 +146,45 @@ class _logtriga_parser: return dbdict(zip(fields, values)) def parse_logtriga_sql(op, sql): - """Parse partial SQL used by logtriga() back to data values. + return parse_sqltriga_sql(op, sql) + +def parse_sqltriga_sql(op, sql): + """Parse partial SQL used by pgq.sqltriga() back to data values. Parser has following limitations: - Expects standard_quoted_strings = off - Does not support dollar quoting. - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1)) - WHERE expression must not contain IS (NOT) NULL - - Does not support updateing pk value. + - Does not support updating pk value. Returns dict of col->data pairs. + + Insert event: + >>> parse_logtriga_sql('I', '(id, data) values (1, null)') + {'data': None, 'id': '1'} + + Update event: + >>> parse_logtriga_sql('U', "data='foo' where id = 1") + {'data': 'foo', 'id': '1'} + + Delete event: + >>> parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'") + {'id2': "str'val", 'id': '1'} """ return _logtriga_parser().parse_sql(op, sql) def parse_tabbed_table(txt): - """Parse a tab-separated table into list of dicts. + r"""Parse a tab-separated table into list of dicts. Expect first row to be column names. Very primitive. + + Example: + >>> parse_tabbed_table('col1\tcol2\nval1\tval2\n') + [{'col2': 'val2', 'col1': 'val1'}] """ txt = txt.replace("\r\n", "\n") @@ -194,9 +223,15 @@ _ext_sql = r"""(?: (?P<str> [E]? %s ) | %s )""" % (_extstr, _base_sql) _std_sql_rc = _ext_sql_rc = None def sql_tokenizer(sql, standard_quoting = False, ignore_whitespace = False): - """Parser SQL to tokens. + r"""Parser SQL to tokens. Iterator, returns (toktype, tokstr) tuples. + + Example + >>> [x for x in sql_tokenizer("select * from a.b", ignore_whitespace=True)] + [('ident', 'select'), ('sym', '*'), ('ident', 'from'), ('ident', 'a'), ('sym', '.'), ('ident', 'b')] + >>> [x for x in sql_tokenizer("\"c olumn\",'str''val'")] + [('ident', '"c olumn"'), ('sym', ','), ('str', "'str''val'")] """ global _std_sql_rc, _ext_sql_rc if not _std_sql_rc: @@ -224,6 +259,9 @@ def parse_statements(sql, standard_quoting = False): """Parse multi-statement string into separate statements. Returns list of statements. + + >>> [sql for sql in parse_statements("begin; select 1; select 'foo'; end;")] + ['begin;', 'select 1;', "select 'foo';", 'end;'] """ global _copy_from_stdin_rc @@ -252,3 +290,7 @@ def parse_statements(sql, standard_quoting = False): if pcount != 0: raise Exception("syntax error - unbalanced parenthesis") +if __name__ == '__main__': + import doctest + doctest.testmod() + diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py index ce860040..e83d2495 100644 --- a/python/skytools/quoting.py +++ b/python/skytools/quoting.py @@ -82,8 +82,17 @@ def quote_fqident(s): The '.' is taken as namespace separator and all parts are quoted separately + + Example: + >>> quote_fqident('tbl') + 'public.tbl' + >>> quote_fqident('Baz.Foo.Bar') + '"Baz"."Foo.Bar"' """ - return '.'.join(map(quote_ident, s.split('.', 1))) + tmp = s.split('.', 1) + if len(tmp) == 1: + return 'public.' + quote_ident(s) + return '.'.join(map(quote_ident, tmp)) # # quoting for JSON strings @@ -110,7 +119,14 @@ def quote_json(s): return '"%s"' % _jsre.sub(_json_quote_char, s) def unescape_copy(val): - """Removes C-style escapes, also converts "\N" to None.""" + r"""Removes C-style escapes, also converts "\N" to None. + + Example: + >>> unescape_copy(r'baz\tfo\'o') + "baz\tfo'o" + >>> unescape_copy(r'\N') is None + True + """ if val == r"\N": return None return unescape(val) @@ -129,3 +145,6 @@ def unquote_fqident(val): tmp = val.split('.', 1) return "%s.%s" % (unquote_ident(tmp[0]), unquote_ident(tmp[1])) +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py index 3037df69..902bf9a6 100644 --- a/python/skytools/sqltools.py +++ b/python/skytools/sqltools.py @@ -44,18 +44,34 @@ class dbdict(dict): # def fq_name_parts(tbl): - "Return fully qualified name parts." + """Return fully qualified name parts. + + >>> fq_name_parts('tbl') + ['public', 'tbl'] + >>> fq_name_parts('foo.tbl') + ['foo', 'tbl'] + >>> fq_name_parts('foo.tbl.baz') + ['foo', 'tbl.baz'] + """ tmp = tbl.split('.', 1) if len(tmp) == 1: - return ('public', tbl) + return ['public', tbl] elif len(tmp) == 2: return tmp else: raise Exception('Syntax error in table name:'+tbl) def fq_name(tbl): - "Return fully qualified name." + """Return fully qualified name. + + >>> fq_name('tbl') + 'public.tbl' + >>> fq_name('foo.tbl') + 'foo.tbl' + >>> fq_name('foo.tbl.baz') + 'foo.tbl.baz' + """ return '.'.join(fq_name_parts(tbl)) # @@ -171,7 +187,19 @@ def exists_temp_table(curs, tbl): # class Snapshot(object): - "Represents a PostgreSQL snapshot." + """Represents a PostgreSQL snapshot. + + Example: + >>> sn = Snapshot('11:20:11,12,15') + >>> sn.contains(9) + True + >>> sn.contains(11) + False + >>> sn.contains(17) + True + >>> sn.contains(20) + False + """ def __init__(self, str): "Create snapshot from string." @@ -235,11 +263,15 @@ def _gen_list_insert(tbl, row, fields, qfields): return fmt % (tbl, ",".join(qfields), ",".join(tmp)) def magic_insert(curs, tablename, data, fields = None, use_insert = 0): - """Copy/insert a list of dict/list data to database. - + r"""Copy/insert a list of dict/list data to database. + If curs == None, then the copy or insert statements are returned as string. For list of dict the field list is optional, as its possible to guess them from dict keys. + + Example: + >>> magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2']) + 'COPY public.tbl (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' """ if len(data) == 0: return @@ -486,7 +518,11 @@ def installer_apply_file(db, filename, log): # def mk_insert_sql(row, tbl, pkey_list = None, field_map = None): - """Generate INSERT statement from dict data.""" + """Generate INSERT statement from dict data. + + >>> mk_insert_sql({'id': '1', 'data': None}, 'tbl') + "insert into public.tbl (data, id) values (null, '1');" + """ col_list = [] val_list = [] @@ -504,7 +540,11 @@ def mk_insert_sql(row, tbl, pkey_list = None, field_map = None): quote_fqident(tbl), col_str, val_str) def mk_update_sql(row, tbl, pkey_list, field_map = None): - """Generate UPDATE statement from dict data.""" + r"""Generate UPDATE statement from dict data. + + >>> mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2']) + 'update only public."Table" set data = E\'str\\\\\' where id = \'0\' and id2 = \'2\';' + """ if len(pkey_list) < 1: raise Exception("update needs pkeys") @@ -787,3 +827,7 @@ class PLPyQueryBuilder(QueryBuilder): res = [dbdict(r) for r in res] return res +if __name__ == '__main__': + import doctest + doctest.testmod() + |