summaryrefslogtreecommitdiff
path: root/python/londiste/table_copy.py
blob: 92ccd7f310e20ccf39d956fb3219bdb4f7a1c078 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#! /usr/bin/env python

"""Do a full table copy.

For internal usage.
"""

import sys, time, skytools

from londiste.util import find_copy_source
from skytools.dbstruct import *
from londiste.playback import *

__all__ = ['CopyTable']

class CopyTable(Replicator):
    """Table copy thread implementation.

    Performs the initial full copy of one table (taken from self.args[2])
    as a separate consumer.  The consumer name and the pidfile get a
    per-table ".copy.<table>" suffix; the original consumer name is kept
    in self.old_consumer_name.
    """

    # Becomes True once register_copy_consumer() has been run by work().
    reg_ok = False

    def __init__(self, args, copy_thread = 1):
        """Initializer.  copy_thread arg shows if the copy process is separate
        from main Playback thread or not.  copy_thread=0 means copying happens
        in same process.

        Raises Exception when copy_thread is false, as combined copy is not
        supported.  Exits with status 1 unless self.args has exactly three
        elements; the third one is the name of the table to copy.
        """

        Replicator.__init__(self, args)

        if not copy_thread:
            raise Exception("Combined copy not supported")

        if len(self.args) != 3:
            self.log.error("londiste copy requires table name")
            sys.exit(1)
        self.copy_table_name = self.args[2]

        # Give the copy process its own pidfile and consumer name, but
        # remember the parent consumer name for register_copy_consumer().
        sfx = self.get_copy_suffix(self.copy_table_name)
        self.old_consumer_name = self.consumer_name
        self.pidfile += sfx
        self.consumer_name += sfx
        self.copy_thread = 1
        self.main_worker = False

    def get_copy_suffix(self, tblname):
        """Return the suffix appended to consumer name and pidfile for tblname."""
        return ".copy.%s" % tblname

    def reload_table_stat(self, dst_curs, tblname):
        """Reload table state via dst_curs and return the entry for tblname.

        Exits with status 1 if tblname is no longer present in
        self.table_map after the reload.
        """
        self.load_table_state(dst_curs)
        if tblname not in self.table_map:
            self.log.warning('Table %s removed from replication', tblname)
            sys.exit(1)
        t = self.table_map[tblname]
        return t

    def do_copy(self, tbl_stat, src_db, dst_db):
        """Entry point into copying logic.

        tbl_stat is the state entry of the table to copy, src_db is the
        provider connection and dst_db the subscriber connection.

        Steps, in order:
          1. wait while this table has the 'wait-copy' role or the parallel
             copy limit is reached;
          2. wait until the provider reports the table as TABLE_OK;
          3. depending on copy mode, drop fkeys, lock the table, drop
             constraints/indexes/rules/parents and truncate;
          4. copy the common columns through the table plugin;
          5. store the source snapshot and switch state to
             TABLE_CATCHING_UP;
          6. re-create the dropped objects and request a tick.

        Raises Exception if the table is not known on the provider.

        The order of commits, locks and DDL below is significant; do not
        reorder statements without checking the transaction boundaries.
        """

        dst_db.commit()

        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        # Poll every 10 seconds until this process is allowed to start
        # copying; table state is re-read on each round.
        while 1:
            if tbl_stat.copy_role == 'wait-copy':
                self.log.info('waiting for first partition to initialize copy')
            elif tbl_stat.max_parallel_copies_reached():
                self.log.info('number of max parallel copies (%s) reached',
                                tbl_stat.max_parallel_copy)
            else:
                break
            time.sleep(10)
            tbl_stat = self.reload_table_stat(dst_curs, tbl_stat.name)
            dst_db.commit()

        # Poll every 10 seconds until the provider has the table in
        # TABLE_OK state; a table missing on the provider is fatal.
        while 1:
            pmap = self.get_state_map(src_db.cursor())
            src_db.commit()
            if tbl_stat.name not in pmap:
                raise Exception("table %s not available on provider" % tbl_stat.name)
            pt = pmap[tbl_stat.name]
            if pt.state == TABLE_OK:
                break

            self.log.warning("table %s not in sync yet on provider, waiting", tbl_stat.name)
            time.sleep(10)

        # Physical table name on the provider side.
        src_real_table = pt.dest_table

        # 0 - dont touch
        # 1 - single tx
        # 2 - multi tx
        # No copy_role -> 1, 'lead' -> 2, any other non-empty role -> 0.
        cmode = 1
        if tbl_stat.copy_role == 'lead':
            cmode = 2
        elif tbl_stat.copy_role:
            cmode = 0

        # We need to see COPY snapshot from txid_current_snapshot() later.
        oldiso = src_db.isolation_level
        src_db.set_isolation_level(skytools.I_REPEATABLE_READ)
        src_db.commit()

        self.sync_database_encodings(src_db, dst_db)

        self.log.info("Starting full copy of %s", tbl_stat.name)

        # just in case, drop all fkeys (in case "replay" was skipped)
        # !! this may commit, so must be done before anything else !!
        if cmode > 0:
            self.drop_fkeys(dst_db, tbl_stat.dest_table)

        # now start ddl-dropping tx
        if cmode > 0:
            q = "lock table " + skytools.quote_fqident(tbl_stat.dest_table)
            dst_curs.execute(q)

        # find dst struct
        src_struct = TableStruct(src_curs, src_real_table)
        dst_struct = TableStruct(dst_curs, tbl_stat.dest_table)

        # take common columns, warn on missing ones
        dlist = dst_struct.get_column_list()
        slist = src_struct.get_column_list()
        common_cols = []
        for c in slist:
            if c not in dlist:
                self.log.warning("Table %s column %s does not exist on subscriber",
                                 tbl_stat.name, c)
            else:
                common_cols.append(c)
        for c in dlist:
            if c not in slist:
                self.log.warning("Table %s column %s does not exist on provider",
                                 tbl_stat.name, c)

        # drop unnecessary stuff
        if cmode > 0:
            # objs is only defined for cmode > 0; its later use in the
            # cmode == 1 branch relies on that.
            objs = T_CONSTRAINT | T_INDEX | T_RULE | T_PARENT # | T_TRIGGER
            dst_struct.drop(dst_curs, objs, log = self.log)

            # drop data
            if tbl_stat.table_attrs.get('skip_truncate'):
                self.log.info("%s: skipping truncate", tbl_stat.name)
            else:
                self.log.info("%s: truncating", tbl_stat.name)
                q = "truncate "
                # "only" is added for server_version >= 80400
                if dst_db.server_version >= 80400:
                    q += "only "
                q += skytools.quote_fqident(tbl_stat.dest_table)
                dst_curs.execute(q)

            # In multi-tx mode the dropped DDL is stored with
            # londiste.local_set_table_struct() and committed, so that it
            # can be re-applied from tbl_stat.dropped_ddl further below.
            if cmode == 2 and tbl_stat.dropped_ddl is None:
                ddl = dst_struct.get_create_sql(objs)
                if ddl:
                    q = "select * from londiste.local_set_table_struct(%s, %s, %s)"
                    self.exec_cmd(dst_curs, q, [self.queue_name, tbl_stat.name, ddl])
                else:
                    ddl = None
                dst_db.commit()
                tbl_stat.dropped_ddl = ddl

        # do truncate & copy
        self.log.info("%s: start copy", tbl_stat.name)
        p = tbl_stat.get_plugin()
        stats = p.real_copy(src_real_table, src_curs, dst_curs, common_cols)
        if stats:
            # stats[0] is logged as bytes, stats[1] as rows
            self.log.info("%s: copy finished: %d bytes, %d rows",
                          tbl_stat.name, stats[0], stats[1])

        # get snapshot
        src_curs.execute("select txid_current_snapshot()")
        snapshot = src_curs.fetchone()[0]
        src_db.commit()

        # restore old behaviour
        src_db.set_isolation_level(oldiso)
        src_db.commit()

        tbl_stat.change_state(TABLE_CATCHING_UP)
        tbl_stat.change_snapshot(snapshot)
        self.save_table_state(dst_curs)

        # create previously dropped objects
        if cmode == 1:
            dst_struct.create(dst_curs, objs, log = self.log)
        elif cmode == 2:
            dst_db.commit()

            # start waiting for other copy processes to finish
            while tbl_stat.copy_role:
                self.log.info('waiting for other partitions to finish copy')
                time.sleep(10)
                tbl_stat = self.reload_table_stat(dst_curs, tbl_stat.name)
                dst_db.commit()

            if tbl_stat.dropped_ddl is not None:
                # NOTE(review): looping is cleared while the stored DDL is
                # replayed and set again afterwards - presumably so that a
                # failure here stops the script instead of being retried;
                # confirm against the base class.
                self.looping = 0
                for ddl in skytools.parse_statements(tbl_stat.dropped_ddl):
                    self.log.info(ddl)
                    dst_curs.execute(ddl)
                q = "select * from londiste.local_set_table_struct(%s, %s, NULL)"
                self.exec_cmd(dst_curs, q, [self.queue_name, tbl_stat.name])
                tbl_stat.dropped_ddl = None
                self.looping = 1
            dst_db.commit()

        # hack for copy-in-playback
        # (__init__ always sets copy_thread to 1)
        if not self.copy_thread:
            tbl_stat.change_state(TABLE_OK)
            self.save_table_state(dst_curs)
        dst_db.commit()

        # copy finished
        if tbl_stat.copy_role == 'wait-replay':
            return

        # if copy done, request immediate tick from pgqd,
        # to make state juggling faster.  on mostly idle db-s
        # each step may take tickers idle_timeout secs, which is pain.
        q = "select pgq.force_tick(%s)"
        src_curs.execute(q, [self.queue_name])
        src_db.commit()

    def work(self):
        """Register the copy consumer on the first call, then run Replicator.work()."""
        if not self.reg_ok:
            # check if needed? (table, not existing reg)
            self.register_copy_consumer()
            self.reg_ok = True
        return Replicator.work(self)

    def register_copy_consumer(self):
        """Register this copy consumer at the location it should copy from.

        The default source is the provider of the parent consumer
        (self.old_consumer_name).  A 'copy_node' table attribute overrides
        it: the value '?' delegates the choice to find_copy_source(), any
        other value is looked up in pgq_node.get_queue_locations().  If that
        lookup finds nothing, a warning is logged and the parent's provider
        is used.
        """
        dst_db = self.get_database('db')
        dst_curs = dst_db.cursor()

        # fetch table attrs
        q = "select * from londiste.get_table_list(%s) where table_name = %s"
        dst_curs.execute(q, [ self.queue_name, self.copy_table_name ])
        rows = dst_curs.fetchall()
        attrs = {}
        if rows:
            v_attrs = rows[0]['table_attrs']
            if v_attrs:
                attrs = skytools.db_urldecode(v_attrs)

        # fetch parent consumer state
        q = "select * from pgq_node.get_consumer_state(%s, %s)"
        rows = self.exec_cmd(dst_db, q, [ self.queue_name, self.old_consumer_name ])
        state = rows[0]
        source_node = state['provider_node']
        source_location = state['provider_location']

        # do we have node here?
        if 'copy_node' in attrs:
            copy_node = attrs['copy_node']
            if copy_node == '?':
                # third element of the result is not used here
                source_node, source_location, _wname = find_copy_source(self,
                        self.queue_name, self.copy_table_name, source_node, source_location)
            else:
                # take node from attrs
                q = "select * from pgq_node.get_queue_locations(%s) where node_name = %s"
                dst_curs.execute(q, [ self.queue_name, copy_node ])
                rows = dst_curs.fetchall()
                if rows:
                    # switch node name and location together, so the log
                    # line below always matches the location really used
                    source_node = copy_node
                    source_location = rows[0]['node_location']
                else:
                    self.log.warning("copy_node '%s' has no known location, using provider '%s'",
                                     copy_node, source_node)

        self.log.info("Using '%s' as source node", source_node)
        self.register_consumer(source_location)

if __name__ == '__main__':
    # Script entry point: build the copy worker from the command-line
    # arguments (program name excluded) and start it.
    CopyTable(sys.argv[1:]).start()