diff options
author | Marko Kreen | 2010-08-02 11:11:43 +0000 |
---|---|---|
committer | Marko Kreen | 2010-08-02 11:14:29 +0000 |
commit | 9cbd773eb892a8b1a91bca43537affc60417e357 (patch) | |
tree | b6e5ae15e81e504fd10202a32da3c2868a618991 /python/skytools/scripting.py | |
parent | 4ef0bcbc4351aacb7b998218a0dbbb4873c63633 (diff) |
DBScript: Wake from sleep on NOTIFY.
Use DBScript .listen(dbname, chan)/.unlisten(dbname, chan) to register.
Diffstat (limited to 'python/skytools/scripting.py')
-rw-r--r-- | python/skytools/scripting.py | 91 |
1 files changed, 86 insertions, 5 deletions
diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py index 2bcf895d..2dddfa3c 100644 --- a/python/skytools/scripting.py +++ b/python/skytools/scripting.py @@ -3,7 +3,7 @@ """ -import sys, os, signal, optparse, time, errno +import sys, os, signal, optparse, time, errno, select import logging, logging.handlers, logging.config from skytools.config import * @@ -189,7 +189,7 @@ def _init_log(job_name, service_name, cf, log_level): class DBCachedConn(object): """Cache a db connection.""" - def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None): + def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]): self.name = name self.loc = loc self.conn = None @@ -199,8 +199,14 @@ class DBCachedConn(object): self.isolation_level = I_DEFAULT self.verbose = verbose self.setup_func = setup_func + self.listen_channel_list = [] - def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT): + def fileno(self): + if not self.conn: + return None + return self.conn.cursor().fileno() + + def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []): # autocommit overrider isolation_level if autocommit: if isolation_level == I_SERIALIZABLE: @@ -225,9 +231,25 @@ class DBCachedConn(object): if self.isolation_level != isolation_level: raise Exception("Conflict in isolation_level") + self._sync_listen(listen_channel_list) + # done return self.conn + def _sync_listen(self, new_clist): + if not new_clist and not self.listen_channel_list: + return + curs = self.conn.cursor() + for ch in self.listen_channel_list: + if ch not in new_clist: + curs.execute("UNLISTEN %s" % skytools.quote_ident(ch)) + for ch in new_clist: + if ch not in self.listen_channel_list: + curs.execute("LISTEN %s" % skytools.quote_ident(ch)) + if self.isolation_level != I_AUTOCOMMIT: + self.conn.commit() + self.listen_channel_list = new_clist[:] + def refresh(self): if not self.conn: return @@ -247,6 +269,7 @@ class DBCachedConn(object): # drop reference conn = self.conn self.conn = None + self.listen_channel_list = [] # close try: @@ -321,6 +344,7 @@ class DBScript(object): self.need_reload = 0 self.stat_dict = {} self.log_level = logging.INFO + self._listen_map = {} # dbname: channel_list # parse command line parser = self.init_optparse() @@ -567,7 +591,11 @@ class DBScript(object): dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook) self.db_cache[cache] = dbc - return dbc.get_connection(autocommit, isolation_level) + clist = [] + if cache in self._listen_map: + clist = self._listen_map[cache] + + return dbc.get_connection(autocommit, isolation_level, clist) def close_database(self, dbname): """Explicitly close a cached connection. @@ -616,6 +644,8 @@ class DBScript(object): if not work: if self.loop_delay > 0: self.sleep(self.loop_delay) + if not self.looping: + break else: break @@ -681,7 +711,28 @@ class DBScript(object): def sleep(self, secs): """Make script sleep for some amount of time.""" - time.sleep(secs) + fdlist = [] + for dbname in self._listen_map.keys(): + if dbname not in self.db_cache: + continue + fd = self.db_cache[dbname].fileno() + if fd is None: + continue + fdlist.append(fd) + + if not fdlist: + return time.sleep(secs) + + try: + if hasattr(select, 'poll'): + p = select.poll() + for fd in fdlist: + p.register(fd, select.POLLIN) + p.poll(int(secs * 1000)) + else: + select.select(fdlist, [], [], secs) + except select.error, d: + self.log.info('wait canceled') def exception_hook(self, det, emsg, cname): """Called on after exception processing. @@ -791,4 +842,34 @@ class DBScript(object): db.rollback() raise Exception("db error") + def listen(self, dbname, channel): + """Make connection listen for specific event channel. + + Listening will be activated on next .get_database() call. + + Basically this means that DBScript.sleep() will poll for events + on that db connection, so when event appears, script will be + woken up. + """ + if dbname not in self._listen_map: + self._listen_map[dbname] = [] + clist = self._listen_map[dbname] + if channel not in clist: + clist.append(channel) + + def unlisten(self, dbname, channel='*'): + """Stop connection for listening on specific event channel. + + Listening will stop on next .get_database() call. + """ + if dbname not in self._listen_map: + return + if channel == '*': + del self._listen_map[dbname] + return + clist = self._listen_map[dbname] + try: + clist.remove(channel) + except ValueError: + pass |