# Copyright © 2023 Ingram Micro Inc. All rights reserved. import logging from datetime import datetime, timedelta, timezone from importlib import import_module, reload import pytest import ujson from django.db import DatabaseError from pika.adapters.utils.connection_workflow import AMQPConnectorException from pika.exceptions import ( AMQPError, ChannelError, ReentrancyError, StreamLostError, ) from dj_cqrs.constants import ( DEFAULT_MASTER_AUTO_UPDATE_FIELDS, DEFAULT_MASTER_MESSAGE_TTL, DEFAULT_REPLICA_DELAY_QUEUE_MAX_SIZE, DEFAULT_REPLICA_MAX_RETRIES, DEFAULT_REPLICA_RETRY_DELAY, SignalType, ) from dj_cqrs.dataclasses import TransportPayload from dj_cqrs.delay import DelayMessage, DelayQueue from dj_cqrs.transport.rabbit_mq import RabbitMQTransport from tests.utils import db_error class PublicRabbitMQTransport(RabbitMQTransport): @classmethod def get_common_settings(cls): return cls._get_common_settings() @classmethod def get_consumer_settings(cls): return cls._get_consumer_settings() @classmethod def get_produced_message_routing_key(cls, *args): return cls._get_produced_message_routing_key(*args) @classmethod def consume_message(cls, *args): return cls._consume_message(*args) @classmethod def delay_message(cls, *args): return cls._delay_message(*args) @classmethod def fail_message(cls, *args): return cls._fail_message(*args) @classmethod def process_delay_messages(cls, *args): return cls._process_delay_messages(*args) @classmethod def produce_message(cls, *args): return cls._produce_message(*args) def test_default_settings(): s = PublicRabbitMQTransport.get_common_settings() assert s[0] == 'localhost' assert s[1] == 5672 assert s[2].username == 'guest' and s[2].password == 'guest' assert s[3] == 'cqrs' def test_non_default_settings(settings, caplog): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', 'host': 'rabbit', 'port': 8000, 'user': 'usr', 'password': 'pswd', 'exchange': 'exchange', } s = PublicRabbitMQTransport.get_common_settings() assert s[0] == 'rabbit' assert s[1] == 8000 assert s[2].username == 'usr' and s[2].password == 'pswd' assert s[3] == 'exchange' def test_default_url_settings(settings): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', 'url': 'amqp://localhost', } s = PublicRabbitMQTransport.get_common_settings() assert s[0] == 'localhost' assert s[1] == 5672 assert s[2].username == 'guest' and s[2].password == 'guest' assert s[3] == 'cqrs' def test_non_default_url_settings(settings): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', 'url': 'amqp://usr:pswd@rabbit:8000', 'exchange': 'exchange', } s = PublicRabbitMQTransport.get_common_settings() assert s[0] == 'rabbit' assert s[1] == 8000 assert s[2].username == 'usr' and s[2].password == 'pswd' assert s[3] == 'exchange' def test_invalid_url_settings(settings): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', 'url': 'rabbit://localhost', } with pytest.raises(AssertionError) as ei: PublicRabbitMQTransport.get_common_settings() assert ei.match('Scheme must be "amqp" for RabbitMQTransport.') def test_consumer_default_settings(settings): settings.CQRS['queue'] = 'replica' settings.CQRS['replica'].pop('dead_letter_queue', None) s = PublicRabbitMQTransport.get_consumer_settings() assert s[1] == 'dead_letter_replica' assert s[2] == 1001 def test_consumer_non_default_settings(settings, caplog): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', 'queue': 'q', 'consumer_prefetch_count': 2, 'replica': { 'delay_queue_max_size': None, # Infinite }, } s = PublicRabbitMQTransport.get_consumer_settings() assert s[0] == 'q' assert s[1] == 'dead_letter_q' assert s[2] == 0 # Infinite assert "The 'consumer_prefetch_count' setting is ignored for RabbitMQTransport." in caplog.text @pytest.fixture def rabbit_transport(settings): settings.CQRS = { 'transport': 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport', 'queue': 'replica', 'master': { 'CQRS_AUTO_UPDATE_FIELDS': DEFAULT_MASTER_AUTO_UPDATE_FIELDS, 'CQRS_MESSAGE_TTL': DEFAULT_MASTER_MESSAGE_TTL, 'correlation_function': None, 'meta_function': None, }, 'replica': { 'CQRS_MAX_RETRIES': DEFAULT_REPLICA_MAX_RETRIES, 'CQRS_RETRY_DELAY': DEFAULT_REPLICA_RETRY_DELAY, 'delay_queue_max_size': DEFAULT_REPLICA_DELAY_QUEUE_MAX_SIZE, }, } module = reload(import_module('dj_cqrs.transport')) yield module.current_transport @pytest.mark.parametrize( 'exception', (AMQPError, ChannelError, ReentrancyError, AMQPConnectorException, AssertionError), ) def test_produce_connection_error(exception, rabbit_transport, mocker, caplog): mocker.patch.object(RabbitMQTransport, '_get_producer_rmq_objects', side_effect=exception) rabbit_transport.produce( TransportPayload( SignalType.SAVE, 'CQRS_ID', {'id': 1}, 1, ), ) assert "CQRS couldn't be published: pk = 1 (CQRS_ID)." in caplog.text def test_produce_publish_error(rabbit_transport, mocker, caplog): mocker.patch.object( RabbitMQTransport, '_get_producer_rmq_objects', return_value=(mocker.MagicMock(), None), ) mocker.patch.object(RabbitMQTransport, '_produce_message', side_effect=AMQPError) rabbit_transport.produce( TransportPayload( SignalType.SAVE, 'CQRS_ID', {'id': 1}, 1, ), ) assert "CQRS couldn't be published: pk = 1 (CQRS_ID)." in caplog.text def test_produce_ok(rabbit_transport, mocker, caplog): caplog.set_level(logging.INFO) mocker.patch.object( RabbitMQTransport, '_get_producer_rmq_objects', return_value=(mocker.MagicMock(), None), ) mocker.patch.object(RabbitMQTransport, '_produce_message', return_value=True) rabbit_transport.produce( TransportPayload( SignalType.SAVE, 'CQRS_ID', {'id': 1}, 1, ), ) assert 'CQRS is published: pk = 1 (CQRS_ID)' in caplog.text def test_produce_retry_on_error(rabbit_transport, mocker, caplog): caplog.set_level(logging.INFO) mocker.patch.object( RabbitMQTransport, '_get_producer_rmq_objects', side_effect=[ AMQPConnectorException, (mocker.MagicMock(), None), ], ) mocker.patch.object(RabbitMQTransport, '_produce_message', return_value=True) rabbit_transport.produce( TransportPayload( SignalType.SAVE, 'CQRS_ID', {'id': 1}, 1, ), ) assert caplog.record_tuples == [ ( 'django-cqrs', logging.WARNING, "CQRS couldn't be published: pk = 1 (CQRS_ID)." ' Error: AMQPConnectorException. Reconnect...', ), ( 'django-cqrs', logging.INFO, 'CQRS is published: pk = 1 (CQRS_ID), correlation_id = None.', ), ] def test_produce_retry_on_error_1(rabbit_transport, mocker, caplog): mocker.patch.object( RabbitMQTransport, '_get_producer_rmq_objects', side_effect=[ StreamLostError, StreamLostError, ], ) mocker.patch.object(RabbitMQTransport, '_produce_message', return_value=True) rabbit_transport.produce( TransportPayload( SignalType.SAVE, 'CQRS_ID', {'id': 1}, 1, ), ) assert caplog.record_tuples == [ ( 'django-cqrs', logging.WARNING, "CQRS couldn't be published: pk = 1 (CQRS_ID). Error: StreamLostError. Reconnect...", ), ( 'django-cqrs', logging.ERROR, "CQRS couldn't be published: pk = 1 (CQRS_ID).", ), ] def test_produce_message_ok(mocker): expires = datetime(2100, 1, 1, tzinfo=timezone.utc) expected_expires = '2100-01-01T00:00:00+00:00' channel = mocker.MagicMock() payload = TransportPayload( SignalType.SAVE, cqrs_id='cqrs_id', instance_data={}, instance_pk='id', previous_data={'e': 'f'}, expires=expires, retries=2, ) PublicRabbitMQTransport.produce_message(channel, 'exchange', payload) assert channel.basic_publish.call_count == 1 basic_publish_kwargs = channel.basic_publish.call_args[1] assert ujson.loads(basic_publish_kwargs['body']) == { 'signal_type': SignalType.SAVE, 'cqrs_id': 'cqrs_id', 'instance_data': {}, 'instance_pk': 'id', 'previous_data': {'e': 'f'}, 'correlation_id': None, 'expires': expected_expires, 'retries': 2, 'meta': None, } assert basic_publish_kwargs['exchange'] == 'exchange' assert basic_publish_kwargs['mandatory'] assert basic_publish_kwargs['routing_key'] == 'cqrs_id' assert basic_publish_kwargs['properties'].content_type == 'text/plain' assert basic_publish_kwargs['properties'].delivery_mode == 2 def test_produce_sync_message_no_queue(mocker): channel = mocker.MagicMock() payload = TransportPayload(SignalType.SYNC, 'cqrs_id', {}, None) PublicRabbitMQTransport.produce_message(channel, 'exchange', payload) basic_publish_kwargs = channel.basic_publish.call_args[1] assert ujson.loads(basic_publish_kwargs['body']) == { 'signal_type': SignalType.SYNC, 'cqrs_id': 'cqrs_id', 'instance_data': {}, 'instance_pk': None, 'previous_data': None, 'correlation_id': None, 'expires': None, 'retries': 0, 'meta': None, } assert basic_publish_kwargs['routing_key'] == 'cqrs_id' def test_produce_sync_message_queue(mocker): channel = mocker.MagicMock() payload = TransportPayload(SignalType.SYNC, 'cqrs_id', {}, 'id', 'queue') PublicRabbitMQTransport.produce_message(channel, 'exchange', payload) basic_publish_kwargs = channel.basic_publish.call_args[1] assert ujson.loads(basic_publish_kwargs['body']) == { 'signal_type': SignalType.SYNC, 'cqrs_id': 'cqrs_id', 'instance_data': {}, 'instance_pk': 'id', 'previous_data': None, 'correlation_id': None, 'expires': None, 'retries': 0, 'meta': None, } assert basic_publish_kwargs['routing_key'] == 'cqrs.queue.cqrs_id' def test_consume_connection_error(rabbit_transport, mocker, caplog): mocker.patch.object( RabbitMQTransport, '_get_consumer_rmq_objects', side_effect=AMQPError, ) mocker.patch('time.sleep', side_effect=db_error) with pytest.raises(DatabaseError): rabbit_transport.consume() assert 'AMQP connection error. Reconnecting...' in caplog.text def test_consume_ok(rabbit_transport, mocker): consumer_generator = (v for v in [(1, None, None)]) mocker.patch.object( RabbitMQTransport, '_get_consumer_rmq_objects', return_value=(None, None, consumer_generator), ) mocker.patch.object( RabbitMQTransport, '_consume_message', db_error, ) with pytest.raises(DatabaseError): rabbit_transport.consume() def test_consume_message_ack(mocker, caplog): caplog.set_level(logging.INFO) consumer_mock = mocker.patch('dj_cqrs.controller.consumer.consume') PublicRabbitMQTransport.consume_message( mocker.MagicMock(), mocker.MagicMock(), None, '{"signal_type":"signal","cqrs_id":"cqrs_id","instance_data":{},' '"instance_pk":1, "previous_data":{}, "correlation_id":"abc",' '"expires":"2100-01-01T00:00:00+00:00", "retries":1}', mocker.MagicMock(), ) assert consumer_mock.call_count == 1 payload = consumer_mock.call_args[0][0] assert payload.signal_type == 'signal' assert payload.cqrs_id == 'cqrs_id' assert payload.instance_data == {} assert payload.previous_data == {} assert payload.pk == 1 assert payload.correlation_id == 'abc' assert payload.expires == datetime(2100, 1, 1, tzinfo=timezone.utc) assert payload.retries == 1 assert 'CQRS is received: pk = 1 (cqrs_id), correlation_id = abc.' in caplog.text assert 'CQRS is applied: pk = 1 (cqrs_id), correlation_id = abc.' in caplog.text def test_consume_message_nack(mocker, caplog): caplog.set_level(logging.INFO) mocker.patch('dj_cqrs.controller.consumer.consume', return_value=None) PublicRabbitMQTransport.consume_message( mocker.MagicMock(), mocker.MagicMock(), None, '{"signal_type":"signal","cqrs_id":"basic","instance_data":{},' '"instance_pk":1,"previous_data":null,' '"expires":"2100-01-01T00:00:00+00:00", "retries":0}', mocker.MagicMock(), ) assert 'CQRS is received: pk = 1 (basic), correlation_id = None.' in caplog.text assert 'CQRS is failed: pk = 1 (basic), correlation_id = None, retries = 0.' in caplog.text def test_consume_message_nack_deprecated_structure(mocker, caplog): caplog.set_level(logging.INFO) consumer_mock = mocker.patch('dj_cqrs.controller.consumer.consume', return_value=None) PublicRabbitMQTransport.consume_message( mocker.MagicMock(), mocker.MagicMock(), None, '{"signal_type":"signal","cqrs_id":"cqrs_id","instance_data":{}}', mocker.MagicMock(), ) assert consumer_mock.call_count == 0 assert "CQRS couldn't proceed, instance_pk isn't found in body" in caplog.text def test_consume_message_expired(mocker, caplog): caplog.set_level(logging.INFO) channel = mocker.MagicMock() PublicRabbitMQTransport.consume_message( channel, mocker.MagicMock(), None, '{"signal_type":"signal","cqrs_id":"cqrs_id","instance_data":{},' '"instance_pk":1,"previous_data":null,' '"expires":"2000-01-01T00:00:00+00:00", "retries":0}', mocker.MagicMock(), ) assert channel.basic_nack.call_count == 1 assert 'CQRS is received: pk = 1 (cqrs_id)' in caplog.text assert 'CQRS is added to dead letter queue: pk = 1 (cqrs_id)' in caplog.text def test_consume_message_json_parsing_error(mocker, caplog): PublicRabbitMQTransport.consume_message( mocker.MagicMock(), mocker.MagicMock(), None, '{bad_payload:', mocker.MagicMock(), ) assert ': {bad_payload:.' in caplog.text def test_consume_message_package_structure_error(mocker, caplog): PublicRabbitMQTransport.consume_message( mocker.MagicMock(), mocker.MagicMock(), None, 'inv{"pk":"1"}', mocker.MagicMock(), ) assert """CQRS couldn't be parsed: inv{"pk":"1"}""" in caplog.text def test_fail_message_with_retry(mocker): payload = TransportPayload(SignalType.SAVE, 'basic', {'id': 1}, 1) delay_queue = DelayQueue() PublicRabbitMQTransport.fail_message(mocker.MagicMock(), 100, payload, None, delay_queue) assert delay_queue.qsize() == 1 delay_message = delay_queue.get() assert delay_message.delivery_tag == 100 assert delay_message.payload is payload def test_message_without_retry_dead_letter(settings, mocker, caplog): settings.CQRS['replica']['CQRS_MAX_RETRIES'] = 1 produce_message = mocker.patch( 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport._produce_message', ) channel = mocker.MagicMock() payload = TransportPayload( SignalType.SAVE, 'basic', {'id': 1}, 1, correlation_id='abc', retries=2, ) delay_queue = DelayQueue() PublicRabbitMQTransport.fail_message(channel, 1, payload, None, delay_queue) assert delay_queue.qsize() == 0 assert channel.basic_nack.call_count == 1 assert produce_message.call_count == 1 produce_payload = produce_message.call_args[0][2] assert produce_payload is payload assert getattr(produce_message, 'is_dead_letter', False) assert 'CQRS is failed: pk = 1 (basic), correlation_id = abc, retries = 2.' in caplog.text assert 'CQRS is added to dead letter queue: pk = 1 (basic), correlation_id = abc' in caplog.text def test_fail_message_invalid_model(mocker, caplog): nack = mocker.patch( 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport._nack', ) payload = TransportPayload(SignalType.SAVE, 'not_existing', {'id': 1}, 1) delay_queue = DelayQueue() delivery_tag = 101 PublicRabbitMQTransport.fail_message( mocker.MagicMock(), delivery_tag, payload, None, delay_queue, ) assert delay_queue.qsize() == 0 assert nack.call_count == 1 assert nack.call_args[0][1] == delivery_tag assert 'Model for cqrs_id not_existing is not found.' in caplog.text def test_get_produced_message_routing_key_dead_letter(settings): settings.CQRS['replica']['dead_letter_queue'] = 'dead_letter_replica' payload = TransportPayload(SignalType.SYNC, 'CQRS_ID', {}, None) payload.is_dead_letter = True routing_key = PublicRabbitMQTransport.get_produced_message_routing_key(payload) assert routing_key == 'cqrs.dead_letter_replica.CQRS_ID' def test_get_produced_message_routing_key_requeue(settings): settings.CQRS['queue'] = 'replica' payload = TransportPayload(SignalType.SAVE, 'CQRS_ID', {}, None) payload.is_requeue = True routing_key = PublicRabbitMQTransport.get_produced_message_routing_key(payload) assert routing_key == 'cqrs.replica.CQRS_ID' def test_process_delay_messages(mocker, caplog): channel = mocker.MagicMock() produce = mocker.patch('dj_cqrs.transport.rabbit_mq.RabbitMQTransport.produce') payload = TransportPayload(SignalType.SAVE, 'CQRS_ID', {'id': 1}, 1) delay_queue = DelayQueue() delay_queue.put( DelayMessage(delivery_tag=1, payload=payload, eta=datetime.now(tz=timezone.utc)), ) PublicRabbitMQTransport.process_delay_messages(channel, delay_queue) assert delay_queue.qsize() == 0 assert channel.basic_nack.call_count == 1 assert produce.call_count == 1 produce_payload = produce.call_args[0][0] assert produce_payload is payload assert produce_payload.retries == 1 assert getattr(produce_payload, 'is_requeue', False) assert 'CQRS is requeued: pk = 1 (CQRS_ID)' in caplog.text def test_delay_message_with_requeue(mocker, caplog): channel = mocker.MagicMock() requeue_message = mocker.patch( 'dj_cqrs.transport.rabbit_mq.RabbitMQTransport._requeue_message', ) delay_messages = [] for delay in (2, 1, 3): payload = TransportPayload(SignalType.SAVE, 'CQRS_ID', {'id': delay}, delay) eta = datetime.now(tz=timezone.utc) + timedelta(hours=delay) delay_message = DelayMessage(delivery_tag=delay, payload=payload, eta=eta) delay_messages.append(delay_message) delay_queue = DelayQueue(max_size=3) for delay_message in delay_messages: delay_queue.put(delay_message) exceeding_delay = 0 exceeding_payload = TransportPayload(SignalType.SAVE, 'CQRS_ID', {'id': 4}, 4) PublicRabbitMQTransport.delay_message( channel, 4, exceeding_payload, exceeding_delay, delay_queue, ) assert delay_queue.qsize() == 3 assert delay_queue.get().payload is exceeding_payload assert 'CQRS is delayed: pk = 4 (CQRS_ID), correlation_id = None, delay = 0 sec' in caplog.text assert requeue_message.call_count == 1 requeue_payload = requeue_message.call_args[0][2] min_eta_delay_message = sorted(delay_messages, key=lambda x: x.eta)[0] assert requeue_payload is min_eta_delay_message.payload