diff --git a/testgres/node.py b/testgres/node.py index 3a29404..41504e8 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1372,6 +1372,8 @@ def psql(self, dbname=None, username=None, input=None, + host: typing.Optional[str] = None, + port: typing.Optional[int] = None, **variables): """ Execute a query using psql. @@ -1382,6 +1384,8 @@ def psql(self, dbname: database name to connect to. username: database user name. input: raw input to be passed. + host: an explicit host of server. + port: an explicit port of server. **variables: vars to be set before execution. Returns: @@ -1393,6 +1397,10 @@ def psql(self, >>> psql(query='select 3', ON_ERROR_STOP=1) """ + assert host is None or type(host) == str # noqa: E721 + assert port is None or type(port) == int # noqa: E721 + assert type(variables) == dict # noqa: E721 + return self._psql( ignore_errors=True, query=query, @@ -1400,6 +1408,8 @@ def psql(self, dbname=dbname, username=username, input=input, + host=host, + port=port, **variables ) @@ -1411,7 +1421,11 @@ def _psql( dbname=None, username=None, input=None, + host: typing.Optional[str] = None, + port: typing.Optional[int] = None, **variables): + assert host is None or type(host) == str # noqa: E721 + assert port is None or type(port) == int # noqa: E721 assert type(variables) == dict # noqa: E721 # @@ -1424,10 +1438,21 @@ def _psql( else: raise Exception("Input data must be None or bytes.") + if host is None: + host = self.host + + if port is None: + port = self.port + + assert host is not None + assert port is not None + assert type(host) == str # noqa: E721 + assert type(port) == int # noqa: E721 + psql_params = [ self._get_bin_path("psql"), - "-p", str(self.port), - "-h", self.host, + "-p", str(port), + "-h", host, "-U", username or self.os_ops.username, "-d", dbname or default_dbname(), "-X", # no .psqlrc diff --git a/tests/test_testgres_common.py b/tests/test_testgres_common.py index f4e5996..21fa00d 100644 --- a/tests/test_testgres_common.py +++ b/tests/test_testgres_common.py @@ -678,6 +678,89 @@ def test_psql(self, node_svc: PostgresNodeService): r = node.safe_psql('select 1') # raises! logging.error("node.safe_psql returns [{}]".format(r)) + def test_psql__another_port(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node1: + with __class__.helper__get_node(node_svc).init() as node2: + node1.start() + node2.start() + assert node1.port != node2.port + assert node1.host == node2.host + + node1.stop() + + logging.info("test table in node2 is creating ...") + node2.safe_psql( + dbname="postgres", + query="create table test (id integer);" + ) + + logging.info("try to find test table through node1.psql ...") + res = node1.psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host=node2.host, + port=node2.port, + ) + assert (__class__.helper__rm_carriage_returns(res) == (0, b'1\n', b'')) + + def test_psql__another_bad_host(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node: + logging.info("try to execute node1.psql ...") + res = node.psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host="DUMMY_HOST_NAME", + port=node.port, + ) + + res2 = __class__.helper__rm_carriage_returns(res) + + assert res2[0] != 0 + assert b"DUMMY_HOST_NAME" in res[2] + + def test_safe_psql__another_port(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node1: + with __class__.helper__get_node(node_svc).init() as node2: + node1.start() + node2.start() + assert node1.port != node2.port + assert node1.host == node2.host + + node1.stop() + + logging.info("test table in node2 is creating ...") + node2.safe_psql( + dbname="postgres", + query="create table test (id integer);" + ) + + logging.info("try to find test table through node1.psql ...") + res = node1.safe_psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host=node2.host, + port=node2.port, + ) + assert (__class__.helper__rm_carriage_returns(res) == b'1\n') + + def test_safe_psql__another_bad_host(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node: + logging.info("try to execute node1.psql ...") + + with pytest.raises(expected_exception=Exception) as x: + node.safe_psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host="DUMMY_HOST_NAME", + port=node.port, + ) + + assert "DUMMY_HOST_NAME" in str(x.value) + def test_safe_psql__expect_error(self, node_svc: PostgresNodeService): assert isinstance(node_svc, PostgresNodeService) with __class__.helper__get_node(node_svc).init().start() as node: