kombu.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright © 2023 Ingram Micro Inc. All rights reserved.
  2. import logging
  3. import ujson
  4. from django.conf import settings
  5. from kombu import (
  6. Connection,
  7. Exchange,
  8. Producer,
  9. Queue,
  10. )
  11. from kombu.exceptions import KombuError
  12. from kombu.mixins import ConsumerMixin
  13. from dj_cqrs.constants import SignalType
  14. from dj_cqrs.controller import consumer
  15. from dj_cqrs.dataclasses import TransportPayload
  16. from dj_cqrs.registries import ReplicaRegistry
  17. from dj_cqrs.transport import BaseTransport
  18. from dj_cqrs.transport.mixins import LoggingMixin
  19. logger = logging.getLogger('django-cqrs')
  20. class _KombuConsumer(ConsumerMixin):
  21. def __init__(self, url, exchange_name, queue_name, prefetch_count, callback, cqrs_ids=None):
  22. self.connection = Connection(url)
  23. self.exchange = Exchange(
  24. exchange_name,
  25. type='topic',
  26. durable=True,
  27. )
  28. self.queue_name = queue_name
  29. self.prefetch_count = prefetch_count
  30. self.callback = callback
  31. self.queues = []
  32. self.cqrs_ids = cqrs_ids
  33. self._init_queues()
  34. def _init_queues(self):
  35. channel = self.connection.channel()
  36. for cqrs_id in ReplicaRegistry.models.keys():
  37. if (not self.cqrs_ids) or (cqrs_id in self.cqrs_ids):
  38. q = Queue(
  39. self.queue_name,
  40. exchange=self.exchange,
  41. routing_key=cqrs_id,
  42. )
  43. q.maybe_bind(channel)
  44. q.declare()
  45. self.queues.append(q)
  46. sync_q = Queue(
  47. self.queue_name,
  48. exchange=self.exchange,
  49. routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id),
  50. )
  51. sync_q.maybe_bind(channel)
  52. sync_q.declare()
  53. self.queues.append(sync_q)
  54. def get_consumers(self, Consumer, channel):
  55. return [
  56. Consumer(
  57. queues=self.queues,
  58. callbacks=[self.callback],
  59. prefetch_count=self.prefetch_count,
  60. auto_declare=True,
  61. ),
  62. ]
  63. class KombuTransport(LoggingMixin, BaseTransport):
  64. """Transport class for Kombu."""
  65. CONSUMER_RETRY_TIMEOUT = 5
  66. @classmethod
  67. def clean_connection(cls):
  68. """Nothing to do here"""
  69. pass
  70. @classmethod
  71. def consume(cls, cqrs_ids=None):
  72. """Receive data from master model.
  73. Args:
  74. cqrs_ids (str): cqrs ids.
  75. """
  76. queue_name, prefetch_count = cls._get_consumer_settings()
  77. url, exchange_name = cls._get_common_settings()
  78. consumer = _KombuConsumer(
  79. url,
  80. exchange_name,
  81. queue_name,
  82. prefetch_count,
  83. cls._consume_message,
  84. cqrs_ids=cqrs_ids,
  85. )
  86. consumer.run()
  87. @classmethod
  88. def produce(cls, payload):
  89. """
  90. Send data from master model to replicas.
  91. Args:
  92. payload (dj_cqrs.dataclasses.TransportPayload): Transport payload from master model.
  93. """
  94. url, exchange_name = cls._get_common_settings()
  95. connection = None
  96. try:
  97. # Decided not to create context-manager to stay within the class
  98. connection, channel = cls._get_producer_kombu_objects(url, exchange_name)
  99. exchange = cls._create_exchange(exchange_name)
  100. cls._produce_message(channel, exchange, payload)
  101. cls.log_produced(payload)
  102. except KombuError:
  103. logger.error(
  104. "CQRS couldn't be published: pk = {0} ({1}).".format(
  105. payload.pk,
  106. payload.cqrs_id,
  107. ),
  108. )
  109. finally:
  110. if connection:
  111. connection.close()
  112. @classmethod
  113. def _consume_message(cls, body, message):
  114. try:
  115. dct = ujson.loads(body)
  116. except ValueError:
  117. logger.error("CQRS couldn't be parsed: {0}.".format(body))
  118. message.reject()
  119. return
  120. required_keys = {'instance_pk', 'signal_type', 'cqrs_id', 'instance_data'}
  121. for key in required_keys:
  122. if key not in dct:
  123. msg = "CQRS couldn't proceed, %s isn't found in body: %s."
  124. logger.error(msg, key, body)
  125. message.reject()
  126. return
  127. payload = TransportPayload(
  128. dct['signal_type'],
  129. dct['cqrs_id'],
  130. dct['instance_data'],
  131. dct.get('instance_pk'),
  132. previous_data=dct.get('previous_data'),
  133. correlation_id=dct.get('correlation_id'),
  134. )
  135. cls.log_consumed(payload)
  136. instance = consumer.consume(payload)
  137. if instance:
  138. message.ack()
  139. cls.log_consumed_accepted(payload)
  140. else:
  141. message.reject()
  142. cls.log_consumed_denied(payload)
  143. @classmethod
  144. def _produce_message(cls, channel, exchange, payload):
  145. routing_key = cls._get_produced_message_routing_key(payload)
  146. producer = Producer(
  147. channel,
  148. exchange=exchange,
  149. auto_declare=True,
  150. )
  151. producer.publish(
  152. ujson.dumps(payload.to_dict()),
  153. routing_key=routing_key,
  154. mandatory=True,
  155. content_type='text/plain',
  156. delivery_mode=2,
  157. )
  158. @staticmethod
  159. def _get_produced_message_routing_key(payload):
  160. routing_key = payload.cqrs_id
  161. if payload.signal_type == SignalType.SYNC and payload.queue:
  162. routing_key = 'cqrs.{0}.{1}'.format(payload.queue, routing_key)
  163. return routing_key
  164. @classmethod
  165. def _get_producer_kombu_objects(cls, url, exchange_name):
  166. connection = Connection(url)
  167. channel = connection.channel()
  168. return connection, channel
  169. @staticmethod
  170. def _create_exchange(exchange_name):
  171. return Exchange(
  172. exchange_name,
  173. type='topic',
  174. durable=True,
  175. )
  176. @staticmethod
  177. def _get_common_settings():
  178. url = settings.CQRS.get('url', 'amqp://localhost')
  179. exchange = settings.CQRS.get('exchange', 'cqrs')
  180. return (
  181. url,
  182. exchange,
  183. )
  184. @staticmethod
  185. def _get_consumer_settings():
  186. queue_name = settings.CQRS['queue']
  187. consumer_prefetch_count = settings.CQRS.get('consumer_prefetch_count', 10)
  188. return (
  189. queue_name,
  190. consumer_prefetch_count,
  191. )