Mercurial > p > mysql-python > mysqldb-2
diff MySQLdb/cursors.py @ 74:80164eb2f090 MySQLdb
This passes all test, yet is still broken and ugly in many ways.
However, a lot of ugliness has been removed.
author | adustman |
---|---|
date | Sat, 20 Feb 2010 04:27:21 +0000 |
parents | c0c00294239b |
children | 3b03cb566032 |
line wrap: on
line diff
--- a/MySQLdb/cursors.py Fri Feb 19 02:21:11 2010 +0000 +++ b/MySQLdb/cursors.py Sat Feb 20 04:27:21 2010 +0000 @@ -13,6 +13,7 @@ import re import sys import weakref +from MySQLdb.converters import get_codec, tuple_row_decoder INSERT_VALUES = re.compile(r"(?P<start>.+values\s*)" r"(?P<values>\(((?<!\\)'[^\)]*?\)[^\)]*(?<!\\)?'|[^\(\)]|(?:\([^\)]*\)))+\))" @@ -39,8 +40,7 @@ _defer_warnings = False _fetch_type = None - def __init__(self, connection, encoders): - from MySQLdb.converters import default_decoders + def __init__(self, connection, encoders, decoders): self.connection = weakref.proxy(connection) self.description = None self.description_flags = None @@ -54,17 +54,41 @@ self._warnings = 0 self._info = None self.rownumber = None - self._encoders = encoders + self.maxrows = 0 + self.encoders = encoders + self.decoders = decoders + self._row_decoders = () + self.row_decoder = tuple_row_decoder + def _flush(self): + """_flush() reads to the end of the current result set, buffering what + it can, and then releases the result set.""" + if self._result: + for row in self._result: + pass + self._result = None + def __del__(self): self.close() self.errorhandler = None self._result = None + def _reset(self): + while True: + if self._result: + for row in self._result: + pass + self._result = None + if not self.nextset(): + break + del self.messages[:] + def close(self): """Close the cursor. No further queries will be possible.""" if not self.connection: return + + self._flush() try: while self.nextset(): pass @@ -106,22 +130,21 @@ num_rows = connection.next_result() if num_rows == -1: return None - self._do_get_result() - self._post_get_result() - self._warning_check() - return True - - def _do_get_result(self): - """Get the result from the last query.""" - connection = self._get_db() - self._result = self._get_result() - self.rowcount = connection.affected_rows() + result = connection.use_result() + self._result = result + if result: + self.field_flags = result.field_flags() + self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ] + self.description = result.describe() + else: + self._row_decoders = self.field_flags = () + self.description = None + self.rowcount = -1 #connection.affected_rows() self.rownumber = 0 - self.description = self._result and self._result.describe() or None - self.description_flags = self._result and self._result.field_flags() or None self.lastrowid = connection.insert_id() self._warnings = connection.warning_count() - self._info = connection.info() + self._info = connection.info() + return True def setinputsizes(self, *args): """Does nothing, required by DB API.""" @@ -150,15 +173,15 @@ Returns long integer rows affected, if any """ - del self.messages[:] db = self._get_db() + self._reset() charset = db.character_set_name() if isinstance(query, unicode): query = query.encode(charset) try: if args is not None: query = query % tuple(map(self.connection.literal, args)) - result = self._query(query) + self._query(query) except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", "not all arguments converted"): @@ -173,10 +196,9 @@ self.messages.append((exc, value)) self.errorhandler(self, exc, value) - self._executed = query if not self._defer_warnings: self._warning_check() - return result + return None def executemany(self, query, args): """Execute a multi-row query. @@ -197,8 +219,8 @@ execute(). """ - del self.messages[:] db = self._get_db() + self._reset() if not args: return charset = self.connection.character_set_name() @@ -216,8 +238,7 @@ try: sql_params = ( values % tuple(map(self.connection.literal, row)) for row in args ) multirow_query = '\n'.join([start, ',\n'.join(sql_params), end]) - self._executed = multirow_query - self.rowcount = int(self._query(multirow_query)) + self._query(multirow_query) except TypeError, msg: if msg.args[0] in ("not enough arguments for format string", @@ -234,7 +255,7 @@ if not self._defer_warnings: self._warning_check() - return self.rowcount + return None def callproc(self, procname, args=()): """Execute stored procedure procname with args @@ -283,71 +304,62 @@ if isinstance(query, unicode): query = query.encode(charset) self._query(query) - self._executed = query if not self._defer_warnings: self._warning_check() return args - - def _do_query(self, query): - """Low-levey query wrapper. Overridden by MixIns.""" - connection = self._get_db() - self._executed = query - connection.query(query) - self._do_get_result() - return self.rowcount - - def _fetch_row(self, size=1): - """Low-level fetch_row wrapper.""" - if not self._result: - return () - return self._result.fetch_row(size, self._fetch_type) def __iter__(self): return iter(self.fetchone, None) - def _get_result(self): - """Low-level; uses mysql_store_result()""" - return self._get_db().store_result() - def _query(self, query): - """Low-level; executes query, gets result, and returns rowcount.""" - rowcount = self._do_query(query) - self._post_get_result() - return rowcount - - def _post_get_result(self): - """Low-level""" - self._rows = self._fetch_row(0) - self._result = None - + """Low-level; executes query, gets result, sets up decoders.""" + connection = self._get_db() + self._flush() + self._executed = query + connection.query(query) + result = connection.use_result() + self._result = result + if result: + self.field_flags = result.field_flags() + self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ] + self.description = result.describe() + else: + self._row_decoders = self.field_flags = () + self.description = None + self.rowcount = -1 #connection.affected_rows() + self.rownumber = 0 + self.lastrowid = connection.insert_id() + self._warnings = connection.warning_count() + self._info = connection.info() + def fetchone(self): """Fetches a single row from the cursor. None indicates that no more rows are available.""" self._check_executed() - if self.rownumber >= len(self._rows): - return None - result = self._rows[self.rownumber] - self.rownumber += 1 - return result + row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row()) + return row def fetchmany(self, size=None): """Fetch up to size rows from the cursor. Result set may be smaller than size. If size is not defined, cursor.arraysize is used.""" self._check_executed() - end = self.rownumber + (size or self.arraysize) - result = self._rows[self.rownumber:end] - self.rownumber = min(end, len(self._rows)) - return result + if size is None: + size = self.arraysize + rows = [] + for i in range(size): + row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row()) + if row is None: break + rows.append(row) + return rows def fetchall(self): - """Fetchs all available rows from the cursor.""" + """Fetches all available rows from the cursor.""" self._check_executed() - if self.rownumber: - result = self._rows[self.rownumber:] + if self._result: + rows = [ self.row_decoder(self._row_decoders, row) for row in self._result ] else: - result = self._rows - self.rownumber = len(self._rows) - return result + rows = [] + return rows def scroll(self, value, mode='relative'): """Scroll the cursor in the result set to a new position according @@ -368,9 +380,3 @@ self.errorhandler(self, IndexError, "out of range") self.rownumber = row - def __iter__(self): - self._check_executed() - result = self.rownumber and self._rows[self.rownumber:] or self._rows - return iter(result) - - _fetch_type = 0