diff options
author | Marko Kreen | 2009-02-13 10:03:53 +0000 |
---|---|---|
committer | Marko Kreen | 2009-02-13 12:21:01 +0000 |
commit | 5521e5fc2f399a923fa7fb313bbe797cfa0d5baa (patch) | |
tree | 7df01ec195273b812bc9b64953a71ad1155d9438 | |
parent | 012aee6a0369d8a4b2617046a27659c02b0bb478 (diff) |
python/skytools update
- docstrings
- some preliminary python 3.0 compat (var names, print())
- sync with 2.1-stable
adminscript:
- move exec_cmd function to dbscript
dbstruct:
- support sequnces. SERIAL columns are not automatically created,
but the link beteween column and sequence is.
psycopgwrapper:
- drop support for psycopg1
- beginnings of quick DB-API / DictRow description.
quoting:
- new unquote_fqident() function, reverse of quote_fqident()
- quote_statement() accepts both row and dict
dbscript:
- catch startup errors
- use log.exception for exceptions, will result in nicer logs
sqltools:
- exists_sequence()
_pyquoting:
- fix typo in variable name
-rw-r--r-- | python/skytools/__init__.py | 4 | ||||
-rw-r--r-- | python/skytools/_pyquoting.py | 4 | ||||
-rw-r--r-- | python/skytools/adminscript.py | 87 | ||||
-rw-r--r-- | python/skytools/config.py | 14 | ||||
-rw-r--r-- | python/skytools/dbstruct.py | 220 | ||||
-rw-r--r-- | python/skytools/installer_config.py.in | 4 | ||||
-rw-r--r-- | python/skytools/parsing.py | 6 | ||||
-rw-r--r-- | python/skytools/psycopgwrapper.py | 59 | ||||
-rw-r--r-- | python/skytools/quoting.py | 29 | ||||
-rw-r--r-- | python/skytools/scripting.py | 179 | ||||
-rw-r--r-- | python/skytools/sqltools.py | 50 |
11 files changed, 477 insertions, 179 deletions
diff --git a/python/skytools/__init__.py b/python/skytools/__init__.py index e42f06a3..b958b483 100644 --- a/python/skytools/__init__.py +++ b/python/skytools/__init__.py @@ -1,6 +1,10 @@ """Tools for Python database scripts.""" +__version__ = '3.0' + +__pychecker__ = 'no-miximport' + import skytools.quoting import skytools.config import skytools.psycopgwrapper diff --git a/python/skytools/_pyquoting.py b/python/skytools/_pyquoting.py index 8f72eb5f..e5511687 100644 --- a/python/skytools/_pyquoting.py +++ b/python/skytools/_pyquoting.py @@ -136,6 +136,7 @@ _esc_map = { } def _sub_unescape_c(m): + """unescape single escape seq.""" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: @@ -152,8 +153,9 @@ def unescape(val): return _esc_rc.sub(_sub_unescape_c, val) _esql_re = r"''|\\([0-7]{1,3}|.)" -_esql_rc = re.compile(_esc_re) +_esql_rc = re.compile(_esql_re) def _sub_unescape_sqlext(m): + """Unescape extended-quoted string.""" if m.group() == "''": return "'" v = m.group(1) diff --git a/python/skytools/adminscript.py b/python/skytools/adminscript.py index 399e5cd3..ce51ae44 100644 --- a/python/skytools/adminscript.py +++ b/python/skytools/adminscript.py @@ -11,15 +11,25 @@ from skytools.quoting import quote_statement __all__ = ['AdminScript'] class AdminScript(DBScript): + """Contains common admin script tools. + + Second argument (first is .ini file) is takes as command + name. If class method 'cmd_' + arg exists, it is called, + otherwise error is given. + """ def __init__(self, service_name, args): + """AdminScript init.""" DBScript.__init__(self, service_name, args) - self.pidfile = self.pidfile + ".admin" + if self.pidfile: + self.pidfile = self.pidfile + ".admin" if len(self.args) < 2: self.log.error("need command") sys.exit(1) def work(self): + """Non-looping work function, calls command function.""" + self.set_single_loop(1) cmd = self.args[1] @@ -47,6 +57,7 @@ class AdminScript(DBScript): fn(*cmdargs) def fetch_list(self, db, sql, args, keycol = None): + """Fetch a resultset from db, optionally turnin it info value list.""" curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() @@ -81,85 +92,25 @@ class AdminScript(DBScript): fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) if desc: - print desc - print fmt % tuple(fields) - print fmt % tuple(['-'*15] * len(fields)) + print(desc) + print(fmt % tuple(fields)) + print(fmt % tuple(['-'*15] * len(fields))) for row in rows: - print fmt % tuple([row[k] for k in fields]) - print '\n' + print(fmt % tuple([row[k] for k in fields])) + print('\n') return 1 - def _exec_cmd(self, curs, sql, args): - self.log.debug("exec_cmd: %s" % quote_statement(sql, args)) - curs.execute(sql, args) - ok = True - rows = curs.fetchall() - for row in rows: - try: - code = row['ret_code'] - msg = row['ret_note'] - except KeyError: - self.log.error("Query does not conform to exec_cmd API:") - self.log.error("SQL: %s" % quote_statement(sql, args)) - self.log.error("Row: %s" % repr(row.copy())) - sys.exit(1) - level = code / 100 - if level == 1: - self.log.debug("%d %s" % (code, msg)) - elif level == 2: - self.log.info("%d %s" % (code, msg)) - elif level == 3: - self.log.warning("%d %s" % (code, msg)) - else: - self.log.error("%d %s" % (code, msg)) - self.log.error("Query was: %s" % skytools.quote_statement(sql, args)) - ok = False - return (ok, rows) - - def _exec_cmd_many(self, curs, sql, baseargs, extra_list): - ok = True - rows = [] - for a in extra_list: - (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a]) - ok = tmp_ok and ok - rows += tmp_rows - return (ok, rows) - - def exec_cmd(self, db, q, args, commit = True): - (ok, rows) = self._exec_cmd(db.cursor(), q, args) - if ok: - if commit: - self.log.info("COMMIT") - db.commit() - return rows - else: - self.log.info("ROLLBACK") - db.rollback() - raise EXception("rollback") - - def exec_cmd_many(self, db, sql, baseargs, extra_list, commit = True): - curs = db.cursor() - (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list) - if ok: - if commit: - self.log.info("COMMIT") - db.commit() - return rows - else: - self.log.info("ROLLBACK") - db.rollback() - raise EXception("rollback") - - def exec_stmt(self, db, sql, args): + """Run regular non-query SQL on db.""" self.log.debug("exec_stmt: %s" % quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) db.commit() def exec_query(self, db, sql, args): + """Run regular query SQL on db.""" self.log.debug("exec_query: %s" % quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) diff --git a/python/skytools/config.py b/python/skytools/config.py index a0e96ec5..fbe0b6ed 100644 --- a/python/skytools/config.py +++ b/python/skytools/config.py @@ -1,7 +1,7 @@ """Nicer config class.""" -import sys, os, ConfigParser, socket +import os, ConfigParser, socket __all__ = ['Config'] @@ -52,7 +52,7 @@ class Config(object): """Reads string value, if not set then default.""" try: return self.cf.get(self.main_section, key) - except ConfigParser.NoOptionError, det: + except ConfigParser.NoOptionError: if default == None: raise Exception("Config value not set: " + key) return default @@ -61,7 +61,7 @@ class Config(object): """Reads int value, if not set then default.""" try: return self.cf.getint(self.main_section, key) - except ConfigParser.NoOptionError, det: + except ConfigParser.NoOptionError: if default == None: raise Exception("Config value not set: " + key) return default @@ -70,7 +70,7 @@ class Config(object): """Reads boolean value, if not set then default.""" try: return self.cf.getboolean(self.main_section, key) - except ConfigParser.NoOptionError, det: + except ConfigParser.NoOptionError: if default == None: raise Exception("Config value not set: " + key) return default @@ -79,7 +79,7 @@ class Config(object): """Reads float value, if not set then default.""" try: return self.cf.getfloat(self.main_section, key) - except ConfigParser.NoOptionError, det: + except ConfigParser.NoOptionError: if default == None: raise Exception("Config value not set: " + key) return default @@ -94,7 +94,7 @@ class Config(object): for v in s.split(","): res.append(v.strip()) return res - except ConfigParser.NoOptionError, det: + except ConfigParser.NoOptionError: if default == None: raise Exception("Config value not set: " + key) return default @@ -129,7 +129,7 @@ class Config(object): for key in keys: try: return self.cf.get(self.main_section, key) - except ConfigParser.NoOptionError, det: + except ConfigParser.NoOptionError: pass if default == None: diff --git a/python/skytools/dbstruct.py b/python/skytools/dbstruct.py index 1c7741c5..2a1c8e47 100644 --- a/python/skytools/dbstruct.py +++ b/python/skytools/dbstruct.py @@ -1,14 +1,15 @@ """Find table structure and allow CREATE/DROP elements from it. """ -import sys, re +import re from skytools.sqltools import fq_name_parts, get_table_oid -from skytools.quoting import quote_ident, quote_fqident +from skytools.quoting import quote_ident, quote_fqident, quote_literal, unquote_fqident -__all__ = ['TableStruct', +__all__ = ['TableStruct', 'SeqStruct', 'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER', - 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL'] + 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL', + 'T_SEQUENCE'] T_TABLE = 1 << 0 T_CONSTRAINT = 1 << 1 @@ -17,8 +18,9 @@ T_TRIGGER = 1 << 3 T_RULE = 1 << 4 T_GRANT = 1 << 5 T_OWNER = 1 << 6 +T_SEQUENCE = 1 << 7 T_PKEY = 1 << 20 # special, one of constraints -T_ALL = ( T_TABLE | T_CONSTRAINT | T_INDEX +T_ALL = ( T_TABLE | T_CONSTRAINT | T_INDEX | T_SEQUENCE | T_TRIGGER | T_RULE | T_GRANT | T_OWNER ) # @@ -63,7 +65,7 @@ class TElem(object): """Keeps info about one metadata object.""" SQL = "" type = 0 - def get_create_sql(self, curs): + def get_create_sql(self, curs, new_name = None): """Return SQL statement for creating or None if not supported.""" return None def get_drop_sql(self, curs): @@ -78,6 +80,7 @@ class TConstraint(TElem): FROM pg_constraint WHERE conrelid = %(oid)s AND contype != 'f' """ def __init__(self, table_name, row): + """Init constraint.""" self.table_name = table_name self.name = row['name'] self.defn = row['def'] @@ -88,6 +91,7 @@ class TConstraint(TElem): self.type += T_PKEY def get_create_sql(self, curs, new_table_name=None): + """Generate creation SQL.""" fmt = "ALTER TABLE ONLY %s ADD CONSTRAINT %s %s;" if new_table_name: name = self.name @@ -102,6 +106,7 @@ class TConstraint(TElem): return sql def get_drop_sql(self, curs): + """Generate removal sql.""" fmt = "ALTER TABLE ONLY %s DROP CONSTRAINT %s;" sql = fmt % (quote_fqident(self.table_name), quote_ident(self.name)) return sql @@ -126,6 +131,7 @@ class TIndex(TElem): self.defn = row['defn'] + ';' def get_create_sql(self, curs, new_table_name = None): + """Generate creation SQL.""" if not new_table_name: return self.defn # fixme: seems broken @@ -151,9 +157,10 @@ class TRule(TElem): self.defn = row['def'] def get_create_sql(self, curs, new_table_name = None): + """Generate creation SQL.""" if not new_table_name: return self.defn - # fixme: broken + # fixme: broken / quoting rx = r"\bTO[ ][a-z0-9._]+[ ]DO[ ]" pnew = "TO %s DO " % new_table_name return rx_replace(rx, self.defn, pnew) @@ -161,11 +168,12 @@ class TRule(TElem): def get_drop_sql(self, curs): return 'DROP RULE %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) + class TTrigger(TElem): """Info about trigger.""" type = T_TRIGGER SQL = """ - SELECT tgname as name, pg_get_triggerdef(oid) as def + SELECT tgname as name, pg_get_triggerdef(oid) as def FROM pg_trigger WHERE tgrelid = %(oid)s AND NOT tgisconstraint """ @@ -175,9 +183,11 @@ class TTrigger(TElem): self.defn = row['def'] + ';' def get_create_sql(self, curs, new_table_name = None): + """Generate creation SQL.""" if not new_table_name: return self.defn - # fixme: broken + + # fixme: broken / quoting rx = r"\bON[ ][a-z0-9._]+[ ]" pnew = "ON %s " % new_table_name return rx_replace(rx, self.defn, pnew) @@ -198,6 +208,7 @@ class TOwner(TElem): self.owner = row['owner'] def get_create_sql(self, curs, new_name = None): + """Generate creation SQL.""" if not new_name: new_name = self.table_name return 'ALTER TABLE %s OWNER TO %s;' % (quote_fqident(new_name), quote_ident(self.owner)) @@ -217,38 +228,40 @@ class TGrant(TElem): return ", ".join([ self.acl_map[c] for c in acl ]) def parse_relacl(self, relacl): + """Parse ACL to tuple of (user, acl, who)""" if relacl is None: return [] if len(relacl) > 0 and relacl[0] == '{' and relacl[-1] == '}': relacl = relacl[1:-1] - list = [] + tup_list = [] for f in relacl.split(','): user, tmp = f.strip('"').split('=') acl, who = tmp.split('/') - list.append((user, acl, who)) - return list + tup_list.append((user, acl, who)) + return tup_list def __init__(self, table_name, row, new_name = None): self.name = table_name self.acl_list = self.parse_relacl(row['relacl']) def get_create_sql(self, curs, new_name = None): + """Generate creation SQL.""" if not new_name: new_name = self.name - list = [] + sql_list = [] for user, acl, who in self.acl_list: astr = self.acl_to_grants(acl) sql = "GRANT %s ON %s TO %s;" % (astr, quote_fqident(new_name), quote_ident(user)) - list.append(sql) - return "\n".join(list) + sql_list.append(sql) + return "\n".join(sql_list) def get_drop_sql(self, curs): - list = [] + sql_list = [] for user, acl, who in self.acl_list: sql = "REVOKE ALL FROM %s ON %s;" % (quote_ident(user), quote_fqident(self.name)) - list.append(sql) - return "\n".join(list) + sql_list.append(sql) + return "\n".join(sql_list) class TColumn(TElem): """Info about table column.""" @@ -257,8 +270,9 @@ class TColumn(TElem): a.attname || ' ' || format_type(a.atttypid, a.atttypmod) || case when a.attnotnull then ' not null' else '' end - || case when a.atthasdef then ' ' || d.adsrc else '' end - as def + || case when a.atthasdef then ' default ' || d.adsrc else '' end + as def, + pg_get_serial_sequence(%(fq2name)s, a.attname) as seqname from pg_attribute a left join pg_attrdef d on (d.adrelid = a.attrelid and d.adnum = a.attnum) where a.attrelid = %(oid)s @@ -266,9 +280,13 @@ class TColumn(TElem): and a.attnum > 0 order by a.attnum; """ + seqname = None def __init__(self, table_name, row): self.name = row['name'] self.column_def = row['def'] + self.sequence = None + if row['seqname']: + self.seqname = unquote_fqident(row['seqname']) class TTable(TElem): """Info about table only (columns).""" @@ -278,6 +296,7 @@ class TTable(TElem): self.col_list = col_list def get_create_sql(self, curs, new_name = None): + """Generate creation SQL.""" if not new_name: new_name = self.name sql = "create table %s (" % quote_fqident(new_name) @@ -287,53 +306,88 @@ class TTable(TElem): sep = ",\n\t" sql += "\n);" return sql - + def get_drop_sql(self, curs): return "DROP TABLE %s;" % quote_fqident(self.name) +class TSeq(TElem): + """Info about sequence.""" + type = T_SEQUENCE + SQL = """SELECT *, %(owner)s as "owner" from %(fqname)s """ + def __init__(self, seq_name, row): + self.name = seq_name + defn = '' + self.owner = row['owner'] + if row['increment_by'] != 1: + defn += ' INCREMENT BY %d' % row['increment_by'] + if row['min_value'] != 1: + defn += ' MINVALUE %d' % row['min_value'] + if row['max_value'] != 9223372036854775807: + defn += ' MAXVALUE %d' % row['max_value'] + last_value = row['last_value'] + if row['is_called']: + last_value += row['increment_by'] + if last_value >= row['max_value']: + raise Exception('duh, seq passed max_value') + if last_value != 1: + defn += ' START %d' % last_value + if row['cache_value'] != 1: + defn += ' CACHE %d' % row['cache_value'] + if row['is_cycled']: + defn += ' CYCLE ' + if self.owner: + defn += ' OWNED BY %s' % self.owner + self.defn = defn + + def get_create_sql(self, curs, new_seq_name = None): + """Generate creation SQL.""" + + # we are in table def, forget full def + if self.owner: + sql = "ALTER SEQUENCE %s OWNED BY %s" % ( + quote_fqident(self.name), self.owner ) + return sql + + name = self.name + if new_seq_name: + name = new_seq_name + sql = 'CREATE SEQUENCE %s %s;' % (quote_fqident(name), self.defn) + return sql + + def get_drop_sql(self, curs): + if self.owner: + return '' + return 'DROP SEQUENCE %s;' % quote_fqident(self.name) + # # Main table object, loads all the others # -class TableStruct(object): - """Collects and manages all info about table. +class BaseStruct(object): + """Collects and manages all info about a higher-level db object. Allow to issue CREATE/DROP statements about any group of elements. """ - def __init__(self, curs, table_name): + object_list = [] + def __init__(self, curs, name): """Initializes class by loading info about table_name from database.""" - self.table_name = table_name - - # fill args - schema, name = fq_name_parts(table_name) - args = { - 'schema': schema, - 'table': name, - 'oid': get_table_oid(curs, table_name), - 'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'), - } - - # load table struct - self.col_list = self._load_elem(curs, args, TColumn) - self.object_list = [ TTable(table_name, self.col_list) ] - - # load additional objects - to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner] - for eclass in to_load: - self.object_list += self._load_elem(curs, args, eclass) + self.name = name + self.fqname = quote_fqident(name) - def _load_elem(self, curs, args, eclass): - list = [] + def _load_elem(self, curs, name, args, eclass): + """Fetch element(s) from db.""" + elem_list = [] + #print "Loading %s, name=%s, args=%s" % (repr(eclass), repr(name), repr(args)) curs.execute(eclass.SQL % args) for row in curs.dictfetchall(): - list.append(eclass(self.table_name, row)) - return list + elem_list.append(eclass(name, row)) + return elem_list def create(self, curs, objs, new_table_name = None, log = None): """Issues CREATE statements for requested set of objects. - + If new_table_name is giver, creates table under that name and also tries to rename all indexes/constraints that conflict with existing table. @@ -361,6 +415,57 @@ class TableStruct(object): log.debug(sql) curs.execute(sql) + def get_create_sql(self, objs): + res = [] + for o in self.object_list: + if o.type & objs: + sql = o.get_create_sql(None, None) + if sql: + res.append(sql) + return "".join(res) + +class TableStruct(BaseStruct): + """Collects and manages all info about table. + + Allow to issue CREATE/DROP statements about any + group of elements. + """ + def __init__(self, curs, table_name): + """Initializes class by loading info about table_name from database.""" + + BaseStruct.__init__(self, curs, table_name) + + self.table_name = table_name + + # fill args + schema, name = fq_name_parts(table_name) + args = { + 'schema': schema, + 'table': name, + 'fqname': self.fqname, + 'fq2name': quote_literal(self.fqname), + 'oid': get_table_oid(curs, table_name), + 'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'), + } + + # load table struct + self.col_list = self._load_elem(curs, self.name, args, TColumn) + self.object_list = [ TTable(table_name, self.col_list) ] + self.seq_list = [] + + # load seqs + for col in self.col_list: + if col.seqname: + owner = self.fqname + '.' + quote_ident(col.name) + seq_args = { 'fqname': col.seqname, 'owner': quote_literal(owner) } + self.seq_list += self._load_elem(curs, col.seqname, seq_args, TSeq) + self.object_list += self.seq_list + + # load additional objects + to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner] + for eclass in to_load: + self.object_list += self._load_elem(curs, self.name, args, eclass) + def get_column_list(self): """Returns list of column names the table has.""" @@ -369,11 +474,28 @@ class TableStruct(object): res.append(c.name) return res +class SeqStruct(BaseStruct): + """Collects and manages all info about sequence. + + Allow to issue CREATE/DROP statements about any + group of elements. + """ + def __init__(self, curs, seq_name): + """Initializes class by loading info about table_name from database.""" + + BaseStruct.__init__(self, curs, seq_name) + + # fill args + args = { 'fqname': self.fqname, 'owner': 'null' } + + # load table struct + self.object_list = self._load_elem(curs, seq_name, args, TSeq) + def test(): from skytools import connect_database db = connect_database("dbname=fooz") curs = db.cursor() - + s = TableStruct(curs, "public.data1") s.drop(curs, T_ALL) diff --git a/python/skytools/installer_config.py.in b/python/skytools/installer_config.py.in index 06c9b956..a01621f7 100644 --- a/python/skytools/installer_config.py.in +++ b/python/skytools/installer_config.py.in @@ -1,4 +1,8 @@ +"""SQL script locations.""" + +__all__ = ['sql_locations'] + sql_locations = [ "@SQLDIR@", ] diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py index 3aa94991..36b17d53 100644 --- a/python/skytools/parsing.py +++ b/python/skytools/parsing.py @@ -42,11 +42,14 @@ def parse_pgarray(array): # class _logtriga_parser: + """Parses logtriga/sqltriga partial SQL to values.""" def tokenizer(self, sql): + """Token generator.""" for typ, tok in sql_tokenizer(sql, ignore_whitespace = True): yield tok def parse_insert(self, tk, fields, values): + """Handler for inserts.""" # (col1, col2) values ('data', null) if tk.next() != "(": raise Exception("syntax error") @@ -73,6 +76,7 @@ class _logtriga_parser: raise Exception("expected EOF, got " + repr(t)) def parse_update(self, tk, fields, values): + """Handler for updates.""" # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2' while 1: fields.append(tk.next()) @@ -97,6 +101,7 @@ class _logtriga_parser: raise Exception("syntax error, expected AND got "+repr(t)) def parse_delete(self, tk, fields, values): + """Handler for deletes.""" # pk1 = 'pk1' and pk2 = 'pk2' while 1: fields.append(tk.next()) @@ -108,6 +113,7 @@ class _logtriga_parser: raise Exception("syntax error, expected AND, got "+repr(t)) def parse_sql(self, op, sql): + """Main entry point.""" tk = self.tokenizer(sql) fields = [] values = [] diff --git a/python/skytools/psycopgwrapper.py b/python/skytools/psycopgwrapper.py index 9d23e506..7bce0c33 100644 --- a/python/skytools/psycopgwrapper.py +++ b/python/skytools/psycopgwrapper.py @@ -1,16 +1,60 @@ -"""Wrapper around psycopg1/2. +"""Wrapper around psycopg2. -Preferred is psycopg2, fallback to psycopg1. +Database connection provides regular DB-API 2.0 interface. -Interface provided is psycopg1: - - dict* methods. - - new columns can be assigned to row. +Connection object methods:: -""" + .cursor() + + .commit() + + .rollback() + + .close() + +Cursor methods:: + + .execute(query[, args]) + + .fetchone() + + .fetchall() + + +Sample usage:: -import sys + db = self.get_database('somedb') + curs = db.cursor() + + # query arguments as array + q = "select * from table where id = %s and name = %s" + curs.execute(q, [1, 'somename']) + + # query arguments as dict + q = "select id, name from table where id = %(id)s and name = %(name)s" + curs.execute(q, {'id': 1, 'name': 'somename'}) + + # loop over resultset + for row in curs.fetchall(): + + # columns can be asked by index: + id = row[0] + name = row[1] + + # and by name: + id = row['id'] + name = row['name'] + + # now commit the transaction + db.commit() + +Deprecated interface: .dictfetchall/.dictfetchone functions on cursor. +Plain .fetchall() / .fetchone() give exact same result. + +""" +# no exports __all__ = [] ##from psycopg2.psycopg1 import connect as _pgconnect @@ -54,6 +98,7 @@ class _CompatCursor(psycopg2.extras.DictCursor): class _CompatConnection(psycopg2.extensions.connection): """Connection object that uses _CompatCursor.""" + my_name = '?' def cursor(self): return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor) diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py index 8225c7b0..9d281254 100644 --- a/python/skytools/quoting.py +++ b/python/skytools/quoting.py @@ -12,7 +12,7 @@ __all__ = [ # local "quote_bytea_literal", "quote_bytea_copy", "quote_statement", "quote_ident", "quote_fqident", "quote_json", "unescape_copy", - "unquote_ident", + "unquote_ident", "unquote_fqident", ] try: @@ -34,15 +34,19 @@ def quote_bytea_copy(s): return quote_copy(quote_bytea_raw(s)) -def quote_statement(sql, dict): +def quote_statement(sql, dict_or_list): """Quote whole statement. - Data values are taken from dict. + Data values are taken from dict or list or tuple. """ - xdict = {} - for k, v in dict.items(): - xdict[k] = quote_literal(v) - return sql % xdict + if hasattr(dict_or_list, 'items'): + qdict = {} + for k, v in dict_or_list.items(): + qdict[k] = quote_literal(v) + return sql % qdict + else: + qvals = [quote_literal(v) for v in dict_or_list] + return sql % tuple(qvals) # reserved keywords _ident_kwmap = { @@ -58,6 +62,8 @@ _ident_kwmap = { "primary":1, "references":1, "returning":1, "select":1, "session_user":1, "some":1, "symmetric":1, "table":1, "then":1, "to":1, "trailing":1, "true":1, "union":1, "unique":1, "user":1, "using":1, "when":1, "where":1, +# greenplum? +"errors":1, } _ident_bad = re.compile(r"[^a-z0-9_]") @@ -90,6 +96,7 @@ _jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r", } def _json_quote_char(m): + """Quote single char.""" c = m.group(0) try: return _jsmap[c] @@ -114,3 +121,11 @@ def unquote_ident(val): return val[1:-1].replace('""', '"') return val +def unquote_fqident(val): + """Unquotes fully-qualified possibly quoted SQL identifier. + + It must be prefixed schema, which does not contain dots. + """ + tmp = val.split('.', 1) + return "%s.%s" % (unquote_ident(tmp[0]), unquote_ident(tmp[1])) + diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py index b2c26bc0..a61154de 100644 --- a/python/skytools/scripting.py +++ b/python/skytools/scripting.py @@ -1,13 +1,29 @@ """Useful functions and classes for database scripts.""" -import sys, os, signal, optparse, traceback, time, errno +import sys, os, signal, optparse, time, errno import logging, logging.handlers, logging.config from skytools.config import * from skytools.psycopgwrapper import connect_database +from skytools.quoting import quote_statement import skytools.skylog +__pychecker__ = 'no-badexcept' + +#: how old connections need to be closed +DEF_CONN_AGE = 20*60 # 20 min + +#: isolation level not set +I_DEFAULT = -1 + +#: isolation level constant for AUTOCOMMIT +I_AUTOCOMMIT = 0 +#: isolation level constant for READ COMMITTED +I_READ_COMMITTED = 1 +#: isolation level constant for SERIALIZABLE +I_SERIALIZABLE = 2 + __all__ = ['DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE', 'signal_pidfile'] #__all__ += ['daemonize', 'run_single_process'] @@ -80,10 +96,10 @@ def run_single_process(runnable, daemon, pidfile): # check if another process is running if pidfile and os.path.isfile(pidfile): if signal_pidfile(pidfile, 0): - print "Pidfile exists, another process running?" + print("Pidfile exists, another process running?") sys.exit(1) else: - print "Ignoring stale pidfile" + print("Ignoring stale pidfile") # daemonize if needed and write pidfile if daemon: @@ -122,8 +138,8 @@ def _init_log(job_name, service_name, cf, log_level): skytools.skylog.set_service_name(service_name) # load general config - list = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini'] - for fn in list: + flist = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini'] + for fn in flist: fn = os.path.expanduser(fn) if os.path.isfile(fn): defs = {'job_name': job_name, 'service_name': service_name} @@ -163,33 +179,24 @@ def _init_log(job_name, service_name, cf, log_level): return log -#: how old connections need to be closed -DEF_CONN_AGE = 20*60 # 20 min - -#: isolation level not set -I_DEFAULT = -1 - -#: isolation level constant for AUTOCOMMIT -I_AUTOCOMMIT = 0 -#: isolation level constant for READ COMMITTED -I_READ_COMMITTED = 1 -#: isolation level constant for SERIALIZABLE -I_SERIALIZABLE = 2 - class DBCachedConn(object): """Cache a db connection.""" - def __init__(self, name, loc, max_age = DEF_CONN_AGE): + def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None): self.name = name self.loc = loc self.conn = None self.conn_time = 0 self.max_age = max_age self.autocommit = -1 - self.isolation_level = -1 + self.isolation_level = I_DEFAULT + self.verbose = verbose + self.setup_func = setup_func - def get_connection(self, autocommit = 0, isolation_level = -1): + def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT): # autocommit overrider isolation_level if autocommit: + if isolation_level == I_SERIALIZABLE: + raise Exception('autocommit is not compatible with I_SERIALIZABLE') isolation_level = I_AUTOCOMMIT # default isolation_level is READ COMMITTED @@ -200,9 +207,12 @@ class DBCachedConn(object): if not self.conn: self.isolation_level = isolation_level self.conn = connect_database(self.loc) + self.conn.my_name = self.name self.conn.set_isolation_level(isolation_level) self.conn_time = time.time() + if self.setup_func: + self.setup_func(self.name, self.conn) else: if self.isolation_level != isolation_level: raise Exception("Conflict in isolation_level") @@ -250,6 +260,8 @@ class DBScript(object): job_name = None cf = None log = None + pidfile = None + loop_delay = 1 def __init__(self, service_name, args): """Script setup. @@ -286,7 +298,7 @@ class DBScript(object): if self.options.verbose: self.log_level = logging.DEBUG if len(self.args) < 1: - print "need config file" + print("need config file") sys.exit(1) # read config file @@ -305,6 +317,8 @@ class DBScript(object): self.send_signal(signal.SIGHUP) def load_config(self): + """Loads and returns skytools.Config instance.""" + conf_file = self.args[0] return Config(self.service_name, conf_file) @@ -369,7 +383,21 @@ class DBScript(object): if not self.pidfile: self.log.error("Daemon needs pidfile") sys.exit(1) - run_single_process(self, self.go_daemon, self.pidfile) + + try: + run_single_process(self, self.go_daemon, self.pidfile) + except KeyboardInterrupt: + raise + except SystemExit: + raise + except Exception: + # catch startup errors + exc, msg, tb = sys.exc_info() + self.log.exception("Job %s crashed on startup: %s: %s" % ( + self.job_name, str(exc), str(msg).rstrip())) + del tb + sys.exit(1) + def stop(self): """Safely stops processing loop.""" @@ -386,9 +414,15 @@ class DBScript(object): "Internal SIGHUP handler. Minimal code here." self.need_reload = 1 + last_sigint = 0 def hook_sigint(self, sig, frame): "Internal SIGINT handler. Minimal code here." self.stop() + t = time.time() + if t - self.last_sigint < 1: + self.log.warning("Double ^C, fast exit") + sys.exit(1) + self.last_sigint = t def stat_add(self, key, value): """Old, deprecated function.""" @@ -419,6 +453,9 @@ class DBScript(object): self.log.info(logmsg) self.stat_dict = {} + def connection_setup(self, dbname, conn): + pass + def get_database(self, dbname, autocommit = 0, isolation_level = -1, cache = None, connstr = None): """Load cached database connection. @@ -435,7 +472,7 @@ class DBScript(object): else: if not connstr: connstr = self.cf.get(dbname) - dbc = DBCachedConn(cache, connstr, max_age) + dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_setup) self.db_cache[cache] = dbc return dbc.get_connection(autocommit, isolation_level) @@ -462,15 +499,14 @@ class DBScript(object): # run startup, safely try: self.startup() - except KeyboardInterrupt, det: + except KeyboardInterrupt: raise - except SystemExit, det: + except SystemExit: raise - except Exception, det: + except Exception: exc, msg, tb = sys.exc_info() - self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % ( - self.job_name, str(exc), str(msg).rstrip(), - str(tb), repr(traceback.format_tb(tb)))) + self.log.exception("Job %s crashed: %s: %s" % ( + self.job_name, str(exc), str(msg).rstrip())) del tb self.reset() sys.exit(1) @@ -523,10 +559,9 @@ class DBScript(object): except Exception, d: self.send_stats() exc, msg, tb = sys.exc_info() - self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % ( - self.job_name, str(exc), str(msg).rstrip(), - str(tb), repr(traceback.format_tb(tb)))) del tb + self.log.exception("Job %s crashed: %s: %s" % ( + self.job_name, str(exc), str(msg).rstrip())) self.reset() if self.looping and not self.do_single_loop: time.sleep(20) @@ -553,4 +588,82 @@ class DBScript(object): signal.signal(signal.SIGHUP, self.hook_sighup) signal.signal(signal.SIGINT, self.hook_sigint) + def _exec_cmd(self, curs, sql, args, quiet = False): + """Internal tool: Run SQL on cursor.""" + self.log.debug("exec_cmd: %s" % quote_statement(sql, args)) + curs.execute(sql, args) + ok = True + rows = curs.fetchall() + for row in rows: + try: + code = row['ret_code'] + msg = row['ret_note'] + except KeyError: + self.log.error("Query does not conform to exec_cmd API:") + self.log.error("SQL: %s" % quote_statement(sql, args)) + self.log.error("Row: %s" % repr(row.copy())) + sys.exit(1) + level = code / 100 + if level == 1: + self.log.debug("%d %s" % (code, msg)) + elif level == 2: + if quiet: + self.log.debug("%d %s" % (code, msg)) + else: + self.log.info("%s" % (msg,)) + elif level == 3: + self.log.warning("%s" % (msg,)) + else: + self.log.error("%s" % (msg,)) + self.log.error("Query was: %s" % quote_statement(sql, args)) + ok = False + return (ok, rows) + + def _exec_cmd_many(self, curs, sql, baseargs, extra_list, quiet = False): + """Internal tool: Run SQL on cursor multiple times.""" + ok = True + rows = [] + for a in extra_list: + (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a], quiet=quiet) + if not tmp_ok: + ok = False + rows += tmp_rows + return (ok, rows) + + def exec_cmd(self, db_or_curs, q, args, commit = True, quiet = False): + """Run SQL on db with code/value error handling.""" + if hasattr(db_or_curs, 'cursor'): + db = db_or_curs + curs = db.cursor() + else: + db = None + curs = db_or_curs + (ok, rows) = self._exec_cmd(curs, q, args, quiet = quiet) + if ok: + if commit and db: + db.commit() + return rows + else: + if db: + db.rollback() + raise Exception("db error") + + def exec_cmd_many(self, db_or_curs, sql, baseargs, extra_list, commit = True, quiet = False): + """Run SQL on db multiple times.""" + if hasattr(db_or_curs, 'cursor'): + db = db_or_curs + curs = db.cursor() + else: + db = None + curs = db_or_curs + (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list, quiet=quiet) + if ok: + if commit and db: + db.commit() + return rows + else: + if db: + db.rollback() + raise Exception("db error") + diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py index 883bbe8b..fd4fb855 100644 --- a/python/skytools/sqltools.py +++ b/python/skytools/sqltools.py @@ -9,6 +9,7 @@ import skytools.installer_config __all__ = [ "fq_name_parts", "fq_name", "get_table_oid", "get_table_pkeys", "get_table_columns", "exists_schema", "exists_table", "exists_type", + "exists_sequence", "exists_function", "exists_language", "Snapshot", "magic_insert", "CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction", "DBLanguage", "db_install", "installer_find_file", "installer_apply_file", @@ -19,9 +20,15 @@ class dbdict(dict): """Wrapper on actual dict that allows accessing dict keys as attributes.""" # obj.foo access - def __getattr__(self, k): return self[k] - def __setattr__(self, k, v): self[k] = v - def __delattr__(self, k): del self[k] + def __getattr__(self, k): + "Return attribute." + return self[k] + def __setattr__(self, k, v): + "Set attribute." + self[k] = v + def __delattr__(self, k): + "Remove attribute" + del self[k] # # Fully qualified table name @@ -46,6 +53,7 @@ def fq_name(tbl): # info about table # def get_table_oid(curs, table_name): + """Find Postgres OID for table.""" schema, name = fq_name_parts(table_name) q = """select c.oid from pg_namespace n, pg_class c where c.relnamespace = n.oid @@ -57,6 +65,7 @@ def get_table_oid(curs, table_name): return res[0][0] def get_table_pkeys(curs, tbl): + """Return list of pkey column names.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_index i, pg_attribute k"\ " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\ @@ -66,6 +75,7 @@ def get_table_pkeys(curs, tbl): return map(lambda x: x[0], curs.fetchall()) def get_table_columns(curs, tbl): + """Return list of column names for table.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_attribute k"\ " WHERE k.attrelid = %s"\ @@ -78,12 +88,14 @@ def get_table_columns(curs, tbl): # exist checks # def exists_schema(curs, schema): + """Does schema exists?""" q = "select count(1) from pg_namespace where nspname = %s" curs.execute(q, [schema]) res = curs.fetchone() return res[0] def exists_table(curs, table_name): + """Does table exists?""" schema, name = fq_name_parts(table_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'r' @@ -92,7 +104,18 @@ def exists_table(curs, table_name): res = curs.fetchone() return res[0] +def exists_sequence(curs, seq_name): + """Does sequence exists?""" + schema, name = fq_name_parts(seq_name) + q = """select count(1) from pg_namespace n, pg_class c + where c.relnamespace = n.oid and c.relkind = 'S' + and n.nspname = %s and c.relname = %s""" + curs.execute(q, [schema, name]) + res = curs.fetchone() + return res[0] + def exists_type(curs, type_name): + """Does type exists?""" schema, name = fq_name_parts(type_name) q = """select count(1) from pg_namespace n, pg_type t where t.typnamespace = n.oid @@ -102,6 +125,7 @@ def exists_type(curs, type_name): return res[0] def exists_function(curs, function_name, nargs): + """Does function exists?""" # this does not check arg types, so may match several functions schema, name = fq_name_parts(function_name) q = """select count(1) from pg_namespace n, pg_proc p @@ -118,6 +142,7 @@ def exists_function(curs, function_name, nargs): return res[0] def exists_language(curs, lang_name): + """Does PL exists?""" q = """select count(1) from pg_language where lanname = %s""" curs.execute(q, [lang_name]) @@ -331,11 +356,13 @@ class DBObject(object): sql = None sql_file = None def __init__(self, name, sql = None, sql_file = None): + """Generic dbobject init.""" self.name = name self.sql = sql self.sql_file = sql_file def create(self, curs, log = None): + """Create a dbobject.""" if log: log.info('Installing %s' % self.name) if self.sql: @@ -352,13 +379,14 @@ class DBObject(object): curs.execute(stmt) def find_file(self): + """Find install script file.""" full_fn = None if self.sql_file[0] == "/": full_fn = self.sql_file else: dir_list = skytools.installer_config.sql_locations - for dir in dir_list: - fn = os.path.join(dir, self.sql_file) + for fdir in dir_list: + fn = os.path.join(fdir, self.sql_file) if os.path.isfile(fn): full_fn = fn break @@ -370,26 +398,32 @@ class DBObject(object): class DBSchema(DBObject): """Handles db schema.""" def exists(self, curs): + """Does schema exists.""" return exists_schema(curs, self.name) class DBTable(DBObject): """Handles db table.""" def exists(self, curs): + """Does table exists.""" return exists_table(curs, self.name) class DBFunction(DBObject): """Handles db function.""" def __init__(self, name, nargs, sql = None, sql_file = None): + """Function object - number of args is significant.""" DBObject.__init__(self, name, sql, sql_file) self.nargs = nargs def exists(self, curs): + """Does function exists.""" return exists_function(curs, self.name, self.nargs) class DBLanguage(DBObject): """Handles db language.""" def __init__(self, name): + """PL object - creation happens with CREATE LANGUAGE.""" DBObject.__init__(self, name, sql = "create language %s" % name) def exists(self, curs): + """Does PL exists.""" return exists_language(curs, self.name) def db_install(curs, list, log = None): @@ -402,14 +436,15 @@ def db_install(curs, list, log = None): log.info('%s is installed' % obj.name) def installer_find_file(filename): + """Find SQL script from pre-defined paths.""" full_fn = None if filename[0] == "/": if os.path.isfile(filename): full_fn = filename else: dir_list = ["."] + skytools.installer_config.sql_locations - for dir in dir_list: - fn = os.path.join(dir, filename) + for fdir in dir_list: + fn = os.path.join(fdir, filename) if os.path.isfile(fn): full_fn = fn break @@ -419,6 +454,7 @@ def installer_find_file(filename): return full_fn def installer_apply_file(db, filename, log): + """Find SQL file and apply it to db, statement-by-statement.""" fn = installer_find_file(filename) sql = open(fn, "r").read() if log: |