From 46af39a2a3286401f87a92b3f488bd8b67184c05 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Wed, 17 Jul 2024 12:10:25 +0100 Subject: [PATCH] Use redis to store the role cache --- backend/authentication/backend.py | 2 +- backend/authentication/user.py | 4 +-- backend/discord.py | 43 ++++++++----------------------- backend/routes/discord.py | 4 +-- 4 files changed, 16 insertions(+), 37 deletions(-) diff --git a/backend/authentication/backend.py b/backend/authentication/backend.py index c84ba10..e150580 100644 --- a/backend/authentication/backend.py +++ b/backend/authentication/backend.py @@ -68,6 +68,6 @@ async def authenticate( if await user.fetch_admin_status(request.state.db): scopes.append("admin") - scopes.extend(await user.get_user_roles(request.state.db)) + scopes.extend(await user.get_user_roles()) return authentication.AuthCredentials(scopes), user diff --git a/backend/authentication/user.py b/backend/authentication/user.py index ad59103..5e99546 100644 --- a/backend/authentication/user.py +++ b/backend/authentication/user.py @@ -44,12 +44,12 @@ def user_id(self) -> str: def decoded_token(self) -> dict[str, any]: return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"]) - async def get_user_roles(self, database: Database) -> list[str]: + async def get_user_roles(self) -> list[str]: """Get a list of the user's discord roles.""" if not self.member: return [] - server_roles = await discord.get_roles(database) + server_roles = await discord.get_roles() roles = [role.name for role in server_roles if role.id in self.member.roles] if "admin" in roles: diff --git a/backend/discord.py b/backend/discord.py index 192fc60..4a1ecf5 100644 --- a/backend/discord.py +++ b/backend/discord.py @@ -1,11 +1,9 @@ """Various utilities for working with the Discord API.""" -import datetime import json import httpx import starlette.requests -from pymongo.database import Database from starlette import exceptions from backend import constants, models @@ -66,7 +64,6 @@ async def _get_role_info() -> list[models.DiscordRole]: async def get_roles( - database: Database, *, force_refresh: bool = False, ) -> list[models.DiscordRole]: @@ -75,35 +72,17 @@ async def get_roles( If `force_refresh` is True, the cache is skipped and the roles are updated. """ - collection = database.get_collection("roles") - - if force_refresh: - # Drop all values in the collection - await collection.delete_many({}) - - # `create_index` creates the index if it does not exist, or passes - # This handles TTL on role objects - await collection.create_index( - "inserted_at", - expireAfterSeconds=60 * 60 * 24, # 1 day - name="inserted_at", - ) - - roles = [models.DiscordRole(**json.loads(role["data"])) async for role in collection.find()] - - if len(roles) == 0: - # Fetch roles from the API and insert into the database - roles = await _get_role_info() - await collection.insert_many( - { - "name": role.name, - "id": role.id, - "data": role.json(), - "inserted_at": datetime.datetime.now(tz=datetime.UTC), - } - for role in roles - ) - + role_cache_key = "forms-backend:role_cache" + if not force_refresh: + roles = await constants.REDIS_CLIENT.hgetall(role_cache_key) + if roles: + return [ + models.DiscordRole(**json.loads(role_data)) for role_id, role_data in roles.items() + ] + + roles = await _get_role_info() + await constants.REDIS_CLIENT.hmset(role_cache_key, {role.id: role.json() for role in roles}) + await constants.REDIS_CLIENT.expire(role_cache_key, 60 * 60 * 24) # 1 day return roles diff --git a/backend/routes/discord.py b/backend/routes/discord.py index 196d902..5cd6b47 100644 --- a/backend/routes/discord.py +++ b/backend/routes/discord.py @@ -31,9 +31,9 @@ class RolesResponse(pydantic.BaseModel): resp=Response(HTTP_200=RolesResponse), tags=["roles"], ) - async def patch(self, request: Request) -> JSONResponse: + async def patch(self, request: Request) -> JSONResponse: # noqa: ARG002 Request is required by @requires """Refresh the roles database.""" - roles = await discord.get_roles(request.state.db, force_refresh=True) + roles = await discord.get_roles(force_refresh=True) return JSONResponse( {"roles": [role.dict() for role in roles]},