import asyncio
import functools
import logging
import ssl
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from operator import attrgetter
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4
import aioredis
from aioredis import ConnectionPool, Redis
from aioredis.exceptions import ResponseError, WatchError
from aioredis.sentinel import Sentinel
from pydantic.validators import make_arbitrary_type_validator
from .constants import default_queue_name, job_key_prefix, result_key_prefix
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .parser import ContextAwareDefaultParser, ContextAwareEncoder, encoder_options_var # type: ignore
from .utils import timestamp_ms, to_ms, to_unix_ms
logger = logging.getLogger('arku.connections')
[docs]class SSLContext(ssl.SSLContext):
"""
Required to avoid problems with
"""
@classmethod
def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
yield make_arbitrary_type_validator(ssl.SSLContext)
[docs]@dataclass
class RedisSettings:
"""
No-Op class used to hold redis connection redis_settings.
Used by :func:`arku.connections.create_pool` and :class:`arku.worker.Worker`.
"""
host: Union[str, List[Tuple[str, int]]] = 'localhost'
port: int = 6379
database: int = 0
password: Optional[str] = None
ssl: Union[bool, None, SSLContext] = None
conn_timeout: int = 1
conn_retries: int = 5
conn_retry_delay: int = 1
sentinel: bool = False
sentinel_master: str = 'mymaster'
@classmethod
def from_dsn(cls, dsn: str) -> 'RedisSettings':
conf = urlparse(dsn)
assert conf.scheme in {'redis', 'rediss'}, 'invalid DSN scheme'
return RedisSettings(
host=conf.hostname or 'localhost',
port=conf.port or 6379,
ssl=conf.scheme == 'rediss',
password=conf.password,
database=int((conf.path or '0').strip('/')),
)
def __repr__(self) -> str:
return 'RedisSettings({})'.format(', '.join(f'{k}={v!r}' for k, v in self.__dict__.items()))
# extra time after the job is expected to start when the job key should expire, 1 day in ms
expires_extra_ms = 86_400_000
[docs]class ArkuRedis(Redis):
"""
Thin subclass of ``aioredis.Redis`` which adds :func:`arku.connections.enqueue_job`.
:param redis_settings: an instance of ``arku.connections.RedisSettings``.
:param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
:param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
:param default_queue_name: the default queue name to use, defaults to ``arku.queue``.
:param kwargs: keyword arguments directly passed to ``aioredis.Redis``.
"""
def __init__(
self,
pool_or_conn: Optional[ConnectionPool] = None,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
**kwargs: Any,
) -> None:
self.job_serializer = job_serializer
self.job_deserializer = job_deserializer
self.default_queue_name = default_queue_name
if pool_or_conn:
kwargs['connection_pool'] = pool_or_conn
super().__init__(**kwargs)
self.connection_pool.connection_kwargs['parser_class'] = ContextAwareDefaultParser
self.connection_pool.connection_kwargs['encoder_class'] = ContextAwareEncoder
@contextmanager
def encoder_context(self, **options: Any) -> Generator['ArkuRedis', None, None]:
token = encoder_options_var.set(options)
yield self
encoder_options_var.reset(token)
[docs] async def enqueue_job(
self,
function: str,
*args: Any,
_job_id: Optional[str] = None,
_queue_name: Optional[str] = None,
_defer_until: Optional[datetime] = None,
_defer_by: Union[None, int, float, timedelta] = None,
_expires: Union[None, int, float, timedelta] = None,
_job_try: Optional[int] = None,
**kwargs: Any,
) -> Optional[Job]:
"""
Enqueue a job.
:param function: Name of the function to call
:param args: args to pass to the function
:param _job_id: ID of the job, can be used to enforce job uniqueness
:param _queue_name: queue of the job, can be used to create job in different queue
:param _defer_until: datetime at which to run the job
:param _defer_by: duration to wait before running the job
:param _expires: if the job still hasn't started after this duration, do not run it
:param _job_try: useful when re-enqueueing jobs within a job
:param kwargs: any keyword arguments to pass to the function
:return: :class:`arku.jobs.Job` instance or ``None`` if a job with this ID already exists
"""
if _queue_name is None:
_queue_name = self.default_queue_name
job_id = _job_id or uuid4().hex
job_key = job_key_prefix + job_id
assert not (_defer_until and _defer_by), "use either 'defer_until' or 'defer_by' or neither, not both"
defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)
async with self as conn:
pipe = conn.pipeline()
await pipe.unwatch()
await pipe.watch(job_key)
job_exists = await pipe.exists(job_key)
job_result_exists = await pipe.exists(result_key_prefix + job_id)
if job_exists or job_result_exists:
await pipe.reset()
return None
enqueue_time_ms = timestamp_ms()
if _defer_until is not None:
score = to_unix_ms(_defer_until)
elif defer_by_ms:
score = enqueue_time_ms + defer_by_ms
else:
score = enqueue_time_ms
expires_ms = expires_ms or score - enqueue_time_ms + expires_extra_ms
job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
pipe.multi()
pipe.psetex(job_key, expires_ms, job)
pipe.zadd(_queue_name, {job_id: score})
try:
await pipe.execute()
except (ResponseError, WatchError):
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
async def _get_job_result(self, key: Union[str]) -> JobResult:
job_id = key[len(result_key_prefix) :]
job = Job(job_id, self, _deserializer=self.job_deserializer)
r = await job.result_info()
if r is None:
raise KeyError(f'job "{key}" not found')
r.job_id = job_id
return r
[docs] async def all_job_results(self) -> List[JobResult]:
"""
Get results for all jobs in redis.
"""
with self.encoder_context(decode_responses=True):
keys = await self.keys(result_key_prefix + '*')
results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
return sorted(results, key=attrgetter('enqueue_time'))
async def _get_job_def(self, job_id: str, score: int) -> JobDef:
v = await self.get(job_key_prefix + job_id)
jd = deserialize_job(v, deserializer=self.job_deserializer)
jd.score = score
return jd
[docs] async def queued_jobs(self, *, queue_name: str = default_queue_name) -> List[JobDef]:
"""
Get information about queued, mostly useful when testing.
"""
with self.encoder_context(decode_responses=True):
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
[docs]async def create_pool(
settings_: RedisSettings = None,
*,
retry: int = 0,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
) -> ArkuRedis:
"""
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
Similar to ``aioredis.create_redis_pool`` except it returns a :class:`arku.connections.ArkuRedis` instance,
thus allowing job enqueuing.
"""
settings: RedisSettings = RedisSettings() if settings_ is None else settings_
assert not (
type(settings.host) is str and settings.sentinel
), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"
if settings.sentinel:
def pool_factory(*args: Any, **kwargs: Any) -> ArkuRedis:
client = Sentinel(*args, sentinels=settings.host, ssl=settings.ssl, **kwargs) # type: ignore
return client.master_for(settings.sentinel_master, redis_class=ArkuRedis)
else:
pool_factory = functools.partial(
ArkuRedis,
host=settings.host,
port=settings.port,
socket_connect_timeout=settings.conn_timeout,
ssl=settings.ssl,
)
try:
pool = pool_factory(db=settings.database, password=settings.password, encoding='utf8')
pool.job_serializer = job_serializer
pool.job_deserializer = job_deserializer
pool.default_queue_name = default_queue_name
await pool.ping()
except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
if retry < settings.conn_retries:
logger.warning(
'redis connection error %s:%s %s %s, %d retries remaining...',
settings.host,
settings.port,
e.__class__.__name__,
e,
settings.conn_retries - retry,
)
await asyncio.sleep(settings.conn_retry_delay)
else:
raise
else:
if retry > 0:
logger.info('redis connection successful')
return pool
# recursively attempt to create the pool outside the except block to avoid
# "During handling of the above exception..." madness
return await create_pool(
settings,
retry=retry + 1,
job_serializer=job_serializer,
job_deserializer=job_deserializer,
default_queue_name=default_queue_name,
)
async def log_redis_info(redis: Redis, log_func: Callable[[str], Any]) -> None:
async with redis as r:
info_server, info_memory, info_clients, key_count = await asyncio.gather(
r.info(section='Server'),
r.info(section='Memory'),
r.info(section='Clients'),
r.dbsize(),
)
redis_version = info_server.get('redis_version', '?')
mem_usage = info_memory.get('used_memory_human', '?')
clients_connected = info_clients.get('connected_clients', '?')
log_func(
f'redis_version={redis_version} '
f'mem_usage={mem_usage} '
f'clients_connected={clients_connected} '
f'db_keys={key_count}'
)