summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarko Kreen2011-04-07 12:57:28 +0000
committerMarko Kreen2011-04-15 10:20:21 +0000
commit68d1f210514bd42f6dec4492f1addd457f313c5d (patch)
tree99cfc9420b0175bd8e7f29165cb3bf0a7ade22e1
parent205de8b710a2489d101352f1390cbe3dbdbcd9c0 (diff)
skytools: Separate generic scripting from DBScript
-rw-r--r--python/pgq/cascade/consumer.py3
-rw-r--r--python/skytools/scripting.py462
2 files changed, 258 insertions, 207 deletions
diff --git a/python/pgq/cascade/consumer.py b/python/pgq/cascade/consumer.py
index fdc8a48b..798ee7e7 100644
--- a/python/pgq/cascade/consumer.py
+++ b/python/pgq/cascade/consumer.py
@@ -274,7 +274,7 @@ class CascadedConsumer(Consumer):
q = "select * from pgq_node.set_consumer_completed(%s, %s, %s)"
self.exec_cmd(dst_db, q, [ self.queue_name, self.consumer_name, tick_id ])
- def exception_hook(self, det, emsg, cname):
+ def exception_hook(self, det, emsg):
try:
dst_db = self.get_database(self.target_db)
q = "select * from pgq_node.set_consumer_error(%s, %s, %s)"
@@ -282,4 +282,5 @@ class CascadedConsumer(Consumer):
except:
self.log.warning("Failure to call pgq_node.set_consumer_error()")
self.reset()
+ Consumer.exception_hook(self, det, emsg)
diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py
index 298c1efa..a8baffd6 100644
--- a/python/skytools/scripting.py
+++ b/python/skytools/scripting.py
@@ -17,22 +17,8 @@ except ImportError:
__pychecker__ = 'no-badexcept'
-#: how old connections need to be closed
-DEF_CONN_AGE = 20*60 # 20 min
-
-#: isolation level not set
-I_DEFAULT = -1
-
-#: isolation level constant for AUTOCOMMIT
-I_AUTOCOMMIT = 0
-#: isolation level constant for READ COMMITTED
-I_READ_COMMITTED = 1
-#: isolation level constant for SERIALIZABLE
-I_SERIALIZABLE = 2
-
-__all__ = ['DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE',
- 'signal_pidfile', 'UsageError']
-#__all__ += ['daemonize', 'run_single_process']
+__all__ = ['BaseScript', 'signal_pidfile', 'UsageError', 'daemonize',
+ 'DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE']
class UsageError(Exception):
"""User induced error."""
@@ -197,97 +183,10 @@ def _init_log(job_name, service_name, cf, log_level, is_daemon):
return log
-class DBCachedConn(object):
- """Cache a db connection."""
- def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]):
- self.name = name
- self.loc = loc
- self.conn = None
- self.conn_time = 0
- self.max_age = max_age
- self.autocommit = -1
- self.isolation_level = I_DEFAULT
- self.verbose = verbose
- self.setup_func = setup_func
- self.listen_channel_list = []
-
- def fileno(self):
- if not self.conn:
- return None
- return self.conn.cursor().fileno()
-
- def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []):
- # autocommit overrider isolation_level
- if autocommit:
- if isolation_level == I_SERIALIZABLE:
- raise Exception('autocommit is not compatible with I_SERIALIZABLE')
- isolation_level = I_AUTOCOMMIT
-
- # default isolation_level is READ COMMITTED
- if isolation_level < 0:
- isolation_level = I_READ_COMMITTED
-
- # new conn?
- if not self.conn:
- self.isolation_level = isolation_level
- self.conn = skytools.connect_database(self.loc)
- self.conn.my_name = self.name
-
- self.conn.set_isolation_level(isolation_level)
- self.conn_time = time.time()
- if self.setup_func:
- self.setup_func(self.name, self.conn)
- else:
- if self.isolation_level != isolation_level:
- raise Exception("Conflict in isolation_level")
- self._sync_listen(listen_channel_list)
- # done
- return self.conn
-
- def _sync_listen(self, new_clist):
- if not new_clist and not self.listen_channel_list:
- return
- curs = self.conn.cursor()
- for ch in self.listen_channel_list:
- if ch not in new_clist:
- curs.execute("UNLISTEN %s" % skytools.quote_ident(ch))
- for ch in new_clist:
- if ch not in self.listen_channel_list:
- curs.execute("LISTEN %s" % skytools.quote_ident(ch))
- if self.isolation_level != I_AUTOCOMMIT:
- self.conn.commit()
- self.listen_channel_list = new_clist[:]
-
- def refresh(self):
- if not self.conn:
- return
- #for row in self.conn.notifies():
- # if row[0].lower() == "reload":
- # self.reset()
- # return
- if not self.max_age:
- return
- if time.time() - self.conn_time >= self.max_age:
- self.reset()
-
- def reset(self):
- if not self.conn:
- return
-
- # drop reference
- conn = self.conn
- self.conn = None
- self.listen_channel_list = []
-
- # close
- try:
- conn.close()
- except: pass
-
-class DBScript(object):
- """Base class for database scripts.
+class BaseScript(object):
+ """Base class for service scripts.
Handles logging, daemonizing, config, errors.
@@ -314,9 +213,6 @@ class DBScript(object):
# 1 - enabled, unless non-daemon on console (os.isatty())
# 2 - always enabled
#use_skylog = 0
-
- # default lifetime for database connections (in seconds)
- #connection_lifetime = 1200
"""
service_name = None
job_name = None
@@ -353,12 +249,10 @@ class DBScript(object):
@param args: cmdline args (sys.argv[1:]), but can be overrided
"""
self.service_name = service_name
- self.db_cache = {}
self.go_daemon = 0
self.need_reload = 0
self.stat_dict = {}
self.log_level = logging.INFO
- self._listen_map = {} # dbname: channel_list
# parse command line
parser = self.init_optparse()
@@ -591,50 +485,9 @@ class DBScript(object):
self.log.info(logmsg)
self.stat_dict = {}
- def connection_hook(self, dbname, conn):
- pass
-
- def get_database(self, dbname, autocommit = 0, isolation_level = -1,
- cache = None, connstr = None):
- """Load cached database connection.
-
- User must not store it permanently somewhere,
- as all connections will be invalidated on reset.
- """
-
- max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE)
- if not cache:
- cache = dbname
- if cache in self.db_cache:
- dbc = self.db_cache[cache]
- else:
- if not connstr:
- connstr = self.cf.get(dbname)
- self.log.debug("Connect '%s' to '%s'" % (cache, connstr))
- dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook)
- self.db_cache[cache] = dbc
-
- clist = []
- if cache in self._listen_map:
- clist = self._listen_map[cache]
-
- return dbc.get_connection(autocommit, isolation_level, clist)
-
- def close_database(self, dbname):
- """Explicitly close a cached connection.
-
- Next call to get_database() will reconnect.
- """
- if dbname in self.db_cache:
- dbc = self.db_cache[dbname]
- dbc.reset()
- del self.db_cache[dbname]
-
def reset(self):
- "Something bad happened, reset all connections."
- for dbc in self.db_cache.values():
- dbc.reset()
- self.db_cache = {}
+ "Something bad happened, reset all state."
+ pass
def run(self):
"Thread main loop."
@@ -651,13 +504,6 @@ class DBScript(object):
# do some work
work = self.run_once()
- # send stats that was added
- self.send_stats()
-
- # reconnect if needed
- for dbc in self.db_cache.values():
- dbc.refresh()
-
if not self.looping or self.loop_delay < 0:
break
@@ -673,7 +519,12 @@ class DBScript(object):
break
def run_once(self):
- return self.run_func_safely(self.work, True)
+ state = self.run_func_safely(self.work, True)
+
+        # send stats that were added
+ self.send_stats()
+
+ return state
def run_func_safely(self, func, prefer_looping = False):
"Run users work function, safely."
@@ -702,31 +553,13 @@ class DBScript(object):
self.log.info("got KeyboardInterrupt, exiting")
self.reset()
sys.exit(1)
- except skytools.DBError, d:
- self.send_stats()
- if d.cursor and d.cursor.connection:
- cname = d.cursor.connection.my_name
- dsn = d.cursor.connection.dsn
- sql = d.cursor.query
- if len(sql) > 200: # avoid logging londiste huge batched queries
- sql = sql[:60] + " ..."
- emsg = str(d).strip()
- self.log.exception("Job %s got error on connection '%s': %s. Query: %s" % (
- self.job_name, cname, emsg, sql))
- else:
- n = "psycopg2.%s" % d.__class__.__name__
- emsg = str(d).rstrip()
- self.log.exception("Job %s crashed: %s: %s" % (
- self.job_name, n, emsg))
except Exception, d:
self.send_stats()
emsg = str(d).rstrip()
- self.log.exception("Job %s crashed: %s" % (
- self.job_name, emsg))
-
+ self.reset()
+ self.exception_hook(d, emsg)
# reset and sleep
self.reset()
- self.exception_hook(d, emsg, cname)
if prefer_looping and self.looping and self.loop_delay > 0:
self.sleep(20)
return -1
@@ -734,39 +567,18 @@ class DBScript(object):
def sleep(self, secs):
"""Make script sleep for some amount of time."""
- fdlist = []
- for dbname in self._listen_map.keys():
- if dbname not in self.db_cache:
- continue
- fd = self.db_cache[dbname].fileno()
- if fd is None:
- continue
- fdlist.append(fd)
-
- if not fdlist:
- return time.sleep(secs)
-
- try:
- if hasattr(select, 'poll'):
- p = select.poll()
- for fd in fdlist:
- p.register(fd, select.POLLIN)
- p.poll(int(secs * 1000))
- else:
- select.select(fdlist, [], [], secs)
- except select.error, d:
- self.log.info('wait canceled')
+ time.sleep(secs)
- def exception_hook(self, det, emsg, cname):
+ def exception_hook(self, det, emsg):
"""Called on after exception processing.
Can do additional logging.
@param det: exception details
@param emsg: exception msg
- @param cname: connection name or None
"""
- pass
+ self.log.exception("Job %s crashed: %s" % (
+ self.job_name, emsg))
def work(self):
"""Here should user's processing happen.
@@ -787,6 +599,152 @@ class DBScript(object):
signal.signal(signal.SIGHUP, self.hook_sighup)
signal.signal(signal.SIGINT, self.hook_sigint)
+##
+## DBScript
+##
+
+#: how old connections need to be closed
+DEF_CONN_AGE = 20*60 # 20 min
+
+#: isolation level not set
+I_DEFAULT = -1
+
+#: isolation level constant for AUTOCOMMIT
+I_AUTOCOMMIT = 0
+#: isolation level constant for READ COMMITTED
+I_READ_COMMITTED = 1
+#: isolation level constant for SERIALIZABLE
+I_SERIALIZABLE = 2
+
+
+class DBScript(BaseScript):
+ """Base class for database scripts.
+
+ Handles database connection state.
+
+ Config template::
+
+ ## Parameters for skytools.DBScript ##
+
+ # default lifetime for database connections (in seconds)
+ #connection_lifetime = 1200
+ """
+
+ def __init__(self, service_name, args):
+ """Script setup.
+
+ User class should override work() and optionally __init__(), startup(),
+ reload(), reset() and init_optparse().
+
+ NB: in case of daemon, the __init__() and startup()/work() will be
+ run in different processes. So nothing fancy should be done in __init__().
+
+ @param service_name: unique name for script.
+ It will be also default job_name, if not specified in config.
+        @param args: cmdline args (sys.argv[1:]), but can be overridden
+ """
+ self.db_cache = {}
+ self._listen_map = {} # dbname: channel_list
+ BaseScript.__init__(self, service_name, args)
+
+ def connection_hook(self, dbname, conn):
+ pass
+
+ def get_database(self, dbname, autocommit = 0, isolation_level = -1,
+ cache = None, connstr = None):
+ """Load cached database connection.
+
+ User must not store it permanently somewhere,
+ as all connections will be invalidated on reset.
+ """
+
+ max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE)
+ if not cache:
+ cache = dbname
+ if cache in self.db_cache:
+ dbc = self.db_cache[cache]
+ else:
+ if not connstr:
+ connstr = self.cf.get(dbname)
+ self.log.debug("Connect '%s' to '%s'" % (cache, connstr))
+ dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook)
+ self.db_cache[cache] = dbc
+
+ clist = []
+ if cache in self._listen_map:
+ clist = self._listen_map[cache]
+
+ return dbc.get_connection(autocommit, isolation_level, clist)
+
+ def close_database(self, dbname):
+ """Explicitly close a cached connection.
+
+ Next call to get_database() will reconnect.
+ """
+ if dbname in self.db_cache:
+ dbc = self.db_cache[dbname]
+ dbc.reset()
+ del self.db_cache[dbname]
+
+ def reset(self):
+ "Something bad happened, reset all connections."
+ for dbc in self.db_cache.values():
+ dbc.reset()
+ self.db_cache = {}
+ BaseScript.reset(self)
+
+ def run_once(self):
+ state = BaseScript.run_once(self)
+
+ # reconnect if needed
+ for dbc in self.db_cache.values():
+ dbc.refresh()
+
+ return state
+
+ def exception_hook(self, d, emsg):
+ """Log database and query details from exception."""
+ curs = getattr(d, 'cursor', None)
+ conn = getattr(curs, 'connection', None)
+ cname = getattr(conn, 'my_name', None)
+ if cname:
+ # Properly named connection
+ cname = d.cursor.connection.my_name
+ dsn = getattr(conn, 'dsn', '?')
+ sql = getattr(curs, 'query', '?')
+ if len(sql) > 200: # avoid logging londiste huge batched queries
+ sql = sql[:60] + " ..."
+ emsg = str(d).strip()
+ self.log.exception("Job %s got error on connection '%s': %s. Query: %s" % (
+ self.job_name, cname, emsg, sql))
+ else:
+ BaseScript.exception_hook(self, d, emsg)
+
+ def sleep(self, secs):
+ """Make script sleep for some amount of time."""
+ fdlist = []
+ for dbname in self._listen_map.keys():
+ if dbname not in self.db_cache:
+ continue
+ fd = self.db_cache[dbname].fileno()
+ if fd is None:
+ continue
+ fdlist.append(fd)
+
+ if not fdlist:
+ return BaseScript.sleep(self, secs)
+
+ try:
+ if hasattr(select, 'poll'):
+ p = select.poll()
+ for fd in fdlist:
+ p.register(fd, select.POLLIN)
+ p.poll(int(secs * 1000))
+ else:
+ select.select(fdlist, [], [], secs)
+ except select.error, d:
+ self.log.info('wait canceled')
+
def _exec_cmd(self, curs, sql, args, quiet = False):
"""Internal tool: Run SQL on cursor."""
self.log.debug("exec_cmd: %s" % skytools.quote_statement(sql, args))
@@ -903,3 +861,95 @@ class DBScript(object):
except ValueError:
pass
+class DBCachedConn(object):
+ """Cache a db connection."""
+ def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]):
+ self.name = name
+ self.loc = loc
+ self.conn = None
+ self.conn_time = 0
+ self.max_age = max_age
+ self.autocommit = -1
+ self.isolation_level = I_DEFAULT
+ self.verbose = verbose
+ self.setup_func = setup_func
+ self.listen_channel_list = []
+
+ def fileno(self):
+ if not self.conn:
+ return None
+ return self.conn.cursor().fileno()
+
+ def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []):
+        # autocommit overrides isolation_level
+ if autocommit:
+ if isolation_level == I_SERIALIZABLE:
+ raise Exception('autocommit is not compatible with I_SERIALIZABLE')
+ isolation_level = I_AUTOCOMMIT
+
+ # default isolation_level is READ COMMITTED
+ if isolation_level < 0:
+ isolation_level = I_READ_COMMITTED
+
+ # new conn?
+ if not self.conn:
+ self.isolation_level = isolation_level
+ self.conn = skytools.connect_database(self.loc)
+ self.conn.my_name = self.name
+
+ self.conn.set_isolation_level(isolation_level)
+ self.conn_time = time.time()
+ if self.setup_func:
+ self.setup_func(self.name, self.conn)
+ else:
+ if self.isolation_level != isolation_level:
+ raise Exception("Conflict in isolation_level")
+
+ self._sync_listen(listen_channel_list)
+
+ # done
+ return self.conn
+
+ def _sync_listen(self, new_clist):
+ if not new_clist and not self.listen_channel_list:
+ return
+ curs = self.conn.cursor()
+ for ch in self.listen_channel_list:
+ if ch not in new_clist:
+ curs.execute("UNLISTEN %s" % skytools.quote_ident(ch))
+ for ch in new_clist:
+ if ch not in self.listen_channel_list:
+ curs.execute("LISTEN %s" % skytools.quote_ident(ch))
+ if self.isolation_level != I_AUTOCOMMIT:
+ self.conn.commit()
+ self.listen_channel_list = new_clist[:]
+
+ def refresh(self):
+ if not self.conn:
+ return
+ #for row in self.conn.notifies():
+ # if row[0].lower() == "reload":
+ # self.reset()
+ # return
+ if not self.max_age:
+ return
+ if time.time() - self.conn_time >= self.max_age:
+ self.reset()
+
+ def reset(self):
+ if not self.conn:
+ return
+
+ # drop reference
+ conn = self.conn
+ self.conn = None
+ self.listen_channel_list = []
+
+ # close
+ try:
+ conn.close()
+ except: pass
+
+
+
+