@@ -48,10 +48,10 @@ def __init__(self, conn_params: ConnectionParams):
48
48
self .ssh_key = conn_params .ssh_key
49
49
self .port = conn_params .port
50
50
self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
51
- if self .ssh_key :
52
- self .ssh_cmd += ["-i" , self .ssh_key ]
53
51
if self .port :
54
52
self .ssh_cmd += ["-p" , self .port ]
53
+ if self .ssh_key :
54
+ self .ssh_cmd += ["-i" , self .ssh_key ]
55
55
self .remote = True
56
56
self .username = conn_params .username or self .get_user ()
57
57
self .tunnel_process = None
@@ -283,6 +283,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
283
283
mode = "r+b" if binary else "r+"
284
284
285
285
with tempfile .NamedTemporaryFile (mode = mode , delete = False ) as tmp_file :
286
+ # Because in scp we set up port using -P option instead -p
286
287
scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
287
288
288
289
if not truncate :
@@ -302,12 +303,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
302
303
tmp_file .write (data )
303
304
304
305
tmp_file .flush ()
305
- # Because in scp we set up port using -P option
306
306
scp_cmd = ['scp' ] + scp_ssh_cmd + [tmp_file .name , f"{ self .username } @{ self .host } :{ filename } " ]
307
307
subprocess .run (scp_cmd , check = True )
308
-
309
308
remote_directory = os .path .dirname (filename )
310
- mkdir_cmd = ['ssh' ] + scp_ssh_cmd + [f"{ self .username } @{ self .host } " , f"mkdir -p { remote_directory } " ]
309
+
310
+ mkdir_cmd = ['ssh' ] + self .ssh_cmd + [f"{ self .username } @{ self .host } " , f'mkdir -p { remote_directory } ' ]
311
311
subprocess .run (mkdir_cmd , check = True )
312
312
313
313
os .remove (tmp_file .name )
@@ -385,9 +385,10 @@ def get_process_children(self, pid):
385
385
# Database control
386
386
def db_connect (self , dbname , user , password = None , host = "localhost" , port = 5432 ):
387
387
"""
388
- Established SSH tunnel and Connects to a PostgreSQL
388
+ Establish SSH tunnel and connect to a PostgreSQL database.
389
389
"""
390
- self .establish_ssh_tunnel (local_port = reserve_port (), remote_port = 5432 )
390
+ self .establish_ssh_tunnel (local_port = port , remote_port = self .conn_params .port )
391
+
391
392
try :
392
393
conn = pglib .connect (
393
394
host = host ,
@@ -396,6 +397,11 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
396
397
user = user ,
397
398
password = password ,
398
399
)
400
+ print ("Database connection established successfully." )
399
401
return conn
400
402
except Exception as e :
401
- raise Exception (f"Could not connect to the database. Error: { e } " )
403
+ print (f"Error connecting to the database: { str (e )} " )
404
+ if self .tunnel_process :
405
+ self .tunnel_process .terminate ()
406
+ print ("SSH tunnel closed due to connection failure." )
407
+ raise
0 commit comments