|
@@ -0,0 +1,439 @@
|
|
|
+# Copyright © 2023 Ingram Micro Inc. All rights reserved.
|
|
|
+
|
|
|
+import logging
|
|
|
+import time
|
|
|
+from datetime import timedelta
|
|
|
+from socket import gaierror
|
|
|
+from urllib.parse import unquote, urlparse
|
|
|
+
|
|
|
+import ujson
|
|
|
+from django.conf import settings
|
|
|
+from django.utils import timezone
|
|
|
+from pika import (
|
|
|
+ BasicProperties,
|
|
|
+ BlockingConnection,
|
|
|
+ ConnectionParameters,
|
|
|
+ credentials,
|
|
|
+ exceptions,
|
|
|
+)
|
|
|
+from pika.adapters.utils.connection_workflow import AMQPConnectorException
|
|
|
+
|
|
|
+from dj_cqrs.constants import DEFAULT_DEAD_MESSAGE_TTL, SignalType
|
|
|
+from dj_cqrs.controller import consumer
|
|
|
+from dj_cqrs.dataclasses import TransportPayload
|
|
|
+from dj_cqrs.delay import DelayMessage, DelayQueue
|
|
|
+from dj_cqrs.registries import ReplicaRegistry
|
|
|
+from dj_cqrs.transport import BaseTransport
|
|
|
+from dj_cqrs.transport.mixins import LoggingMixin
|
|
|
+from dj_cqrs.utils import get_delay_queue_max_size, get_messages_prefetch_count_per_worker
|
|
|
+
|
|
|
+
|
|
|
+logger = logging.getLogger('django-cqrs')
|
|
|
+
|
|
|
+
|
|
|
+class RabbitMQTransport(LoggingMixin, BaseTransport):
|
|
|
+ """Transport class for RabbitMQ."""
|
|
|
+
|
|
|
+ CONSUMER_RETRY_TIMEOUT = 5
|
|
|
+ PRODUCER_RETRIES = 1
|
|
|
+
|
|
|
+ _producer_connection = None
|
|
|
+ _producer_channel = None
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def clean_connection(cls):
|
|
|
+ """Clean the RabbitMQ connection."""
|
|
|
+ connection = cls._producer_connection
|
|
|
+ if connection and not connection.is_closed:
|
|
|
+ try:
|
|
|
+ connection.close()
|
|
|
+ except (exceptions.StreamLostError, exceptions.ConnectionClosed, ConnectionError):
|
|
|
+ logger.warning('Connection was closed or is closing. Skip it...')
|
|
|
+
|
|
|
+ cls._producer_connection = None
|
|
|
+ cls._producer_channel = None
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def consume(cls, cqrs_ids=None):
|
|
|
+ """Receive data from master model.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ cqrs_ids (str): cqrs ids.
|
|
|
+ """
|
|
|
+ consumer_rabbit_settings = cls._get_consumer_settings()
|
|
|
+ common_rabbit_settings = cls._get_common_settings()
|
|
|
+
|
|
|
+ while True:
|
|
|
+ connection = None
|
|
|
+ try:
|
|
|
+ delay_queue = DelayQueue(max_size=get_delay_queue_max_size())
|
|
|
+ connection, channel, consumer_generator = cls._get_consumer_rmq_objects(
|
|
|
+ *(common_rabbit_settings + consumer_rabbit_settings),
|
|
|
+ cqrs_ids=cqrs_ids,
|
|
|
+ )
|
|
|
+
|
|
|
+ for method_frame, properties, body in consumer_generator:
|
|
|
+ if method_frame is not None:
|
|
|
+ cls._consume_message(
|
|
|
+ channel,
|
|
|
+ method_frame,
|
|
|
+ properties,
|
|
|
+ body,
|
|
|
+ delay_queue,
|
|
|
+ )
|
|
|
+ cls._process_delay_messages(channel, delay_queue)
|
|
|
+ except (
|
|
|
+ exceptions.AMQPError,
|
|
|
+ exceptions.ChannelError,
|
|
|
+ exceptions.ReentrancyError,
|
|
|
+ gaierror,
|
|
|
+ ):
|
|
|
+ logger.warning('AMQP connection error. Reconnecting...', exc_info=True)
|
|
|
+ time.sleep(cls.CONSUMER_RETRY_TIMEOUT)
|
|
|
+ finally:
|
|
|
+ if connection and not connection.is_closed:
|
|
|
+ connection.close()
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def produce(cls, payload):
|
|
|
+ """
|
|
|
+ Send data from master model to replicas.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ payload (dj_cqrs.dataclasses.TransportPayload): Transport payload from master model.
|
|
|
+ """
|
|
|
+ cls._produce_with_retries(payload, retries=cls.PRODUCER_RETRIES)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _produce_with_retries(cls, payload, retries):
|
|
|
+ try:
|
|
|
+ rmq_settings = cls._get_common_settings()
|
|
|
+ exchange = rmq_settings[-1]
|
|
|
+ # Decided not to create context-manager to stay within the class
|
|
|
+ _, channel = cls._get_producer_rmq_objects(
|
|
|
+ *rmq_settings,
|
|
|
+ signal_type=payload.signal_type,
|
|
|
+ )
|
|
|
+
|
|
|
+ cls._produce_message(channel, exchange, payload)
|
|
|
+ cls.log_produced(payload)
|
|
|
+ except (
|
|
|
+ exceptions.AMQPError,
|
|
|
+ exceptions.ChannelError,
|
|
|
+ exceptions.ReentrancyError,
|
|
|
+ AMQPConnectorException,
|
|
|
+ AssertionError,
|
|
|
+ ) as e:
|
|
|
+ # in case of any error - close connection and try to reconnect
|
|
|
+ cls.clean_connection()
|
|
|
+
|
|
|
+ base_log_message = "CQRS couldn't be published: pk = {0} ({1}).".format(
|
|
|
+ payload.pk,
|
|
|
+ payload.cqrs_id,
|
|
|
+ )
|
|
|
+ if not retries:
|
|
|
+ logger.exception(base_log_message)
|
|
|
+ return
|
|
|
+
|
|
|
+ logger.warning(
|
|
|
+ '{0} Error: {1}. Reconnect...'.format(
|
|
|
+ base_log_message,
|
|
|
+ e.__class__.__name__,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ cls._produce_with_retries(payload, retries - 1)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _consume_message(cls, ch, method, properties, body, delay_queue):
|
|
|
+ try:
|
|
|
+ dct = ujson.loads(body)
|
|
|
+ except ValueError:
|
|
|
+ logger.error("CQRS couldn't be parsed: {0}.".format(body))
|
|
|
+ ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
|
|
|
+ return
|
|
|
+
|
|
|
+ required_keys = {'instance_pk', 'signal_type', 'cqrs_id', 'instance_data'}
|
|
|
+ for key in required_keys:
|
|
|
+ if key not in dct:
|
|
|
+ msg = "CQRS couldn't proceed, %s isn't found in body: %s."
|
|
|
+ logger.error(msg, key, body)
|
|
|
+ ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
|
|
|
+ return
|
|
|
+
|
|
|
+ payload = TransportPayload.from_message(dct)
|
|
|
+ cls.log_consumed(payload)
|
|
|
+
|
|
|
+ delivery_tag = method.delivery_tag
|
|
|
+ if payload.is_expired():
|
|
|
+ cls._add_to_dead_letter_queue(ch, payload)
|
|
|
+ cls._nack(ch, delivery_tag)
|
|
|
+ return
|
|
|
+
|
|
|
+ instance, exception = None, None
|
|
|
+ try:
|
|
|
+ instance = consumer.consume(payload)
|
|
|
+ except Exception as e:
|
|
|
+ exception = e
|
|
|
+ logger.error('CQRS service exception', exc_info=True)
|
|
|
+
|
|
|
+ if instance and exception is None:
|
|
|
+ cls._ack(ch, delivery_tag, payload)
|
|
|
+ else:
|
|
|
+ cls._fail_message(
|
|
|
+ ch,
|
|
|
+ delivery_tag,
|
|
|
+ payload,
|
|
|
+ exception,
|
|
|
+ delay_queue,
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _fail_message(cls, channel, delivery_tag, payload, exception, delay_queue):
|
|
|
+ cls.log_consumed_failed(payload)
|
|
|
+ model_cls = ReplicaRegistry.get_model_by_cqrs_id(payload.cqrs_id)
|
|
|
+ if model_cls is None:
|
|
|
+ logger.error('Model for cqrs_id {0} is not found.'.format(payload.cqrs_id))
|
|
|
+ cls._nack(channel, delivery_tag)
|
|
|
+ return
|
|
|
+
|
|
|
+ if model_cls.should_retry_cqrs(payload.retries, exception):
|
|
|
+ delay = model_cls.get_cqrs_retry_delay(payload.retries)
|
|
|
+ cls._delay_message(channel, delivery_tag, payload, delay, delay_queue)
|
|
|
+ else:
|
|
|
+ cls._add_to_dead_letter_queue(channel, payload)
|
|
|
+ cls._nack(channel, delivery_tag)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _delay_message(cls, channel, delivery_tag, payload, delay, delay_queue):
|
|
|
+ if delay_queue.full():
|
|
|
+ # Memory limits handling, requeuing message with lowest ETA
|
|
|
+ requeue_message = delay_queue.get()
|
|
|
+ cls._requeue_message(
|
|
|
+ channel,
|
|
|
+ requeue_message.delivery_tag,
|
|
|
+ requeue_message.payload,
|
|
|
+ )
|
|
|
+
|
|
|
+ eta = timezone.now() + timedelta(seconds=delay)
|
|
|
+ delay_message = DelayMessage(delivery_tag, payload, eta)
|
|
|
+ delay_queue.put(delay_message)
|
|
|
+ cls.log_delayed(payload, delay, delay_message.eta)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _add_to_dead_letter_queue(cls, channel, payload):
|
|
|
+ replica_settings = settings.CQRS.get('replica', {})
|
|
|
+ dead_message_ttl = DEFAULT_DEAD_MESSAGE_TTL
|
|
|
+ if 'dead_message_ttl' in replica_settings:
|
|
|
+ dead_message_ttl = replica_settings['dead_message_ttl']
|
|
|
+
|
|
|
+ expiration = None
|
|
|
+ if dead_message_ttl is not None:
|
|
|
+ expiration = str(dead_message_ttl * 1000) # milliseconds
|
|
|
+
|
|
|
+ payload.is_dead_letter = True
|
|
|
+ exchange = cls._get_common_settings()[-1]
|
|
|
+ cls._produce_message(channel, exchange, payload, expiration)
|
|
|
+ cls.log_dead_letter(payload)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _requeue_message(cls, channel, delivery_tag, payload):
|
|
|
+ payload.retries += 1
|
|
|
+ payload.is_requeue = True
|
|
|
+
|
|
|
+ cls.produce(payload)
|
|
|
+ cls._nack(channel, delivery_tag)
|
|
|
+ cls.log_requeued(payload)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _process_delay_messages(cls, channel, delay_queue):
|
|
|
+ for delay_message in delay_queue.get_ready():
|
|
|
+ cls._requeue_message(channel, delay_message.delivery_tag, delay_message.payload)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _produce_message(cls, channel, exchange, payload, expiration=None):
|
|
|
+ routing_key = cls._get_produced_message_routing_key(payload)
|
|
|
+
|
|
|
+ channel.basic_publish(
|
|
|
+ exchange=exchange,
|
|
|
+ routing_key=routing_key,
|
|
|
+ body=ujson.dumps(payload.to_dict()),
|
|
|
+ mandatory=True,
|
|
|
+ properties=BasicProperties(
|
|
|
+ content_type='text/plain',
|
|
|
+ delivery_mode=2, # make message persistent
|
|
|
+ expiration=expiration,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_produced_message_routing_key(cls, payload):
|
|
|
+ routing_key = payload.cqrs_id
|
|
|
+
|
|
|
+ if payload.signal_type == SignalType.SYNC and payload.queue:
|
|
|
+ routing_key = 'cqrs.{0}.{1}'.format(payload.queue, routing_key)
|
|
|
+ elif getattr(payload, 'is_dead_letter', False):
|
|
|
+ dead_letter_queue_name = cls._get_consumer_settings()[1]
|
|
|
+ routing_key = 'cqrs.{0}.{1}'.format(dead_letter_queue_name, routing_key)
|
|
|
+ elif getattr(payload, 'is_requeue', False):
|
|
|
+ queue = cls._get_consumer_settings()[0]
|
|
|
+ routing_key = 'cqrs.{0}.{1}'.format(queue, routing_key)
|
|
|
+
|
|
|
+ return routing_key
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_consumer_rmq_objects(
|
|
|
+ cls,
|
|
|
+ host,
|
|
|
+ port,
|
|
|
+ creds,
|
|
|
+ exchange,
|
|
|
+ queue_name,
|
|
|
+ dead_letter_queue_name,
|
|
|
+ prefetch_count,
|
|
|
+ cqrs_ids=None,
|
|
|
+ ):
|
|
|
+ connection = BlockingConnection(
|
|
|
+ ConnectionParameters(host=host, port=port, credentials=creds),
|
|
|
+ )
|
|
|
+ channel = connection.channel()
|
|
|
+ channel.basic_qos(prefetch_count=prefetch_count)
|
|
|
+ cls._declare_exchange(channel, exchange)
|
|
|
+
|
|
|
+ channel.queue_declare(queue_name, durable=True, exclusive=False)
|
|
|
+ channel.queue_declare(dead_letter_queue_name, durable=True, exclusive=False)
|
|
|
+
|
|
|
+ for cqrs_id, _ in ReplicaRegistry.models.items():
|
|
|
+ if cqrs_ids and cqrs_id not in cqrs_ids:
|
|
|
+ continue
|
|
|
+
|
|
|
+ channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=cqrs_id)
|
|
|
+
|
|
|
+ # Every service must have specific SYNC or requeue routes
|
|
|
+ channel.queue_bind(
|
|
|
+ exchange=exchange,
|
|
|
+ queue=queue_name,
|
|
|
+ routing_key='cqrs.{0}.{1}'.format(queue_name, cqrs_id),
|
|
|
+ )
|
|
|
+
|
|
|
+ # Dead letter
|
|
|
+ channel.queue_bind(
|
|
|
+ exchange=exchange,
|
|
|
+ queue=dead_letter_queue_name,
|
|
|
+ routing_key='cqrs.{0}.{1}'.format(dead_letter_queue_name, cqrs_id),
|
|
|
+ )
|
|
|
+
|
|
|
+ delay_queue_check_timeout = 1 # seconds
|
|
|
+ consumer_generator = channel.consume(
|
|
|
+ queue=queue_name,
|
|
|
+ auto_ack=False,
|
|
|
+ exclusive=False,
|
|
|
+ inactivity_timeout=delay_queue_check_timeout,
|
|
|
+ )
|
|
|
+ return connection, channel, consumer_generator
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_producer_rmq_objects(cls, host, port, creds, exchange, signal_type=None):
|
|
|
+ """
|
|
|
+ Use shared connection in case of sync mode, otherwise create new connection for each
|
|
|
+ message
|
|
|
+ """
|
|
|
+ if signal_type == SignalType.SYNC:
|
|
|
+ if cls._producer_connection is None:
|
|
|
+ connection, channel = cls._create_connection(host, port, creds, exchange)
|
|
|
+
|
|
|
+ cls._producer_connection = connection
|
|
|
+ cls._producer_channel = channel
|
|
|
+
|
|
|
+ return cls._producer_connection, cls._producer_channel
|
|
|
+ else:
|
|
|
+ return cls._create_connection(host, port, creds, exchange)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _create_connection(cls, host, port, creds, exchange):
|
|
|
+ connection = BlockingConnection(
|
|
|
+ ConnectionParameters(
|
|
|
+ host=host,
|
|
|
+ port=port,
|
|
|
+ credentials=creds,
|
|
|
+ blocked_connection_timeout=10,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ channel = connection.channel()
|
|
|
+ channel.basic_qos(prefetch_count=get_messages_prefetch_count_per_worker())
|
|
|
+ cls._declare_exchange(channel, exchange)
|
|
|
+
|
|
|
+ return connection, channel
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _declare_exchange(channel, exchange):
|
|
|
+ channel.exchange_declare(
|
|
|
+ exchange=exchange,
|
|
|
+ exchange_type='topic',
|
|
|
+ durable=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _parse_url(url):
|
|
|
+ scheme = urlparse(url).scheme
|
|
|
+ assert scheme == 'amqp', 'Scheme must be "amqp" for RabbitMQTransport.'
|
|
|
+
|
|
|
+ schemeless = url[len(scheme) + 3 :]
|
|
|
+ parts = urlparse('http://' + schemeless)
|
|
|
+
|
|
|
+ return (
|
|
|
+ unquote(parts.hostname or '') or ConnectionParameters.DEFAULT_HOST,
|
|
|
+ parts.port or ConnectionParameters.DEFAULT_PORT,
|
|
|
+ unquote(parts.username or '') or ConnectionParameters.DEFAULT_USERNAME,
|
|
|
+ unquote(parts.password or '') or ConnectionParameters.DEFAULT_PASSWORD,
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_common_settings(cls):
|
|
|
+ if 'url' in settings.CQRS:
|
|
|
+ host, port, user, password = cls._parse_url(settings.CQRS.get('url'))
|
|
|
+ else:
|
|
|
+ host = settings.CQRS.get('host', ConnectionParameters.DEFAULT_HOST)
|
|
|
+ port = settings.CQRS.get('port', ConnectionParameters.DEFAULT_PORT)
|
|
|
+ user = settings.CQRS.get('user', ConnectionParameters.DEFAULT_USERNAME)
|
|
|
+ password = settings.CQRS.get('password', ConnectionParameters.DEFAULT_PASSWORD)
|
|
|
+ exchange = settings.CQRS.get('exchange', 'cqrs')
|
|
|
+ return (
|
|
|
+ host,
|
|
|
+ port,
|
|
|
+ credentials.PlainCredentials(user, password, erase_on_connect=True),
|
|
|
+ exchange,
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _get_consumer_settings():
|
|
|
+ queue_name = settings.CQRS['queue']
|
|
|
+
|
|
|
+ replica_settings = settings.CQRS.get('replica', {})
|
|
|
+ dead_letter_queue_name = 'dead_letter_{0}'.format(queue_name)
|
|
|
+ if 'dead_letter_queue' in replica_settings:
|
|
|
+ dead_letter_queue_name = replica_settings['dead_letter_queue']
|
|
|
+
|
|
|
+ if 'consumer_prefetch_count' in settings.CQRS:
|
|
|
+ logger.warning(
|
|
|
+ "The 'consumer_prefetch_count' setting is ignored for RabbitMQTransport.",
|
|
|
+ )
|
|
|
+ prefetch_count = get_messages_prefetch_count_per_worker()
|
|
|
+
|
|
|
+ return (
|
|
|
+ queue_name,
|
|
|
+ dead_letter_queue_name,
|
|
|
+ prefetch_count,
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _ack(cls, channel, delivery_tag, payload=None):
|
|
|
+ channel.basic_ack(delivery_tag)
|
|
|
+ if payload is not None:
|
|
|
+ cls.log_consumed_accepted(payload)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _nack(cls, channel, delivery_tag, payload=None):
|
|
|
+ channel.basic_nack(delivery_tag, requeue=False)
|
|
|
+ if payload is not None:
|
|
|
+ cls.log_consumed_denied(payload)
|