Source code for taskhawk.consumer

import itertools
import json
import logging
import threading
import typing

import boto3
import boto3.resources.base
import boto3.resources.model

from taskhawk.conf import settings
from taskhawk.exceptions import RetryException, LoggingException, ValidationError, IgnoreException
from taskhawk.models import Message
from taskhawk import Priority


WAIT_TIME_SECONDS = 20  # Maximum allowed by SQS

logger = logging.getLogger(__name__)


def _get_sqs_resource():
    return boto3.resource(
        'sqs',
        region_name=settings.AWS_REGION,
        aws_access_key_id=settings.AWS_ACCESS_KEY,
        aws_secret_access_key=settings.AWS_SECRET_KEY,
        aws_session_token=settings.AWS_SESSION_TOKEN,
        endpoint_url=settings.AWS_ENDPOINT_SQS,
    )


def get_queue(queue_name: str):
    sqs = _get_sqs_resource()
    return sqs.get_queue_by_name(QueueName=queue_name)


def log_received_message(message_body: dict) -> None:
    logger.debug('Received message', extra={'message_body': message_body})


def log_invalid_message(message_json: str) -> None:
    logger.error('Received invalid message', extra={'message_json': message_json})


def message_handler(message_json: str, receipt: typing.Optional[str]) -> None:
    try:
        message_body = json.loads(message_json)
        message = Message(message_body)
    except (ValidationError, ValueError):
        log_invalid_message(message_json)
        raise

    log_received_message(message_body)

    try:
        message.call_task(receipt)
    except IgnoreException:
        logger.info(f'Ignoring task {message.id}')
        return
    except LoggingException as e:
        # log with message and extra
        logger.exception(str(e), extra=e.extra)
        # let it bubble up so message ends up in DLQ
        raise
    except RetryException:
        # Retry without logging exception
        logger.info('Retrying due to exception')
        # let it bubble up so message ends up in DLQ
        raise
    except Exception:
        logger.exception(f'Exception while processing message')
        # let it bubble up so message ends up in DLQ
        raise


def message_handler_sqs(queue_message) -> None:
    message_json = queue_message.body
    receipt = queue_message.receipt_handle

    message_handler(message_json, receipt)


def message_handler_lambda(lambda_record: dict) -> None:
    message_json = lambda_record['Sns']['Message']
    receipt = None

    message_handler(message_json, receipt)


def get_queue_messages(queue, num_messages: int, visibility_timeout: int = None) -> list:
    params = {
        'MaxNumberOfMessages': num_messages,
        'WaitTimeSeconds': WAIT_TIME_SECONDS,
        'MessageAttributeNames': ['All'],
    }
    if visibility_timeout is not None:
        params['VisibilityTimeout'] = visibility_timeout
    return queue.receive_messages(**params)


def get_queue_name(priority: Priority) -> str:
    name = f'TASKHAWK-{settings.TASKHAWK_QUEUE.upper()}'
    if priority is Priority.high:
        name += '-HIGH-PRIORITY'
    elif priority is Priority.low:
        name += '-LOW-PRIORITY'
    elif priority is Priority.bulk:
        name += '-BULK'
    return name


def fetch_and_process_messages(queue_name: str, queue, num_messages: int = 1, visibility_timeout: int = None) -> None:

    for queue_message in get_queue_messages(queue, num_messages, visibility_timeout=visibility_timeout):
        settings.TASKHAWK_PRE_PROCESS_HOOK(queue_name=queue_name, sqs_queue_message=queue_message)

        try:
            message_handler_sqs(queue_message)
            try:
                settings.TASKHAWK_POST_PROCESS_HOOK(queue_name=queue_name, sqs_queue_message=queue_message)
            except Exception:
                logger.exception(f'Exception in post process hook for message from {queue_name}')
                raise
            try:
                queue_message.delete()
            except Exception:
                logger.exception(f'Exception while deleting message from {queue_name}')
        except Exception:
            # already logged in message_handler
            pass


[docs]def process_messages_for_lambda_consumer(lambda_event: dict) -> None: """ Process messages for a Taskhawk consumer Lambda app, and calls the task function with given `args` and `kwargs` If the task function accepts a param called `metadata`, it'll be passed in with a dict containing the metadata fields: id, timestamp, version, receipt. In case of an exception, the message is kept on Lambda's retry queue and processed again a fixed number of times. If the task function keeps failing, Lambda dead letter queue mechanism kicks in and the message is moved to the dead-letter queue. """ for record in lambda_event['Records']: settings.TASKHAWK_PRE_PROCESS_HOOK(sns_record=record) message_handler_lambda(record) settings.TASKHAWK_POST_PROCESS_HOOK(sns_record=record)
[docs]def listen_for_messages( priority: Priority, num_messages: int = 1, visibility_timeout_s: int = None, loop_count: int = None, shutdown_event: threading.Event = None, ) -> None: """ Starts a taskhawk listener for message types provided and calls the task function with given `args` and `kwargs`. If the task function accepts a param called `metadata`, it'll be passed in with a dict containing the metadata fields: id, timestamp, version, receipt. The message is explicitly deleted only if task function ran successfully. In case of an exception the message is kept on queue and processed again. If the task function keeps failing, SQS dead letter queue mechanism kicks in and the message is moved to the dead-letter queue. This function is blocking by default. It may be run for specific number of loops by passing `loop_count`. It may also be stopped by passing a shut down event object which can be set to stop the function. :param priority: The priority queue to listen to :param num_messages: Maximum number of messages to fetch in one SQS API call. Defaults to 1 :param visibility_timeout_s: The number of seconds the message should remain invisible to other queue readers. Defaults to None, which is queue default :param loop_count: How many times to fetch messages from SQS. Default to None, which means loop forever. :param shutdown_event: An event to signal that the process should shut down. This prevents more messages from being de-queued and function exits after the current messages have been processed. """ if not shutdown_event: shutdown_event = threading.Event() queue_name = get_queue_name(priority) queue = get_queue(queue_name) for count in itertools.count(): if (loop_count is None or count < loop_count) and not shutdown_event.is_set(): fetch_and_process_messages( queue_name, queue, num_messages=num_messages, visibility_timeout=visibility_timeout_s ) else: break