From 649478b253b698411f1ea65fd858389c151b2e41 Mon Sep 17 00:00:00 2001 From: Albin Stjerna Date: Wed, 11 Jul 2018 09:09:51 +0200 Subject: [PATCH] Implement async contest manager interface on AsyncElasticsearch Features: - updated README with the new syntax (first example) - test code in a separate file, because it has to use Python 3.5+ syntax --- README | 11 ++++------- elasticsearch_async/__init__.py | 15 +++++++++++++-- test_elasticsearch_async/test_transport.py | 13 ++++++++++--- test_elasticsearch_async/test_with_syntax.py | 15 +++++++++++++++ 4 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 test_elasticsearch_async/test_with_syntax.py diff --git a/README b/README index c7784c5..21a96c1 100644 --- a/README +++ b/README @@ -11,17 +11,14 @@ Example:: import asyncio from elasticsearch_async import AsyncElasticsearch + hosts = ['localhost', 'other-host'] - client = AsyncElasticsearch(hosts=['localhost', 'other-host']) - - @asyncio.coroutine - def print_info(): - info = yield from client.info() - print(info) + async def print_info(): + async with AsyncElasticsearch(hosts=hosts) as client: + print(await client.info()) loop = asyncio.get_event_loop() loop.run_until_complete(print_info()) - loop.run_until_complete(client.transport.close()) loop.close() diff --git a/elasticsearch_async/__init__.py b/elasticsearch_async/__init__.py index 3c27411..1549842 100644 --- a/elasticsearch_async/__init__.py +++ b/elasticsearch_async/__init__.py @@ -1,8 +1,19 @@ -from .transport import AsyncTransport -from .connection import AIOHttpConnection +import asyncio from elasticsearch import Elasticsearch +from .connection import AIOHttpConnection +from .transport import AsyncTransport + + class AsyncElasticsearch(Elasticsearch): def __init__(self, hosts=None, transport_class=AsyncTransport, **kwargs): super().__init__(hosts, transport_class=transport_class, **kwargs) + + @asyncio.coroutine + def __aenter__(self): + return self + + @asyncio.coroutine + def __aexit__(self, _exc_type, _exc_val, _exc_tb): + yield from self.transport.close() diff --git a/test_elasticsearch_async/test_transport.py b/test_elasticsearch_async/test_transport.py index d56f93a..5c6e2d5 100644 --- a/test_elasticsearch_async/test_transport.py +++ b/test_elasticsearch_async/test_transport.py @@ -4,11 +4,13 @@ from elasticsearch_async import AsyncElasticsearch + @mark.asyncio def test_sniff_on_start_sniffs(server, event_loop, port, sniff_data): server.register_response('/_nodes/_all/http', sniff_data) - client = AsyncElasticsearch(port=port, sniff_on_start=True, loop=event_loop) + client = AsyncElasticsearch( + port=port, sniff_on_start=True, loop=event_loop) # sniff has been called in the background assert client.transport.sniffing_task is not None @@ -21,10 +23,15 @@ def test_sniff_on_start_sniffs(server, event_loop, port, sniff_data): assert 'http://node1:9200' == connections[0].host yield from client.transport.close() + @mark.asyncio def test_retry_will_work(port, server, event_loop): - client = AsyncElasticsearch(hosts=['not-an-es-host', 'localhost'], port=port, loop=event_loop, randomize_hosts=False) + client = AsyncElasticsearch( + hosts=['not-an-es-host', 'localhost'], + port=port, + loop=event_loop, + randomize_hosts=False) data = yield from client.info() - assert {'body': '', 'method': 'GET', 'params': {}, 'path': '/'} == data + assert {'body': '', 'method': 'GET', 'params': {}, 'path': '/'} == data yield from client.transport.close() diff --git a/test_elasticsearch_async/test_with_syntax.py b/test_elasticsearch_async/test_with_syntax.py new file mode 100644 index 0000000..d196380 --- /dev/null +++ b/test_elasticsearch_async/test_with_syntax.py @@ -0,0 +1,15 @@ +import asyncio +import sys + +from pytest import mark + +from elasticsearch_async import AsyncElasticsearch + + +@mark.skipif(sys.version_info < (3, 5), reason="async with is Python 3.5+") +@mark.asyncio +async def test_with_notation_works(port, server, event_loop): + async with AsyncElasticsearch( + hosts=['localhost'], port=port, loop=event_loop) as client: + info = await client.info() + assert info