Diffstat (limited to 'python/pgq/cascade/worker.py')
 python/pgq/cascade/worker.py | 97
 1 file changed, 87 insertions(+), 10 deletions(-)
diff --git a/python/pgq/cascade/worker.py b/python/pgq/cascade/worker.py
index 1d3f8325..a721eaa0 100644
--- a/python/pgq/cascade/worker.py
+++ b/python/pgq/cascade/worker.py
@@ -4,10 +4,11 @@ CascadedConsumer that also maintains node.
 
 """
 
-import sys, time
+import sys, time, skytools
 
 from pgq.cascade.consumer import CascadedConsumer
 from pgq.producer import bulk_insert_events
+from pgq.event import Event
 
 __all__ = ['CascadedWorker']
 
@@ -32,10 +33,20 @@ class WorkerState:
     filtered_copy = 0       # ok
     process_global_wm = 0   # ok
 
+    sync_watermark = 0      # sync global watermark with listed nodes
+    wm_sync_nodes = []      # nodes from the 'sync_watermark' node attr
+
     def __init__(self, queue_name, nst):
         self.node_type = nst['node_type']
         self.node_name = nst['node_name']
         self.local_watermark = nst['local_watermark']
+        self.global_watermark = nst['global_watermark']
+
+        self.node_attrs = {}
+        attrs = nst.get('node_attrs', '')
+        if attrs:
+            self.node_attrs = skytools.db_urldecode(attrs)
+
         ntype = nst['node_type']
         ctype = nst['combined_type']
         if ntype == 'root':
@@ -49,7 +60,12 @@ class WorkerState:
             self.process_tick_event = 1
             self.keep_event_ids = 1
             self.create_tick = 1
-            self.process_global_wm = 1
+            if 'sync_watermark' in self.node_attrs:
+                slist = self.node_attrs['sync_watermark']
+                self.sync_watermark = 1
+                self.wm_sync_nodes = slist.split(',')
+            else:
+                self.process_global_wm = 1
         elif ntype == 'leaf' and not ctype:
             self.process_batch = 1
             self.process_events = 1
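
For reference, the sync_watermark attribute above arrives urlencoded inside node_attrs. A minimal sketch of the decode path, assuming skytools.db_urldecode behaves like a plain urldecode-to-dict (the node names in the sample string are hypothetical):

    import skytools

    # sample node_attrs string; the comma-separated value is urlencoded,
    # so commas may arrive as %2c (sample data only)
    attrs = 'sync_watermark=node_a%2cnode_b%2cnode_c'

    node_attrs = skytools.db_urldecode(attrs)
    # -> {'sync_watermark': 'node_a,node_b,node_c'}
    wm_sync_nodes = node_attrs['sync_watermark'].split(',')
    # -> ['node_a', 'node_b', 'node_c']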
@@ -139,8 +155,6 @@ class CascadedWorker(CascadedConsumer):
                 self.process_remote_event(src_curs, dst_curs, ev)
             if ev.ev_id > max_id:
                 max_id = ev.ev_id
-        if st.local_wm_publish:
-            self.publish_local_wm(src_db)
 
         if max_id > self.cur_max_id:
             self.cur_max_id = max_id
@@ -195,22 +209,72 @@ class CascadedWorker(CascadedConsumer):
             self.create_branch_tick(dst_db, cur_tick, tick_time)
         return True
 
-    def publish_local_wm(self, src_db):
+    def publish_local_wm(self, src_db, dst_db):
         """Send local watermark to provider.
         """
-        if not self.main_worker:
-            return
+
         t = time.time()
         if t - self.local_wm_publish_time < self.local_wm_publish_period:
             return
 
         st = self._worker_state
-        self.log.debug("Publishing local watermark: %d" % st.local_watermark)
+        wm = st.local_watermark
+        if st.sync_watermark:
+            # don't send the local watermark upstream
+            wm = self.batch_info['prev_tick_id']
+
+        self.log.debug("Publishing local watermark: %d" % wm)
         src_curs = src_db.cursor()
         q = "select * from pgq_node.set_subscriber_watermark(%s, %s, %s)"
-        src_curs.execute(q, [self.pgq_queue_name, st.node_name, st.local_watermark])
+        src_curs.execute(q, [self.pgq_queue_name, st.node_name, wm])
+
+        # if the following part fails, don't repeat it immediately
         self.local_wm_publish_time = t
 
+        if st.sync_watermark:
+            # instead, sync 'global-watermark' with the specific nodes
+            dst_curs = dst_db.cursor()
+            nmap = self._get_node_map(dst_curs)
+            dst_db.commit()
+
+            wm = st.local_watermark
+            for node in st.wm_sync_nodes:
+                if node == st.node_name:
+                    continue
+                if node not in nmap:
+                    # don't ignore missing nodes - the cluster may be partially set up
+                    self.log.warning('Unknown node in sync_watermark list: %s' % node)
+                    return
+                n = nmap[node]
+                if n['dead']:
+                    # ignore dead nodes
+                    continue
+                wmdb = self.get_database('wmdb', connstr = n['node_location'], autocommit = 1)
+                wmcurs = wmdb.cursor()
+                q = 'select local_watermark from pgq_node.get_node_info(%s)'
+                wmcurs.execute(q, [self.queue_name])
+                row = wmcurs.fetchone()
+                if not row:
+                    # partially set up node?
+                    self.log.warning('Node not working: %s' % node)
+                elif row['local_watermark'] < wm:
+                    # keep the lowest wm
+                    wm = row['local_watermark']
+                self.close_database('wmdb')
+
+            # now we have the lowest wm, store it
+            q = "select pgq_node.set_global_watermark(%s, %s)"
+            dst_curs.execute(q, [self.queue_name, wm])
+            dst_db.commit()
+
+    def _get_node_map(self, curs):
+        q = "select node_name, node_location, dead from pgq_node.get_queue_locations(%s)"
+        curs.execute(q, [self.queue_name])
+        res = {}
+        for row in curs.fetchall():
+            res[row['node_name']] = row
+        return res
+
     def process_remote_event(self, src_curs, dst_curs, ev):
         """Handle cascading events.
         """
@@ -245,7 +309,10 @@
             q = "select * from pgq_node.unregister_location(%s, %s)"
             dst_curs.execute(q, [self.pgq_queue_name, node])
         elif t == "pgq.global-watermark":
-            if st.process_global_wm:
+            if st.sync_watermark:
+                tick_id = int(ev.ev_data)
+                self.log.info('Ignoring global watermark %s' % tick_id)
+            elif st.process_global_wm:
                 tick_id = int(ev.ev_data)
                 q = "select * from pgq_node.set_global_watermark(%s, %s)"
                 dst_curs.execute(q, [self.pgq_queue_name, tick_id])
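
On a sync_watermark node the upstream pgq.global-watermark event is deliberately a no-op: the node negotiates the global watermark with its peers in publish_local_wm(), and applying the upstream value could advance it past what the peer group has agreed on. A hedged sketch of the resulting dispatch, pulled out of the elif chain above (helper name is hypothetical):

    def handle_global_wm(st, ev_data, dst_curs, queue_name, log):
        # st is the WorkerState; ev_data carries the tick id as text
        tick_id = int(ev_data)
        if st.sync_watermark:
            # watermark comes from peer sync, not from upstream
            log.info('Ignoring global watermark %s' % tick_id)
        elif st.process_global_wm:
            q = "select * from pgq_node.set_global_watermark(%s, %s)"
            dst_curs.execute(q, [queue_name, tick_id])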
@@ -288,6 +355,8 @@
             tick_id = self.batch_info['tick_id']
             tick_time = self.batch_info['batch_end']
             self.create_branch_tick(dst_db, tick_id, tick_time)
+        if st.local_wm_publish:
+            self.publish_local_wm(src_db, dst_db)
 
     def create_branch_tick(self, dst_db, tick_id, tick_time):
         q = "select pgq.ticker(%s, %s, %s, %s)"
@@ -308,6 +377,14 @@
                 return
         if len(self.ev_buf) >= self.max_evbuf:
             self.flush_events(dst_curs)
+
+        if ev.type == 'pgq.global-watermark':
+            st = self._worker_state
+            if st.sync_watermark:
+                # replace payload with synced global watermark
+                row = ev._event_row.copy()
+                row['ev_data'] = str(st.global_watermark)
+                ev = Event(self.queue_name, row)
         self.ev_buf.append(ev)
 
     def flush_events(self, dst_curs):
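
A note on the payload rewrite in copy_event() above: it copies the underlying event row and wraps it in a fresh Event instead of mutating the incoming event, so the original row stays intact if the batch is retried. A minimal sketch of the same pattern (queue name and field values hypothetical, row dict abbreviated):

    from pgq.event import Event

    row = {'ev_id': 1, 'ev_type': 'pgq.global-watermark', 'ev_data': '100'}
    ev = Event('testqueue', row)

    # forward with the locally synced watermark instead of the
    # upstream value, leaving the original row untouched
    synced_wm = 120
    new_row = ev._event_row.copy()
    new_row['ev_data'] = str(synced_wm)
    ev2 = Event('testqueue', new_row)
    print(ev2.ev_data)   # -> '120'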