summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarko Kreen2010-09-23 00:52:16 +0000
committerMarko Kreen2010-10-06 15:53:32 +0000
commit6785f0a3c9aaecb1c72d4e88ea5812550af3089c (patch)
treeed7c0cbc5d14d16ad1765ce7ec5128f3435517eb
parent3beeb608f31a2392178b757cb9260f1664d136c6 (diff)
more checker
-rw-r--r--python/skytools/checker.py431
1 files changed, 87 insertions, 344 deletions
diff --git a/python/skytools/checker.py b/python/skytools/checker.py
index 765afd14..dc89f721 100644
--- a/python/skytools/checker.py
+++ b/python/skytools/checker.py
@@ -2,347 +2,99 @@
"""Catch moment when tables are in sync on master and slave.
"""
-import sys, time, os
+import sys, time, os, subprocess
import pkgloader
pkgloader.require('skytools', '3.0')
import skytools
-CONFDB = "dbname=confdb host=confdb.service user=replicator"
-
-def unescape(s):
- """Remove copy escapes."""
- return skytools.unescape_copy(s)
-
-def get_pkey_list(curs, tbl):
- """Get list of pkey fields in right order."""
-
- oid = skytools.get_table_oid(curs, tbl)
- q = """SELECT k.attname FROM pg_index i, pg_attribute k
- WHERE i.indrelid = %s AND k.attrelid = i.indexrelid
- AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped
- ORDER BY k.attnum"""
- curs.execute(q, [oid])
- list = []
- for row in curs.fetchall():
- list.append(row[0])
- return list
-
-def get_column_list(curs, tbl):
- """Get list of columns in right order."""
-
- oid = skytools.get_table_oid(curs, tbl)
- q = """SELECT a.attname FROM pg_attribute a
- WHERE a.attrelid = %s
- AND a.attnum > 0 AND NOT a.attisdropped
- ORDER BY a.attnum"""
- curs.execute(q, [oid])
- list = []
- for row in curs.fetchall():
- list.append(row[0])
- return list
-
-class Checker(skytools.DBScript):
+class TableRepair:
"""Checks that tables in two databases are in sync."""
- cnt_insert = 0
- cnt_update = 0
- cnt_delete = 0
- total_src = 0
- total_dst = 0
- pkey_list = []
- common_fields = []
-
- def __init__(self, args):
- """Checker init."""
- skytools.DBScript.__init__(self, 'cross_mover', args)
- self.set_single_loop(1)
- self.log.info('Checker starting %s' % str(args))
- # compat names
- self.queue_name = self.cf.get("pgq_queue_name", '')
- self.consumer_name = self.cf.get('pgq_consumer_id', '')
- # good names
- if not self.queue_name:
- self.queue_name = self.cf.get("queue_name")
- if not self.consumer_name:
- self.consumer_name = self.cf.get('consumer_name', self.job_name)
- self.lock_timeout = self.cf.getfloat('lock_timeout', 10)
- # get tables to be compared
- if not self.options.table_list:
- self.log.error("--table is required")
- # create temp pidfile
- if self.pidfile:
- self.pidfile += ".repair"
-
- def set_lock_timeout(self, curs):
- ms = int(1000 * self.lock_timeout)
- if ms > 0:
- q = "SET LOCAL statement_timeout = %d" % ms
- self.log.debug(q)
- curs.execute(q)
-
- def init_optparse(self, p=None):
- """ Initialize cmdline switches.
- """
- p = skytools.DBScript.init_optparse(self, p)
- p.add_option("--table", dest='table_list', help="space separated list of table names")
- p.add_option("--part_expr", dest='part_expr', help="table partitioning expression")
-
- return p
-
- def check_consumer(self, setup_curs):
- """ Before locking anything check if consumer is working ok.
- """
- self.log.info("Queue: %s Consumer: %s" % (self.queue_name, self.consumer_name))
- # get ticker lag
- q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);"
- setup_curs.execute(q, [self.queue_name])
- ticker_lag = setup_curs.fetchone()[0]
- self.log.info("Ticker lag: %s" % ticker_lag)
- # get consumer lag
- q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);"
- setup_curs.execute(q, [self.queue_name, self.consumer_name])
- res = setup_curs.fetchall()
- if len(res) == 0:
- self.log.error('No such consumer')
- sys.exit(1)
- consumer_lag = res[0][0]
- self.log.info("Consumer lag: %s" % consumer_lag)
- # check that lag is acceptable
- if consumer_lag > ticker_lag + 10:
- self.log.error('Consumer lagging too much, cannot proceed')
- sys.exit(1)
-
- def work(self):
- """Syncer main function."""
- # get sourcedb connection and slots provided there
- setup_db = self.get_database('setup_db', autocommit = 1, connstr = self.cf.get('src_db'))
- setup_curs = setup_db.cursor()
- setup_curs.execute("select hostname(), current_database();")
- r_source = setup_curs.fetchone()
- self.log.info("Source: %s" % str(r_source))
-
- # get proxy db name and host (used to find out target cluster target partitons and their respective slots)
- proxy_db = self.get_database('dst_db', autocommit = 1)
- proxy_curs = proxy_db.cursor()
- proxy_curs.execute("select hostname(), current_database();")
- r_proxy = proxy_curs.fetchone()
- self.log.info("Proxy: %s" % str(r_proxy))
-
- # get target partitions from confdb and do also some sanity checks
- conf_db = self.get_database('conf_db', autocommit = 1, connstr = CONFDB)
- conf_curs = conf_db.cursor()
- q = "select db_name, hostname, slots, max_slot from dba.get_cross_targets(%s, %s, %s, %s)"
- conf_curs.execute(q, r_source + r_proxy)
- targets = conf_curs.fetchall()
-
- # get special purpose connections for magic locking
- lock_db = self.get_database('lock_db', connstr = self.cf.get('src_db'))
- src_db = self.get_database('src_db', isolation_level = skytools.I_SERIALIZABLE)
-
- # check that consumer is up and running
- self.check_consumer(setup_curs)
-
- # loop over all tables and all targets
- mismatch_count = 0
- for tbl in self.options.table_list.split():
- self.log.info("Checking table: %s" % tbl)
- tbl = skytools.fq_name(tbl)
- for target in targets:
- self.log.info("Target: %s" % str(target))
- connstr = "dbname=%s host=%s user=replicator" % (target[0], target[1])
- fn = "%s.%s" % (target[1], target[0])
- dst_db = self.get_database(target[0], isolation_level = skytools.I_SERIALIZABLE, connstr = connstr)
- where = "%s & %s in (%s)" % (self.options.part_expr, target[3],target[2])
- if not self.check_table(tbl, lock_db, src_db, dst_db, setup_curs, where, fn):
- mismatch_count += 1
- lock_db.commit()
- src_db.commit()
- dst_db.commit()
- if mismatch_count > 0:
- self.log.error("%s mismatching tables found" % mismatch_count)
- sys.exit(1)
- def force_tick(self, setup_curs):
- """ Force tick into source queue so that consumer can move on faster
- """
- q = "select pgq.force_tick(%s)"
- setup_curs.execute(q, [self.queue_name])
- res = setup_curs.fetchone()
- cur_pos = res[0]
-
- start = time.time()
- while 1:
- time.sleep(0.5)
- setup_curs.execute(q, [self.queue_name])
- res = setup_curs.fetchone()
- if res[0] != cur_pos:
- # new pos
- return res[0]
-
- # dont loop more than 10 secs
- dur = time.time() - start
- if dur > 10 and not self.options.force:
- raise Exception("Ticker seems dead")
-
- def check_table(self, tbl, lock_db, src_db, dst_db, setup_curs, where, target):
- """ Get transaction to same state, then process.
- """
- lock_curs = lock_db.cursor()
- src_curs = src_db.cursor()
- dst_curs = dst_db.cursor()
-
- if not skytools.exists_table(src_curs, tbl):
- self.log.warning("Table %s does not exist on provider side" % tbl)
- return
- if not skytools.exists_table(dst_curs, tbl):
- self.log.warning("Table %s does not exist on subscriber side" % tbl)
- return
-
- # lock table in separate connection
- self.log.info('Locking %s' % tbl)
- lock_db.commit()
- self.set_lock_timeout(lock_curs)
- lock_time = time.time()
- lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(tbl))
-
- # now wait until consumer has updated target table until locking
- self.log.info('Syncing %s' % tbl)
-
- # consumer must get further than this tick
- tick_id = self.force_tick(setup_curs)
- # try to force second tick also
- self.force_tick(setup_curs)
-
- # take server time
- setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')")
- tpos = setup_curs.fetchone()[0]
- # now wait
- while 1:
- time.sleep(0.5)
-
- q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)"
- setup_curs.execute(q, [tpos, self.queue_name, self.consumer_name])
- res = setup_curs.fetchall()
- if len(res) == 0:
- raise Exception('No such consumer')
- row = res[0]
- self.log.debug("tpos=%s now=%s lag=%s ok=%s" % (tpos, row[1], row[2], row[0]))
- if row[0]:
- break
-
- # limit lock time
- if time.time() > lock_time + self.lock_timeout:
- self.log.error('Consumer lagging too much, exiting')
- lock_db.rollback()
- sys.exit(1)
-
- # take snapshot on provider side
- src_db.commit()
- src_curs.execute("SELECT 1")
+ def __init__(self, table_name, log):
+ self.table_name = table_name
+ self.fq_table_name = skytools.quote_fqident(table_name)
+ self.log = log
+ self.reset()
- # take snapshot on subscriber side
- dst_db.commit()
- dst_curs.execute("SELECT 1")
-
- # release lock
- lock_db.commit()
-
- # do work
- result = self.do_compare(tbl, src_db, dst_db, where)
- if not result:
- self.do_repair(tbl, src_db, dst_db, where, target)
- # done
- src_db.commit()
- dst_db.commit()
-
- return result
+ def reset(self):
+ self.cnt_insert = 0
+ self.cnt_update = 0
+ self.cnt_delete = 0
+ self.total_src = 0
+ self.total_dst = 0
+ self.pkey_list = []
+ self.common_fields = []
- def do_compare(self, tbl, src_db, dst_db, where):
+ def do_repair(self, src_db, dst_db, where, pfx = 'repair', apply_fixes = False):
"""Actual comparision."""
- src_curs = src_db.cursor()
- dst_curs = dst_db.cursor()
-
- self.log.info('Counting %s' % tbl)
-
- q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" % where
- q = self.cf.get('compare_sql', q)
- q = q.replace('_TABLE_', skytools.quote_fqident(tbl))
-
- f = "%(cnt)d rows, checksum=%(chksum)s"
- f = self.cf.get('compare_fmt', f)
-
- self.log.debug("srcdb: " + q)
- src_curs.execute(q)
- src_row = src_curs.fetchone()
- src_str = f % src_row
- self.log.info("srcdb: %s" % src_str)
-
- self.log.debug("dstdb: " + q)
- dst_curs.execute(q)
- dst_row = dst_curs.fetchone()
- dst_str = f % dst_row
- self.log.info("dstdb: %s" % dst_str)
-
- if src_str != dst_str:
- self.log.warning("%s: Results do not match!" % tbl)
- return False
- else:
- self.log.info("%s: OK!" % tbl)
- return True
-
- def do_repair(self, tbl, src_db, dst_db, where, target):
- """Actual comparision."""
+ self.reset()
src_curs = src_db.cursor()
dst_curs = dst_db.cursor()
- self.log.info('Checking %s' % tbl)
+ self.log.info('Checking %s' % self.table_name)
- self.common_fields = []
- self.pkey_list = []
- copy_tbl = self.gen_copy_tbl(tbl, src_curs, dst_curs, where)
+ copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where)
- dump_src = target + "__" + tbl + ".src"
- dump_dst = target + "__" + tbl + ".dst"
+ dump_src = "%s.%s.src" % (pfx, self.table_name)
+ dump_dst = "%s.%s.dst" % (pfx, self.table_name)
+ fix = "%s.%s.fix" % (pfx, self.table_name)
- self.log.info("Dumping src table: %s" % tbl)
- self.dump_table(tbl, copy_tbl, src_curs, dump_src)
+ self.log.info("Dumping src table: %s" % self.table_name)
+ self.dump_table(copy_tbl, src_curs, dump_src)
src_db.commit()
- self.log.info("Dumping dst table: %s" % tbl)
- self.dump_table(tbl, copy_tbl, dst_curs, dump_dst)
+ self.log.info("Dumping dst table: %s" % self.table_name)
+ self.dump_table(copy_tbl, dst_curs, dump_dst)
dst_db.commit()
- self.log.info("Sorting src table: %s" % tbl)
+ self.log.info("Sorting src table: %s" % self.table_name)
+ self.do_sort(dump_src, dump_src + '.sorted')
- s_in, s_out = os.popen4("sort --version")
- s_ver = s_out.read()
- del s_in, s_out
- if s_ver.find("coreutils") > 0:
- args = "-S 30%"
- else:
- args = ""
- os.system("sort %s -T . -o %s.sorted %s" % (args, dump_src, dump_src))
- self.log.info("Sorting dst table: %s" % tbl)
- os.system("sort %s -T . -o %s.sorted %s" % (args, dump_dst, dump_dst))
+ self.log.info("Sorting dst table: %s" % self.table_name)
+ self.do_sort(dump_dst, dump_dst + '.sorted')
- self.dump_compare(tbl, dump_src + ".sorted", dump_dst + ".sorted", target)
+ self.dump_compare(dump_src + ".sorted", dump_dst + ".sorted", fix)
os.unlink(dump_src)
os.unlink(dump_dst)
os.unlink(dump_src + ".sorted")
os.unlink(dump_dst + ".sorted")
- def gen_copy_tbl(self, tbl, src_curs, dst_curs, where):
+ if apply_fixes:
+ pass
+
+ def do_sort(self, src, dst):
+ p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ s_ver = p.communicate()[0]
+ del p
+
+ xenv = os.environ.copy()
+ xenv['LANG'] = 'C'
+ xenv['LC_ALL'] = 'C'
+
+ cmdline = ['sort', '-T', '.']
+ if s_ver.find("coreutils") > 0:
+ cmdline.append('-S')
+ cmdline.append('30%')
+ cmdline.append('-o')
+ cmdline.append(dst)
+ cmdline.append(src)
+ p = subprocess.Popen(cmdline, env = xenv)
+ if p.wait() != 0:
+ raise Exception('sort failed')
+
+ def gen_copy_tbl(self, src_curs, dst_curs, where):
"""Create COPY expession from common fields."""
- self.pkey_list = get_pkey_list(src_curs, tbl)
- dst_pkey = get_pkey_list(dst_curs, tbl)
+ self.pkey_list = skytools.get_table_pkeys(src_curs, self.table_name)
+ dst_pkey = skytools.get_table_pkeys(dst_curs, self.table_name)
if dst_pkey != self.pkey_list:
self.log.error('pkeys do not match')
sys.exit(1)
- src_cols = get_column_list(src_curs, tbl)
- dst_cols = get_column_list(dst_curs, tbl)
+ src_cols = skytools.get_table_columns(src_curs, self.table_name)
+ dst_cols = skytools.get_table_columns(dst_curs, self.table_name)
field_list = []
for f in self.pkey_list:
field_list.append(f)
@@ -356,19 +108,21 @@ class Checker(skytools.DBScript):
fqlist = [skytools.quote_ident(col) for col in field_list]
- tbl_expr = "( select %s from %s where %s )" % (",".join(fqlist), skytools.quote_fqident(tbl), where)
+ tbl_expr = "select %s from %s" % (",".join(fqlist), self.fq_table_name)
+ if where:
+ tbl_expr += ' where ' + where
+ tbl_expr = "COPY (%s) TO STDOUT" % tbl_expr
self.log.debug("using copy expr: %s" % tbl_expr)
return tbl_expr
- def dump_table(self, tbl, copy_tbl, curs, fn):
+ def dump_table(self, copy_cmd, curs, fn):
"""Dump table to disk."""
f = open(fn, "w", 64*1024)
- curs.copy_to(f, copy_tbl)
- size = f.tell()
+ curs.copy_expert(f, copy_cmd)
+ self.log.info('%s: Got %d bytes' % (self.table_name, f.tell()))
f.close()
- self.log.info('%s: Got %d bytes' % (tbl, size))
def get_row(self, ln):
"""Parse a row into dict."""
@@ -380,14 +134,9 @@ class Checker(skytools.DBScript):
row[self.common_fields[i]] = t[i]
return row
- def dump_compare(self, tbl, src_fn, dst_fn, target):
+ def dump_compare(self, src_fn, dst_fn, fix):
"""Dump + compare single table."""
- self.log.info("Comparing dumps: %s" % tbl)
- self.cnt_insert = 0
- self.cnt_update = 0
- self.cnt_delete = 0
- self.total_src = 0
- self.total_dst = 0
+ self.log.info("Comparing dumps: %s" % self.table_name)
f1 = open(src_fn, "r", 64*1024)
f2 = open(dst_fn, "r", 64*1024)
src_ln = f1.readline()
@@ -395,7 +144,6 @@ class Checker(skytools.DBScript):
if src_ln: self.total_src += 1
if dst_ln: self.total_dst += 1
- fix = "fix.%s.%s.sql" % (target, tbl)
if os.path.isfile(fix):
os.unlink(fix)
@@ -408,15 +156,15 @@ class Checker(skytools.DBScript):
diff = self.cmp_keys(src_row, dst_row)
if diff > 0:
# src > dst
- self.got_missed_delete(tbl, dst_row, fix)
+ self.got_missed_delete(dst_row, fix)
keep_src = 1
elif diff < 0:
# src < dst
- self.got_missed_insert(tbl, src_row, fix)
+ self.got_missed_insert(src_row, fix)
keep_dst = 1
else:
if self.cmp_data(src_row, dst_row) != 0:
- self.got_missed_update(tbl, src_row, dst_row, fix)
+ self.got_missed_update(src_row, dst_row, fix)
if not keep_src:
src_ln = f1.readline()
@@ -427,10 +175,10 @@ class Checker(skytools.DBScript):
self.log.info("finished %s: src: %d rows, dst: %d rows,"\
" missed: %d inserts, %d updates, %d deletes" % (
- tbl, self.total_src, self.total_dst,
+ self.table_name, self.total_src, self.total_dst,
self.cnt_insert, self.cnt_update, self.cnt_delete))
- def got_missed_insert(self, tbl, src_row, fn):
+ def got_missed_insert(self, src_row, fn):
"""Create sql for missed insert."""
self.cnt_insert += 1
fld_list = self.common_fields
@@ -438,43 +186,43 @@ class Checker(skytools.DBScript):
val_list = []
for f in fld_list:
fq_list.append(skytools.quote_ident(f))
- v = unescape(src_row[f])
+ v = skytools.unescape_copy(src_row[f])
val_list.append(skytools.quote_literal(v))
q = "insert into %s (%s) values (%s);" % (
- tbl, ", ".join(fq_list), ", ".join(val_list))
- self.show_fix(tbl, q, 'insert', fn)
+ self.fq_table_name, ", ".join(fq_list), ", ".join(val_list))
+ self.show_fix(q, 'insert', fn)
- def got_missed_update(self, tbl, src_row, dst_row, fn):
+ def got_missed_update(self, src_row, dst_row, fn):
"""Create sql for missed update."""
self.cnt_update += 1
fld_list = self.common_fields
set_list = []
whe_list = []
for f in self.pkey_list:
- self.addcmp(whe_list, skytools.quote_ident(f), unescape(src_row[f]))
+ self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(src_row[f]))
for f in fld_list:
v1 = src_row[f]
v2 = dst_row[f]
if self.cmp_value(v1, v2) == 0:
continue
- self.addeq(set_list, skytools.quote_ident(f), unescape(v1))
- self.addcmp(whe_list, skytools.quote_ident(f), unescape(v2))
+ self.addeq(set_list, skytools.quote_ident(f), skytools.unescape_copy(v1))
+ self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(v2))
q = "update only %s set %s where %s;" % (
- tbl, ", ".join(set_list), " and ".join(whe_list))
- self.show_fix(tbl, q, 'update', fn)
+ self.fq_table_name, ", ".join(set_list), " and ".join(whe_list))
+ self.show_fix(q, 'update', fn)
- def got_missed_delete(self, tbl, dst_row, fn):
+ def got_missed_delete(self, dst_row, fn):
"""Create sql for missed delete."""
self.cnt_delete += 1
whe_list = []
for f in self.pkey_list:
- self.addcmp(whe_list, skytools.quote_ident(f), unescape(dst_row[f]))
- q = "delete from only %s where %s;" % (skytools.quote_fqident(tbl), " and ".join(whe_list))
- self.show_fix(tbl, q, 'delete', fn)
+ self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(dst_row[f]))
+ q = "delete from only %s where %s;" % (self.fq_table_name, " and ".join(whe_list))
+ self.show_fix(q, 'delete', fn)
- def show_fix(self, tbl, q, desc, fn):
+ def show_fix(self, q, desc, fn):
"""Print/write/apply repair sql."""
self.log.debug("missed %s: %s" % (desc, q))
open(fn, "a").write("%s\n" % q)
@@ -544,8 +292,3 @@ class Checker(skytools.DBScript):
return 1
return 0
-
-if __name__ == '__main__':
- script = Checker(sys.argv[1:])
- script.start()
-