summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarko Kreen2010-07-21 16:02:07 +0000
committerMarko Kreen2010-07-21 16:02:07 +0000
commit78bdfc16c395bc355b984aa47154621a2a7afd43 (patch)
tree0f586039361efdc6644c731fcf22c5c2cd516952
parentba4043769febf0474a14c5fa08d6ab1e0e0cbbb9 (diff)
skytools: set_tcp_keepalive()
Separate TCP keepalive code out from connect_database(), to be usable separately. (Eg. for HTTP connections.) Add also support for Darwin.
-rw-r--r--python/skytools/psycopgwrapper.py66
1 files changed, 52 insertions, 14 deletions
diff --git a/python/skytools/psycopgwrapper.py b/python/skytools/psycopgwrapper.py
index 3445201f..a005100a 100644
--- a/python/skytools/psycopgwrapper.py
+++ b/python/skytools/psycopgwrapper.py
@@ -55,14 +55,14 @@ Plain .fetchall() / .fetchone() give exact same result.
"""
# no exports
-__all__ = ['connect_database']
+__all__ = ['connect_database', 'set_tcp_keepalive']
##from psycopg2.psycopg1 import connect as _pgconnect
# psycopg2.psycopg1.cursor is too backwards compatible,
# to the point of avoiding optimized access.
# only backwards compat thing we need is dict* methods
-import socket
+import sys, socket
import psycopg2.extensions, psycopg2.extras
from skytools.sqltools import dbdict
@@ -103,6 +103,53 @@ class _CompatConnection(psycopg2.extensions.connection):
def cursor(self):
return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor)
+def set_tcp_keepalive(fd, keepalive = True,
+ tcp_keepidle = 4 * 60,
+ tcp_keepcnt = 4,
+ tcp_keepintvl = 15):
+ """Turn on TCP keepalive. The fd can be either numeric or socket
+ object with 'fileno' method.
+
+ OS defaults for SO_KEEPALIVE=1:
+ - Linux: (7200, 9, 75) - can configure all.
+ - MacOS: (7200, 8, 75) - can configure only tcp_keepidle.
+ - Win32: (7200, 5|10, 1) - can configure tcp_keepidle and tcp_keepintvl.
+ Python needs SIO_KEEPALIVE_VALS support in socket.ioctl to enable it.
+
+ Our defaults: (240, 4, 15).
+ """
+
+ # usable on this OS?
+ if not hasattr(socket, 'SO_KEEPALIVE'):
+ return
+
+ # get numeric fd and cast to socket
+ if hasattr(fd, 'fileno'):
+ fd = fd.fileno()
+ s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
+
+ # skip if unix socket
+ if type(s.getsockname()) != type(()):
+ return
+
+ # turn on keepalive on the connection
+ if keepalive:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ if hasattr(socket, 'TCP_KEEPCNT'):
+ s.setsockopt(socket.IPPROTO_TCP, getattr(socket, 'TCP_KEEPIDLE'), tcp_keepidle)
+ s.setsockopt(socket.IPPROTO_TCP, getattr(socket, 'TCP_KEEPCNT'), tcp_keepcnt)
+ s.setsockopt(socket.IPPROTO_TCP, getattr(socket, 'TCP_KEEPINTVL'), tcp_keepintvl)
+ elif hasattr(socket, 'TCP_KEEPALIVE'):
+ s.setsockopt(socket.IPPROTO_TCP, getattr(socket, 'TCP_KEEPALIVE'), tcp_keepidle)
+ elif sys.platform == 'darwin':
+ TCP_KEEPALIVE = 0x10
+ s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, tcp_keepidle)
+ elif sys.platform == 'win32':
+ #s.ioctl(SIO_KEEPALIVE_VALS, (1, tcp_keepidle*1000, tcp_keepintvl*1000))
+ pass
+ else:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0)
+
def connect_database(connstr, keepalive = True,
tcp_keepidle = 4 * 60, # 7200
tcp_keepcnt = 4, # 9
@@ -120,24 +167,15 @@ def connect_database(connstr, keepalive = True,
# create connection
db = _CompatConnection(connstr)
+ curs = db.cursor()
- # turn on keepalive on the connection
- if keepalive and hasattr(socket, 'SO_KEEPALIVE'):
- fd = db.cursor().fileno()
- s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
- # avoid unix sockets
- if type(s.getsockname()) == type(()):
- s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- if hasattr(socket, 'TCP_KEEPCNT'):
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)
- s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)
+ # tune keepalive
+ set_tcp_keepalive(curs, keepalive, tcp_keepidle, tcp_keepcnt, tcp_keepintvl)
# fill .server_version on older psycopg
if not hasattr(db, 'server_version'):
iso = db.isolation_level
db.set_isolation_level(0)
- curs = db.cursor()
curs.execute('show server_version_num')
db.server_version = int(curs.fetchone()[0])
db.set_isolation_level(iso)