123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- # Copyright © 2023 Ingram Micro Inc. All rights reserved.
- import logging
- import time
- from datetime import timedelta
- from socket import gaierror
- from urllib.parse import unquote, urlparse
- import ujson
- from django.conf import settings
- from django.utils import timezone
- from pika import (
- BasicProperties,
- BlockingConnection,
- ConnectionParameters,
- credentials,
- exceptions,
- )
- from pika.adapters.utils.connection_workflow import AMQPConnectorException
- from dj_cqrs.constants import DEFAULT_DEAD_MESSAGE_TTL, SignalType
- from dj_cqrs.controller import consumer
- from dj_cqrs.dataclasses import TransportPayload
- from dj_cqrs.delay import DelayMessage, DelayQueue
- from dj_cqrs.registries import ReplicaRegistry
- from dj_cqrs.transport import BaseTransport
- from dj_cqrs.transport.mixins import LoggingMixin
- from dj_cqrs.utils import get_delay_queue_max_size, get_messages_prefetch_count_per_worker
# Module-level logger for all transport messages; handlers/levels are
# configured by the host Django project.
logger = logging.getLogger('django-cqrs')
class RabbitMQTransport(LoggingMixin, BaseTransport):
    """Transport class for RabbitMQ.

    Built on pika's ``BlockingConnection``: ``produce`` publishes master-model
    payloads to a topic exchange, ``consume`` runs an endless receive loop that
    hands messages to the replica consumer, schedules failed messages for
    retry via an in-memory delay queue, and copies expired or unretryable
    messages to a dead letter queue.
    """

    # Seconds to sleep before reconnecting after a consumer-side AMQP error.
    CONSUMER_RETRY_TIMEOUT = 5
    # Number of additional publish attempts after the first produce failure.
    PRODUCER_RETRIES = 1

    # Shared producer connection/channel, cached for SYNC signals only
    # (see _get_producer_rmq_objects).
    _producer_connection = None
    _producer_channel = None

    @classmethod
    def clean_connection(cls):
        """Clean the RabbitMQ connection.

        Closes the shared producer connection (if open) and resets the cached
        connection/channel so the next SYNC publish reconnects.
        """
        connection = cls._producer_connection
        if connection and not connection.is_closed:
            try:
                connection.close()
            except (exceptions.StreamLostError, exceptions.ConnectionClosed, ConnectionError):
                # The broker side may already be closing it; nothing else to do.
                logger.warning('Connection was closed or is closing. Skip it...')

        cls._producer_connection = None
        cls._producer_channel = None

    @classmethod
    def consume(cls, cqrs_ids=None):
        """Receive data from master model.

        Runs forever: (re)connects to RabbitMQ, consumes incoming messages and
        flushes due delayed retries; on AMQP/DNS errors it waits
        ``CONSUMER_RETRY_TIMEOUT`` seconds and reconnects with fresh objects.

        Args:
            cqrs_ids (str): cqrs ids.
        """
        consumer_rabbit_settings = cls._get_consumer_settings()
        common_rabbit_settings = cls._get_common_settings()

        while True:
            connection = None
            try:
                # The delay queue is rebuilt per connection: delivery tags held
                # by pending retries would be invalid on a new channel anyway.
                delay_queue = DelayQueue(max_size=get_delay_queue_max_size())
                connection, channel, consumer_generator = cls._get_consumer_rmq_objects(
                    *(common_rabbit_settings + consumer_rabbit_settings),
                    cqrs_ids=cqrs_ids,
                )

                for method_frame, properties, body in consumer_generator:
                    # method_frame is None on inactivity timeout: no new
                    # message arrived, but due delayed retries still get
                    # processed below.
                    if method_frame is not None:
                        cls._consume_message(
                            channel,
                            method_frame,
                            properties,
                            body,
                            delay_queue,
                        )
                    cls._process_delay_messages(channel, delay_queue)
            except (
                exceptions.AMQPError,
                exceptions.ChannelError,
                exceptions.ReentrancyError,
                gaierror,
            ):
                logger.warning('AMQP connection error. Reconnecting...', exc_info=True)
                time.sleep(cls.CONSUMER_RETRY_TIMEOUT)
            finally:
                if connection and not connection.is_closed:
                    connection.close()

    @classmethod
    def produce(cls, payload):
        """
        Send data from master model to replicas.

        Args:
            payload (dj_cqrs.dataclasses.TransportPayload): Transport payload from master model.
        """
        cls._produce_with_retries(payload, retries=cls.PRODUCER_RETRIES)

    @classmethod
    def _produce_with_retries(cls, payload, retries):
        """Publish *payload*, recursing up to *retries* more times on broker errors.

        ``AssertionError`` is caught on purpose: ``_parse_url`` asserts the URL
        scheme, and that failure must follow the same cleanup/log path.
        """
        try:
            rmq_settings = cls._get_common_settings()
            exchange = rmq_settings[-1]
            # Decided not to create context-manager to stay within the class
            _, channel = cls._get_producer_rmq_objects(
                *rmq_settings,
                signal_type=payload.signal_type,
            )

            cls._produce_message(channel, exchange, payload)
            cls.log_produced(payload)
        except (
            exceptions.AMQPError,
            exceptions.ChannelError,
            exceptions.ReentrancyError,
            AMQPConnectorException,
            AssertionError,
        ) as e:
            # in case of any error - close connection and try to reconnect
            cls.clean_connection()

            base_log_message = "CQRS couldn't be published: pk = {0} ({1}).".format(
                payload.pk,
                payload.cqrs_id,
            )
            if not retries:
                # Out of attempts: log with traceback and give up on this payload.
                logger.exception(base_log_message)
                return

            logger.warning(
                '{0} Error: {1}. Reconnect...'.format(
                    base_log_message,
                    e.__class__.__name__,
                ),
            )
            cls._produce_with_retries(payload, retries - 1)

    @classmethod
    def _consume_message(cls, ch, method, properties, body, delay_queue):
        """Parse, validate and process one delivered message.

        Unparsable or incomplete bodies are rejected without requeue; expired
        payloads go to the dead letter queue; successfully consumed payloads
        are acked; everything else is routed through ``_fail_message``.
        """
        try:
            dct = ujson.loads(body)
        except ValueError:
            logger.error("CQRS couldn't be parsed: {0}.".format(body))
            ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
            return

        required_keys = {'instance_pk', 'signal_type', 'cqrs_id', 'instance_data'}
        for key in required_keys:
            if key not in dct:
                msg = "CQRS couldn't proceed, %s isn't found in body: %s."
                logger.error(msg, key, body)
                ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
                return

        payload = TransportPayload.from_message(dct)
        cls.log_consumed(payload)

        delivery_tag = method.delivery_tag
        if payload.is_expired():
            # Expired messages are never processed: copy to the dead letter
            # queue, then drop the original.
            cls._add_to_dead_letter_queue(ch, payload)
            cls._nack(ch, delivery_tag)
            return

        instance, exception = None, None
        try:
            instance = consumer.consume(payload)
        except Exception as e:
            exception = e
            logger.error('CQRS service exception', exc_info=True)

        # A falsy instance (consume declined the payload) counts as a failure
        # even when no exception was raised.
        if instance and exception is None:
            cls._ack(ch, delivery_tag, payload)
        else:
            cls._fail_message(
                ch,
                delivery_tag,
                payload,
                exception,
                delay_queue,
            )

    @classmethod
    def _fail_message(cls, channel, delivery_tag, payload, exception, delay_queue):
        """Route a failed message: schedule a retry or dead-letter it."""
        cls.log_consumed_failed(payload)
        model_cls = ReplicaRegistry.get_model_by_cqrs_id(payload.cqrs_id)
        if model_cls is None:
            # Unknown cqrs_id: nothing can retry it, drop without dead-lettering.
            logger.error('Model for cqrs_id {0} is not found.'.format(payload.cqrs_id))
            cls._nack(channel, delivery_tag)
            return

        if model_cls.should_retry_cqrs(payload.retries, exception):
            delay = model_cls.get_cqrs_retry_delay(payload.retries)
            cls._delay_message(channel, delivery_tag, payload, delay, delay_queue)
        else:
            cls._add_to_dead_letter_queue(channel, payload)
            cls._nack(channel, delivery_tag)

    @classmethod
    def _delay_message(cls, channel, delivery_tag, payload, delay, delay_queue):
        """Park the message in the in-memory delay queue until its ETA."""
        if delay_queue.full():
            # Memory limits handling, requeuing message with lowest ETA
            requeue_message = delay_queue.get()
            cls._requeue_message(
                channel,
                requeue_message.delivery_tag,
                requeue_message.payload,
            )

        eta = timezone.now() + timedelta(seconds=delay)
        delay_message = DelayMessage(delivery_tag, payload, eta)
        delay_queue.put(delay_message)
        cls.log_delayed(payload, delay, delay_message.eta)

    @classmethod
    def _add_to_dead_letter_queue(cls, channel, payload):
        """Publish a copy of *payload* to the dead letter queue.

        The message TTL comes from ``CQRS['replica']['dead_message_ttl']`` when
        set (``None`` disables expiration), otherwise DEFAULT_DEAD_MESSAGE_TTL.
        """
        replica_settings = settings.CQRS.get('replica', {})
        dead_message_ttl = DEFAULT_DEAD_MESSAGE_TTL
        if 'dead_message_ttl' in replica_settings:
            dead_message_ttl = replica_settings['dead_message_ttl']

        expiration = None
        if dead_message_ttl is not None:
            expiration = str(dead_message_ttl * 1000)  # milliseconds

        # Flag drives the dead-letter routing key in
        # _get_produced_message_routing_key.
        payload.is_dead_letter = True
        exchange = cls._get_common_settings()[-1]
        cls._produce_message(channel, exchange, payload, expiration)
        cls.log_dead_letter(payload)

    @classmethod
    def _requeue_message(cls, channel, delivery_tag, payload):
        """Re-publish the message with an incremented retry counter and drop the original."""
        payload.retries += 1
        # Flag drives the requeue routing key in _get_produced_message_routing_key.
        payload.is_requeue = True

        cls.produce(payload)
        cls._nack(channel, delivery_tag)
        cls.log_requeued(payload)

    @classmethod
    def _process_delay_messages(cls, channel, delay_queue):
        """Requeue every delayed message whose ETA has passed."""
        for delay_message in delay_queue.get_ready():
            cls._requeue_message(channel, delay_message.delivery_tag, delay_message.payload)

    @classmethod
    def _produce_message(cls, channel, exchange, payload, expiration=None):
        """Publish the payload as a persistent JSON message on *exchange*."""
        routing_key = cls._get_produced_message_routing_key(payload)
        channel.basic_publish(
            exchange=exchange,
            routing_key=routing_key,
            body=ujson.dumps(payload.to_dict()),
            mandatory=True,
            properties=BasicProperties(
                content_type='text/plain',
                delivery_mode=2,  # make message persistent
                expiration=expiration,
            ),
        )

    @classmethod
    def _get_produced_message_routing_key(cls, payload):
        """Build the topic routing key.

        Plain messages route by cqrs_id; SYNC-to-queue, dead-letter and
        requeued messages use the queue-specific ``cqrs.<queue>.<cqrs_id>``
        form so only the targeted service consumes them.
        """
        routing_key = payload.cqrs_id

        if payload.signal_type == SignalType.SYNC and payload.queue:
            routing_key = 'cqrs.{0}.{1}'.format(payload.queue, routing_key)
        elif getattr(payload, 'is_dead_letter', False):
            dead_letter_queue_name = cls._get_consumer_settings()[1]
            routing_key = 'cqrs.{0}.{1}'.format(dead_letter_queue_name, routing_key)
        elif getattr(payload, 'is_requeue', False):
            queue = cls._get_consumer_settings()[0]
            routing_key = 'cqrs.{0}.{1}'.format(queue, routing_key)

        return routing_key

    @classmethod
    def _get_consumer_rmq_objects(
        cls,
        host,
        port,
        creds,
        exchange,
        queue_name,
        dead_letter_queue_name,
        prefetch_count,
        cqrs_ids=None,
    ):
        """Connect, declare exchange/queues/bindings and start a consumer generator.

        Returns:
            tuple: (connection, channel, consumer_generator).
        """
        connection = BlockingConnection(
            ConnectionParameters(host=host, port=port, credentials=creds),
        )
        channel = connection.channel()
        channel.basic_qos(prefetch_count=prefetch_count)

        cls._declare_exchange(channel, exchange)
        channel.queue_declare(queue_name, durable=True, exclusive=False)
        channel.queue_declare(dead_letter_queue_name, durable=True, exclusive=False)

        # Bind one route per registered replica model (optionally filtered by
        # cqrs_ids).
        for cqrs_id, _ in ReplicaRegistry.models.items():
            if cqrs_ids and cqrs_id not in cqrs_ids:
                continue

            channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=cqrs_id)

            # Every service must have specific SYNC or requeue routes
            channel.queue_bind(
                exchange=exchange,
                queue=queue_name,
                routing_key='cqrs.{0}.{1}'.format(queue_name, cqrs_id),
            )

            # Dead letter
            channel.queue_bind(
                exchange=exchange,
                queue=dead_letter_queue_name,
                routing_key='cqrs.{0}.{1}'.format(dead_letter_queue_name, cqrs_id),
            )

        # The inactivity timeout makes the consumer generator yield (None, ...)
        # periodically so delayed retries are checked even when idle.
        delay_queue_check_timeout = 1  # seconds
        consumer_generator = channel.consume(
            queue=queue_name,
            auto_ack=False,
            exclusive=False,
            inactivity_timeout=delay_queue_check_timeout,
        )
        return connection, channel, consumer_generator

    @classmethod
    def _get_producer_rmq_objects(cls, host, port, creds, exchange, signal_type=None):
        """
        Use shared connection in case of sync mode, otherwise create new connection for each
        message
        """
        if signal_type == SignalType.SYNC:
            if cls._producer_connection is None:
                connection, channel = cls._create_connection(host, port, creds, exchange)

                cls._producer_connection = connection
                cls._producer_channel = channel

            return cls._producer_connection, cls._producer_channel
        else:
            return cls._create_connection(host, port, creds, exchange)

    @classmethod
    def _create_connection(cls, host, port, creds, exchange):
        """Open a fresh blocking connection/channel with the exchange declared."""
        connection = BlockingConnection(
            ConnectionParameters(
                host=host,
                port=port,
                credentials=creds,
                blocked_connection_timeout=10,
            ),
        )
        channel = connection.channel()
        channel.basic_qos(prefetch_count=get_messages_prefetch_count_per_worker())
        cls._declare_exchange(channel, exchange)

        return connection, channel

    @staticmethod
    def _declare_exchange(channel, exchange):
        """Declare the durable topic exchange used for all CQRS routing."""
        channel.exchange_declare(
            exchange=exchange,
            exchange_type='topic',
            durable=True,
        )

    @staticmethod
    def _parse_url(url):
        """Split an ``amqp://`` URL into (host, port, user, password).

        Missing parts fall back to pika's ConnectionParameters defaults.
        The scheme assert is relied upon by _produce_with_retries, which
        catches AssertionError.
        """
        scheme = urlparse(url).scheme
        assert scheme == 'amqp', 'Scheme must be "amqp" for RabbitMQTransport.'

        # Strip "amqp://" and reparse with a dummy scheme so urlparse fills in
        # hostname/port/username/password.
        schemeless = url[len(scheme) + 3 :]
        parts = urlparse('http://' + schemeless)

        return (
            unquote(parts.hostname or '') or ConnectionParameters.DEFAULT_HOST,
            parts.port or ConnectionParameters.DEFAULT_PORT,
            unquote(parts.username or '') or ConnectionParameters.DEFAULT_USERNAME,
            unquote(parts.password or '') or ConnectionParameters.DEFAULT_PASSWORD,
        )

    @classmethod
    def _get_common_settings(cls):
        """Read connection settings shared by producer and consumer.

        ``CQRS['url']`` takes precedence over the individual
        host/port/user/password keys.

        Returns:
            tuple: (host, port, PlainCredentials, exchange).
        """
        if 'url' in settings.CQRS:
            host, port, user, password = cls._parse_url(settings.CQRS.get('url'))
        else:
            host = settings.CQRS.get('host', ConnectionParameters.DEFAULT_HOST)
            port = settings.CQRS.get('port', ConnectionParameters.DEFAULT_PORT)
            user = settings.CQRS.get('user', ConnectionParameters.DEFAULT_USERNAME)
            password = settings.CQRS.get('password', ConnectionParameters.DEFAULT_PASSWORD)
        exchange = settings.CQRS.get('exchange', 'cqrs')
        return (
            host,
            port,
            credentials.PlainCredentials(user, password, erase_on_connect=True),
            exchange,
        )

    @staticmethod
    def _get_consumer_settings():
        """Read consumer-only settings.

        Returns:
            tuple: (queue_name, dead_letter_queue_name, prefetch_count).
        """
        queue_name = settings.CQRS['queue']

        replica_settings = settings.CQRS.get('replica', {})
        dead_letter_queue_name = 'dead_letter_{0}'.format(queue_name)
        if 'dead_letter_queue' in replica_settings:
            dead_letter_queue_name = replica_settings['dead_letter_queue']

        # Deprecated knob: prefetch is now computed per worker.
        if 'consumer_prefetch_count' in settings.CQRS:
            logger.warning(
                "The 'consumer_prefetch_count' setting is ignored for RabbitMQTransport.",
            )
        prefetch_count = get_messages_prefetch_count_per_worker()

        return (
            queue_name,
            dead_letter_queue_name,
            prefetch_count,
        )

    @classmethod
    def _ack(cls, channel, delivery_tag, payload=None):
        """Acknowledge the message; log acceptance when a payload is given."""
        channel.basic_ack(delivery_tag)
        if payload is not None:
            cls.log_consumed_accepted(payload)

    @classmethod
    def _nack(cls, channel, delivery_tag, payload=None):
        """Negatively acknowledge without requeue; log denial when a payload is given."""
        channel.basic_nack(delivery_tag, requeue=False)
        if payload is not None:
            cls.log_consumed_denied(payload)
|