summaryrefslogtreecommitdiff
path: root/python/skytools/scripting.py
diff options
context:
space:
mode:
authorMarko Kreen2010-08-02 11:11:43 +0000
committerMarko Kreen2010-08-02 11:14:29 +0000
commit9cbd773eb892a8b1a91bca43537affc60417e357 (patch)
treeb6e5ae15e81e504fd10202a32da3c2868a618991 /python/skytools/scripting.py
parent4ef0bcbc4351aacb7b998218a0dbbb4873c63633 (diff)
DBScript: Wake from sleep on NOTIFY.
Use DBScript .listen(dbname, chan)/.unlisten(dbname, chan) to register.
Diffstat (limited to 'python/skytools/scripting.py')
-rw-r--r--python/skytools/scripting.py91
1 files changed, 86 insertions, 5 deletions
diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py
index 2bcf895d..2dddfa3c 100644
--- a/python/skytools/scripting.py
+++ b/python/skytools/scripting.py
@@ -3,7 +3,7 @@
"""
-import sys, os, signal, optparse, time, errno
+import sys, os, signal, optparse, time, errno, select
import logging, logging.handlers, logging.config
from skytools.config import *
@@ -189,7 +189,7 @@ def _init_log(job_name, service_name, cf, log_level):
class DBCachedConn(object):
"""Cache a db connection."""
- def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None):
+ def __init__(self, name, loc, max_age = DEF_CONN_AGE, verbose = False, setup_func=None, channels=[]):
self.name = name
self.loc = loc
self.conn = None
@@ -199,8 +199,14 @@ class DBCachedConn(object):
self.isolation_level = I_DEFAULT
self.verbose = verbose
self.setup_func = setup_func
+ self.listen_channel_list = []
- def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT):
+ def fileno(self):
+ if not self.conn:
+ return None
+ return self.conn.cursor().fileno()
+
+ def get_connection(self, autocommit = 0, isolation_level = I_DEFAULT, listen_channel_list = []):
# autocommit overrider isolation_level
if autocommit:
if isolation_level == I_SERIALIZABLE:
@@ -225,9 +231,25 @@ class DBCachedConn(object):
if self.isolation_level != isolation_level:
raise Exception("Conflict in isolation_level")
+ self._sync_listen(listen_channel_list)
+
# done
return self.conn
+ def _sync_listen(self, new_clist):
+ if not new_clist and not self.listen_channel_list:
+ return
+ curs = self.conn.cursor()
+ for ch in self.listen_channel_list:
+ if ch not in new_clist:
+ curs.execute("UNLISTEN %s" % skytools.quote_ident(ch))
+ for ch in new_clist:
+ if ch not in self.listen_channel_list:
+ curs.execute("LISTEN %s" % skytools.quote_ident(ch))
+ if self.isolation_level != I_AUTOCOMMIT:
+ self.conn.commit()
+ self.listen_channel_list = new_clist[:]
+
def refresh(self):
if not self.conn:
return
@@ -247,6 +269,7 @@ class DBCachedConn(object):
# drop reference
conn = self.conn
self.conn = None
+ self.listen_channel_list = []
# close
try:
@@ -321,6 +344,7 @@ class DBScript(object):
self.need_reload = 0
self.stat_dict = {}
self.log_level = logging.INFO
+ self._listen_map = {} # dbname: channel_list
# parse command line
parser = self.init_optparse()
@@ -567,7 +591,11 @@ class DBScript(object):
dbc = DBCachedConn(cache, connstr, max_age, setup_func = self.connection_hook)
self.db_cache[cache] = dbc
- return dbc.get_connection(autocommit, isolation_level)
+ clist = []
+ if cache in self._listen_map:
+ clist = self._listen_map[cache]
+
+ return dbc.get_connection(autocommit, isolation_level, clist)
def close_database(self, dbname):
"""Explicitly close a cached connection.
@@ -616,6 +644,8 @@ class DBScript(object):
if not work:
if self.loop_delay > 0:
self.sleep(self.loop_delay)
+ if not self.looping:
+ break
else:
break
@@ -681,7 +711,28 @@ class DBScript(object):
def sleep(self, secs):
"""Make script sleep for some amount of time."""
- time.sleep(secs)
+ fdlist = []
+ for dbname in self._listen_map.keys():
+ if dbname not in self.db_cache:
+ continue
+ fd = self.db_cache[dbname].fileno()
+ if fd is None:
+ continue
+ fdlist.append(fd)
+
+ if not fdlist:
+ return time.sleep(secs)
+
+ try:
+ if hasattr(select, 'poll'):
+ p = select.poll()
+ for fd in fdlist:
+ p.register(fd, select.POLLIN)
+ p.poll(int(secs * 1000))
+ else:
+ select.select(fdlist, [], [], secs)
+ except select.error, d:
+ self.log.info('wait canceled')
def exception_hook(self, det, emsg, cname):
"""Called on after exception processing.
@@ -791,4 +842,34 @@ class DBScript(object):
db.rollback()
raise Exception("db error")
+ def listen(self, dbname, channel):
+ """Make connection listen for specific event channel.
+
+ Listening will be activated on next .get_database() call.
+
+ Basically this means that DBScript.sleep() will poll for events
+ on that db connection, so when event appears, script will be
+ woken up.
+ """
+ if dbname not in self._listen_map:
+ self._listen_map[dbname] = []
+ clist = self._listen_map[dbname]
+ if channel not in clist:
+ clist.append(channel)
+
+ def unlisten(self, dbname, channel='*'):
+ """Stop connection for listening on specific event channel.
+
+ Listening will stop on next .get_database() call.
+ """
+ if dbname not in self._listen_map:
+ return
+ if channel == '*':
+ del self._listen_map[dbname]
+ return
+ clist = self._listen_map[dbname]
+ try:
+ clist.remove(channel)
+ except ValueError:
+ pass