Mercurial > p > mysql-python > mysqldb-2
diff MySQLdb/cursors.py @ 10:3f4c6af70e52 MySQLdb
Me and PyLint had a knife fight, but PyLint had a gun.
author | adustman |
---|---|
date | Mon, 26 Feb 2007 02:40:02 +0000 |
parents | fa8974a41c76 |
children | 7773efbe9b30 |
line wrap: on
line diff
--- a/MySQLdb/cursors.py Mon Feb 26 00:55:29 2007 +0000 +++ b/MySQLdb/cursors.py Mon Feb 26 02:40:02 2007 +0000 @@ -5,11 +5,12 @@ """ +__revision__ = "$ Revision: $"[11:-2] + import re -insert_values = re.compile(r"\svalues\s*(\(((?<!\\)'.*?\).*(?<!\\)?'|.)+?\))", re.IGNORECASE) -from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \ - DatabaseError, OperationalError, IntegrityError, InternalError, \ - NotSupportedError, ProgrammingError +INSERT_VALUES = re.compile( + r"^(P<start>.+\svalues\s*)(P<values>\(((?<!\\)'.*?\).*(?<!\\)?'|.)+?\))(P<end>.*)$", + re.IGNORECASE) class BaseCursor(object): @@ -36,6 +37,7 @@ InternalError, ProgrammingError, NotSupportedError _defer_warnings = False + _fetch_type = None def __init__(self, connection): from weakref import proxy @@ -61,25 +63,29 @@ def close(self): """Close the cursor. No further queries will be possible.""" - if not self.connection: return - while self.nextset(): pass + if not self.connection: + return + while self.nextset(): + pass self.connection = None def _check_executed(self): + """Ensure that .execute() has been called.""" if not self._executed: - self.errorhandler(self, ProgrammingError, "execute() first") + self.errorhandler(self, self.ProgrammingError, "execute() first") def _warning_check(self): + """Check for warnings, and report via the warnings module.""" from warnings import warn if self._warnings: warnings = self._get_db().show_warnings() if warnings: # This is done in two loops in case # Warnings are set to raise exceptions. - for w in warnings: - self.messages.append((self.Warning, w)) - for w in warnings: - warn(w[-1], self.Warning, 3) + for warning in warnings: + self.messages.append((self.Warning, warning)) + for warning in warnings: + warn(warning[-1], self.Warning, 3) elif self._info: self.messages.append((self.Warning, self._info)) warn(self._info, self.Warning, 3) @@ -93,27 +99,33 @@ self.fetchall() del self.messages[:] - db = self._get_db() - nr = db.next_result() - if nr == -1: + connection = self._get_db() + num_rows = connection.next_result() + if num_rows == -1: return None self._do_get_result() self._post_get_result() self._warning_check() - return 1 + return True - def _post_get_result(self): pass + def _post_get_result(self): + """Stub to be overridden by MixIn.""" + + def _get_result(self): + """Stub to be overridden by MixIn.""" + return [] def _do_get_result(self): - db = self._get_db() + """Get the result from the last query.""" + connection = self._get_db() self._result = self._get_result() - self.rowcount = db.affected_rows() + self.rowcount = connection.affected_rows() self.rownumber = 0 self.description = self._result and self._result.describe() or None self.description_flags = self._result and self._result.field_flags() or None - self.lastrowid = db.insert_id() - self._warnings = db.warning_count() - self._info = db.info() + self.lastrowid = connection.insert_id() + self._warnings = connection.warning_count() + self._info = connection.info() def setinputsizes(self, *args): """Does nothing, required by DB API.""" @@ -122,12 +134,14 @@ """Does nothing, required by DB API.""" def _get_db(self): + """Get the database connection. + + Raises ProgrammingError if the connection has been closed.""" if not self.connection: - self.errorhandler(self, ProgrammingError, "cursor closed") + self.errorhandler(self, self.ProgrammingError, "cursor closed") return self.connection def execute(self, query, args=None): - """Execute a query. query -- string, query to execute on server @@ -140,39 +154,40 @@ Returns long integer rows affected, if any """ - from types import ListType, TupleType from sys import exc_info del self.messages[:] - db = self._get_db() - charset = db.character_set_name() + connection = self._get_db() + charset = connection.character_set_name() if isinstance(query, unicode): query = query.encode(charset) if args is not None: - query = query % db.literal(args) + query = query % connection.literal(args) try: - r = self._query(query) - except TypeError, m: - if m.args[0] in ("not enough arguments for format string", - "not all arguments converted"): - self.messages.append((ProgrammingError, m.args[0])) - self.errorhandler(self, ProgrammingError, m.args[0]) + result = self._query(query) + except TypeError, msg: + if msg.args[0] in ("not enough arguments for format string", + "not all arguments converted"): + self.messages.append((self.ProgrammingError, msg.args[0])) + self.errorhandler(self, self.ProgrammingError, msg.args[0]) else: - self.messages.append((TypeError, m)) - self.errorhandler(self, TypeError, m) + self.messages.append((TypeError, msg)) + self.errorhandler(self, TypeError, msg) except: - exc, value, tb = exc_info() - del tb + exc, value, traceback = exc_info() + del traceback self.messages.append((exc, value)) self.errorhandler(self, exc, value) self._executed = query - if not self._defer_warnings: self._warning_check() - return r + if not self._defer_warnings: + self._warning_check() + return result def executemany(self, query, args): - """Execute a multi-row query. - query -- string, query to execute on server + query + + string, query to execute on server args @@ -187,45 +202,51 @@ """ del self.messages[:] - db = self._get_db() - if not args: return - charset = db.character_set_name() - if isinstance(query, unicode): query = query.encode(charset) - m = insert_values.search(query) - if not m: - r = 0 - for a in args: - r = r + self.execute(query, a) - return r - p = m.start(1) - e = m.end(1) - qv = m.group(1) + connection = self._get_db() + if not args: + return + charset = connection.character_set_name() + if isinstance(query, unicode): + query = query.encode(charset) + matched = INSERT_VALUES.match(query) + if not matched: + self.rowcount = sum([ self.execute(query, arg) for arg in args ]) + return self.rowcount + + start = matched.group('start') + end = matched.group('end') + values = matched.group('values') + try: - q = [ qv % db.literal(a) for a in args ] + sql_params = [ values % connection.literal(arg) for arg in args ] except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", "not all arguments converted"): - self.messages.append((ProgrammingError, msg.args[0])) - self.errorhandler(self, ProgrammingError, msg.args[0]) + self.messages.append((self.ProgrammingError, msg.args[0])) + self.errorhandler(self, self.ProgrammingError, msg.args[0]) else: self.messages.append((TypeError, msg)) self.errorhandler(self, TypeError, msg) except: from sys import exc_info - exc, value, tb = exc_info() - del tb + exc, value, traceback = exc_info() + del traceback self.errorhandler(self, exc, value) - r = self._query('\n'.join([query[:p], ',\n'.join(q), query[e:]])) - if not self._defer_warnings: self._warning_check() - return r + self.rowcount = int(self._query( + '\n'.join([start, ',\n'.join(sql_params), end, + ]))) + if not self._defer_warnings: + self._warning_check() + return self.rowcount def callproc(self, procname, args=()): - """Execute stored procedure procname with args - procname -- string, name of procedure to execute on server + procname + string, name of procedure to execute on server - args -- Sequence of parameters to use with procedure + args + Sequence of parameters to use with procedure Returns the original args. @@ -249,37 +270,41 @@ disconnected. """ - from types import UnicodeType - db = self._get_db() - charset = db.character_set_name() + connection = self._get_db() + charset = connection.character_set_name() for index, arg in enumerate(args): - q = "SET @_%s_%d=%s" % (procname, index, - db.literal(arg)) - if isinstance(q, unicode): - q = q.encode(charset) - self._query(q) + query = "SET @_%s_%d=%s" % (procname, index, + connection.literal(arg)) + if isinstance(query, unicode): + query = query.encode(charset) + self._query(query) self.nextset() - q = "CALL %s(%s)" % (procname, - ','.join(['@_%s_%d' % (procname, i) - for i in range(len(args))])) - if type(q) is UnicodeType: - q = q.encode(charset) - self._query(q) - self._executed = q - if not self._defer_warnings: self._warning_check() + query = "CALL %s(%s)" % (procname, + ','.join(['@_%s_%d' % (procname, i) + for i in range(len(args))])) + if isinstance(query, unicode): + query = query.encode(charset) + self._query(query) + self._executed = query + if not self._defer_warnings: + self._warning_check() return args - def _do_query(self, q): - db = self._get_db() - self._last_executed = q - db.query(q) + def _do_query(self, query): + """Low-levey query wrapper. Overridden by MixIns.""" + connection = self._get_db() + self._last_executed = query + connection.query(query) self._do_get_result() return self.rowcount - def _query(self, q): return self._do_query(q) + def _query(self, query): + """Hook for _do_query.""" + return self._do_query(query) def _fetch_row(self, size=1): + """Low-level fetch_row wrapper.""" if not self._result: return () return self._result.fetch_row(size, self._fetch_type) @@ -287,17 +312,14 @@ def __iter__(self): return iter(self.fetchone, None) - Warning = Warning - Error = Error - InterfaceError = InterfaceError - DatabaseError = DatabaseError - DataError = DataError - OperationalError = OperationalError - IntegrityError = IntegrityError - InternalError = InternalError - ProgrammingError = ProgrammingError - NotSupportedError = NotSupportedError - + def fetchone(self): + """Stub to be overridden by a MixIn.""" + return None + + def fetchall(self): + """Stub to be overridden by a MixIn.""" + return [] + class CursorStoreResultMixIn(object): @@ -306,14 +328,18 @@ result set can be very large, consider adding a LIMIT clause to your query, or using CursorUseResultMixIn instead.""" - def _get_result(self): return self._get_db().store_result() + def _get_result(self): + """Low-level; uses mysql_store_result()""" + return self._get_db().store_result() - def _query(self, q): - rowcount = self._do_query(q) + def _query(self, query): + """Low-level; executes query, gets result, and returns rowcount.""" + rowcount = self._do_query(query) self._post_get_result() return rowcount def _post_get_result(self): + """Low-level""" self._rows = self._fetch_row(0) self._result = None @@ -321,9 +347,10 @@ """Fetches a single row from the cursor. None indicates that no more rows are available.""" self._check_executed() - if self.rownumber >= len(self._rows): return None + if self.rownumber >= len(self._rows): + return None result = self._rows[self.rownumber] - self.rownumber = self.rownumber+1 + self.rownumber += 1 return result def fetchmany(self, size=None): @@ -354,15 +381,15 @@ value states an absolute target position.""" self._check_executed() if mode == 'relative': - r = self.rownumber + value + row = self.rownumber + value elif mode == 'absolute': - r = value + row = value else: - self.errorhandler(self, ProgrammingError, + self.errorhandler(self, self.ProgrammingError, "unknown scroll mode %s" % `mode`) - if r < 0 or r >= len(self._rows): + if row < 0 or row >= len(self._rows): self.errorhandler(self, IndexError, "out of range") - self.rownumber = r + self.rownumber = row def __iter__(self): self._check_executed() @@ -380,35 +407,37 @@ _defer_warnings = True - def _get_result(self): return self._get_db().use_result() + def _get_result(self): + """Low-level; calls mysql_use_result()""" + return self._get_db().use_result() def fetchone(self): """Fetches a single row from the cursor.""" self._check_executed() - r = self._fetch_row(1) - if not r: + rows = self._fetch_row(1) + if not rows: self._warning_check() return None self.rownumber = self.rownumber + 1 - return r[0] + return rows[0] def fetchmany(self, size=None): """Fetch up to size rows from the cursor. Result set may be smaller than size. If size is not defined, cursor.arraysize is used.""" self._check_executed() - r = self._fetch_row(size or self.arraysize) - self.rownumber = self.rownumber + len(r) - if not r: + rows = self._fetch_row(size or self.arraysize) + self.rownumber = self.rownumber + len(rows) + if not rows: self._warning_check() - return r + return rows def fetchall(self): """Fetchs all available rows from the cursor.""" self._check_executed() - r = self._fetch_row(0) - self.rownumber = self.rownumber + len(r) + rows = self._fetch_row(0) + self.rownumber = self.rownumber + len(rows) self._warning_check() - return r + return rows def __iter__(self): return self @@ -435,39 +464,6 @@ _fetch_type = 1 - def fetchoneDict(self): - """Fetch a single row as a dictionary. Deprecated: - Use fetchone() instead. Will be removed in 1.3.""" - from warnings import warn - warn("fetchoneDict() is non-standard and will be removed in 1.3", - DeprecationWarning, 2) - return self.fetchone() - - def fetchmanyDict(self, size=None): - """Fetch several rows as a list of dictionaries. Deprecated: - Use fetchmany() instead. Will be removed in 1.3.""" - from warnings import warn - warn("fetchmanyDict() is non-standard and will be removed in 1.3", - DeprecationWarning, 2) - return self.fetchmany(size) - - def fetchallDict(self): - """Fetch all available rows as a list of dictionaries. Deprecated: - Use fetchall() instead. Will be removed in 1.3.""" - from warnings import warn - warn("fetchallDict() is non-standard and will be removed in 1.3", - DeprecationWarning, 2) - return self.fetchall() - - -class CursorOldDictRowsMixIn(CursorDictRowsMixIn): - - """This is a MixIn class that returns rows as dictionaries with - the same key convention as the old Mysqldb (MySQLmodule). Don't - use this.""" - - _fetch_type = 2 - class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn, BaseCursor): @@ -479,7 +475,7 @@ class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn, BaseCursor): - """This is a Cursor class that returns rows as dictionaries and + """This is a Cursor class that returns rows as dictionaries and stores the result set in the client."""