diff options
Diffstat (limited to 'python/skytools/checker.py')
-rw-r--r-- | python/skytools/checker.py | 551 |
1 files changed, 551 insertions, 0 deletions
# Origin: python/skytools/checker.py (new file, blob 765afd14).

"""Catch moment when tables are in sync on master and slave.

For every requested table and every target partition the script locks the
source table, waits until the queue consumer has caught up, takes snapshots
on both sides and compares row count + checksum.  On mismatch both sides are
dumped, sorted and merged, and the SQL needed to fix the target is written
to a ``fix.<target>.<table>.sql`` file.
"""

import os
import subprocess
import sys
import time

import pkgloader
pkgloader.require('skytools', '3.0')
import skytools

CONFDB = "dbname=confdb host=confdb.service user=replicator"


def unescape(s):
    """Remove copy escapes.

    Thin wrapper around skytools.unescape_copy().
    """
    return skytools.unescape_copy(s)


def get_pkey_list(curs, tbl):
    """Get list of pkey fields in right order.

    curs - open cursor on the database holding the table
    tbl  - table name, resolved to an oid via skytools.get_table_oid()

    Returns list of column names of the primary key index, ordered by
    their position in that index.
    """
    oid = skytools.get_table_oid(curs, tbl)
    q = """SELECT k.attname FROM pg_index i, pg_attribute k
            WHERE i.indrelid = %s AND k.attrelid = i.indexrelid
              AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped
            ORDER BY k.attnum"""
    curs.execute(q, [oid])
    return [row[0] for row in curs.fetchall()]


def get_column_list(curs, tbl):
    """Get list of columns in right order.

    curs - open cursor on the database holding the table
    tbl  - table name, resolved to an oid via skytools.get_table_oid()

    Returns names of all live (not dropped, non-system) columns ordered
    by attribute number.
    """
    oid = skytools.get_table_oid(curs, tbl)
    q = """SELECT a.attname FROM pg_attribute a
            WHERE a.attrelid = %s
              AND a.attnum > 0 AND NOT a.attisdropped
            ORDER BY a.attnum"""
    curs.execute(q, [oid])
    return [row[0] for row in curs.fetchall()]


class Checker(skytools.DBScript):
    """Checks that tables in two databases are in sync."""

    # per-table counters, reset by dump_compare()
    cnt_insert = 0
    cnt_update = 0
    cnt_delete = 0
    total_src = 0
    total_dst = 0
    # per-table metadata, always rebound (never mutated in place)
    # by do_repair() / gen_copy_tbl()
    pkey_list = []
    common_fields = []

    def __init__(self, args):
        """Checker init.

        Reads queue/consumer names and lock timeout from config and
        validates the command line.  Exits with status 1 when a required
        switch is missing.
        """
        skytools.DBScript.__init__(self, 'cross_mover', args)
        self.set_single_loop(1)
        self.log.info('Checker starting %s' % str(args))

        # compat names
        self.queue_name = self.cf.get("pgq_queue_name", '')
        self.consumer_name = self.cf.get('pgq_consumer_id', '')

        # good names
        if not self.queue_name:
            self.queue_name = self.cf.get("queue_name")
        if not self.consumer_name:
            self.consumer_name = self.cf.get('consumer_name', self.job_name)

        self.lock_timeout = self.cf.getfloat('lock_timeout', 10)

        # get tables to be compared; work() calls .split() on the value,
        # so continuing without it would only crash later
        if not self.options.table_list:
            self.log.error("--table is required")
            sys.exit(1)

        # the partition filter is interpolated into every WHERE clause
        if not self.options.part_expr:
            self.log.error("--part_expr is required")
            sys.exit(1)

        # create temp pidfile
        if self.pidfile:
            self.pidfile += ".repair"

    def set_lock_timeout(self, curs):
        """Apply configured lock_timeout (seconds) as a transaction-local
        statement_timeout on the given cursor.  Non-positive values
        leave the timeout untouched.
        """
        ms = int(1000 * self.lock_timeout)
        if ms > 0:
            q = "SET LOCAL statement_timeout = %d" % ms
            self.log.debug(q)
            curs.execute(q)

    def init_optparse(self, p=None):
        """ Initialize cmdline switches.

        Adds --table, --part_expr and --force on top of the DBScript
        switches and returns the parser.
        """
        p = skytools.DBScript.init_optparse(self, p)
        p.add_option("--table", dest='table_list', help="space separated list of table names")
        p.add_option("--part_expr", dest='part_expr', help="table partitioning expression")
        # force_tick() consults options.force; define it here unless the
        # base parser already provides such a switch
        if not p.has_option("--force"):
            p.add_option("--force", dest='force', action="store_true", default=False,
                         help="keep waiting for a new tick even if ticker seems dead")

        return p

    def check_consumer(self, setup_curs):
        """ Before locking anything check if consumer is working ok.

        Exits with status 1 if the consumer is not registered on the
        queue or lags more than 10 seconds behind the ticker.
        """
        self.log.info("Queue: %s Consumer: %s" % (self.queue_name, self.consumer_name))

        # get ticker lag
        q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);"
        setup_curs.execute(q, [self.queue_name])
        ticker_lag = setup_curs.fetchone()[0]
        self.log.info("Ticker lag: %s" % ticker_lag)

        # get consumer lag
        q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);"
        setup_curs.execute(q, [self.queue_name, self.consumer_name])
        res = setup_curs.fetchall()
        if not res:
            self.log.error('No such consumer')
            sys.exit(1)
        consumer_lag = res[0][0]
        self.log.info("Consumer lag: %s" % consumer_lag)

        # check that lag is acceptable
        if consumer_lag > ticker_lag + 10:
            self.log.error('Consumer lagging too much, cannot proceed')
            sys.exit(1)

    def work(self):
        """Syncer main function.

        Looks up target partitions in confdb, then checks every requested
        table against every target.  Exits with status 1 if any
        table/target pair did not match.
        """
        # get sourcedb connection and slots provided there
        setup_db = self.get_database('setup_db', autocommit = 1, connstr = self.cf.get('src_db'))
        setup_curs = setup_db.cursor()
        setup_curs.execute("select hostname(), current_database();")
        r_source = setup_curs.fetchone()
        self.log.info("Source: %s" % str(r_source))

        # get proxy db name and host (used to find out target cluster
        # target partitions and their respective slots)
        proxy_db = self.get_database('dst_db', autocommit = 1)
        proxy_curs = proxy_db.cursor()
        proxy_curs.execute("select hostname(), current_database();")
        r_proxy = proxy_curs.fetchone()
        self.log.info("Proxy: %s" % str(r_proxy))

        # get target partitions from confdb and do also some sanity checks
        conf_db = self.get_database('conf_db', autocommit = 1, connstr = CONFDB)
        conf_curs = conf_db.cursor()
        q = "select db_name, hostname, slots, max_slot from dba.get_cross_targets(%s, %s, %s, %s)"
        conf_curs.execute(q, r_source + r_proxy)
        targets = conf_curs.fetchall()

        # get special purpose connections for magic locking
        lock_db = self.get_database('lock_db', connstr = self.cf.get('src_db'))
        src_db = self.get_database('src_db', isolation_level = skytools.I_SERIALIZABLE)

        # check that consumer is up and running
        self.check_consumer(setup_curs)

        # loop over all tables and all targets
        mismatch_count = 0
        for tbl in self.options.table_list.split():
            self.log.info("Checking table: %s" % tbl)
            tbl = skytools.fq_name(tbl)
            for target in targets:
                self.log.info("Target: %s" % str(target))
                connstr = "dbname=%s host=%s user=replicator" % (target[0], target[1])
                fn = "%s.%s" % (target[1], target[0])
                dst_db = self.get_database(target[0], isolation_level = skytools.I_SERIALIZABLE, connstr = connstr)
                # restrict comparison to the slots served by this target
                where = "%s & %s in (%s)" % (self.options.part_expr, target[3], target[2])
                if not self.check_table(tbl, lock_db, src_db, dst_db, setup_curs, where, fn):
                    mismatch_count += 1
                lock_db.commit()
                src_db.commit()
                dst_db.commit()

        if mismatch_count > 0:
            self.log.error("%s mismatching tables found" % mismatch_count)
            sys.exit(1)

    def force_tick(self, setup_curs):
        """ Force tick into source queue so that consumer can move on faster

        Polls pgq.force_tick() every 0.5s until the returned position
        changes and returns the new position.  Raises Exception after
        10 seconds without a change unless --force was given.
        """
        q = "select pgq.force_tick(%s)"
        setup_curs.execute(q, [self.queue_name])
        cur_pos = setup_curs.fetchone()[0]

        # tolerate option objects that lack the attribute
        force = getattr(self.options, 'force', False)

        start = time.time()
        while 1:
            time.sleep(0.5)
            setup_curs.execute(q, [self.queue_name])
            res = setup_curs.fetchone()
            if res[0] != cur_pos:
                # new pos
                return res[0]

            # dont loop more than 10 secs
            dur = time.time() - start
            if dur > 10 and not force:
                raise Exception("Ticker seems dead")

    def check_table(self, tbl, lock_db, src_db, dst_db, setup_curs, where, target):
        """ Get transaction to same state, then process.

        Locks the source table, waits for the consumer to pass the lock
        moment, snapshots both sides, releases the lock and compares.
        On mismatch do_repair() is run to produce the fix script.

        Returns True when both sides match.  Returns False on mismatch or
        when the table is missing on either side (the caller counts any
        false result as a mismatch).  Exits with status 1 if the consumer
        does not catch up within lock_timeout.
        """
        lock_curs = lock_db.cursor()
        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        if not skytools.exists_table(src_curs, tbl):
            self.log.warning("Table %s does not exist on provider side" % tbl)
            return False
        if not skytools.exists_table(dst_curs, tbl):
            self.log.warning("Table %s does not exist on subscriber side" % tbl)
            return False

        # lock table in separate connection
        self.log.info('Locking %s' % tbl)
        lock_db.commit()
        self.set_lock_timeout(lock_curs)
        lock_time = time.time()
        lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(tbl))

        # now wait until consumer has updated target table until locking
        self.log.info('Syncing %s' % tbl)

        # consumer must get further than this tick
        self.force_tick(setup_curs)
        # try to force second tick also
        self.force_tick(setup_curs)

        # take server time
        setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')")
        tpos = setup_curs.fetchone()[0]

        # now wait
        while 1:
            time.sleep(0.5)

            q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)"
            setup_curs.execute(q, [tpos, self.queue_name, self.consumer_name])
            res = setup_curs.fetchall()
            if not res:
                raise Exception('No such consumer')
            row = res[0]
            self.log.debug("tpos=%s now=%s lag=%s ok=%s" % (tpos, row[1], row[2], row[0]))
            if row[0]:
                break

            # limit lock time
            if time.time() > lock_time + self.lock_timeout:
                self.log.error('Consumer lagging too much, exiting')
                lock_db.rollback()
                sys.exit(1)

        # take snapshot on provider side
        src_db.commit()
        src_curs.execute("SELECT 1")

        # take snapshot on subscriber side
        dst_db.commit()
        dst_curs.execute("SELECT 1")

        # release lock
        lock_db.commit()

        # do work
        result = self.do_compare(tbl, src_db, dst_db, where)
        if not result:
            self.do_repair(tbl, src_db, dst_db, where, target)

        # done
        src_db.commit()
        dst_db.commit()

        return result

    def do_compare(self, tbl, src_db, dst_db, where):
        """Actual comparision.

        Runs the (configurable) compare query on both sides, formats each
        result row with the (configurable) format string and compares
        the formatted strings.  Returns True when they are equal.
        """
        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.log.info('Counting %s' % tbl)

        q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" % where
        q = self.cf.get('compare_sql', q)
        q = q.replace('_TABLE_', skytools.quote_fqident(tbl))

        f = "%(cnt)d rows, checksum=%(chksum)s"
        f = self.cf.get('compare_fmt', f)

        self.log.debug("srcdb: " + q)
        src_curs.execute(q)
        src_row = src_curs.fetchone()
        src_str = f % src_row
        self.log.info("srcdb: %s" % src_str)

        self.log.debug("dstdb: " + q)
        dst_curs.execute(q)
        dst_row = dst_curs.fetchone()
        dst_str = f % dst_row
        self.log.info("dstdb: %s" % dst_str)

        if src_str != dst_str:
            self.log.warning("%s: Results do not match!" % tbl)
            return False

        self.log.info("%s: OK!" % tbl)
        return True

    def do_repair(self, tbl, src_db, dst_db, where, target):
        """Dump both sides, sort the dumps and generate the fix script.

        tbl    - fully qualified table name
        where  - partition filter applied to both dumps
        target - "<host>.<db>" tag used in dump and fix file names

        Temporary dump files are created in the current directory and
        removed after the comparison.  Raises Exception if sort(1) fails.
        """
        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.log.info('Checking %s' % tbl)

        self.common_fields = []
        self.pkey_list = []
        copy_tbl = self.gen_copy_tbl(tbl, src_curs, dst_curs, where)

        dump_src = target + "__" + tbl + ".src"
        dump_dst = target + "__" + tbl + ".dst"

        self.log.info("Dumping src table: %s" % tbl)
        self.dump_table(tbl, copy_tbl, src_curs, dump_src)
        src_db.commit()
        self.log.info("Dumping dst table: %s" % tbl)
        self.dump_table(tbl, copy_tbl, dst_curs, dump_dst)
        dst_db.commit()

        sort_args = self._sort_extra_args()
        self.log.info("Sorting src table: %s" % tbl)
        self._sort_file(dump_src, dump_src + ".sorted", sort_args)
        self.log.info("Sorting dst table: %s" % tbl)
        self._sort_file(dump_dst, dump_dst + ".sorted", sort_args)

        self.dump_compare(tbl, dump_src + ".sorted", dump_dst + ".sorted", target)

        os.unlink(dump_src)
        os.unlink(dump_dst)
        os.unlink(dump_src + ".sorted")
        os.unlink(dump_dst + ".sorted")

    def _sort_extra_args(self):
        """Return extra sort(1) switches: a memory limit for coreutils sort,
        nothing for other implementations."""
        # stderr is merged into stdout, as the former os.popen4() call did
        p = subprocess.Popen(["sort", "--version"],
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             universal_newlines=True)
        s_ver = p.communicate()[0]
        if "coreutils" in s_ver:
            return ["-S", "30%"]
        return []

    def _sort_file(self, src_fn, dst_fn, extra_args):
        """Sort src_fn into dst_fn using external sort(1) in byte order."""
        # cmp_keys() compares keys with plain string comparison, so the
        # files must be sorted in C locale, not in the user's collation
        env = os.environ.copy()
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'

        # argument list instead of shell string: file names are built
        # from table and host names and must not be shell-interpreted
        cmd = ["sort"] + list(extra_args) + ["-T", ".", "-o", dst_fn, src_fn]
        if subprocess.call(cmd, env=env) != 0:
            raise Exception("sort failed: %s" % src_fn)

    def gen_copy_tbl(self, tbl, src_curs, dst_curs, where):
        """Create COPY expession from common fields.

        Stores the primary key columns in self.pkey_list and the columns
        present on both sides (pkey columns first) in self.common_fields.
        Exits with status 1 if the primary keys differ between sides.

        Returns a parenthesized sub-select usable as COPY source.
        """
        self.pkey_list = get_pkey_list(src_curs, tbl)
        dst_pkey = get_pkey_list(dst_curs, tbl)
        if dst_pkey != self.pkey_list:
            self.log.error('pkeys do not match')
            sys.exit(1)

        src_cols = get_column_list(src_curs, tbl)
        dst_cols = get_column_list(dst_curs, tbl)

        # pkey columns go first so that sorted dumps are ordered by key
        field_list = list(self.pkey_list)
        for f in src_cols:
            if f in self.pkey_list:
                continue
            if f in dst_cols:
                field_list.append(f)

        self.common_fields = field_list

        fqlist = [skytools.quote_ident(col) for col in field_list]

        tbl_expr = "( select %s from %s where %s )" % (",".join(fqlist), skytools.quote_fqident(tbl), where)

        self.log.debug("using copy expr: %s" % tbl_expr)

        return tbl_expr

    def dump_table(self, tbl, copy_tbl, curs, fn):
        """Dump table to disk.

        Copies copy_tbl via curs.copy_to() into file fn and logs the
        number of bytes written.
        """
        f = open(fn, "w", 64*1024)
        try:
            curs.copy_to(f, copy_tbl)
            size = f.tell()
        finally:
            # close the dump even if COPY fails
            f.close()
        self.log.info('%s: Got %d bytes' % (tbl, size))

    def get_row(self, ln):
        """Parse a row into dict.

        ln - one tab-separated dump line, or empty string at end of file.

        Returns dict mapping self.common_fields to the raw (still
        escaped) field values, or None for an empty line.
        """
        if not ln:
            return None
        # drop the line terminator only; never eat a data character
        if ln.endswith('\n'):
            ln = ln[:-1]
        t = ln.split('\t')
        row = {}
        for i, name in enumerate(self.common_fields):
            row[name] = t[i]
        return row

    def dump_compare(self, tbl, src_fn, dst_fn, target):
        """Dump + compare single table.

        Merges the two sorted dump files and writes the statements needed
        to bring the destination in line with the source into
        fix.<target>.<tbl>.sql (a pre-existing file is removed first).
        Counters on self are reset and a summary is logged at the end.
        """
        self.log.info("Comparing dumps: %s" % tbl)
        self.cnt_insert = 0
        self.cnt_update = 0
        self.cnt_delete = 0
        self.total_src = 0
        self.total_dst = 0

        fix = "fix.%s.%s.sql" % (target, tbl)
        if os.path.isfile(fix):
            os.unlink(fix)

        f1 = open(src_fn, "r", 64*1024)
        try:
            f2 = open(dst_fn, "r", 64*1024)
            try:
                self._merge_dumps(tbl, f1, f2, fix)
            finally:
                f2.close()
        finally:
            f1.close()

        self.log.info("finished %s: src: %d rows, dst: %d rows,"\
                      " missed: %d inserts, %d updates, %d deletes" % (
                tbl, self.total_src, self.total_dst,
                self.cnt_insert, self.cnt_update, self.cnt_delete))

    def _merge_dumps(self, tbl, f1, f2, fix):
        """Walk two sorted dump streams in parallel and emit fix statements."""
        src_ln = f1.readline()
        dst_ln = f2.readline()
        if src_ln: self.total_src += 1
        if dst_ln: self.total_dst += 1

        while src_ln or dst_ln:
            keep_src = keep_dst = 0
            if src_ln != dst_ln:
                src_row = self.get_row(src_ln)
                dst_row = self.get_row(dst_ln)

                diff = self.cmp_keys(src_row, dst_row)
                if diff > 0:
                    # src > dst: row exists only on dst
                    self.got_missed_delete(tbl, dst_row, fix)
                    keep_src = 1
                elif diff < 0:
                    # src < dst: row exists only on src
                    self.got_missed_insert(tbl, src_row, fix)
                    keep_dst = 1
                else:
                    if self.cmp_data(src_row, dst_row) != 0:
                        self.got_missed_update(tbl, src_row, dst_row, fix)

            # advance only the side(s) that were consumed
            if not keep_src:
                src_ln = f1.readline()
                if src_ln: self.total_src += 1
            if not keep_dst:
                dst_ln = f2.readline()
                if dst_ln: self.total_dst += 1

    def got_missed_insert(self, tbl, src_row, fn):
        """Create sql for missed insert."""
        self.cnt_insert += 1
        fld_list = self.common_fields
        fq_list = []
        val_list = []
        for f in fld_list:
            fq_list.append(skytools.quote_ident(f))
            v = unescape(src_row[f])
            val_list.append(skytools.quote_literal(v))
        # table name is quoted the same way got_missed_delete() does it
        q = "insert into %s (%s) values (%s);" % (
                skytools.quote_fqident(tbl), ", ".join(fq_list), ", ".join(val_list))
        self.show_fix(tbl, q, 'insert', fn)

    def got_missed_update(self, tbl, src_row, dst_row, fn):
        """Create sql for missed update.

        The WHERE clause matches on the primary key plus the current
        destination value of every changed field.
        """
        self.cnt_update += 1
        fld_list = self.common_fields
        set_list = []
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), unescape(src_row[f]))
        for f in fld_list:
            v1 = src_row[f]
            v2 = dst_row[f]
            if self.cmp_value(v1, v2) == 0:
                continue

            self.addeq(set_list, skytools.quote_ident(f), unescape(v1))
            self.addcmp(whe_list, skytools.quote_ident(f), unescape(v2))

        # table name is quoted the same way got_missed_delete() does it
        q = "update only %s set %s where %s;" % (
                skytools.quote_fqident(tbl), ", ".join(set_list), " and ".join(whe_list))
        self.show_fix(tbl, q, 'update', fn)

    def got_missed_delete(self, tbl, dst_row, fn):
        """Create sql for missed delete."""
        self.cnt_delete += 1
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), unescape(dst_row[f]))
        q = "delete from only %s where %s;" % (skytools.quote_fqident(tbl), " and ".join(whe_list))
        self.show_fix(tbl, q, 'delete', fn)

    def show_fix(self, tbl, q, desc, fn):
        """Log repair sql at debug level and append it to file fn."""
        self.log.debug("missed %s: %s" % (desc, q))
        f = open(fn, "a")
        try:
            f.write("%s\n" % q)
        finally:
            f.close()

    def addeq(self, list, f, v):
        """Add quoted SET.

        Appends "<f> = <quoted v>" to list; f must already be quoted.
        """
        vq = skytools.quote_literal(v)
        s = "%s = %s" % (f, vq)
        list.append(s)

    def addcmp(self, list, f, v):
        """Add quoted comparision.

        Appends "<f> is null" when v is None, else "<f> = <quoted v>";
        f must already be quoted.
        """
        if v is None:
            s = "%s is null" % f
        else:
            vq = skytools.quote_literal(v)
            s = "%s = %s" % (f, vq)
        list.append(s)

    def cmp_data(self, src_row, dst_row):
        """Compare data field-by-field.

        Returns 0 when all common fields match per cmp_value(), else -1.
        """
        for k in self.common_fields:
            v1 = src_row[k]
            v2 = dst_row[k]
            if self.cmp_value(v1, v2) != 0:
                return -1
        return 0

    def cmp_value(self, v1, v2):
        """Compare single field, tolerates tz vs notz dates.

        Returns 0 when the values are equal, or when one is exactly three
        characters longer than the other (which is at least 19 chars),
        the extra part starts with '+' and the rest is equal.
        Returns -1 otherwise.
        """
        if v1 == v2:
            return 0

        # try to work around tz vs. notz
        z1 = len(v1)
        z2 = len(v2)
        if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+':
            v1 = v1[:-3]
            if v1 == v2:
                return 0
        elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+':
            v2 = v2[:-3]
            if v1 == v2:
                return 0

        return -1

    def cmp_keys(self, src_row, dst_row):
        """Compare primary keys of the rows.

        Returns 1 if src > dst, -1 if src < dst and 0 if src == dst"""

        # None means table is done. tag it larger than any existing row.
        if src_row is None:
            if dst_row is None:
                return 0
            return 1
        elif dst_row is None:
            return -1

        # raw string comparison; dumps are sorted in C locale to match
        for k in self.pkey_list:
            v1 = src_row[k]
            v2 = dst_row[k]
            if v1 < v2:
                return -1
            elif v1 > v2:
                return 1
        return 0


if __name__ == '__main__':
    script = Checker(sys.argv[1:])
    script.start()