0001import atexit
0002from cgi import parse_qsl
0003import inspect
0004import new
0005import os
0006import sys
0007import threading
0008import types
0009import urllib
0010import warnings
0011import weakref
0012
0013from cache import CacheSet
0014import classregistry
0015import col
0016from converters import sqlrepr
0017import main
0018import sqlbuilder
0019from util.threadinglocal import local as threading_local
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
0025def _closeConnection(ref):
0026 conn = ref()
0027 if conn is not None:
0028 conn.close()
0029
class ConsoleWriter:
    """
    Debug writer that prints each message, newline-terminated, to a
    stream of the ``sys`` module.
    """

    def __init__(self, connection, loglevel):
        # ``loglevel`` is the name of a ``sys`` attribute; "stdout" by default.
        self.loglevel = loglevel or "stdout"
        # Unicode messages are encoded with the connection's encoding.
        self.dbEncoding = getattr(connection, "dbEncoding", None) or "ascii"

    def write(self, text):
        out = getattr(sys, self.loglevel)
        if not isinstance(text, unicode):
            out.write(text + '\n')
            return
        try:
            encoded = text.encode(self.dbEncoding)
        except UnicodeEncodeError:
            # Fall back to the escaped repr with the leading u' and the
            # trailing quote stripped off.
            encoded = repr(text)[2:-1]
        out.write(encoded + '\n')
0043
class LogWriter:
    """
    Debug writer that forwards each message to a method of a logger
    object (``logger.debug``, ``logger.info``, ...).
    """

    def __init__(self, connection, logger, loglevel):
        # ``connection`` is unused; it keeps the signature parallel to
        # ConsoleWriter so makeDebugWriter can build either kind.
        self.logger = logger
        # DBConnection passes loglevel=None by default, and
        # getattr(logger, None) raises TypeError.  Fall back to "debug",
        # just as ConsoleWriter falls back to "stdout".
        self.loglevel = loglevel or "debug"
        self.logmethod = getattr(logger, self.loglevel)

    def write(self, text):
        self.logmethod(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """
    Return the object that receives debug output for *connection*.

    With a *loggerName* the output goes to ``logging.getLogger(loggerName)``
    (method chosen by *loglevel*); otherwise it is printed to the ``sys``
    stream named by *loglevel*.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """
    Boolean coercion that also understands the string keywords
    yes/no, true/false, on/off and 1/0 (case-insensitively).

    Calling ``Boolean(x)`` returns a plain ``bool``, not a Boolean instance.
    """

    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        table = Boolean._keywords
        try:
            # Non-strings have no .lower(); unknown strings miss the table.
            result = table[value.lower()]
        except (KeyError, AttributeError):
            result = bool(value)
        return result
0068
class DBConnection:

    """
    Base class for SQLObject database connections.

    Stores the settings every backend shares (debugging, caching, naming
    style, auto-commit mode, class registry) and implements building and
    parsing of connection URIs.  ``uri()``/``oldUri()`` read the
    ``dbName``, ``user``, ``password``, ``host``, ``port`` and ``db``
    attributes that concrete backends are expected to provide.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        """
        :param name: optional instance name; registered so the
            connection can later be looked up by that name.
        :param debug: truthy to print queries; the keyword ``'Pool'``
            additionally traces connection-pool activity.
        :param debugOutput: truthy to also print query results.
        :param cache: whether to cache row objects.
        :param style: naming style object, stored as-is.
        :param autoCommit: truthy to commit implicitly when a pooled
            connection is released; the keyword ``'exception'`` makes an
            implicit commit/rollback an error instead.
        :param debugThreading: include the thread name in debug output.
        :param registry: passed to ``classregistry.registry()`` to pick
            the class registry this connection listens to.
        :param logger: logger name for debug output (console if empty).
        :param loglevel: logger method name, or ``sys`` stream name for
            console output.
        """
        self.name = name
        # Keep 'Pool' verbatim: the pool-tracing check compares against it.
        self.debug = self._keywordOrBoolean(debug, 'Pool')
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # id(raw connection) -> small sequential number used in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        # Keep 'exception' verbatim: releaseConnection compares against it.
        self.autoCommit = self._keywordOrBoolean(autoCommit, 'exception')
        self.registry = registry or None
        classregistry.registry(self.registry).addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        # A weak reference, so the atexit hook does not keep us alive.
        atexit.register(_closeConnection, weakref.ref(self))

    @staticmethod
    def _keywordOrBoolean(value, keyword):
        """
        Coerce *value* with ``Boolean`` unless it spells *keyword*.

        Some options accept one special string on top of plain booleans
        (``debug='Pool'``, ``autoCommit='exception'``).  ``Boolean`` alone
        turns those strings into ``True``, which makes the code comparing
        against the keyword unreachable.  So the keyword, matched
        case-insensitively, is returned verbatim.
        """
        lower = getattr(value, 'lower', None)
        if lower is not None and lower() == keyword.lower():
            return keyword
        return Boolean(value)

    def oldUri(self):
        """Return this connection's URI in the legacy, unquoted format."""
        auth = getattr(self, 'user', '') or ''
        if auth:
            password = getattr(self, 'password', None)
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + db

    def uri(self):
        """Return this connection's URI with user, password and db percent-quoted."""
        auth = getattr(self, 'user', '') or ''
        if auth:
            auth = urllib.quote(auth)
            password = getattr(self, 'password', None)
            if password:
                auth = auth + ':' + urllib.quote(password)
            auth = auth + '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + urllib.quote(db)

    @classmethod
    def connectionFromOldURI(cls, uri):
        """Build a connection from a legacy-format URI."""
        return cls._connectionFromParams(*cls._parseOldURI(uri))

    @classmethod
    def connectionFromURI(cls, uri):
        """Build a connection from a (percent-quoted) URI."""
        return cls._connectionFromParams(*cls._parseURI(uri))

    @staticmethod
    def _parseOldURI(uri):
        """
        Parse a legacy-format URI.  Components are not percent-decoded,
        except the values of query arguments.

        :returns: ``(user, password, host, port, path, args)``
        :raises ValueError: if the port is not an integer in 1-65535.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            # scheme:/path -- no host part at all.
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            # scheme:///path -- empty host part.
            host = None
            rest = rest[3:]
        else:
            # scheme://[user[:password]@]host[:port][/path]
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            # The last '@' separates the credentials from the host.
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # A Windows drive may be written "C|..." and becomes "C:...".
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                if not single:
                    # Tolerate empty segments ("a=1&&b=2", trailing "&").
                    continue
                argname, argvalue = single.split('=', 1)
                args[argname] = urllib.unquote(argvalue)
        return user, password, host, port, path, args

    @staticmethod
    def _parseURI(uri):
        """
        Parse a URI with the ``urllib`` split helpers; user, password
        and path are percent-decoded.

        :returns: ``(user, password, host, port, path, args)``
        """
        protocol, request = urllib.splittype(uri)
        user, password, port = None, None, None
        host, path = urllib.splithost(request)

        if host:
            if '@' in host:
                user, host = host.split('@', 1)
            if user:
                user, password = [x and urllib.unquote(x) or None for x in urllib.splitpasswd(user)]
            host, port = urllib.splitport(host)
            if port:
                port = int(port)
        elif host == '':
            host = None

        path, tag = urllib.splittag(path)
        path, query = urllib.splitquery(path)

        path = urllib.unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # "/C|/..." -> "/C:/...", then drop the leading slash before a drive.
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]

        args = {}
        if query:
            for name, value in parse_qsl(query):
                args[name] = value

        return user, password, host, port, path, args

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0250
class ConnWrapper(object):

    """
    Proxy for a SQLObject class bound to one specific connection.

    Instances of a SQLObject class carry their own connection, but the
    class itself is global.  This wrapper supplies ``connection=`` lazily
    whenever the class is instantiated, or when one of its methods that
    accepts a ``connection`` argument is looked up.
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        member = getattr(self._soClass, attr)
        if not isinstance(member, types.MethodType):
            # Plain attributes pass through untouched.
            return member
        try:
            wantsConn = member.takes_connection
        except AttributeError:
            # First lookup of this method: inspect its signature once and
            # memoize the answer on the underlying function.
            argNames, starArgs, starKw, _defaults = inspect.getargspec(member)
            assert not starKw and not starArgs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % member)
            wantsConn = 'connection' in argNames
            member.im_func.takes_connection = wantsConn
        if wantsConn:
            return ConnMethodWrapper(member, self._connection)
        return member
0289
class ConnMethodWrapper(object):

    """
    Callable proxy that always passes a fixed ``connection`` keyword
    argument to the wrapped method; every other attribute is delegated
    to that method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __call__(self, *args, **kw):
        # Any caller-supplied connection is overridden.
        kw.update(connection=self._connection)
        return self._method(*args, **kw)

    def __getattr__(self, attr):
        return getattr(self._method, attr)

    def __repr__(self):
        parts = (self._method, self._connection)
        return '<Wrapped %r with connection %r>' % parts
0306
class DBAPI(DBConnection):

    """
    Base class for connections built on a DB-API 2.0 driver module.

    Keeps a simple pool of raw driver connections and implements the
    generic query and DDL helpers used by SQLObject.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.

    ``_queryInsertID`` must also be defined (``queryInsertID`` delegates
    to it), as well as the ``module`` (driver), ``supportTransactions``
    and ``_setAutoCommit`` members and the DDL hooks that merely assert
    "Implement in subclasses" here.
    """

    dbName = None

    def __init__(self, **kw):
        # Pool state is created first; close() (also called from
        # __del__) copes with an instance that never got this far.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Remember the driver's Binary wrapper type.
        self._binaryType = type(self.module.Binary(''))

    def _runWithConnection(self, meth, *args):
        """Call ``meth(conn, *args)`` with a pooled connection, always releasing it."""
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """
        Take a raw connection from the pool, creating (and numbering) a
        new one with ``makeConnection()`` when the pool is empty.
        """
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def _returnConnection(self, conn):
        """Put *conn* back at the front of the pool, or close it when pooling is off."""
        if self._pool is not None:
            if conn not in self._pool:
                self._pool.insert(0, conn)
        else:
            conn.close()

    def releaseConnection(self, conn, explicit=False):
        """
        Return *conn* to the pool.

        For an implicit release (not via a Transaction) on a backend that
        supports transactions, the pending work is first committed or
        rolled back according to ``autoCommit``.

        :raises Exception: when ``autoCommit == 'exception'`` and the
            release is implicit; the work is rolled back and the
            connection is still returned to the pool.
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                # Give the (now clean) connection back before complaining,
                # otherwise every such error would leak a connection.
                self._returnConnection(conn)
                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # Drivers running in autocommit mode need no explicit commit.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        self._returnConnection(conn)

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line: ``<conn#>[:thread]/<name><pad><sep> <repr(s)>``.

        Messages named 'Pool' are only written when ``self.debug == 'Pool'``.
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
        s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' '*(8-len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' '*(8-len(threadName)))
        else:
            threadName = ''
        # The format string pulls n, threadName, name, spaces, sep, s from locals().
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*; a hook for subclasses that retry."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute *s*, discarding any result."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute *s* and return ``cursor.fetchall()``."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute *s* and return ``cursor.fetchone()``."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Start a new Transaction on a dedicated pooled connection."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        """Iterate over *select*; the iteration releases its connection when exhausted."""
        return select.IterationClass(self, self.getConnection(),
                                     select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        assert 0, 'Implement in subclasses'

    def createTable(self, soClass):
        """Create the table for *soClass* and return the deferred constraint SQL."""
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """Return the non-empty foreign-key constraint statements for *soClass*."""
        refConstraints = [self.createReferenceConstraint(soClass, column) for column in soClass.sqlmeta.columnList if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint for constraint in refConstraints if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """
        Return the extra SQL statements from ``sqlmeta.createSQL`` as a
        list.  A dict maps ``dbName`` to the statement(s) for that backend.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs,(str,list,dict,tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, str):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs,list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(CREATE TABLE statement, constraints + extra SQL)``."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                     (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        columnDefs = [self.createIDColumn(soClass)] + [self.createColumn(soClass, column)
                                                        for column in soClass.sqlmeta.columnList]
        return ",\n".join(["    %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        assert 0, "Implement in subclasses"

    def createColumn(self, soClass, col):
        assert 0, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        # ``cascade`` is ignored here; presumably honoured by subclasses.
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        """Delete every row of *tableName*; the table itself is kept."""
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        return self.module.Binary(value)

    def _SO_update(self, so, values):
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)

    def _SO_selectOneAlt(self, so, columnNames, condition):
        if columnNames:
            columns = []
            for name in columnNames:
                # Raw column names become SQL constants; expressions pass through.
                if isinstance(name, basestring):
                    columns.append(sqlbuilder.SQLConstant(name))
                else:
                    columns.append(name)
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
                                                            staticTables=[so.sqlmeta.table],
                                                            clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """
        Build a WHERE clause ANDing one comparison per criterion in *kw*.

        Keys may be ``id``, column names, or foreign-key names (whose
        values may be SQLObject instances).  Recognised keys are popped
        from *kw*; any that remain raise ``TypeError``.  ``None`` values
        are compared with ``IS``, everything else with ``=``.  Returns
        ``None`` when no criteria were given.
        """
        data = []
        if 'id' in kw:
            data.append((soClass.sqlmeta.idName, kw.pop('id')))
        for soColumn in soClass.sqlmeta.columnList:
            key = soColumn.name
            if key in kw:
                val = kw.pop(key)
                if soColumn.from_python:
                    val = soColumn.from_python(val, sqlbuilder.SQLObjectState(soClass, connection=self))
                data.append((soColumn.dbName, val))
            elif soColumn.foreignName in kw:
                obj = kw.pop(soColumn.foreignName)
                if isinstance(obj, main.SQLObject):
                    data.append((soColumn.dbName, obj.id))
                else:
                    data.append((soColumn.dbName, obj))
        if kw:
            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())

        if not data:
            return None
        clauses = []
        for dbName, value in data:
            # Identity test: a dict lookup would require hashable values
            # and would invoke the value's __eq__/__hash__.
            if value is None:
                op = "IS"
            else:
                op = "="
            clauses.append('%s %s %s' % (dbName, op, self.sqlrepr(value)))
        return ' AND '.join(clauses)

    def sqlrepr(self, v):
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """Close every pooled connection, ignoring driver errors."""
        if not hasattr(self, '_pool'):
            # __init__ never got far enough to create the pool.
            return
        if not self._pool:
            return
        self._poolLock.acquire()
        try:
            if not self._pool:
                return
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
            del conn
            del conns
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0705
class Iteration(object):

    """
    Row-by-row iterator over a SELECT, turning each row into an object
    via ``select.sourceClass.get``.

    The query is executed on construction.  Once the cursor is exhausted
    (or the iterator is garbage collected), the raw connection is
    released back to *dbconn* unless *keepConnection* is true.
    """

    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn = dbconn
        self.rawconn = rawconn
        self.select = select
        self.keepConnection = keepConnection
        self.cursor = rawconn.cursor()
        self.query = dbconn.queryForSelect(select)
        if dbconn.debug:
            dbconn.printDebug(rawconn, self.query, 'Select')
        dbconn._executeRetry(rawconn, self.cursor, self.query)

    def __iter__(self):
        return self

    def next(self):
        row = self.cursor.fetchone()
        if row is None:
            self._cleanup()
            raise StopIteration
        rowId = row[0]
        if rowId is None:
            return None
        sourceClass = self.select.sourceClass
        if self.select.ops.get('lazyColumns', 0):
            # Lazy mode: only the id is used; columns load on demand.
            return sourceClass.get(rowId, connection=self.dbconn)
        return sourceClass.get(rowId, selectResults=row[1:], connection=self.dbconn)

    def _cleanup(self):
        # A None query means we already cleaned up, or __init__ failed
        # before the query was built.
        if getattr(self, 'query', None) is not None:
            self.query = None
            if not self.keepConnection:
                self.dbconn.releaseConnection(self.rawconn)
            self.dbconn = self.rawconn = self.select = self.cursor = None

    def __del__(self):
        self._cleanup()
0747
class Transaction(object):

    """
    A transaction running on one dedicated raw connection.

    The connection is taken from the parent DBAPI connection with
    auto-commit switched off.  Attributes not defined here are fetched
    from the parent connection and rebound to the transaction (see
    ``__getattr__``), so the transaction can stand in for a connection.
    """

    def __init__(self, dbConnection):
        # Start out obsolete so __del__ does nothing if setup fails below.
        self._obsolete = True
        self._dbConnection = dbConnection
        conn = dbConnection.getConnection()
        self._connection = conn
        dbConnection._setAutoCommit(conn, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        # class name -> ids deleted inside this transaction.
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"

    def query(self, s):
        self.assertActive()
        run = self._dbConnection._query
        return run(self._connection, s)

    def queryAll(self, s):
        self.assertActive()
        run = self._dbConnection._queryAll
        return run(self._connection, s)

    def queryOne(self, s):
        self.assertActive()
        run = self._dbConnection._queryOne
        return run(self._connection, s)

    def queryInsertID(self, soInstance, id, names, values):
        self.assertActive()
        run = self._dbConnection._queryInsertID
        return run(self._connection, soInstance, id, names, values)

    def iterSelect(self, select):
        self.assertActive()
        # All rows are fetched eagerly; presumably so no cursor stays open
        # on the shared transaction connection (NOTE(review): confirm).
        rows = list(select.IterationClass(self, self._connection,
                                          select, keepConnection=True))
        return iter(rows)

    def _SO_delete(self, inst):
        className = inst.__class__.__name__
        self._deletedCache.setdefault(className, []).append(inst.id)
        # Run the parent's implementation with this transaction as ``self``.
        func = self._dbConnection._SO_delete.im_func
        bound = new.instancemethod(func, self, self.__class__)
        return bound(inst)

    def commit(self, close=False):
        if self._obsolete:
            # Already closed or rolled back: nothing to commit.
            return
        dbConn = self._dbConnection
        if dbConn.debug:
            dbConn.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        # Expire the parent-cache copies of every object touched or
        # deleted inside the transaction.
        toExpire = []
        for className, subCache in self.cache.allSubCachesByClassNames().items():
            toExpire.append((className, subCache.allIDs()))
        toExpire.extend(self._deletedCache.items())
        for className, ids in toExpire:
            for objId in ids:
                inst = dbConn.cache.tryGetByName(objId, className)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        if self._obsolete:
            # Already closed or rolled back.
            return
        dbConn = self._dbConnection
        if dbConn.debug:
            dbConn.printDebug(self._connection, '', 'ROLLBACK')
        # Snapshot the cached ids before the rollback, then expire them.
        snapshot = [(subCache, subCache.allIDs()) for subCache in self.cache.allSubCaches()]
        self._connection.rollback()
        for subCache, ids in snapshot:
            for objId in ids:
                inst = subCache.tryGet(objId)
                if inst is not None:
                    inst.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'. Poor man's
        acquisition? Bad programming? Okay, maybe.
        """
        self.assertActive()
        value = getattr(self._dbConnection, attr)
        try:
            func = value.im_func
        except AttributeError:
            if isinstance(value, ConnWrapper):
                # Re-bind class wrappers to the transaction.
                return ConnWrapper(value._soClass, self)
            return value
        return new.instancemethod(func, self, self.__class__)

    def _makeObsolete(self):
        self._obsolete = True
        dbConn = self._dbConnection
        if dbConn.autoCommit:
            dbConn._setAutoCommit(self._connection, 1)
        dbConn.releaseConnection(self._connection,
                                 explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
        self._obsolete = False
        conn = self._dbConnection.getConnection()
        self._connection = conn
        self._dbConnection._setAutoCommit(conn, 0)

    def __del__(self):
        # An unfinished transaction is rolled back when collected.
        if not self._obsolete:
            self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0874
class ConnectionHub(object):

    """
    A stand-in for a connection whose real database connection is bound
    later, either per thread (``threadConnection``) or for the whole
    process (``processConnection``).  A thread binding wins over the
    process one.

    Assign the hub to a SQLObject class's ``_connection``; since the hub
    is a descriptor, the class and its instances only ever see the bound
    connection.  Keep your own reference to the hub itself::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # A connection stored on the instance itself (via __set__) wins.
        if obj is not None:
            instanceDict = obj.__dict__
            if '_connection' in instanceDict:
                return instanceDict['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value

    def getConnection(self):
        try:
            return self.threadingLocal.connection
        except AttributeError:
            pass
        try:
            return self.processConnection
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def doInTransaction(self, func, *args, **kw):
        """
        Run ``func(*args, **kw)`` inside a transaction and return its
        result.  Any exception rolls the transaction back and propagates;
        otherwise the transaction is committed and closed.

        While *func* runs, the transaction temporarily replaces whichever
        binding (thread or process) was active.  Example::

            sqlhub.doInTransaction(process_request, os.environ)
        """
        try:
            previous = self.threadingLocal.connection
            perThread = True
        except AttributeError:
            previous = self.processConnection
            perThread = False
        if perThread:
            slot = 'threadConnection'
        else:
            slot = 'processConnection'
        txn = previous.transaction()
        setattr(self, slot, txn)
        try:
            try:
                result = func(*args, **kw)
            except:
                txn.rollback()
                raise
            txn.commit(close=True)
            return result
        finally:
            setattr(self, slot, previous)

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    threadConnection = property(fget=_get_threadConnection,
                                fset=_set_threadConnection,
                                fdel=_del_threadConnection)
0974
class ConnectionURIOpener(object):

    """
    Resolves connection URIs, and names of registered connection
    instances, to connection objects, caching the result per URI string.
    """

    def __init__(self):
        # URI scheme -> zero-argument builder returning a connection class.
        self.schemeBuilders = {}
        # Registered instance name -> connection instance.
        self.instanceNames = {}
        # URI string (including any appended args) -> connection.
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """
        Register *builder* for every URI scheme in *schemes*.
        Re-registering the same builder for a scheme is allowed.
        """
        for uriScheme in schemes:
            assert not uriScheme in self.schemeBuilders or self.schemeBuilders[uriScheme] is builder, "A driver has already been registered for the URI scheme %s" % uriScheme
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Make *inst* reachable through ``connectionForURI(inst.name)``.

        Instances without a name are ignored.  Registering the same
        instance twice is allowed.  A different instance under an
        existing name, or a name containing ':' (which would be parsed
        as a URI scheme), fails the assertion.
        """
        if inst.name:
            # Compare against ``inst`` itself: the previous code referenced
            # an undefined ``cls`` and raised NameError instead.
            assert not inst.name in self.instanceNames or self.instanceNames[inst.name] is inst, "A instance has already been registered with the name %s" % inst.name
            assert inst.name.find(':') == -1, "You cannot include ':' in your instance names (%r)" % inst.name
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """
        Return the connection for *uri*, creating and caching it on first use.

        Extra keyword *args* are URL-encoded and appended to the query
        string.  A *uri* without ':' is looked up as the name of a
        registered instance.  *oldUri* selects the legacy URI parser.
        """
        if args:
            if '?' not in uri:
                uri += '?' + urllib.urlencode(args)
            else:
                uri += '&' + urllib.urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if uri.find(':') != -1:
            scheme, rest = uri.split(':', 1)
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # No scheme: must be the name of a registered instance.
            assert uri in self.instanceNames, "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]
        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """Return the connection class produced by the builder for *scheme*."""
        assert scheme in self.schemeBuilders, (
            "No SQLObject driver exists for %s (only %s)"
            % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1026
# The process-wide opener; the module-level helpers below are its
# bound methods.
TheURIOpener = ConnectionURIOpener()

registerConnection, registerConnectionInstance = (
    TheURIOpener.registerConnection,
    TheURIOpener.registerConnectionInstance)
connectionForURI, dbConnectionForScheme = (
    TheURIOpener.connectionForURI,
    TheURIOpener.dbConnectionForScheme)
1033
1034
1035import firebird
1036import maxdb
1037import mssql
1038import mysql
1039import postgres
1040import rdbhost
1041import sqlite
1042import sybase