diff options
author | Marko Kreen | 2011-04-07 12:57:28 +0000 |
---|---|---|
committer | Marko Kreen | 2011-04-15 10:20:21 +0000 |
commit | 68d1f210514bd42f6dec4492f1addd457f313c5d (patch) | |
tree | 99cfc9420b0175bd8e7f29165cb3bf0a7ade22e1 | |
parent | 205de8b710a2489d101352f1390cbe3dbdbcd9c0 (diff) |
skytools: Separate generic scripting from DBScript
-rw-r--r-- | python/pgq/cascade/consumer.py | 3 | ||||
-rw-r--r-- | python/skytools/scripting.py | 462 |
2 files changed, 258 insertions, 207 deletions
diff --git a/python/pgq/cascade/consumer.py b/python/pgq/cascade/consumer.py index fdc8a48b..798ee7e7 100644 --- a/python/pgq/cascade/consumer.py +++ b/python/pgq/cascade/consumer.py @@ -274,7 +274,7 @@ class CascadedConsumer(Consumer): q = "select * from pgq_node.set_consumer_completed(%s, %s, %s)" self.exec_cmd(dst_db, q, [ self.queue_name, self.consumer_name, tick_id ]) - def exception_hook(self, det, emsg, cname): + def exception_hook(self, det, emsg): try: dst_db = self.get_database(self.target_db) q = "select * from pgq_node.set_consumer_error(%s, %s, %s)" @@ -282,4 +282,5 @@ class CascadedConsumer(Consumer): except: self.log.warning("Failure to call pgq_node.set_consumer_error()") self.reset() + Consumer.exception_hook(self, det, emsg) diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py index 298c1efa..a8baffd6 100644 --- a/python/skytools/scripting.py +++ b/python/skytools/scripting.py @@ -17,22 +17,8 @@ except ImportError: __pychecker__ = 'no-badexcept' -#: how old connections need to be closed -DEF_CONN_AGE = 20*60 # 20 min - -#: isolation level not set -I_DEFAULT = -1 - -#: isolation level constant for AUTOCOMMIT -I_AUTOCOMMIT = 0 -#: isolation level constant for READ COMMITTED -I_READ_COMMITTED = 1 -#: isolation level constant for SERIALIZABLE -I_SERIALIZABLE = 2 - -__all__ = ['DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE', - 'signal_pidfile', 'UsageError'] -#__all__ += ['daemonize', 'run_single_process'] +__all__ = ['BaseScript', 'signal_pidfile', 'UsageError', 'daemonize', + 'DBScript', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE'] class UsageError(Exception): """User induced error.""" @@ -197,97 +183,10 @@ def _init_log(job_name, service_name, cf, log_level, is_daemon): return log -class DBCachedConn(object): - """Cache a db connection.""" - def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]): - self.name = name - self.loc = loc - self.conn = None - self.conn_time = 0 - self.max_age = max_age - self.autocommit = -1 - self.isolation_level = I_DEFAULT - self.verbose = verbose - self.setup_func = setup_func - self.listen_channel_list = [] - - def fileno(self): - if not self.conn: - return None - return self.conn.cursor().fileno() - - def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []): - # autocommit overrider isolation_level - if autocommit: - if isolation_level == I_SERIALIZABLE: - raise Exception('autocommit is not compatible with I_SERIALIZABLE') - isolation_level = I_AUTOCOMMIT - - # default isolation_level is READ COMMITTED - if isolation_level < 0: - isolation_level = I_READ_COMMITTED - - # new conn? - if not self.conn: - self.isolation_level = isolation_level - self.conn = skytools.connect_database(self.loc) - self.conn.my_name = self.name - - self.conn.set_isolation_level(isolation_level) - self.conn_time = time.time() - if self.setup_func: - self.setup_func(self.name, self.conn) - else: - if self.isolation_level != isolation_level: - raise Exception("Conflict in isolation_level") - self._sync_listen(listen_channel_list) - # done - return self.conn - - def _sync_listen(self, new_clist): - if not new_clist and not self.listen_channel_list: - return - curs = self.conn.cursor() - for ch in self.listen_channel_list: - if ch not in new_clist: - curs.execute("UNLISTEN %s" % skytools.quote_ident(ch)) - for ch in new_clist: - if ch not in self.listen_channel_list: - curs.execute("LISTEN %s" % skytools.quote_ident(ch)) - if self.isolation_level != I_AUTOCOMMIT: - self.conn.commit() - self.listen_channel_list = new_clist[:] - - def refresh(self): - if not self.conn: - return - #for row in self.conn.notifies(): - # if row[0].lower() == "reload": - # self.reset() - # return - if not self.max_age: - return - if time.time() - self.conn_time >= self.max_age: - self.reset() - - def reset(self): - if not self.conn: - return - - # drop reference - conn = self.conn - self.conn = None - self.listen_channel_list = [] - - # close - try: - conn.close() - except: pass - -class DBScript(object): - """Base class for database scripts. +class BaseScript(object): + """Base class for service scripts. Handles logging, daemonizing, config, errors. @@ -314,9 +213,6 @@ class DBScript(object): # 1 - enabled, unless non-daemon on console (os.isatty()) # 2 - always enabled #use_skylog = 0 - - # default lifetime for database connections (in seconds) - #connection_lifetime = 1200 """ service_name = None job_name = None @@ -353,12 +249,10 @@ class DBScript(object): @param args: cmdline args (sys.argv[1:]), but can be overrided """ self.service_name = service_name - self.db_cache = {} self.go_daemon = 0 self.need_reload = 0 self.stat_dict = {} self.log_level = logging.INFO - self._listen_map = {} # dbname: channel_list # parse command line parser = self.init_optparse() @@ -591,50 +485,9 @@ class DBScript(object): self.log.info(logmsg) self.stat_dict = {} - def connection_hook(self, dbname, conn): - pass - - def get_database(self, dbname, autocommit = 0, isolation_level = -1, - cache = None, connstr = None): - """Load cached database connection. - - User must not store it permanently somewhere, - as all connections will be invalidated on reset. - """ - - max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE) - if not cache: - cache = dbname - if cache in self.db_cache: - dbc = self.db_cache[cache] - else: - if not connstr: - connstr = self.cf.get(dbname) - self.log.debug("Connect '%s' to '%s'" % (cache, connstr)) - dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook) - self.db_cache[cache] = dbc - - clist = [] - if cache in self._listen_map: - clist = self._listen_map[cache] - - return dbc.get_connection(autocommit, isolation_level, clist) - - def close_database(self, dbname): - """Explicitly close a cached connection. - - Next call to get_database() will reconnect. - """ - if dbname in self.db_cache: - dbc = self.db_cache[dbname] - dbc.reset() - del self.db_cache[dbname] - def reset(self): - "Something bad happened, reset all connections." - for dbc in self.db_cache.values(): - dbc.reset() - self.db_cache = {} + "Something bad happened, reset all state." + pass def run(self): "Thread main loop." @@ -651,13 +504,6 @@ class DBScript(object): # do some work work = self.run_once() - # send stats that was added - self.send_stats() - - # reconnect if needed - for dbc in self.db_cache.values(): - dbc.refresh() - if not self.looping or self.loop_delay < 0: break @@ -673,7 +519,12 @@ class DBScript(object): break def run_once(self): - return self.run_func_safely(self.work, True) + state = self.run_func_safely(self.work, True) + + # send stats that was added + self.send_stats() + + return state def run_func_safely(self, func, prefer_looping = False): "Run users work function, safely." @@ -702,31 +553,13 @@ class DBScript(object): self.log.info("got KeyboardInterrupt, exiting") self.reset() sys.exit(1) - except skytools.DBError, d: - self.send_stats() - if d.cursor and d.cursor.connection: - cname = d.cursor.connection.my_name - dsn = d.cursor.connection.dsn - sql = d.cursor.query - if len(sql) > 200: # avoid logging londiste huge batched queries - sql = sql[:60] + " ..." - emsg = str(d).strip() - self.log.exception("Job %s got error on connection '%s': %s. Query: %s" % ( - self.job_name, cname, emsg, sql)) - else: - n = "psycopg2.%s" % d.__class__.__name__ - emsg = str(d).rstrip() - self.log.exception("Job %s crashed: %s: %s" % ( - self.job_name, n, emsg)) except Exception, d: self.send_stats() emsg = str(d).rstrip() - self.log.exception("Job %s crashed: %s" % ( - self.job_name, emsg)) - + self.reset() + self.exception_hook(d, emsg) # reset and sleep self.reset() - self.exception_hook(d, emsg, cname) if prefer_looping and self.looping and self.loop_delay > 0: self.sleep(20) return -1 @@ -734,39 +567,18 @@ class DBScript(object): def sleep(self, secs): """Make script sleep for some amount of time.""" - fdlist = [] - for dbname in self._listen_map.keys(): - if dbname not in self.db_cache: - continue - fd = self.db_cache[dbname].fileno() - if fd is None: - continue - fdlist.append(fd) - - if not fdlist: - return time.sleep(secs) - - try: - if hasattr(select, 'poll'): - p = select.poll() - for fd in fdlist: - p.register(fd, select.POLLIN) - p.poll(int(secs * 1000)) - else: - select.select(fdlist, [], [], secs) - except select.error, d: - self.log.info('wait canceled') + time.sleep(secs) - def exception_hook(self, det, emsg, cname): + def exception_hook(self, det, emsg): """Called on after exception processing. Can do additional logging. @param det: exception details @param emsg: exception msg - @param cname: connection name or None """ - pass + self.log.exception("Job %s crashed: %s" % ( + self.job_name, emsg)) def work(self): """Here should user's processing happen. @@ -787,6 +599,152 @@ class DBScript(object): signal.signal(signal.SIGHUP, self.hook_sighup) signal.signal(signal.SIGINT, self.hook_sigint) +## +## DBScript +## + +#: how old connections need to be closed +DEF_CONN_AGE = 20*60 # 20 min + +#: isolation level not set +I_DEFAULT = -1 + +#: isolation level constant for AUTOCOMMIT +I_AUTOCOMMIT = 0 +#: isolation level constant for READ COMMITTED +I_READ_COMMITTED = 1 +#: isolation level constant for SERIALIZABLE +I_SERIALIZABLE = 2 + + +class DBScript(BaseScript): + """Base class for database scripts. + + Handles database connection state. + + Config template:: + + ## Parameters for skytools.DBScript ## + + # default lifetime for database connections (in seconds) + #connection_lifetime = 1200 + """ + + def __init__(self, service_name, args): + """Script setup. + + User class should override work() and optionally __init__(), startup(), + reload(), reset() and init_optparse(). + + NB: in case of daemon, the __init__() and startup()/work() will be + run in different processes. So nothing fancy should be done in __init__(). + + @param service_name: unique name for script. + It will be also default job_name, if not specified in config. + @param args: cmdline args (sys.argv[1:]), but can be overrided + """ + self.db_cache = {} + self._listen_map = {} # dbname: channel_list + BaseScript.__init__(self, service_name, args) + + def connection_hook(self, dbname, conn): + pass + + def get_database(self, dbname, autocommit = 0, isolation_level = -1, + cache = None, connstr = None): + """Load cached database connection. + + User must not store it permanently somewhere, + as all connections will be invalidated on reset. + """ + + max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE) + if not cache: + cache = dbname + if cache in self.db_cache: + dbc = self.db_cache[cache] + else: + if not connstr: + connstr = self.cf.get(dbname) + self.log.debug("Connect '%s' to '%s'" % (cache, connstr)) + dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook) + self.db_cache[cache] = dbc + + clist = [] + if cache in self._listen_map: + clist = self._listen_map[cache] + + return dbc.get_connection(autocommit, isolation_level, clist) + + def close_database(self, dbname): + """Explicitly close a cached connection. + + Next call to get_database() will reconnect. + """ + if dbname in self.db_cache: + dbc = self.db_cache[dbname] + dbc.reset() + del self.db_cache[dbname] + + def reset(self): + "Something bad happened, reset all connections." + for dbc in self.db_cache.values(): + dbc.reset() + self.db_cache = {} + BaseScript.reset(self) + + def run_once(self): + state = BaseScript.run_once(self) + + # reconnect if needed + for dbc in self.db_cache.values(): + dbc.refresh() + + return state + + def exception_hook(self, d, emsg): + """Log database and query details from exception.""" + curs = getattr(d, 'cursor', None) + conn = getattr(curs, 'connection', None) + cname = getattr(conn, 'my_name', None) + if cname: + # Properly named connection + cname = d.cursor.connection.my_name + dsn = getattr(conn, 'dsn', '?') + sql = getattr(curs, 'query', '?') + if len(sql) > 200: # avoid logging londiste huge batched queries + sql = sql[:60] + " ..." + emsg = str(d).strip() + self.log.exception("Job %s got error on connection '%s': %s. Query: %s" % ( + self.job_name, cname, emsg, sql)) + else: + BaseScript.exception_hook(self, d, emsg) + + def sleep(self, secs): + """Make script sleep for some amount of time.""" + fdlist = [] + for dbname in self._listen_map.keys(): + if dbname not in self.db_cache: + continue + fd = self.db_cache[dbname].fileno() + if fd is None: + continue + fdlist.append(fd) + + if not fdlist: + return BaseScript.sleep(self, secs) + + try: + if hasattr(select, 'poll'): + p = select.poll() + for fd in fdlist: + p.register(fd, select.POLLIN) + p.poll(int(secs * 1000)) + else: + select.select(fdlist, [], [], secs) + except select.error, d: + self.log.info('wait canceled') + def _exec_cmd(self, curs, sql, args, quiet = False): """Internal tool: Run SQL on cursor.""" self.log.debug("exec_cmd: %s" % skytools.quote_statement(sql, args)) @@ -903,3 +861,95 @@ class DBScript(object): except ValueError: pass +class DBCachedConn(object): + """Cache a db connection.""" + def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]): + self.name = name + self.loc = loc + self.conn = None + self.conn_time = 0 + self.max_age = max_age + self.autocommit = -1 + self.isolation_level = I_DEFAULT + self.verbose = verbose + self.setup_func = setup_func + self.listen_channel_list = [] + + def fileno(self): + if not self.conn: + return None + return self.conn.cursor().fileno() + + def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []): + # autocommit overrider isolation_level + if autocommit: + if isolation_level == I_SERIALIZABLE: + raise Exception('autocommit is not compatible with I_SERIALIZABLE') + isolation_level = I_AUTOCOMMIT + + # default isolation_level is READ COMMITTED + if isolation_level < 0: + isolation_level = I_READ_COMMITTED + + # new conn? + if not self.conn: + self.isolation_level = isolation_level + self.conn = skytools.connect_database(self.loc) + self.conn.my_name = self.name + + self.conn.set_isolation_level(isolation_level) + self.conn_time = time.time() + if self.setup_func: + self.setup_func(self.name, self.conn) + else: + if self.isolation_level != isolation_level: + raise Exception("Conflict in isolation_level") + + self._sync_listen(listen_channel_list) + + # done + return self.conn + + def _sync_listen(self, new_clist): + if not new_clist and not self.listen_channel_list: + return + curs = self.conn.cursor() + for ch in self.listen_channel_list: + if ch not in new_clist: + curs.execute("UNLISTEN %s" % skytools.quote_ident(ch)) + for ch in new_clist: + if ch not in self.listen_channel_list: + curs.execute("LISTEN %s" % skytools.quote_ident(ch)) + if self.isolation_level != I_AUTOCOMMIT: + self.conn.commit() + self.listen_channel_list = new_clist[:] + + def refresh(self): + if not self.conn: + return + #for row in self.conn.notifies(): + # if row[0].lower() == "reload": + # self.reset() + # return + if not self.max_age: + return + if time.time() - self.conn_time >= self.max_age: + self.reset() + + def reset(self): + if not self.conn: + return + + # drop reference + conn = self.conn + self.conn = None + self.listen_channel_list = [] + + # close + try: + conn.close() + except: pass + + + + |