Source code for dynamic_batcher.batcher

"""
====================================
 :mod:`batcher` Module
====================================
.. moduleauthor:: Youngju Jaden Kim <pydemia@gmail.com>
.. note:: Info

Info
====
    `DynamicBatcher` and `BatchProcessor`

"""


from typing import Optional, List, Dict, Callable, NamedTuple
import os
import json
import logging
import asyncio
import redis
from autologging import logged
from . import logger

from .redis_engine import (
    REDIS__HOST,
    REDIS__PORT,
    REDIS__DB,
    REDIS__PASSWORD,
    REDIS__STREAM_KEY_REQUEST,
    REDIS__STREAM_GROUP_PROCESSOR,
    REDIS__STREAM_KEY_RESPONSE,
    REDIS__STREAM_GROUP_BATCHER,
    get_client,
)
from .types import ResponseStream, PendingRequestStream


# Public names exported by ``from dynamic_batcher.batcher import *``.
__all__ = ["DynamicBatcher", "BatchProcessor"]


# Defaults for `BatchProcessor`, overridable through environment variables.
# The batch time is parsed as a float so that sub-second gathering windows
# (e.g. ``DYNAMIC_BATCHER__BATCH_TIME=0.5``) are accepted; ``int("0.5")``
# would raise ``ValueError`` at import time.
DYNAMIC_BATCHER__BATCH_SIZE = int(os.getenv("DYNAMIC_BATCHER__BATCH_SIZE", "64"))
DYNAMIC_BATCHER__BATCH_TIME = float(os.getenv("DYNAMIC_BATCHER__BATCH_TIME", "2"))


@logged
class DynamicBatcher:
    """A client class for dynamic batch processing.

    A `DynamicBatcher` connects to a redis server using the connection info
    given by the following ``ENVVAR``:

    .. code-block:: bash

        REDIS__HOST=localhost
        REDIS__PORT=6379
        REDIS__DB=0
        REDIS__PASSWORD=

    Each request is appended to the request stream. The batcher then polls
    until the request shows up as pending for the processor group, and then
    polls until a response record for it has been stored.

    Args:
        delay (float): Polling interval, in seconds, while waiting for the
            response to a request. Defaults to ``0.01``.
        timeout (float): Seconds to wait in each of the two waiting stages
            (until the request is accepted, then until its response arrives).
            Defaults to ``100``.
            If `timeout` is too large, callers may be stuck waiting too long,
            which is not intended. If `timeout` is too small, the batcher gives
            up before the batch process has finished.

    Attributes:
        delay (float): Polling interval, in seconds.
        timeout (float): Seconds to wait in each waiting stage.

    Example:
        Create a `batcher`:

        >>> from dynamic_batcher import DynamicBatcher
        >>> batcher = DynamicBatcher()

        You can give some parameters:

        >>> lazy_batcher = DynamicBatcher(delay=1)

        Or, create a fail-fast `batcher`:

        >>> fail_fast_batcher = DynamicBatcher(timeout=3)

    Raises:
        redis.exceptions.ConnectionError: a redis server is not available.
    """

    def __init__(
        self,
        delay: float = 0.01,
        timeout: float = 100,
    ):
        # Relies on `@logged` (autologging) to set the name-mangled `__log`
        # attribute; falls back to a class-named logger when it is falsy.
        self.log = self.__log or logging.getLogger(self.__class__.__qualname__)
        # Connection settings come from the REDIS__* module constants.
        self._redis_client = get_client(
            host=REDIS__HOST,
            port=REDIS__PORT,
            db=REDIS__DB,
            password=REDIS__PASSWORD,
        )
        self._request_key: str = REDIS__STREAM_KEY_REQUEST
        self._response_key: str = REDIS__STREAM_KEY_RESPONSE
        self._processor_group: str = REDIS__STREAM_GROUP_PROCESSOR
        self._batcher_group: str = REDIS__STREAM_GROUP_BATCHER
        self.delay = delay
        self.timeout = timeout

    async def asend(self, body: Dict | List, *args, **kwargs) -> Optional[Dict | List]:
        r"""Send a request and wait for a response, with JSON-serializable body.

        Args:
            body (:obj: ``Dict`` or ``List``): A **JSON-serializable object**,
                especially ``Dict`` or ``List``.
            \*args: Variable length argument list (currently unused).
            \**kwargs: Arbitrary keyword arguments (currently unused).

        Returns:
            Dict or List: the JSON-decoded response body, or ``None`` when the
            body cannot be serialized, redis fails, or no response arrives
            before the timeout.

        Example:
            >>> import time
            >>> import uvicorn
            >>> from typing import List
            >>> from fastapi import FastAPI
            >>> from pydantic import BaseModel
            >>> from dynamic_batcher import DynamicBatcher
            >>>
            >>> app = FastAPI()
            >>> batcher = DynamicBatcher()
            >>> class RequestItem(BaseModel):
            ...     key: str
            ...     values: List[int] = [1, 5, 2]
            >>>
            >>> @app.post("/batch/{key}")
            ... async def run_batch(key: str, body: RequestItem):
            ...     start_time = time.time()
            ...     resp_body = await batcher.asend(body.model_dump())
            ...     result = {
            ...         "key": key,
            ...         "values": body.values,
            ...         "values_sum": resp_body,
            ...         "elapsed": time.time() - start_time
            ...     }
            ...     return result
            >>>
            >>> if __name__ == "__main__":
            ...     uvicorn.run(app)
            INFO:     Started server process [27085]
            INFO:     Waiting for application startup.
            INFO:     Application startup complete.
            INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
        """
        try:
            json_body = json.dumps(body)
        except (TypeError, ValueError) as json_e:
            # json.dumps raises TypeError for unsupported types and ValueError
            # for circular references; it never raises JSONDecodeError.
            self.log.error("cannot serialize request body: %s", json_e, exc_info=True)
            return None

        try:
            requested_stream_id: bytes = self._redis_client.xadd(
                self._request_key, {"body": json_body}
            )
            # The acceptance result is informational only: even if no processor
            # picked the request up in time, we still wait for a response.
            await self._wait_for_start(
                requested_stream_id, delay=self.delay, timeout=self.timeout
            )
            response = await self._wait_for_finish(
                requested_stream_id, delay=self.delay, timeout=self.timeout
            )
        except redis.RedisError as redis_e:
            self.log.error("redis not available: %s", redis_e, exc_info=True)
            return None
        except json.JSONDecodeError as json_e:
            self.log.error("cannot de-serialize response body: %s", json_e, exc_info=True)
            return None
        except Exception as unknown_e:
            self.log.error("failed to respond (unknown): %s", unknown_e, exc_info=True)
            return None

        if response is None:
            self.log.warning(
                "no response within timeout (%s s): %s", self.timeout, requested_stream_id
            )
            return None
        return response.body

    async def _poll_until_truthy(
        self, check: Callable[[], object], delay: float, timeout: float
    ) -> Optional[object]:
        """Call `check` every `delay` seconds until it returns a truthy value.

        `check` always runs at least once, and once more after the last sleep,
        so ``timeout <= 0`` performs a single non-blocking check. Elapsed time
        is measured on the event loop's clock, so time spent in the synchronous
        redis calls made by `check` counts against `timeout`.

        Returns:
            The first truthy value returned by `check`, or ``None`` on timeout.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            result = check()
            if result:
                return result
            if loop.time() >= deadline:
                return None
            await asyncio.sleep(delay)

    async def _wait_for_start(
        self, stream_id: bytes, delay: float = 0.1, timeout: float = 10
    ) -> Optional[bytes]:
        """Wait until `stream_id` is pending for the processor group.

        Returns the pending message id, or ``None`` on timeout.
        """
        return await self._poll_until_truthy(
            lambda: self._get_request_accepted(stream_id), delay, timeout
        )

    async def _wait_for_finish(
        self, stream_id: bytes, delay: float = 0.1, timeout: float = 10
    ) -> Optional[ResponseStream]:
        """Wait until a response record for `stream_id` is stored.

        Returns the `ResponseStream`, or ``None`` on timeout.
        """
        return await self._poll_until_truthy(
            lambda: self._get_response_arrived_as_record(stream_id), delay, timeout
        )

    def _get_request_accepted(self, stream_id: bytes) -> Optional[bytes]:
        """Return `stream_id`'s pending-entry id if the processor group has it, else ``None``."""
        messages: List = self._redis_client.xpending_range(
            self._request_key,
            groupname=self._processor_group,
            count=1,
            min=stream_id,
            max=stream_id,
        )
        if not messages:
            return None
        return PendingRequestStream(**messages[0]).message_id

    def _get_response_arrived_as_record(self, stream_id: bytes) -> Optional[ResponseStream]:
        """Fetch, JSON-decode and delete the response record keyed by `stream_id`.

        Returns ``None`` when no record exists yet. A JSON decode error
        propagates to the caller, and the record is then left in place.
        """
        message: Optional[str] = self._redis_client.get(stream_id)
        if not message:
            return None
        _body = json.loads(message)
        self._redis_client.delete(stream_id)
        return ResponseStream(stream_id, _body)

    def _get_response_arrived_as_stream(self, stream_id: bytes) -> Optional[ResponseStream]:
        """Stream-based alternative to `_get_response_arrived_as_record` (not used by `asend`).

        Reads the entry `stream_id` from the response stream, acknowledges and
        deletes it, and returns its raw (not JSON-decoded) fields.
        """
        messages: List = self._redis_client.xrange(
            self._response_key,
            min=stream_id,
            max=stream_id,
        )
        if not messages:
            return None
        _id, _body = messages[0]
        self._redis_client.xack(self._response_key, self._batcher_group, _id)
        self._redis_client.xdel(self._response_key, _id)
        return ResponseStream(_id, _body)
class BatchProcessor:
    """A server-side class for dynamic batch processing.

    A `BatchProcessor` connects to a redis server using the connection info
    given by the following ``ENVVAR``:

    .. code-block:: bash

        REDIS__HOST=localhost
        REDIS__PORT=6379
        REDIS__DB=0
        REDIS__PASSWORD=
        DYNAMIC_BATCHER__BATCH_SIZE=64
        DYNAMIC_BATCHER__BATCH_TIME=2

    Args:
        batch_size (int): Maximum number of requests in a batch. Defaults to ``64``.
            If ``DYNAMIC_BATCHER__BATCH_SIZE`` is set, it overrides the default.
            An explicitly passed argument takes precedence over everything else.
            Priority::

                values passed > ENVVAR > default value

        batch_time (float): Seconds spent gathering requests into one batch.
            Defaults to ``2``. A batch is dispatched when either `batch_time`
            has elapsed or `batch_size` requests have been gathered.
            If `batch_time` is too large, requests wait long before being
            processed. If it is too small, batches are dispatched with few
            requests in them.

    Attributes:
        delay (float): Polling interval, in seconds, while gathering requests.
        batch_size (int): Maximum number of requests in a batch.
        batch_time (float): Seconds spent gathering requests into one batch.

    Example:
        Create a `processor`:

        >>> import asyncio
        >>> from dynamic_batcher import BatchProcessor
        >>> processor = BatchProcessor()
        >>> asyncio.run(processor.start_daemon(lambda x: x))

    Raises:
        redis.exceptions.ConnectionError: a redis server is not available.
    """

    def __init__(
        self,
        batch_size: int = DYNAMIC_BATCHER__BATCH_SIZE or 64,
        batch_time: float = DYNAMIC_BATCHER__BATCH_TIME or 2,
    ):
        self.log = logging.getLogger(logger.LOGGERNAME_BATCHPROCESSOR)
        self.log.info("LOG_LEVEL: %s", logging.getLevelName(self.log.level))
        self.delay = 0.001  # polling interval (seconds) while gathering a batch
        self.batch_size = batch_size
        self.batch_time = batch_time
        self._redis_client = get_client(
            host=REDIS__HOST,
            port=REDIS__PORT,
            db=REDIS__DB,
            password=REDIS__PASSWORD,
        )
        self._request_key = REDIS__STREAM_KEY_REQUEST
        self._response_key = REDIS__STREAM_KEY_RESPONSE
        self._processor_group = REDIS__STREAM_GROUP_PROCESSOR
        self._batcher_group = REDIS__STREAM_GROUP_BATCHER
        self.response_expiration_sec = 600  # TTL (seconds) of stored response records

    async def start_daemon(self, func: Callable) -> None:
        """Start a single batch process as a daemon.

        This concatenates the gathered requests into one batch, calls `func`
        on it, and splits the result into the corresponding responses.
        It loops forever, trimming the streams after every batch.

        Args:
            func (:obj: `Callable`): A callable object, like a function or method.
                `func` should take exactly one positional argument, a ``List``
                holding the batch. It should return a ``List`` whose i-th item
                is the response to the i-th request, so it can operate
                elementwise on a batch of any size.
                Both the argument and the returned value must be
                **JSON (de)serializable**.

        Returns:
            None

        Example:
            First, define a function to run:

            >>> import asyncio
            >>> from dynamic_batcher import BatchProcessor
            >>> from typing import List, Dict
            >>> body_list = [
            ...     {'values': [1, 2, 3]},
            ...     {'values': [4, 5, 6]}
            ... ]
            >>> def sum_values(bodies: List[Dict]) -> List[Dict]:
            ...     result = []
            ...     for body in bodies:
            ...         result.append( { 'sum': sum(body['values']) } )
            ...     return result
            >>> sum_values(body_list)
            [{'sum': 6}, {'sum': 15}]

            Then, run a ``BatchProcessor``:

            >>> import asyncio
            >>> from dynamic_batcher import BatchProcessor
            >>> batch_processor = BatchProcessor()
            >>> asyncio.run(batch_processor.start_daemon(sum_values))
        """
        self.log.info(
            'BatchProcessor start: ' + ', '.join([
                f'delay={self.delay}',
                f'batch_size={self.batch_size}',
                f'batch_time={self.batch_time}',
            ])
        )
        while True:
            await self._run(func)
            await self._trim()

    async def _run(self, func: Callable) -> None:
        """Gather one batch of requests, run `func` on it and store the results."""
        loop = asyncio.get_running_loop()
        started = loop.time()
        elapsed = 0.0
        messages: List = []
        # Elapsed time comes from the loop clock, so the blocking redis reads
        # count toward `batch_time` as well as the sleeps.
        while elapsed < self.batch_time and len(messages) < self.batch_size:
            new_requests = await self._get_next_request(
                count=self.batch_size - len(messages)
            )
            if new_requests:
                # Each reply item is [stream_key, [(id, fields), ...]]; keep
                # every message, not just the first one per stream.
                messages.extend(
                    message
                    for _stream_key, stream_messages in new_requests
                    for message in stream_messages
                )
            await asyncio.sleep(self.delay)
            elapsed = loop.time() - started

        if not messages:
            return

        self.log.debug(
            f'batch start: {elapsed:.3f}/{self.batch_time}, {len(messages)}/{self.batch_size}'
        )
        messages.sort(key=lambda message: message[0])
        stream_ids = [message_id for message_id, _fields in messages]
        stream_bodies = [json.loads(fields['body']) for _message_id, fields in messages]
        try:
            results = func(stream_bodies)
        except Exception as e:
            # Not every callable (e.g. functools.partial) has `__name__`.
            func_name = getattr(func, '__name__', repr(func))
            self.log.error(f'Error while running `{func_name}`: {e}', exc_info=True)
            results = None
        await self._mark_as_finished_as_record(stream_ids, results)

    async def _mark_as_finished_as_record(self, stream_ids: List[str], results: Optional[List[Dict]]) -> None:
        """Store each result under its request id (with a TTL) and delete the requests.

        When `results` is ``None`` (i.e. `func` failed), JSON ``null`` is stored
        for every request, so waiting clients get an empty response instead of
        none at all. Previously ``zip(stream_ids, None)`` raised ``TypeError``
        here and stopped the daemon.
        """
        if results is None:
            results = [None] * len(stream_ids)
        else:
            results = list(results)
            if len(results) != len(stream_ids):
                # zip() pairs only the common prefix; unmatched requests get no record.
                self.log.warning(
                    f'result count {len(results)} does not match request count {len(stream_ids)}'
                )
        for stream_id, stream_body in zip(stream_ids, results):
            self._redis_client.set(
                stream_id,
                json.dumps(stream_body),
                ex=self.response_expiration_sec,
            )
        self._redis_client.xdel(self._request_key, *stream_ids)

    async def _mark_as_finished_as_stream(self, stream_ids: List[str], results: Optional[List[Dict]]) -> None:
        """Stream-based alternative to `_mark_as_finished_as_record` (not used by `_run`).

        Appends each result to the response stream, reusing the request id, and
        deletes the requests.
        NOTE(review): `xadd` takes a field mapping, so the ``None`` placeholders
        used when `results` is ``None`` likely fail — verify before using this path.
        """
        if results is None:
            results = [None] * len(stream_ids)
        for stream_id, stream_body in zip(stream_ids, results):
            self._redis_client.xadd(
                self._response_key,
                stream_body,
                stream_id,
            )
        self._redis_client.xdel(self._request_key, *stream_ids)

    async def _get_next_request(self, count: int = 1) -> Optional[List]:
        """Read up to `count` new requests for the processor group.

        Returns the raw XREADGROUP reply, or ``None`` when nothing was read or
        the read failed (the error is logged).
        """
        try:
            requests: List = self._redis_client.xreadgroup(
                groupname=self._processor_group,
                consumername=self._processor_group,
                streams={self._request_key: '>'},
                count=count,
                noack=False,
            )
        except Exception as e:
            self.log.error(f'Error while reading message {e}')
            return None
        return requests or None

    async def _trim(self) -> None:
        """Cap the request and response streams at ``batch_size * 10`` entries.

        NOTE(review): processed requests are already XDEL'd, so entries left in
        the request stream are mostly not yet processed. Trimming can therefore
        drop queued requests under heavy load. Confirm this load shedding is
        intended.
        """
        try:
            remaining_msg_cnt = self.batch_size * 10
            request_trimmed_cnt: int = self._redis_client.xtrim(
                self._request_key,
                maxlen=remaining_msg_cnt,
            )
            self.log.debug(f'Trimmed old requests: {request_trimmed_cnt}')
            response_trimmed_cnt: int = self._redis_client.xtrim(
                self._response_key,
                maxlen=remaining_msg_cnt,
            )
            self.log.debug(f'Trimmed old responses: {response_trimmed_cnt}')
        except Exception as e:
            self.log.error(f'Error while trimming message {e}')