diff MySQLdb/cursors.py @ 10:3f4c6af70e52 MySQLdb

Me and PyLint had a knife fight, but PyLint had a gun.
author adustman
date Mon, 26 Feb 2007 02:40:02 +0000
parents fa8974a41c76
children 7773efbe9b30
line wrap: on
line diff
--- a/MySQLdb/cursors.py	Mon Feb 26 00:55:29 2007 +0000
+++ b/MySQLdb/cursors.py	Mon Feb 26 02:40:02 2007 +0000
@@ -5,11 +5,12 @@
 
 """
 
+__revision__ = "$ Revision: $"[11:-2]
+
 import re
-insert_values = re.compile(r"\svalues\s*(\(((?<!\\)'.*?\).*(?<!\\)?'|.)+?\))", re.IGNORECASE)
-from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
-     DatabaseError, OperationalError, IntegrityError, InternalError, \
-     NotSupportedError, ProgrammingError
+INSERT_VALUES = re.compile(
+    r"^(P<start>.+\svalues\s*)(P<values>\(((?<!\\)'.*?\).*(?<!\\)?'|.)+?\))(P<end>.*)$",
+    re.IGNORECASE)
 
 
 class BaseCursor(object):
@@ -36,6 +37,7 @@
          InternalError, ProgrammingError, NotSupportedError
     
     _defer_warnings = False
+    _fetch_type = None
     
     def __init__(self, connection):
         from weakref import proxy
@@ -61,25 +63,29 @@
         
     def close(self):
         """Close the cursor. No further queries will be possible."""
-        if not self.connection: return
-        while self.nextset(): pass
+        if not self.connection:
+            return
+        while self.nextset():
+            pass
         self.connection = None
 
     def _check_executed(self):
+        """Ensure that .execute() has been called."""
         if not self._executed:
-            self.errorhandler(self, ProgrammingError, "execute() first")
+            self.errorhandler(self, self.ProgrammingError, "execute() first")
 
     def _warning_check(self):
+        """Check for warnings, and report via the warnings module."""
         from warnings import warn
         if self._warnings:
             warnings = self._get_db().show_warnings()
             if warnings:
                 # This is done in two loops in case
                 # Warnings are set to raise exceptions.
-                for w in warnings:
-                    self.messages.append((self.Warning, w))
-                for w in warnings:
-                    warn(w[-1], self.Warning, 3)
+                for warning in warnings:
+                    self.messages.append((self.Warning, warning))
+                for warning in warnings:
+                    warn(warning[-1], self.Warning, 3)
             elif self._info:
                 self.messages.append((self.Warning, self._info))
                 warn(self._info, self.Warning, 3)
@@ -93,27 +99,33 @@
             self.fetchall()
         del self.messages[:]
         
-        db = self._get_db()
-        nr = db.next_result()
-        if nr == -1:
+        connection = self._get_db()
+        num_rows = connection.next_result()
+        if num_rows == -1:
             return None
         self._do_get_result()
         self._post_get_result()
         self._warning_check()
-        return 1
+        return True
 
-    def _post_get_result(self): pass
+    def _post_get_result(self):
+        """Stub to be overridden by MixIn."""
+        
+    def _get_result(self):
+        """Stub to be overridden by MixIn."""
+        return []
     
     def _do_get_result(self):
-        db = self._get_db()
+        """Get the result from the last query."""
+        connection = self._get_db()
         self._result = self._get_result()
-        self.rowcount = db.affected_rows()
+        self.rowcount = connection.affected_rows()
         self.rownumber = 0
         self.description = self._result and self._result.describe() or None
         self.description_flags = self._result and self._result.field_flags() or None
-        self.lastrowid = db.insert_id()
-        self._warnings = db.warning_count()
-        self._info = db.info()
+        self.lastrowid = connection.insert_id()
+        self._warnings = connection.warning_count()
+        self._info = connection.info()
     
     def setinputsizes(self, *args):
         """Does nothing, required by DB API."""
@@ -122,12 +134,14 @@
         """Does nothing, required by DB API."""
 
     def _get_db(self):
+        """Get the database connection.
+        
+        Raises ProgrammingError if the connection has been closed."""
         if not self.connection:
-            self.errorhandler(self, ProgrammingError, "cursor closed")
+            self.errorhandler(self, self.ProgrammingError, "cursor closed")
         return self.connection
     
     def execute(self, query, args=None):
-
         """Execute a query.
         
         query -- string, query to execute on server
@@ -140,39 +154,40 @@
         Returns long integer rows affected, if any
 
         """
-        from types import ListType, TupleType
         from sys import exc_info
         del self.messages[:]
-        db = self._get_db()
-        charset = db.character_set_name()
+        connection = self._get_db()
+        charset = connection.character_set_name()
         if isinstance(query, unicode):
             query = query.encode(charset)
         if args is not None:
-            query = query % db.literal(args)
+            query = query % connection.literal(args)
         try:
-            r = self._query(query)
-        except TypeError, m:
-            if m.args[0] in ("not enough arguments for format string",
-                             "not all arguments converted"):
-                self.messages.append((ProgrammingError, m.args[0]))
-                self.errorhandler(self, ProgrammingError, m.args[0])
+            result = self._query(query)
+        except TypeError, msg:
+            if msg.args[0] in ("not enough arguments for format string",
+                               "not all arguments converted"):
+                self.messages.append((self.ProgrammingError, msg.args[0]))
+                self.errorhandler(self, self.ProgrammingError, msg.args[0])
             else:
-                self.messages.append((TypeError, m))
-                self.errorhandler(self, TypeError, m)
+                self.messages.append((TypeError, msg))
+                self.errorhandler(self, TypeError, msg)
         except:
-            exc, value, tb = exc_info()
-            del tb
+            exc, value, traceback = exc_info()
+            del traceback
             self.messages.append((exc, value))
             self.errorhandler(self, exc, value)
         self._executed = query
-        if not self._defer_warnings: self._warning_check()
-        return r
+        if not self._defer_warnings:
+            self._warning_check()
+        return result
 
     def executemany(self, query, args):
-
         """Execute a multi-row query.
         
-        query -- string, query to execute on server
+        query
+        
+            string, query to execute on server
 
         args
 
@@ -187,45 +202,51 @@
 
         """
         del self.messages[:]
-        db = self._get_db()
-        if not args: return
-        charset = db.character_set_name()
-        if isinstance(query, unicode): query = query.encode(charset)
-        m = insert_values.search(query)
-        if not m:
-            r = 0
-            for a in args:
-                r = r + self.execute(query, a)
-            return r
-        p = m.start(1)
-        e = m.end(1)
-        qv = m.group(1)
+        connection = self._get_db()
+        if not args:
+            return
+        charset = connection.character_set_name()
+        if isinstance(query, unicode):
+            query = query.encode(charset)
+        matched = INSERT_VALUES.match(query)
+        if not matched:
+            self.rowcount = sum([ self.execute(query, arg) for arg in args ])
+            return self.rowcount
+        
+        start = matched.group('start')
+        end = matched.group('end')
+        values = matched.group('values')
+ 
         try:
-            q = [ qv % db.literal(a) for a in args ]
+            sql_params = [ values % connection.literal(arg) for arg in args ]
         except TypeError, msg:
             if msg.args[0] in ("not enough arguments for format string",
                                "not all arguments converted"):
-                self.messages.append((ProgrammingError, msg.args[0]))
-                self.errorhandler(self, ProgrammingError, msg.args[0])
+                self.messages.append((self.ProgrammingError, msg.args[0]))
+                self.errorhandler(self, self.ProgrammingError, msg.args[0])
             else:
                 self.messages.append((TypeError, msg))
                 self.errorhandler(self, TypeError, msg)
         except:
             from sys import exc_info
-            exc, value, tb = exc_info()
-            del tb
+            exc, value, traceback = exc_info()
+            del traceback
             self.errorhandler(self, exc, value)
-        r = self._query('\n'.join([query[:p], ',\n'.join(q), query[e:]]))
-        if not self._defer_warnings: self._warning_check()
-        return r
+        self.rowcount = int(self._query(
+            '\n'.join([start, ',\n'.join(sql_params), end,
+        ])))
+        if not self._defer_warnings:
+            self._warning_check()
+        return self.rowcount
     
     def callproc(self, procname, args=()):
-
         """Execute stored procedure procname with args
         
-        procname -- string, name of procedure to execute on server
+        procname
+            string, name of procedure to execute on server
 
-        args -- Sequence of parameters to use with procedure
+        args
+            Sequence of parameters to use with procedure
 
         Returns the original args.
 
@@ -249,37 +270,41 @@
         disconnected.
         """
 
-        from types import UnicodeType
-        db = self._get_db()
-        charset = db.character_set_name()
+        connection = self._get_db()
+        charset = connection.character_set_name()
         for index, arg in enumerate(args):
-            q = "SET @_%s_%d=%s" % (procname, index,
-                                         db.literal(arg))
-            if isinstance(q, unicode):
-                q = q.encode(charset)
-            self._query(q)
+            query = "SET @_%s_%d=%s" % (procname, index,
+                                        connection.literal(arg))
+            if isinstance(query, unicode):
+                query = query.encode(charset)
+            self._query(query)
             self.nextset()
             
-        q = "CALL %s(%s)" % (procname,
-                             ','.join(['@_%s_%d' % (procname, i)
-                                       for i in range(len(args))]))
-        if type(q) is UnicodeType:
-            q = q.encode(charset)
-        self._query(q)
-        self._executed = q
-        if not self._defer_warnings: self._warning_check()
+        query = "CALL %s(%s)" % (procname,
+                                 ','.join(['@_%s_%d' % (procname, i)
+                                           for i in range(len(args))]))
+        if isinstance(query, unicode):
+            query = query.encode(charset)
+        self._query(query)
+        self._executed = query
+        if not self._defer_warnings:
+            self._warning_check()
         return args
     
-    def _do_query(self, q):
-        db = self._get_db()
-        self._last_executed = q
-        db.query(q)
+    def _do_query(self, query):
+        """Low-levey query wrapper. Overridden by MixIns."""
+        connection = self._get_db()
+        self._last_executed = query
+        connection.query(query)
         self._do_get_result()
         return self.rowcount
 
-    def _query(self, q): return self._do_query(q)
+    def _query(self, query):
+        """Hook for _do_query."""
+        return self._do_query(query)
     
     def _fetch_row(self, size=1):
+        """Low-level fetch_row wrapper."""
         if not self._result:
             return ()
         return self._result.fetch_row(size, self._fetch_type)
@@ -287,17 +312,14 @@
     def __iter__(self):
         return iter(self.fetchone, None)
 
-    Warning = Warning
-    Error = Error
-    InterfaceError = InterfaceError
-    DatabaseError = DatabaseError
-    DataError = DataError
-    OperationalError = OperationalError
-    IntegrityError = IntegrityError
-    InternalError = InternalError
-    ProgrammingError = ProgrammingError
-    NotSupportedError = NotSupportedError
-   
+    def fetchone(self):
+        """Stub to be overridden by a MixIn."""
+        return None
+    
+    def fetchall(self):
+        """Stub to be overridden by a MixIn."""
+        return []
+    
 
 class CursorStoreResultMixIn(object):
 
@@ -306,14 +328,18 @@
     result set can be very large, consider adding a LIMIT clause to your
     query, or using CursorUseResultMixIn instead."""
 
-    def _get_result(self): return self._get_db().store_result()
+    def _get_result(self):
+        """Low-level; uses mysql_store_result()"""
+        return self._get_db().store_result()
 
-    def _query(self, q):
-        rowcount = self._do_query(q)
+    def _query(self, query):
+        """Low-level; executes query, gets result, and returns rowcount."""
+        rowcount = self._do_query(query)
         self._post_get_result()
         return rowcount
 
     def _post_get_result(self):
+        """Low-level"""
         self._rows = self._fetch_row(0)
         self._result = None
 
@@ -321,9 +347,10 @@
         """Fetches a single row from the cursor. None indicates that
         no more rows are available."""
         self._check_executed()
-        if self.rownumber >= len(self._rows): return None
+        if self.rownumber >= len(self._rows):
+            return None
         result = self._rows[self.rownumber]
-        self.rownumber = self.rownumber+1
+        self.rownumber += 1
         return result
 
     def fetchmany(self, size=None):
@@ -354,15 +381,15 @@
         value states an absolute target position."""
         self._check_executed()
         if mode == 'relative':
-            r = self.rownumber + value
+            row = self.rownumber + value
         elif mode == 'absolute':
-            r = value
+            row = value
         else:
-            self.errorhandler(self, ProgrammingError,
+            self.errorhandler(self, self.ProgrammingError,
                               "unknown scroll mode %s" % `mode`)
-        if r < 0 or r >= len(self._rows):
+        if row < 0 or row >= len(self._rows):
             self.errorhandler(self, IndexError, "out of range")
-        self.rownumber = r
+        self.rownumber = row
 
     def __iter__(self):
         self._check_executed()
@@ -380,35 +407,37 @@
 
     _defer_warnings = True
     
-    def _get_result(self): return self._get_db().use_result()
+    def _get_result(self):
+        """Low-level; calls mysql_use_result()"""
+        return self._get_db().use_result()
 
     def fetchone(self):
         """Fetches a single row from the cursor."""
         self._check_executed()
-        r = self._fetch_row(1)
-        if not r:
+        rows = self._fetch_row(1)
+        if not rows:
             self._warning_check()
             return None
         self.rownumber = self.rownumber + 1
-        return r[0]
+        return rows[0]
              
     def fetchmany(self, size=None):
         """Fetch up to size rows from the cursor. Result set may be smaller
         than size. If size is not defined, cursor.arraysize is used."""
         self._check_executed()
-        r = self._fetch_row(size or self.arraysize)
-        self.rownumber = self.rownumber + len(r)
-        if not r:
+        rows = self._fetch_row(size or self.arraysize)
+        self.rownumber = self.rownumber + len(rows)
+        if not rows:
             self._warning_check()
-        return r
+        return rows
          
     def fetchall(self):
         """Fetchs all available rows from the cursor."""
         self._check_executed()
-        r = self._fetch_row(0)
-        self.rownumber = self.rownumber + len(r)
+        rows = self._fetch_row(0)
+        self.rownumber = self.rownumber + len(rows)
         self._warning_check()
-        return r
+        return rows
 
     def __iter__(self):
         return self
@@ -435,39 +464,6 @@
 
     _fetch_type = 1
 
-    def fetchoneDict(self):
-        """Fetch a single row as a dictionary. Deprecated:
-        Use fetchone() instead. Will be removed in 1.3."""
-        from warnings import warn
-        warn("fetchoneDict() is non-standard and will be removed in 1.3",
-             DeprecationWarning, 2)
-        return self.fetchone()
-
-    def fetchmanyDict(self, size=None):
-        """Fetch several rows as a list of dictionaries. Deprecated:
-        Use fetchmany() instead. Will be removed in 1.3."""
-        from warnings import warn
-        warn("fetchmanyDict() is non-standard and will be removed in 1.3",
-             DeprecationWarning, 2)
-        return self.fetchmany(size)
-
-    def fetchallDict(self):
-        """Fetch all available rows as a list of dictionaries. Deprecated:
-        Use fetchall() instead. Will be removed in 1.3."""
-        from warnings import warn
-        warn("fetchallDict() is non-standard and will be removed in 1.3",
-             DeprecationWarning, 2)
-        return self.fetchall()
-
-
-class CursorOldDictRowsMixIn(CursorDictRowsMixIn):
-
-    """This is a MixIn class that returns rows as dictionaries with
-    the same key convention as the old Mysqldb (MySQLmodule). Don't
-    use this."""
-
-    _fetch_type = 2
-
 
 class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn,
              BaseCursor):
@@ -479,7 +475,7 @@
 class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn,
                  BaseCursor):
 
-     """This is a Cursor class that returns rows as dictionaries and
+    """This is a Cursor class that returns rows as dictionaries and
     stores the result set in the client."""