summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormartinko2013-12-17 14:54:45 +0000
committermartinko2013-12-17 14:54:45 +0000
commit8263f2fef55e078bfe15d026db5495c1880f0ffa (patch)
tree887922857ad2a15384f2a0bf5ed40d1ae705cd7c
parent80c1cff7b39d8139ab4b91d67e70d90614170c2a (diff)
scripts/data_maintainer.py: latest code cleanup
-rwxr-xr-xscripts/data_maintainer.py55
1 files changed, 31 insertions, 24 deletions
diff --git a/scripts/data_maintainer.py b/scripts/data_maintainer.py
index b78c4ea7..781d21f1 100755
--- a/scripts/data_maintainer.py
+++ b/scripts/data_maintainer.py
@@ -17,7 +17,7 @@ Config template::
from user_service
where expire_date < now();
- # if source is csv file you need to specify fileread and optionally csv_delimiter and csv_quotechar
+ # if source is csv file, you need to specify fileread and optionally csv_delimiter and csv_quotechar
#fileread = data.csv
#csv_delimiter = ,
#csv_quotechar = "
@@ -78,11 +78,11 @@ Config template::
use_skylog = 0
"""
+import csv
import datetime
+import os.path
import sys
import time
-import csv
-import os.path
import pkgloader
pkgloader.require('skytools', '3.0')
@@ -90,23 +90,29 @@ import skytools
class DataSource (object):
+ def __init__(self, log):
+ self.log = log
+
def open(self):
- raise NotImplementedError()
+ raise NotImplementedError
def close(self):
- raise NotImplementedError()
+ raise NotImplementedError
- def fetch(self, fetchcnt=0):
- raise NotImplementedError()
+ def fetch(self, count):
+ raise NotImplementedError
-class DBDataSource (object):
+
+class DBDataSource (DataSource):
def __init__(self, log, db, query, bres = None, with_hold = False):
- self.log = log
+ super(DBDataSource, self).__init__(log)
self.db = db
- self.query = "DECLARE data_maint_cur NO SCROLL CURSOR WITH HOLD FOR %s"\
- if with_hold else "DECLARE data_maint_cur NO SCROLL CURSOR FOR %s"
- self.query = self.query % query
+ if with_hold:
+ self.query = "DECLARE data_maint_cur NO SCROLL CURSOR WITH HOLD FOR %s" % query
+ else:
+ self.query = "DECLARE data_maint_cur NO SCROLL CURSOR FOR %s" % query
self.bres = bres
+ self.with_hold = with_hold
def _run_query(self, query, params = None):
self.cur.execute(query, params)
@@ -115,20 +121,21 @@ class DBDataSource (object):
def open(self):
self.cur = self.db.cursor()
- self._run_query(self.query, self.bres)
+ self._run_query(self.query, self.bres) # pass results from before_query into sql_pk
def close(self):
self.cur.execute("CLOSE data_maint_cur")
- if not self.withhold:
+ if not self.with_hold:
self.db.rollback()
- def fetch(self, fetchcnt=0):
- self._run_query("FETCH FORWARD %s FROM data_maint_cur" % fetchcnt)
+ def fetch(self, count):
+ self._run_query("FETCH FORWARD %i FROM data_maint_cur" % count)
return self.cur.fetchall()
-class CSVDataSource (object):
+
+class CSVDataSource (DataSource):
def __init__(self, log, filename, delimiter, quotechar):
- self.log = log
+ super(CSVDataSource, self).__init__(log)
self.filename = filename
self.delimiter = delimiter
self.quotechar = quotechar
@@ -140,12 +147,12 @@ class CSVDataSource (object):
def close(self):
self.fp.close()
- def fetch(self, fetchcnt=1):
+ def fetch(self, count):
ret = []
for row in self.reader:
ret.append(row)
- fetchcnt = fetchcnt - 1
- if fetchcnt <= 0:
+ count -= 1
+ if count <= 0:
break
return ret
@@ -161,10 +168,9 @@ class DataMaintainer (skytools.DBScript):
self.fileread = self.cf.get("fileread", "")
if self.fileread:
self.fileread = os.path.expanduser(self.fileread)
- # force single run if source is file
- self.loop_delay = -1
+ self.set_single_loop(True) # force single run if source is file
- self.csv_delimiter = self.cf.get("csv_delimiter", ",")
+ self.csv_delimiter = self.cf.get("csv_delimiter", ',')
self.csv_quotechar = self.cf.get("csv_quotechar", '"')
# query for fetching the PK-s of the data set to be maintained
@@ -333,6 +339,7 @@ class DataMaintainer (skytools.DBScript):
self.log.info(text, self.total_count, datetime.timedelta(0, round(time.time() - self.started)))
self.lap_time = time.time()
+
if __name__ == '__main__':
script = DataMaintainer(sys.argv[1:])
script.start()