summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarko Kreen2009-02-13 10:03:53 +0000
committerMarko Kreen2009-02-13 12:21:01 +0000
commit5521e5fc2f399a923fa7fb313bbe797cfa0d5baa (patch)
tree7df01ec195273b812bc9b64953a71ad1155d9438
parent012aee6a0369d8a4b2617046a27659c02b0bb478 (diff)
python/skytools update
- docstrings - some preliminary python 3.0 compat (var names, print()) - sync with 2.1-stable adminscript: - move exec_cmd function to dbscript dbstruct: - support sequnces. SERIAL columns are not automatically created, but the link beteween column and sequence is. psycopgwrapper: - drop support for psycopg1 - beginnings of quick DB-API / DictRow description. quoting: - new unquote_fqident() function, reverse of quote_fqident() - quote_statement() accepts both row and dict dbscript: - catch startup errors - use log.exception for exceptions, will result in nicer logs sqltools: - exists_sequence() _pyquoting: - fix typo in variable name
-rw-r--r--python/skytools/__init__.py4
-rw-r--r--python/skytools/_pyquoting.py4
-rw-r--r--python/skytools/adminscript.py87
-rw-r--r--python/skytools/config.py14
-rw-r--r--python/skytools/dbstruct.py220
-rw-r--r--python/skytools/installer_config.py.in4
-rw-r--r--python/skytools/parsing.py6
-rw-r--r--python/skytools/psycopgwrapper.py59
-rw-r--r--python/skytools/quoting.py29
-rw-r--r--python/skytools/scripting.py179
-rw-r--r--python/skytools/sqltools.py50
11 files changed, 477 insertions, 179 deletions
diff --git a/python/skytools/__init__.py b/python/skytools/__init__.py
index e42f06a3..b958b483 100644
--- a/python/skytools/__init__.py
+++ b/python/skytools/__init__.py
@@ -1,6 +1,10 @@
"""Tools for Python database scripts."""
+__version__ = '3.0'
+
+__pychecker__ = 'no-miximport'
+
import skytools.quoting
import skytools.config
import skytools.psycopgwrapper
diff --git a/python/skytools/_pyquoting.py b/python/skytools/_pyquoting.py
index 8f72eb5f..e5511687 100644
--- a/python/skytools/_pyquoting.py
+++ b/python/skytools/_pyquoting.py
@@ -136,6 +136,7 @@ _esc_map = {
}
def _sub_unescape_c(m):
+ """unescape single escape seq."""
v = m.group(1)
if (len(v) == 1) and (v < '0' or v > '7'):
try:
@@ -152,8 +153,9 @@ def unescape(val):
return _esc_rc.sub(_sub_unescape_c, val)
_esql_re = r"''|\\([0-7]{1,3}|.)"
-_esql_rc = re.compile(_esc_re)
+_esql_rc = re.compile(_esql_re)
def _sub_unescape_sqlext(m):
+ """Unescape extended-quoted string."""
if m.group() == "''":
return "'"
v = m.group(1)
diff --git a/python/skytools/adminscript.py b/python/skytools/adminscript.py
index 399e5cd3..ce51ae44 100644
--- a/python/skytools/adminscript.py
+++ b/python/skytools/adminscript.py
@@ -11,15 +11,25 @@ from skytools.quoting import quote_statement
__all__ = ['AdminScript']
class AdminScript(DBScript):
+ """Contains common admin script tools.
+
+ Second argument (first is .ini file) is takes as command
+ name. If class method 'cmd_' + arg exists, it is called,
+ otherwise error is given.
+ """
def __init__(self, service_name, args):
+ """AdminScript init."""
DBScript.__init__(self, service_name, args)
- self.pidfile = self.pidfile + ".admin"
+ if self.pidfile:
+ self.pidfile = self.pidfile + ".admin"
if len(self.args) < 2:
self.log.error("need command")
sys.exit(1)
def work(self):
+ """Non-looping work function, calls command function."""
+
self.set_single_loop(1)
cmd = self.args[1]
@@ -47,6 +57,7 @@ class AdminScript(DBScript):
fn(*cmdargs)
def fetch_list(self, db, sql, args, keycol = None):
+ """Fetch a resultset from db, optionally turnin it info value list."""
curs = db.cursor()
curs.execute(sql, args)
rows = curs.fetchall()
@@ -81,85 +92,25 @@ class AdminScript(DBScript):
fmt = '%%-%ds' * (len(widths) - 1) + '%%s'
fmt = fmt % tuple(widths[:-1])
if desc:
- print desc
- print fmt % tuple(fields)
- print fmt % tuple(['-'*15] * len(fields))
+ print(desc)
+ print(fmt % tuple(fields))
+ print(fmt % tuple(['-'*15] * len(fields)))
for row in rows:
- print fmt % tuple([row[k] for k in fields])
- print '\n'
+ print(fmt % tuple([row[k] for k in fields]))
+ print('\n')
return 1
- def _exec_cmd(self, curs, sql, args):
- self.log.debug("exec_cmd: %s" % quote_statement(sql, args))
- curs.execute(sql, args)
- ok = True
- rows = curs.fetchall()
- for row in rows:
- try:
- code = row['ret_code']
- msg = row['ret_note']
- except KeyError:
- self.log.error("Query does not conform to exec_cmd API:")
- self.log.error("SQL: %s" % quote_statement(sql, args))
- self.log.error("Row: %s" % repr(row.copy()))
- sys.exit(1)
- level = code / 100
- if level == 1:
- self.log.debug("%d %s" % (code, msg))
- elif level == 2:
- self.log.info("%d %s" % (code, msg))
- elif level == 3:
- self.log.warning("%d %s" % (code, msg))
- else:
- self.log.error("%d %s" % (code, msg))
- self.log.error("Query was: %s" % skytools.quote_statement(sql, args))
- ok = False
- return (ok, rows)
-
- def _exec_cmd_many(self, curs, sql, baseargs, extra_list):
- ok = True
- rows = []
- for a in extra_list:
- (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a])
- ok = tmp_ok and ok
- rows += tmp_rows
- return (ok, rows)
-
- def exec_cmd(self, db, q, args, commit = True):
- (ok, rows) = self._exec_cmd(db.cursor(), q, args)
- if ok:
- if commit:
- self.log.info("COMMIT")
- db.commit()
- return rows
- else:
- self.log.info("ROLLBACK")
- db.rollback()
- raise EXception("rollback")
-
- def exec_cmd_many(self, db, sql, baseargs, extra_list, commit = True):
- curs = db.cursor()
- (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list)
- if ok:
- if commit:
- self.log.info("COMMIT")
- db.commit()
- return rows
- else:
- self.log.info("ROLLBACK")
- db.rollback()
- raise EXception("rollback")
-
-
def exec_stmt(self, db, sql, args):
+ """Run regular non-query SQL on db."""
self.log.debug("exec_stmt: %s" % quote_statement(sql, args))
curs = db.cursor()
curs.execute(sql, args)
db.commit()
def exec_query(self, db, sql, args):
+ """Run regular query SQL on db."""
self.log.debug("exec_query: %s" % quote_statement(sql, args))
curs = db.cursor()
curs.execute(sql, args)
diff --git a/python/skytools/config.py b/python/skytools/config.py
index a0e96ec5..fbe0b6ed 100644
--- a/python/skytools/config.py
+++ b/python/skytools/config.py
@@ -1,7 +1,7 @@
"""Nicer config class."""
-import sys, os, ConfigParser, socket
+import os, ConfigParser, socket
__all__ = ['Config']
@@ -52,7 +52,7 @@ class Config(object):
"""Reads string value, if not set then default."""
try:
return self.cf.get(self.main_section, key)
- except ConfigParser.NoOptionError, det:
+ except ConfigParser.NoOptionError:
if default == None:
raise Exception("Config value not set: " + key)
return default
@@ -61,7 +61,7 @@ class Config(object):
"""Reads int value, if not set then default."""
try:
return self.cf.getint(self.main_section, key)
- except ConfigParser.NoOptionError, det:
+ except ConfigParser.NoOptionError:
if default == None:
raise Exception("Config value not set: " + key)
return default
@@ -70,7 +70,7 @@ class Config(object):
"""Reads boolean value, if not set then default."""
try:
return self.cf.getboolean(self.main_section, key)
- except ConfigParser.NoOptionError, det:
+ except ConfigParser.NoOptionError:
if default == None:
raise Exception("Config value not set: " + key)
return default
@@ -79,7 +79,7 @@ class Config(object):
"""Reads float value, if not set then default."""
try:
return self.cf.getfloat(self.main_section, key)
- except ConfigParser.NoOptionError, det:
+ except ConfigParser.NoOptionError:
if default == None:
raise Exception("Config value not set: " + key)
return default
@@ -94,7 +94,7 @@ class Config(object):
for v in s.split(","):
res.append(v.strip())
return res
- except ConfigParser.NoOptionError, det:
+ except ConfigParser.NoOptionError:
if default == None:
raise Exception("Config value not set: " + key)
return default
@@ -129,7 +129,7 @@ class Config(object):
for key in keys:
try:
return self.cf.get(self.main_section, key)
- except ConfigParser.NoOptionError, det:
+ except ConfigParser.NoOptionError:
pass
if default == None:
diff --git a/python/skytools/dbstruct.py b/python/skytools/dbstruct.py
index 1c7741c5..2a1c8e47 100644
--- a/python/skytools/dbstruct.py
+++ b/python/skytools/dbstruct.py
@@ -1,14 +1,15 @@
"""Find table structure and allow CREATE/DROP elements from it.
"""
-import sys, re
+import re
from skytools.sqltools import fq_name_parts, get_table_oid
-from skytools.quoting import quote_ident, quote_fqident
+from skytools.quoting import quote_ident, quote_fqident, quote_literal, unquote_fqident
-__all__ = ['TableStruct',
+__all__ = ['TableStruct', 'SeqStruct',
'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER',
- 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL']
+ 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL',
+ 'T_SEQUENCE']
T_TABLE = 1 << 0
T_CONSTRAINT = 1 << 1
@@ -17,8 +18,9 @@ T_TRIGGER = 1 << 3
T_RULE = 1 << 4
T_GRANT = 1 << 5
T_OWNER = 1 << 6
+T_SEQUENCE = 1 << 7
T_PKEY = 1 << 20 # special, one of constraints
-T_ALL = ( T_TABLE | T_CONSTRAINT | T_INDEX
+T_ALL = ( T_TABLE | T_CONSTRAINT | T_INDEX | T_SEQUENCE
| T_TRIGGER | T_RULE | T_GRANT | T_OWNER )
#
@@ -63,7 +65,7 @@ class TElem(object):
"""Keeps info about one metadata object."""
SQL = ""
type = 0
- def get_create_sql(self, curs):
+ def get_create_sql(self, curs, new_name = None):
"""Return SQL statement for creating or None if not supported."""
return None
def get_drop_sql(self, curs):
@@ -78,6 +80,7 @@ class TConstraint(TElem):
FROM pg_constraint WHERE conrelid = %(oid)s AND contype != 'f'
"""
def __init__(self, table_name, row):
+ """Init constraint."""
self.table_name = table_name
self.name = row['name']
self.defn = row['def']
@@ -88,6 +91,7 @@ class TConstraint(TElem):
self.type += T_PKEY
def get_create_sql(self, curs, new_table_name=None):
+ """Generate creation SQL."""
fmt = "ALTER TABLE ONLY %s ADD CONSTRAINT %s %s;"
if new_table_name:
name = self.name
@@ -102,6 +106,7 @@ class TConstraint(TElem):
return sql
def get_drop_sql(self, curs):
+ """Generate removal sql."""
fmt = "ALTER TABLE ONLY %s DROP CONSTRAINT %s;"
sql = fmt % (quote_fqident(self.table_name), quote_ident(self.name))
return sql
@@ -126,6 +131,7 @@ class TIndex(TElem):
self.defn = row['defn'] + ';'
def get_create_sql(self, curs, new_table_name = None):
+ """Generate creation SQL."""
if not new_table_name:
return self.defn
# fixme: seems broken
@@ -151,9 +157,10 @@ class TRule(TElem):
self.defn = row['def']
def get_create_sql(self, curs, new_table_name = None):
+ """Generate creation SQL."""
if not new_table_name:
return self.defn
- # fixme: broken
+ # fixme: broken / quoting
rx = r"\bTO[ ][a-z0-9._]+[ ]DO[ ]"
pnew = "TO %s DO " % new_table_name
return rx_replace(rx, self.defn, pnew)
@@ -161,11 +168,12 @@ class TRule(TElem):
def get_drop_sql(self, curs):
return 'DROP RULE %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name))
+
class TTrigger(TElem):
"""Info about trigger."""
type = T_TRIGGER
SQL = """
- SELECT tgname as name, pg_get_triggerdef(oid) as def
+ SELECT tgname as name, pg_get_triggerdef(oid) as def
FROM pg_trigger
WHERE tgrelid = %(oid)s AND NOT tgisconstraint
"""
@@ -175,9 +183,11 @@ class TTrigger(TElem):
self.defn = row['def'] + ';'
def get_create_sql(self, curs, new_table_name = None):
+ """Generate creation SQL."""
if not new_table_name:
return self.defn
- # fixme: broken
+
+ # fixme: broken / quoting
rx = r"\bON[ ][a-z0-9._]+[ ]"
pnew = "ON %s " % new_table_name
return rx_replace(rx, self.defn, pnew)
@@ -198,6 +208,7 @@ class TOwner(TElem):
self.owner = row['owner']
def get_create_sql(self, curs, new_name = None):
+ """Generate creation SQL."""
if not new_name:
new_name = self.table_name
return 'ALTER TABLE %s OWNER TO %s;' % (quote_fqident(new_name), quote_ident(self.owner))
@@ -217,38 +228,40 @@ class TGrant(TElem):
return ", ".join([ self.acl_map[c] for c in acl ])
def parse_relacl(self, relacl):
+ """Parse ACL to tuple of (user, acl, who)"""
if relacl is None:
return []
if len(relacl) > 0 and relacl[0] == '{' and relacl[-1] == '}':
relacl = relacl[1:-1]
- list = []
+ tup_list = []
for f in relacl.split(','):
user, tmp = f.strip('"').split('=')
acl, who = tmp.split('/')
- list.append((user, acl, who))
- return list
+ tup_list.append((user, acl, who))
+ return tup_list
def __init__(self, table_name, row, new_name = None):
self.name = table_name
self.acl_list = self.parse_relacl(row['relacl'])
def get_create_sql(self, curs, new_name = None):
+ """Generate creation SQL."""
if not new_name:
new_name = self.name
- list = []
+ sql_list = []
for user, acl, who in self.acl_list:
astr = self.acl_to_grants(acl)
sql = "GRANT %s ON %s TO %s;" % (astr, quote_fqident(new_name), quote_ident(user))
- list.append(sql)
- return "\n".join(list)
+ sql_list.append(sql)
+ return "\n".join(sql_list)
def get_drop_sql(self, curs):
- list = []
+ sql_list = []
for user, acl, who in self.acl_list:
sql = "REVOKE ALL FROM %s ON %s;" % (quote_ident(user), quote_fqident(self.name))
- list.append(sql)
- return "\n".join(list)
+ sql_list.append(sql)
+ return "\n".join(sql_list)
class TColumn(TElem):
"""Info about table column."""
@@ -257,8 +270,9 @@ class TColumn(TElem):
a.attname || ' '
|| format_type(a.atttypid, a.atttypmod)
|| case when a.attnotnull then ' not null' else '' end
- || case when a.atthasdef then ' ' || d.adsrc else '' end
- as def
+ || case when a.atthasdef then ' default ' || d.adsrc else '' end
+ as def,
+ pg_get_serial_sequence(%(fq2name)s, a.attname) as seqname
from pg_attribute a left join pg_attrdef d
on (d.adrelid = a.attrelid and d.adnum = a.attnum)
where a.attrelid = %(oid)s
@@ -266,9 +280,13 @@ class TColumn(TElem):
and a.attnum > 0
order by a.attnum;
"""
+ seqname = None
def __init__(self, table_name, row):
self.name = row['name']
self.column_def = row['def']
+ self.sequence = None
+ if row['seqname']:
+ self.seqname = unquote_fqident(row['seqname'])
class TTable(TElem):
"""Info about table only (columns)."""
@@ -278,6 +296,7 @@ class TTable(TElem):
self.col_list = col_list
def get_create_sql(self, curs, new_name = None):
+ """Generate creation SQL."""
if not new_name:
new_name = self.name
sql = "create table %s (" % quote_fqident(new_name)
@@ -287,53 +306,88 @@ class TTable(TElem):
sep = ",\n\t"
sql += "\n);"
return sql
-
+
def get_drop_sql(self, curs):
return "DROP TABLE %s;" % quote_fqident(self.name)
+class TSeq(TElem):
+ """Info about sequence."""
+ type = T_SEQUENCE
+ SQL = """SELECT *, %(owner)s as "owner" from %(fqname)s """
+ def __init__(self, seq_name, row):
+ self.name = seq_name
+ defn = ''
+ self.owner = row['owner']
+ if row['increment_by'] != 1:
+ defn += ' INCREMENT BY %d' % row['increment_by']
+ if row['min_value'] != 1:
+ defn += ' MINVALUE %d' % row['min_value']
+ if row['max_value'] != 9223372036854775807:
+ defn += ' MAXVALUE %d' % row['max_value']
+ last_value = row['last_value']
+ if row['is_called']:
+ last_value += row['increment_by']
+ if last_value >= row['max_value']:
+ raise Exception('duh, seq passed max_value')
+ if last_value != 1:
+ defn += ' START %d' % last_value
+ if row['cache_value'] != 1:
+ defn += ' CACHE %d' % row['cache_value']
+ if row['is_cycled']:
+ defn += ' CYCLE '
+ if self.owner:
+ defn += ' OWNED BY %s' % self.owner
+ self.defn = defn
+
+ def get_create_sql(self, curs, new_seq_name = None):
+ """Generate creation SQL."""
+
+ # we are in table def, forget full def
+ if self.owner:
+ sql = "ALTER SEQUENCE %s OWNED BY %s" % (
+ quote_fqident(self.name), self.owner )
+ return sql
+
+ name = self.name
+ if new_seq_name:
+ name = new_seq_name
+ sql = 'CREATE SEQUENCE %s %s;' % (quote_fqident(name), self.defn)
+ return sql
+
+ def get_drop_sql(self, curs):
+ if self.owner:
+ return ''
+ return 'DROP SEQUENCE %s;' % quote_fqident(self.name)
+
#
# Main table object, loads all the others
#
-class TableStruct(object):
- """Collects and manages all info about table.
+class BaseStruct(object):
+ """Collects and manages all info about a higher-level db object.
Allow to issue CREATE/DROP statements about any
group of elements.
"""
- def __init__(self, curs, table_name):
+ object_list = []
+ def __init__(self, curs, name):
"""Initializes class by loading info about table_name from database."""
- self.table_name = table_name
-
- # fill args
- schema, name = fq_name_parts(table_name)
- args = {
- 'schema': schema,
- 'table': name,
- 'oid': get_table_oid(curs, table_name),
- 'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'),
- }
-
- # load table struct
- self.col_list = self._load_elem(curs, args, TColumn)
- self.object_list = [ TTable(table_name, self.col_list) ]
-
- # load additional objects
- to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner]
- for eclass in to_load:
- self.object_list += self._load_elem(curs, args, eclass)
+ self.name = name
+ self.fqname = quote_fqident(name)
- def _load_elem(self, curs, args, eclass):
- list = []
+ def _load_elem(self, curs, name, args, eclass):
+ """Fetch element(s) from db."""
+ elem_list = []
+ #print "Loading %s, name=%s, args=%s" % (repr(eclass), repr(name), repr(args))
curs.execute(eclass.SQL % args)
for row in curs.dictfetchall():
- list.append(eclass(self.table_name, row))
- return list
+ elem_list.append(eclass(name, row))
+ return elem_list
def create(self, curs, objs, new_table_name = None, log = None):
"""Issues CREATE statements for requested set of objects.
-
+
If new_table_name is giver, creates table under that name
and also tries to rename all indexes/constraints that conflict
with existing table.
@@ -361,6 +415,57 @@ class TableStruct(object):
log.debug(sql)
curs.execute(sql)
+ def get_create_sql(self, objs):
+ res = []
+ for o in self.object_list:
+ if o.type & objs:
+ sql = o.get_create_sql(None, None)
+ if sql:
+ res.append(sql)
+ return "".join(res)
+
+class TableStruct(BaseStruct):
+ """Collects and manages all info about table.
+
+ Allow to issue CREATE/DROP statements about any
+ group of elements.
+ """
+ def __init__(self, curs, table_name):
+ """Initializes class by loading info about table_name from database."""
+
+ BaseStruct.__init__(self, curs, table_name)
+
+ self.table_name = table_name
+
+ # fill args
+ schema, name = fq_name_parts(table_name)
+ args = {
+ 'schema': schema,
+ 'table': name,
+ 'fqname': self.fqname,
+ 'fq2name': quote_literal(self.fqname),
+ 'oid': get_table_oid(curs, table_name),
+ 'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'),
+ }
+
+ # load table struct
+ self.col_list = self._load_elem(curs, self.name, args, TColumn)
+ self.object_list = [ TTable(table_name, self.col_list) ]
+ self.seq_list = []
+
+ # load seqs
+ for col in self.col_list:
+ if col.seqname:
+ owner = self.fqname + '.' + quote_ident(col.name)
+ seq_args = { 'fqname': col.seqname, 'owner': quote_literal(owner) }
+ self.seq_list += self._load_elem(curs, col.seqname, seq_args, TSeq)
+ self.object_list += self.seq_list
+
+ # load additional objects
+ to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner]
+ for eclass in to_load:
+ self.object_list += self._load_elem(curs, self.name, args, eclass)
+
def get_column_list(self):
"""Returns list of column names the table has."""
@@ -369,11 +474,28 @@ class TableStruct(object):
res.append(c.name)
return res
+class SeqStruct(BaseStruct):
+ """Collects and manages all info about sequence.
+
+ Allow to issue CREATE/DROP statements about any
+ group of elements.
+ """
+ def __init__(self, curs, seq_name):
+ """Initializes class by loading info about table_name from database."""
+
+ BaseStruct.__init__(self, curs, seq_name)
+
+ # fill args
+ args = { 'fqname': self.fqname, 'owner': 'null' }
+
+ # load table struct
+ self.object_list = self._load_elem(curs, seq_name, args, TSeq)
+
def test():
from skytools import connect_database
db = connect_database("dbname=fooz")
curs = db.cursor()
-
+
s = TableStruct(curs, "public.data1")
s.drop(curs, T_ALL)
diff --git a/python/skytools/installer_config.py.in b/python/skytools/installer_config.py.in
index 06c9b956..a01621f7 100644
--- a/python/skytools/installer_config.py.in
+++ b/python/skytools/installer_config.py.in
@@ -1,4 +1,8 @@
+"""SQL script locations."""
+
+__all__ = ['sql_locations']
+
sql_locations = [
"@SQLDIR@",
]
diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py
index 3aa94991..36b17d53 100644
--- a/python/skytools/parsing.py
+++ b/python/skytools/parsing.py
@@ -42,11 +42,14 @@ def parse_pgarray(array):
#
class _logtriga_parser:
+ """Parses logtriga/sqltriga partial SQL to values."""
def tokenizer(self, sql):
+ """Token generator."""
for typ, tok in sql_tokenizer(sql, ignore_whitespace = True):
yield tok
def parse_insert(self, tk, fields, values):
+ """Handler for inserts."""
# (col1, col2) values ('data', null)
if tk.next() != "(":
raise Exception("syntax error")
@@ -73,6 +76,7 @@ class _logtriga_parser:
raise Exception("expected EOF, got " + repr(t))
def parse_update(self, tk, fields, values):
+ """Handler for updates."""
# col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2'
while 1:
fields.append(tk.next())
@@ -97,6 +101,7 @@ class _logtriga_parser:
raise Exception("syntax error, expected AND got "+repr(t))
def parse_delete(self, tk, fields, values):
+ """Handler for deletes."""
# pk1 = 'pk1' and pk2 = 'pk2'
while 1:
fields.append(tk.next())
@@ -108,6 +113,7 @@ class _logtriga_parser:
raise Exception("syntax error, expected AND, got "+repr(t))
def parse_sql(self, op, sql):
+ """Main entry point."""
tk = self.tokenizer(sql)
fields = []
values = []
diff --git a/python/skytools/psycopgwrapper.py b/python/skytools/psycopgwrapper.py
index 9d23e506..7bce0c33 100644
--- a/python/skytools/psycopgwrapper.py
+++ b/python/skytools/psycopgwrapper.py
@@ -1,16 +1,60 @@
-"""Wrapper around psycopg1/2.
+"""Wrapper around psycopg2.
-Preferred is psycopg2, fallback to psycopg1.
+Database connection provides regular DB-API 2.0 interface.
-Interface provided is psycopg1:
- - dict* methods.
- - new columns can be assigned to row.
+Connection object methods::
-"""
+ .cursor()
+
+ .commit()
+
+ .rollback()
+
+ .close()
+
+Cursor methods::
+
+ .execute(query[, args])
+
+ .fetchone()
+
+ .fetchall()
+
+
+Sample usage::
-import sys
+ db = self.get_database('somedb')
+ curs = db.cursor()
+
+ # query arguments as array
+ q = "select * from table where id = %s and name = %s"
+ curs.execute(q, [1, 'somename'])
+
+ # query arguments as dict
+ q = "select id, name from table where id = %(id)s and name = %(name)s"
+ curs.execute(q, {'id': 1, 'name': 'somename'})
+
+ # loop over resultset
+ for row in curs.fetchall():
+
+ # columns can be asked by index:
+ id = row[0]
+ name = row[1]
+
+ # and by name:
+ id = row['id']
+ name = row['name']
+
+ # now commit the transaction
+ db.commit()
+
+Deprecated interface: .dictfetchall/.dictfetchone functions on cursor.
+Plain .fetchall() / .fetchone() give exact same result.
+
+"""
+# no exports
__all__ = []
##from psycopg2.psycopg1 import connect as _pgconnect
@@ -54,6 +98,7 @@ class _CompatCursor(psycopg2.extras.DictCursor):
class _CompatConnection(psycopg2.extensions.connection):
"""Connection object that uses _CompatCursor."""
+ my_name = '?'
def cursor(self):
return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor)
diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py
index 8225c7b0..9d281254 100644
--- a/python/skytools/quoting.py
+++ b/python/skytools/quoting.py
@@ -12,7 +12,7 @@ __all__ = [
# local
"quote_bytea_literal", "quote_bytea_copy", "quote_statement",
"quote_ident", "quote_fqident", "quote_json", "unescape_copy",
- "unquote_ident",
+ "unquote_ident", "unquote_fqident",
]
try:
@@ -34,15 +34,19 @@ def quote_bytea_copy(s):
return quote_copy(quote_bytea_raw(s))
-def quote_statement(sql, dict):
+def quote_statement(sql, dict_or_list):
"""Quote whole statement.
- Data values are taken from dict.
+ Data values are taken from dict or list or tuple.
"""
- xdict = {}
- for k, v in dict.items():
- xdict[k] = quote_literal(v)
- return sql % xdict
+ if hasattr(dict_or_list, 'items'):
+ qdict = {}
+ for k, v in dict_or_list.items():
+ qdict[k] = quote_literal(v)
+ return sql % qdict
+ else:
+ qvals = [quote_literal(v) for v in dict_or_list]
+ return sql % tuple(qvals)
# reserved keywords
_ident_kwmap = {
@@ -58,6 +62,8 @@ _ident_kwmap = {
"primary":1, "references":1, "returning":1, "select":1, "session_user":1,
"some":1, "symmetric":1, "table":1, "then":1, "to":1, "trailing":1, "true":1,
"union":1, "unique":1, "user":1, "using":1, "when":1, "where":1,
+# greenplum?
+"errors":1,
}
_ident_bad = re.compile(r"[^a-z0-9_]")
@@ -90,6 +96,7 @@ _jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r",
}
def _json_quote_char(m):
+ """Quote single char."""
c = m.group(0)
try:
return _jsmap[c]
@@ -114,3 +121,11 @@ def unquote_ident(val):
return val[1:-1].replace('""', '"')
return val
+def unquote_fqident(val):
+ """Unquotes fully-qualified possibly quoted SQL identifier.
+
+ It must be prefixed schema, which does not contain dots.
+ """
+ tmp = val.split('.', 1)
+ return "%s.%s" % (unquote_ident(tmp[0]), unquote_ident(tmp[1]))
+
diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py
index b2c26bc0..a61154de 100644
--- a/python/skytools/scripting.py
+++ b/python/skytools/scripting.py
@@ -1,13 +1,29 @@
"""Useful functions and classes for database scripts."""
-import sys, os, signal, optparse, traceback, time, errno
+import sys, os, signal, optparse, time, errno
import logging, logging.handlers, logging.config
from skytools.config import *
from skytools.psycopgwrapper import connect_database
+from skytools.quoting import quote_statement
import skytools.skylog
+__pychecker__ = 'no-badexcept'
+
+#: how old connections need to be closed
+DEF_CONN_AGE = 20*60 # 20 min
+
+#: isolation level not set
+I_DEFAULT = -1
+
+#: isolation level constant for AUTOCOMMIT
+I_AUTOCOMMIT = 0
+#: isolation level constant for READ COMMITTED
+I_READ_COMMITTED = 1
+#: isolation level constant for SERIALIZABLE
+I_SERIALIZABLE = 2
+
__all__ = ['DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE',
'signal_pidfile']
#__all__ += ['daemonize', 'run_single_process']
@@ -80,10 +96,10 @@ def run_single_process(runnable, daemon, pidfile):
# check if another process is running
if pidfile and os.path.isfile(pidfile):
if signal_pidfile(pidfile, 0):
- print "Pidfile exists, another process running?"
+ print("Pidfile exists, another process running?")
sys.exit(1)
else:
- print "Ignoring stale pidfile"
+ print("Ignoring stale pidfile")
# daemonize if needed and write pidfile
if daemon:
@@ -122,8 +138,8 @@ def _init_log(job_name, service_name, cf, log_level):
skytools.skylog.set_service_name(service_name)
# load general config
- list = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']
- for fn in list:
+ flist = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']
+ for fn in flist:
fn = os.path.expanduser(fn)
if os.path.isfile(fn):
defs = {'job_name': job_name, 'service_name': service_name}
@@ -163,33 +179,24 @@ def _init_log(job_name, service_name, cf, log_level):
return log
-#: how old connections need to be closed
-DEF_CONN_AGE = 20*60 # 20 min
-
-#: isolation level not set
-I_DEFAULT = -1
-
-#: isolation level constant for AUTOCOMMIT
-I_AUTOCOMMIT = 0
-#: isolation level constant for READ COMMITTED
-I_READ_COMMITTED = 1
-#: isolation level constant for SERIALIZABLE
-I_SERIALIZABLE = 2
-
class DBCachedConn(object):
"""Cache a db connection."""
- def __init__(self, name, loc, max_age = DEF_CONN_AGE):
+ def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None):
self.name = name
self.loc = loc
self.conn = None
self.conn_time = 0
self.max_age = max_age
self.autocommit = -1
- self.isolation_level = -1
+ self.isolation_level = I_DEFAULT
+ self.verbose = verbose
+ self.setup_func = setup_func
- def get_connection(self, autocommit = 0, isolation_level = -1):
+ def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT):
# autocommit overrider isolation_level
if autocommit:
+ if isolation_level == I_SERIALIZABLE:
+ raise Exception('autocommit is not compatible with I_SERIALIZABLE')
isolation_level = I_AUTOCOMMIT
# default isolation_level is READ COMMITTED
@@ -200,9 +207,12 @@ class DBCachedConn(object):
if not self.conn:
self.isolation_level = isolation_level
self.conn = connect_database(self.loc)
+ self.conn.my_name = self.name
self.conn.set_isolation_level(isolation_level)
self.conn_time = time.time()
+ if self.setup_func:
+ self.setup_func(self.name, self.conn)
else:
if self.isolation_level != isolation_level:
raise Exception("Conflict in isolation_level")
@@ -250,6 +260,8 @@ class DBScript(object):
job_name = None
cf = None
log = None
+ pidfile = None
+ loop_delay = 1
def __init__(self, service_name, args):
"""Script setup.
@@ -286,7 +298,7 @@ class DBScript(object):
if self.options.verbose:
self.log_level = logging.DEBUG
if len(self.args) < 1:
- print "need config file"
+ print("need config file")
sys.exit(1)
# read config file
@@ -305,6 +317,8 @@ class DBScript(object):
self.send_signal(signal.SIGHUP)
def load_config(self):
+ """Loads and returns skytools.Config instance."""
+
conf_file = self.args[0]
return Config(self.service_name, conf_file)
@@ -369,7 +383,21 @@ class DBScript(object):
if not self.pidfile:
self.log.error("Daemon needs pidfile")
sys.exit(1)
- run_single_process(self, self.go_daemon, self.pidfile)
+
+ try:
+ run_single_process(self, self.go_daemon, self.pidfile)
+ except KeyboardInterrupt:
+ raise
+ except SystemExit:
+ raise
+ except Exception:
+ # catch startup errors
+ exc, msg, tb = sys.exc_info()
+ self.log.exception("Job %s crashed on startup: %s: %s" % (
+ self.job_name, str(exc), str(msg).rstrip()))
+ del tb
+ sys.exit(1)
+
def stop(self):
"""Safely stops processing loop."""
@@ -386,9 +414,15 @@ class DBScript(object):
"Internal SIGHUP handler. Minimal code here."
self.need_reload = 1
+ last_sigint = 0
def hook_sigint(self, sig, frame):
"Internal SIGINT handler. Minimal code here."
self.stop()
+ t = time.time()
+ if t - self.last_sigint < 1:
+ self.log.warning("Double ^C, fast exit")
+ sys.exit(1)
+ self.last_sigint = t
def stat_add(self, key, value):
"""Old, deprecated function."""
@@ -419,6 +453,9 @@ class DBScript(object):
self.log.info(logmsg)
self.stat_dict = {}
+ def connection_setup(self, dbname, conn):
+ pass
+
def get_database(self, dbname, autocommit = 0, isolation_level = -1,
cache = None, connstr = None):
"""Load cached database connection.
@@ -435,7 +472,7 @@ class DBScript(object):
else:
if not connstr:
connstr = self.cf.get(dbname)
- dbc = DBCachedConn(cache, connstr, max_age)
+ dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_setup)
self.db_cache[cache] = dbc
return dbc.get_connection(autocommit, isolation_level)
@@ -462,15 +499,14 @@ class DBScript(object):
# run startup, safely
try:
self.startup()
- except KeyboardInterrupt, det:
+ except KeyboardInterrupt:
raise
- except SystemExit, det:
+ except SystemExit:
raise
- except Exception, det:
+ except Exception:
exc, msg, tb = sys.exc_info()
- self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % (
- self.job_name, str(exc), str(msg).rstrip(),
- str(tb), repr(traceback.format_tb(tb))))
+ self.log.exception("Job %s crashed: %s: %s" % (
+ self.job_name, str(exc), str(msg).rstrip()))
del tb
self.reset()
sys.exit(1)
@@ -523,10 +559,9 @@ class DBScript(object):
except Exception, d:
self.send_stats()
exc, msg, tb = sys.exc_info()
- self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % (
- self.job_name, str(exc), str(msg).rstrip(),
- str(tb), repr(traceback.format_tb(tb))))
del tb
+ self.log.exception("Job %s crashed: %s: %s" % (
+ self.job_name, str(exc), str(msg).rstrip()))
self.reset()
if self.looping and not self.do_single_loop:
time.sleep(20)
@@ -553,4 +588,82 @@ class DBScript(object):
signal.signal(signal.SIGHUP, self.hook_sighup)
signal.signal(signal.SIGINT, self.hook_sigint)
+ def _exec_cmd(self, curs, sql, args, quiet = False):
+ """Internal tool: Run SQL on cursor."""
+ self.log.debug("exec_cmd: %s" % quote_statement(sql, args))
+ curs.execute(sql, args)
+ ok = True
+ rows = curs.fetchall()
+ for row in rows:
+ try:
+ code = row['ret_code']
+ msg = row['ret_note']
+ except KeyError:
+ self.log.error("Query does not conform to exec_cmd API:")
+ self.log.error("SQL: %s" % quote_statement(sql, args))
+ self.log.error("Row: %s" % repr(row.copy()))
+ sys.exit(1)
+ level = code / 100
+ if level == 1:
+ self.log.debug("%d %s" % (code, msg))
+ elif level == 2:
+ if quiet:
+ self.log.debug("%d %s" % (code, msg))
+ else:
+ self.log.info("%s" % (msg,))
+ elif level == 3:
+ self.log.warning("%s" % (msg,))
+ else:
+ self.log.error("%s" % (msg,))
+ self.log.error("Query was: %s" % quote_statement(sql, args))
+ ok = False
+ return (ok, rows)
+
+ def _exec_cmd_many(self, curs, sql, baseargs, extra_list, quiet = False):
+ """Internal tool: Run SQL on cursor multiple times."""
+ ok = True
+ rows = []
+ for a in extra_list:
+ (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a], quiet=quiet)
+ if not tmp_ok:
+ ok = False
+ rows += tmp_rows
+ return (ok, rows)
+
+ def exec_cmd(self, db_or_curs, q, args, commit = True, quiet = False):
+ """Run SQL on db with code/value error handling."""
+ if hasattr(db_or_curs, 'cursor'):
+ db = db_or_curs
+ curs = db.cursor()
+ else:
+ db = None
+ curs = db_or_curs
+ (ok, rows) = self._exec_cmd(curs, q, args, quiet = quiet)
+ if ok:
+ if commit and db:
+ db.commit()
+ return rows
+ else:
+ if db:
+ db.rollback()
+ raise Exception("db error")
+
+ def exec_cmd_many(self, db_or_curs, sql, baseargs, extra_list, commit = True, quiet = False):
+ """Run SQL on db multiple times."""
+ if hasattr(db_or_curs, 'cursor'):
+ db = db_or_curs
+ curs = db.cursor()
+ else:
+ db = None
+ curs = db_or_curs
+ (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list, quiet=quiet)
+ if ok:
+ if commit and db:
+ db.commit()
+ return rows
+ else:
+ if db:
+ db.rollback()
+ raise Exception("db error")
+
diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py
index 883bbe8b..fd4fb855 100644
--- a/python/skytools/sqltools.py
+++ b/python/skytools/sqltools.py
@@ -9,6 +9,7 @@ import skytools.installer_config
__all__ = [
"fq_name_parts", "fq_name", "get_table_oid", "get_table_pkeys",
"get_table_columns", "exists_schema", "exists_table", "exists_type",
+ "exists_sequence",
"exists_function", "exists_language", "Snapshot", "magic_insert",
"CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction",
"DBLanguage", "db_install", "installer_find_file", "installer_apply_file",
@@ -19,9 +20,15 @@ class dbdict(dict):
"""Wrapper on actual dict that allows
accessing dict keys as attributes."""
# obj.foo access
- def __getattr__(self, k): return self[k]
- def __setattr__(self, k, v): self[k] = v
- def __delattr__(self, k): del self[k]
+ def __getattr__(self, k):
+ "Return attribute."
+ return self[k]
+ def __setattr__(self, k, v):
+ "Set attribute."
+ self[k] = v
+ def __delattr__(self, k):
+ "Remove attribute"
+ del self[k]
#
# Fully qualified table name
@@ -46,6 +53,7 @@ def fq_name(tbl):
# info about table
#
def get_table_oid(curs, table_name):
+ """Find Postgres OID for table."""
schema, name = fq_name_parts(table_name)
q = """select c.oid from pg_namespace n, pg_class c
where c.relnamespace = n.oid
@@ -57,6 +65,7 @@ def get_table_oid(curs, table_name):
return res[0][0]
def get_table_pkeys(curs, tbl):
+ """Return list of pkey column names."""
oid = get_table_oid(curs, tbl)
q = "SELECT k.attname FROM pg_index i, pg_attribute k"\
" WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\
@@ -66,6 +75,7 @@ def get_table_pkeys(curs, tbl):
return map(lambda x: x[0], curs.fetchall())
def get_table_columns(curs, tbl):
+ """Return list of column names for table."""
oid = get_table_oid(curs, tbl)
q = "SELECT k.attname FROM pg_attribute k"\
" WHERE k.attrelid = %s"\
@@ -78,12 +88,14 @@ def get_table_columns(curs, tbl):
# exist checks
#
def exists_schema(curs, schema):
+ """Does schema exists?"""
q = "select count(1) from pg_namespace where nspname = %s"
curs.execute(q, [schema])
res = curs.fetchone()
return res[0]
def exists_table(curs, table_name):
+ """Does table exists?"""
schema, name = fq_name_parts(table_name)
q = """select count(1) from pg_namespace n, pg_class c
where c.relnamespace = n.oid and c.relkind = 'r'
@@ -92,7 +104,18 @@ def exists_table(curs, table_name):
res = curs.fetchone()
return res[0]
+def exists_sequence(curs, seq_name):
+ """Does sequence exists?"""
+ schema, name = fq_name_parts(seq_name)
+ q = """select count(1) from pg_namespace n, pg_class c
+ where c.relnamespace = n.oid and c.relkind = 'S'
+ and n.nspname = %s and c.relname = %s"""
+ curs.execute(q, [schema, name])
+ res = curs.fetchone()
+ return res[0]
+
def exists_type(curs, type_name):
+ """Does type exists?"""
schema, name = fq_name_parts(type_name)
q = """select count(1) from pg_namespace n, pg_type t
where t.typnamespace = n.oid
@@ -102,6 +125,7 @@ def exists_type(curs, type_name):
return res[0]
def exists_function(curs, function_name, nargs):
+ """Does function exists?"""
# this does not check arg types, so may match several functions
schema, name = fq_name_parts(function_name)
q = """select count(1) from pg_namespace n, pg_proc p
@@ -118,6 +142,7 @@ def exists_function(curs, function_name, nargs):
return res[0]
def exists_language(curs, lang_name):
+ """Does PL exists?"""
q = """select count(1) from pg_language
where lanname = %s"""
curs.execute(q, [lang_name])
@@ -331,11 +356,13 @@ class DBObject(object):
sql = None
sql_file = None
def __init__(self, name, sql = None, sql_file = None):
+ """Generic dbobject init."""
self.name = name
self.sql = sql
self.sql_file = sql_file
def create(self, curs, log = None):
+ """Create a dbobject."""
if log:
log.info('Installing %s' % self.name)
if self.sql:
@@ -352,13 +379,14 @@ class DBObject(object):
curs.execute(stmt)
def find_file(self):
+ """Find install script file."""
full_fn = None
if self.sql_file[0] == "/":
full_fn = self.sql_file
else:
dir_list = skytools.installer_config.sql_locations
- for dir in dir_list:
- fn = os.path.join(dir, self.sql_file)
+ for fdir in dir_list:
+ fn = os.path.join(fdir, self.sql_file)
if os.path.isfile(fn):
full_fn = fn
break
@@ -370,26 +398,32 @@ class DBObject(object):
class DBSchema(DBObject):
"""Handles db schema."""
def exists(self, curs):
+ """Does schema exists."""
return exists_schema(curs, self.name)
class DBTable(DBObject):
"""Handles db table."""
def exists(self, curs):
+ """Does table exists."""
return exists_table(curs, self.name)
class DBFunction(DBObject):
"""Handles db function."""
def __init__(self, name, nargs, sql = None, sql_file = None):
+ """Function object - number of args is significant."""
DBObject.__init__(self, name, sql, sql_file)
self.nargs = nargs
def exists(self, curs):
+ """Does function exists."""
return exists_function(curs, self.name, self.nargs)
class DBLanguage(DBObject):
"""Handles db language."""
def __init__(self, name):
+ """PL object - creation happens with CREATE LANGUAGE."""
DBObject.__init__(self, name, sql = "create language %s" % name)
def exists(self, curs):
+ """Does PL exists."""
return exists_language(curs, self.name)
def db_install(curs, list, log = None):
@@ -402,14 +436,15 @@ def db_install(curs, list, log = None):
log.info('%s is installed' % obj.name)
def installer_find_file(filename):
+ """Find SQL script from pre-defined paths."""
full_fn = None
if filename[0] == "/":
if os.path.isfile(filename):
full_fn = filename
else:
dir_list = ["."] + skytools.installer_config.sql_locations
- for dir in dir_list:
- fn = os.path.join(dir, filename)
+ for fdir in dir_list:
+ fn = os.path.join(fdir, filename)
if os.path.isfile(fn):
full_fn = fn
break
@@ -419,6 +454,7 @@ def installer_find_file(filename):
return full_fn
def installer_apply_file(db, filename, log):
+ """Find SQL file and apply it to db, statement-by-statement."""
fn = installer_find_file(filename)
sql = open(fn, "r").read()
if log: