comparison MySQLdb/cursors.py @ 10:3f4c6af70e52 MySQLdb

Me and PyLint had a knife fight, but PyLint had a gun.
author adustman
date Mon, 26 Feb 2007 02:40:02 +0000
parents fa8974a41c76
children 7773efbe9b30
comparison
equal deleted inserted replaced
9:0e37ee00beb7 10:3f4c6af70e52
3 This module implements Cursors of various types for MySQLdb. By 3 This module implements Cursors of various types for MySQLdb. By
4 default, MySQLdb uses the Cursor class. 4 default, MySQLdb uses the Cursor class.
5 5
6 """ 6 """
7 7
8 __revision__ = "$ Revision: $"[11:-2]
9
8 import re 10 import re
9 insert_values = re.compile(r"\svalues\s*(\(((?<!\\)'.*?\).*(?<!\\)?'|.)+?\))", re.IGNORECASE) 11 INSERT_VALUES = re.compile(
10 from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \ 12 r"^(P<start>.+\svalues\s*)(P<values>\(((?<!\\)'.*?\).*(?<!\\)?'|.)+?\))(P<end>.*)$",
11 DatabaseError, OperationalError, IntegrityError, InternalError, \ 13 re.IGNORECASE)
12 NotSupportedError, ProgrammingError
13 14
14 15
15 class BaseCursor(object): 16 class BaseCursor(object):
16 17
17 """A base for Cursor classes. Useful attributes: 18 """A base for Cursor classes. Useful attributes:
34 from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \ 35 from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
35 DatabaseError, DataError, OperationalError, IntegrityError, \ 36 DatabaseError, DataError, OperationalError, IntegrityError, \
36 InternalError, ProgrammingError, NotSupportedError 37 InternalError, ProgrammingError, NotSupportedError
37 38
38 _defer_warnings = False 39 _defer_warnings = False
40 _fetch_type = None
39 41
40 def __init__(self, connection): 42 def __init__(self, connection):
41 from weakref import proxy 43 from weakref import proxy
42 44
43 self.connection = proxy(connection) 45 self.connection = proxy(connection)
59 self.errorhandler = None 61 self.errorhandler = None
60 self._result = None 62 self._result = None
61 63
62 def close(self): 64 def close(self):
63 """Close the cursor. No further queries will be possible.""" 65 """Close the cursor. No further queries will be possible."""
64 if not self.connection: return 66 if not self.connection:
65 while self.nextset(): pass 67 return
68 while self.nextset():
69 pass
66 self.connection = None 70 self.connection = None
67 71
68 def _check_executed(self): 72 def _check_executed(self):
73 """Ensure that .execute() has been called."""
69 if not self._executed: 74 if not self._executed:
70 self.errorhandler(self, ProgrammingError, "execute() first") 75 self.errorhandler(self, self.ProgrammingError, "execute() first")
71 76
72 def _warning_check(self): 77 def _warning_check(self):
78 """Check for warnings, and report via the warnings module."""
73 from warnings import warn 79 from warnings import warn
74 if self._warnings: 80 if self._warnings:
75 warnings = self._get_db().show_warnings() 81 warnings = self._get_db().show_warnings()
76 if warnings: 82 if warnings:
77 # This is done in two loops in case 83 # This is done in two loops in case
78 # Warnings are set to raise exceptions. 84 # Warnings are set to raise exceptions.
79 for w in warnings: 85 for warning in warnings:
80 self.messages.append((self.Warning, w)) 86 self.messages.append((self.Warning, warning))
81 for w in warnings: 87 for warning in warnings:
82 warn(w[-1], self.Warning, 3) 88 warn(warning[-1], self.Warning, 3)
83 elif self._info: 89 elif self._info:
84 self.messages.append((self.Warning, self._info)) 90 self.messages.append((self.Warning, self._info))
85 warn(self._info, self.Warning, 3) 91 warn(self._info, self.Warning, 3)
86 92
87 def nextset(self): 93 def nextset(self):
91 """ 97 """
92 if self._executed: 98 if self._executed:
93 self.fetchall() 99 self.fetchall()
94 del self.messages[:] 100 del self.messages[:]
95 101
96 db = self._get_db() 102 connection = self._get_db()
97 nr = db.next_result() 103 num_rows = connection.next_result()
98 if nr == -1: 104 if num_rows == -1:
99 return None 105 return None
100 self._do_get_result() 106 self._do_get_result()
101 self._post_get_result() 107 self._post_get_result()
102 self._warning_check() 108 self._warning_check()
103 return 1 109 return True
104 110
105 def _post_get_result(self): pass 111 def _post_get_result(self):
112 """Stub to be overridden by MixIn."""
113
114 def _get_result(self):
115 """Stub to be overridden by MixIn."""
116 return []
106 117
107 def _do_get_result(self): 118 def _do_get_result(self):
108 db = self._get_db() 119 """Get the result from the last query."""
120 connection = self._get_db()
109 self._result = self._get_result() 121 self._result = self._get_result()
110 self.rowcount = db.affected_rows() 122 self.rowcount = connection.affected_rows()
111 self.rownumber = 0 123 self.rownumber = 0
112 self.description = self._result and self._result.describe() or None 124 self.description = self._result and self._result.describe() or None
113 self.description_flags = self._result and self._result.field_flags() or None 125 self.description_flags = self._result and self._result.field_flags() or None
114 self.lastrowid = db.insert_id() 126 self.lastrowid = connection.insert_id()
115 self._warnings = db.warning_count() 127 self._warnings = connection.warning_count()
116 self._info = db.info() 128 self._info = connection.info()
117 129
118 def setinputsizes(self, *args): 130 def setinputsizes(self, *args):
119 """Does nothing, required by DB API.""" 131 """Does nothing, required by DB API."""
120 132
121 def setoutputsizes(self, *args): 133 def setoutputsizes(self, *args):
122 """Does nothing, required by DB API.""" 134 """Does nothing, required by DB API."""
123 135
124 def _get_db(self): 136 def _get_db(self):
137 """Get the database connection.
138
139 Raises ProgrammingError if the connection has been closed."""
125 if not self.connection: 140 if not self.connection:
126 self.errorhandler(self, ProgrammingError, "cursor closed") 141 self.errorhandler(self, self.ProgrammingError, "cursor closed")
127 return self.connection 142 return self.connection
128 143
129 def execute(self, query, args=None): 144 def execute(self, query, args=None):
130
131 """Execute a query. 145 """Execute a query.
132 146
133 query -- string, query to execute on server 147 query -- string, query to execute on server
134 args -- optional sequence or mapping, parameters to use with query. 148 args -- optional sequence or mapping, parameters to use with query.
135 149
138 %(key)s must be used as the placeholder. 152 %(key)s must be used as the placeholder.
139 153
140 Returns long integer rows affected, if any 154 Returns long integer rows affected, if any
141 155
142 """ 156 """
143 from types import ListType, TupleType
144 from sys import exc_info 157 from sys import exc_info
145 del self.messages[:] 158 del self.messages[:]
146 db = self._get_db() 159 connection = self._get_db()
147 charset = db.character_set_name() 160 charset = connection.character_set_name()
148 if isinstance(query, unicode): 161 if isinstance(query, unicode):
149 query = query.encode(charset) 162 query = query.encode(charset)
150 if args is not None: 163 if args is not None:
151 query = query % db.literal(args) 164 query = query % connection.literal(args)
152 try: 165 try:
153 r = self._query(query) 166 result = self._query(query)
154 except TypeError, m: 167 except TypeError, msg:
155 if m.args[0] in ("not enough arguments for format string", 168 if msg.args[0] in ("not enough arguments for format string",
156 "not all arguments converted"): 169 "not all arguments converted"):
157 self.messages.append((ProgrammingError, m.args[0])) 170 self.messages.append((self.ProgrammingError, msg.args[0]))
158 self.errorhandler(self, ProgrammingError, m.args[0]) 171 self.errorhandler(self, self.ProgrammingError, msg.args[0])
159 else: 172 else:
160 self.messages.append((TypeError, m)) 173 self.messages.append((TypeError, msg))
161 self.errorhandler(self, TypeError, m) 174 self.errorhandler(self, TypeError, msg)
162 except: 175 except:
163 exc, value, tb = exc_info() 176 exc, value, traceback = exc_info()
164 del tb 177 del traceback
165 self.messages.append((exc, value)) 178 self.messages.append((exc, value))
166 self.errorhandler(self, exc, value) 179 self.errorhandler(self, exc, value)
167 self._executed = query 180 self._executed = query
168 if not self._defer_warnings: self._warning_check() 181 if not self._defer_warnings:
169 return r 182 self._warning_check()
183 return result
170 184
171 def executemany(self, query, args): 185 def executemany(self, query, args):
172
173 """Execute a multi-row query. 186 """Execute a multi-row query.
174 187
175 query -- string, query to execute on server 188 query
189
190 string, query to execute on server
176 191
177 args 192 args
178 193
179 Sequence of sequences or mappings, parameters to use with 194 Sequence of sequences or mappings, parameters to use with
180 query. 195 query.
185 REPLACE. Otherwise it is equivalent to looping over args with 200 REPLACE. Otherwise it is equivalent to looping over args with
186 execute(). 201 execute().
187 202
188 """ 203 """
189 del self.messages[:] 204 del self.messages[:]
190 db = self._get_db() 205 connection = self._get_db()
191 if not args: return 206 if not args:
192 charset = db.character_set_name() 207 return
193 if isinstance(query, unicode): query = query.encode(charset) 208 charset = connection.character_set_name()
194 m = insert_values.search(query) 209 if isinstance(query, unicode):
195 if not m: 210 query = query.encode(charset)
196 r = 0 211 matched = INSERT_VALUES.match(query)
197 for a in args: 212 if not matched:
198 r = r + self.execute(query, a) 213 self.rowcount = sum([ self.execute(query, arg) for arg in args ])
199 return r 214 return self.rowcount
200 p = m.start(1) 215
201 e = m.end(1) 216 start = matched.group('start')
202 qv = m.group(1) 217 end = matched.group('end')
218 values = matched.group('values')
219
203 try: 220 try:
204 q = [ qv % db.literal(a) for a in args ] 221 sql_params = [ values % connection.literal(arg) for arg in args ]
205 except TypeError, msg: 222 except TypeError, msg:
206 if msg.args[0] in ("not enough arguments for format string", 223 if msg.args[0] in ("not enough arguments for format string",
207 "not all arguments converted"): 224 "not all arguments converted"):
208 self.messages.append((ProgrammingError, msg.args[0])) 225 self.messages.append((self.ProgrammingError, msg.args[0]))
209 self.errorhandler(self, ProgrammingError, msg.args[0]) 226 self.errorhandler(self, self.ProgrammingError, msg.args[0])
210 else: 227 else:
211 self.messages.append((TypeError, msg)) 228 self.messages.append((TypeError, msg))
212 self.errorhandler(self, TypeError, msg) 229 self.errorhandler(self, TypeError, msg)
213 except: 230 except:
214 from sys import exc_info 231 from sys import exc_info
215 exc, value, tb = exc_info() 232 exc, value, traceback = exc_info()
216 del tb 233 del traceback
217 self.errorhandler(self, exc, value) 234 self.errorhandler(self, exc, value)
218 r = self._query('\n'.join([query[:p], ',\n'.join(q), query[e:]])) 235 self.rowcount = int(self._query(
219 if not self._defer_warnings: self._warning_check() 236 '\n'.join([start, ',\n'.join(sql_params), end,
220 return r 237 ])))
238 if not self._defer_warnings:
239 self._warning_check()
240 return self.rowcount
221 241
222 def callproc(self, procname, args=()): 242 def callproc(self, procname, args=()):
223
224 """Execute stored procedure procname with args 243 """Execute stored procedure procname with args
225 244
226 procname -- string, name of procedure to execute on server 245 procname
227 246 string, name of procedure to execute on server
228 args -- Sequence of parameters to use with procedure 247
248 args
249 Sequence of parameters to use with procedure
229 250
230 Returns the original args. 251 Returns the original args.
231 252
232 Compatibility warning: PEP-249 specifies that any modified 253 Compatibility warning: PEP-249 specifies that any modified
233 parameters must be returned. This is currently impossible 254 parameters must be returned. This is currently impossible
247 behavior with respect to the DB-API. Be sure to use nextset() 268 behavior with respect to the DB-API. Be sure to use nextset()
248 to advance through all result sets; otherwise you may get 269 to advance through all result sets; otherwise you may get
249 disconnected. 270 disconnected.
250 """ 271 """
251 272
252 from types import UnicodeType 273 connection = self._get_db()
253 db = self._get_db() 274 charset = connection.character_set_name()
254 charset = db.character_set_name()
255 for index, arg in enumerate(args): 275 for index, arg in enumerate(args):
256 q = "SET @_%s_%d=%s" % (procname, index, 276 query = "SET @_%s_%d=%s" % (procname, index,
257 db.literal(arg)) 277 connection.literal(arg))
258 if isinstance(q, unicode): 278 if isinstance(query, unicode):
259 q = q.encode(charset) 279 query = query.encode(charset)
260 self._query(q) 280 self._query(query)
261 self.nextset() 281 self.nextset()
262 282
263 q = "CALL %s(%s)" % (procname, 283 query = "CALL %s(%s)" % (procname,
264 ','.join(['@_%s_%d' % (procname, i) 284 ','.join(['@_%s_%d' % (procname, i)
265 for i in range(len(args))])) 285 for i in range(len(args))]))
266 if type(q) is UnicodeType: 286 if isinstance(query, unicode):
267 q = q.encode(charset) 287 query = query.encode(charset)
268 self._query(q) 288 self._query(query)
269 self._executed = q 289 self._executed = query
270 if not self._defer_warnings: self._warning_check() 290 if not self._defer_warnings:
291 self._warning_check()
271 return args 292 return args
272 293
273 def _do_query(self, q): 294 def _do_query(self, query):
274 db = self._get_db() 295 """Low-levey query wrapper. Overridden by MixIns."""
275 self._last_executed = q 296 connection = self._get_db()
276 db.query(q) 297 self._last_executed = query
298 connection.query(query)
277 self._do_get_result() 299 self._do_get_result()
278 return self.rowcount 300 return self.rowcount
279 301
280 def _query(self, q): return self._do_query(q) 302 def _query(self, query):
303 """Hook for _do_query."""
304 return self._do_query(query)
281 305
282 def _fetch_row(self, size=1): 306 def _fetch_row(self, size=1):
307 """Low-level fetch_row wrapper."""
283 if not self._result: 308 if not self._result:
284 return () 309 return ()
285 return self._result.fetch_row(size, self._fetch_type) 310 return self._result.fetch_row(size, self._fetch_type)
286 311
287 def __iter__(self): 312 def __iter__(self):
288 return iter(self.fetchone, None) 313 return iter(self.fetchone, None)
289 314
290 Warning = Warning 315 def fetchone(self):
291 Error = Error 316 """Stub to be overridden by a MixIn."""
292 InterfaceError = InterfaceError 317 return None
293 DatabaseError = DatabaseError 318
294 DataError = DataError 319 def fetchall(self):
295 OperationalError = OperationalError 320 """Stub to be overridden by a MixIn."""
296 IntegrityError = IntegrityError 321 return []
297 InternalError = InternalError 322
298 ProgrammingError = ProgrammingError
299 NotSupportedError = NotSupportedError
300
301 323
302 class CursorStoreResultMixIn(object): 324 class CursorStoreResultMixIn(object):
303 325
304 """This is a MixIn class which causes the entire result set to be 326 """This is a MixIn class which causes the entire result set to be
305 stored on the client side, i.e. it uses mysql_store_result(). If the 327 stored on the client side, i.e. it uses mysql_store_result(). If the
306 result set can be very large, consider adding a LIMIT clause to your 328 result set can be very large, consider adding a LIMIT clause to your
307 query, or using CursorUseResultMixIn instead.""" 329 query, or using CursorUseResultMixIn instead."""
308 330
309 def _get_result(self): return self._get_db().store_result() 331 def _get_result(self):
310 332 """Low-level; uses mysql_store_result()"""
311 def _query(self, q): 333 return self._get_db().store_result()
312 rowcount = self._do_query(q) 334
335 def _query(self, query):
336 """Low-level; executes query, gets result, and returns rowcount."""
337 rowcount = self._do_query(query)
313 self._post_get_result() 338 self._post_get_result()
314 return rowcount 339 return rowcount
315 340
316 def _post_get_result(self): 341 def _post_get_result(self):
342 """Low-level"""
317 self._rows = self._fetch_row(0) 343 self._rows = self._fetch_row(0)
318 self._result = None 344 self._result = None
319 345
320 def fetchone(self): 346 def fetchone(self):
321 """Fetches a single row from the cursor. None indicates that 347 """Fetches a single row from the cursor. None indicates that
322 no more rows are available.""" 348 no more rows are available."""
323 self._check_executed() 349 self._check_executed()
324 if self.rownumber >= len(self._rows): return None 350 if self.rownumber >= len(self._rows):
351 return None
325 result = self._rows[self.rownumber] 352 result = self._rows[self.rownumber]
326 self.rownumber = self.rownumber+1 353 self.rownumber += 1
327 return result 354 return result
328 355
329 def fetchmany(self, size=None): 356 def fetchmany(self, size=None):
330 """Fetch up to size rows from the cursor. Result set may be smaller 357 """Fetch up to size rows from the cursor. Result set may be smaller
331 than size. If size is not defined, cursor.arraysize is used.""" 358 than size. If size is not defined, cursor.arraysize is used."""
352 If mode is 'relative' (default), value is taken as offset to 379 If mode is 'relative' (default), value is taken as offset to
353 the current position in the result set, if set to 'absolute', 380 the current position in the result set, if set to 'absolute',
354 value states an absolute target position.""" 381 value states an absolute target position."""
355 self._check_executed() 382 self._check_executed()
356 if mode == 'relative': 383 if mode == 'relative':
357 r = self.rownumber + value 384 row = self.rownumber + value
358 elif mode == 'absolute': 385 elif mode == 'absolute':
359 r = value 386 row = value
360 else: 387 else:
361 self.errorhandler(self, ProgrammingError, 388 self.errorhandler(self, self.ProgrammingError,
362 "unknown scroll mode %s" % `mode`) 389 "unknown scroll mode %s" % `mode`)
363 if r < 0 or r >= len(self._rows): 390 if row < 0 or row >= len(self._rows):
364 self.errorhandler(self, IndexError, "out of range") 391 self.errorhandler(self, IndexError, "out of range")
365 self.rownumber = r 392 self.rownumber = row
366 393
367 def __iter__(self): 394 def __iter__(self):
368 self._check_executed() 395 self._check_executed()
369 result = self.rownumber and self._rows[self.rownumber:] or self._rows 396 result = self.rownumber and self._rows[self.rownumber:] or self._rows
370 return iter(result) 397 return iter(result)
378 close() the cursor before additional queries can be peformed on 405 close() the cursor before additional queries can be peformed on
379 the connection.""" 406 the connection."""
380 407
381 _defer_warnings = True 408 _defer_warnings = True
382 409
383 def _get_result(self): return self._get_db().use_result() 410 def _get_result(self):
411 """Low-level; calls mysql_use_result()"""
412 return self._get_db().use_result()
384 413
385 def fetchone(self): 414 def fetchone(self):
386 """Fetches a single row from the cursor.""" 415 """Fetches a single row from the cursor."""
387 self._check_executed() 416 self._check_executed()
388 r = self._fetch_row(1) 417 rows = self._fetch_row(1)
389 if not r: 418 if not rows:
390 self._warning_check() 419 self._warning_check()
391 return None 420 return None
392 self.rownumber = self.rownumber + 1 421 self.rownumber = self.rownumber + 1
393 return r[0] 422 return rows[0]
394 423
395 def fetchmany(self, size=None): 424 def fetchmany(self, size=None):
396 """Fetch up to size rows from the cursor. Result set may be smaller 425 """Fetch up to size rows from the cursor. Result set may be smaller
397 than size. If size is not defined, cursor.arraysize is used.""" 426 than size. If size is not defined, cursor.arraysize is used."""
398 self._check_executed() 427 self._check_executed()
399 r = self._fetch_row(size or self.arraysize) 428 rows = self._fetch_row(size or self.arraysize)
400 self.rownumber = self.rownumber + len(r) 429 self.rownumber = self.rownumber + len(rows)
401 if not r: 430 if not rows:
402 self._warning_check() 431 self._warning_check()
403 return r 432 return rows
404 433
405 def fetchall(self): 434 def fetchall(self):
406 """Fetchs all available rows from the cursor.""" 435 """Fetchs all available rows from the cursor."""
407 self._check_executed() 436 self._check_executed()
408 r = self._fetch_row(0) 437 rows = self._fetch_row(0)
409 self.rownumber = self.rownumber + len(r) 438 self.rownumber = self.rownumber + len(rows)
410 self._warning_check() 439 self._warning_check()
411 return r 440 return rows
412 441
413 def __iter__(self): 442 def __iter__(self):
414 return self 443 return self
415 444
416 def next(self): 445 def next(self):
433 """This is a MixIn class that causes all rows to be returned as 462 """This is a MixIn class that causes all rows to be returned as
434 dictionaries. This is a non-standard feature.""" 463 dictionaries. This is a non-standard feature."""
435 464
436 _fetch_type = 1 465 _fetch_type = 1
437 466
438 def fetchoneDict(self):
439 """Fetch a single row as a dictionary. Deprecated:
440 Use fetchone() instead. Will be removed in 1.3."""
441 from warnings import warn
442 warn("fetchoneDict() is non-standard and will be removed in 1.3",
443 DeprecationWarning, 2)
444 return self.fetchone()
445
446 def fetchmanyDict(self, size=None):
447 """Fetch several rows as a list of dictionaries. Deprecated:
448 Use fetchmany() instead. Will be removed in 1.3."""
449 from warnings import warn
450 warn("fetchmanyDict() is non-standard and will be removed in 1.3",
451 DeprecationWarning, 2)
452 return self.fetchmany(size)
453
454 def fetchallDict(self):
455 """Fetch all available rows as a list of dictionaries. Deprecated:
456 Use fetchall() instead. Will be removed in 1.3."""
457 from warnings import warn
458 warn("fetchallDict() is non-standard and will be removed in 1.3",
459 DeprecationWarning, 2)
460 return self.fetchall()
461
462
463 class CursorOldDictRowsMixIn(CursorDictRowsMixIn):
464
465 """This is a MixIn class that returns rows as dictionaries with
466 the same key convention as the old Mysqldb (MySQLmodule). Don't
467 use this."""
468
469 _fetch_type = 2
470
471 467
472 class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn, 468 class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn,
473 BaseCursor): 469 BaseCursor):
474 470
475 """This is the standard Cursor class that returns rows as tuples 471 """This is the standard Cursor class that returns rows as tuples
477 473
478 474
479 class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn, 475 class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn,
480 BaseCursor): 476 BaseCursor):
481 477
482 """This is a Cursor class that returns rows as dictionaries and 478 """This is a Cursor class that returns rows as dictionaries and
483 stores the result set in the client.""" 479 stores the result set in the client."""
484 480
485 481
486 class SSCursor(CursorUseResultMixIn, CursorTupleRowsMixIn, 482 class SSCursor(CursorUseResultMixIn, CursorTupleRowsMixIn,
487 BaseCursor): 483 BaseCursor):