# Copyright © 2023 Ingram Micro Inc. All rights reserved.

import logging
import time
from datetime import timedelta
from socket import gaierror
from urllib.parse import unquote, urlparse

import ujson
from django.conf import settings
from django.utils import timezone
from pika import (
    BasicProperties,
    BlockingConnection,
    ConnectionParameters,
    credentials,
    exceptions,
)
from pika.adapters.utils.connection_workflow import AMQPConnectorException

from dj_cqrs.constants import DEFAULT_DEAD_MESSAGE_TTL, SignalType
from dj_cqrs.controller import consumer
from dj_cqrs.dataclasses import TransportPayload
from dj_cqrs.delay import DelayMessage, DelayQueue
from dj_cqrs.registries import ReplicaRegistry
from dj_cqrs.transport import BaseTransport
from dj_cqrs.transport.mixins import LoggingMixin
from dj_cqrs.utils import get_delay_queue_max_size, get_messages_prefetch_count_per_worker


logger = logging.getLogger('django-cqrs')


class RabbitMQTransport(LoggingMixin, BaseTransport):
    """Transport class for RabbitMQ."""

    CONSUMER_RETRY_TIMEOUT = 5
    PRODUCER_RETRIES = 1

    _producer_connection = None
    _producer_channel = None

    @classmethod
    def clean_connection(cls):
        """Clean the RabbitMQ connection."""
        connection = cls._producer_connection
        if connection and not connection.is_closed:
            try:
                connection.close()
            except (exceptions.StreamLostError, exceptions.ConnectionClosed, ConnectionError):
                logger.warning('Connection was closed or is closing. Skip it...')

        cls._producer_connection = None
        cls._producer_channel = None

    @classmethod
    def consume(cls, cqrs_ids=None):
        """Receive data from master model.

        Args:
            cqrs_ids: CQRS ids of the replica models to consume; all registered
                models are consumed when None.
        """
        consumer_rabbit_settings = cls._get_consumer_settings()
        common_rabbit_settings = cls._get_common_settings()

        while True:
            connection = None
            try:
                delay_queue = DelayQueue(max_size=get_delay_queue_max_size())
                connection, channel, consumer_generator = cls._get_consumer_rmq_objects(
                    *(common_rabbit_settings + consumer_rabbit_settings),
                    cqrs_ids=cqrs_ids,
                )

                for method_frame, properties, body in consumer_generator:
                    if method_frame is not None:
                        cls._consume_message(
                            channel,
                            method_frame,
                            properties,
                            body,
                            delay_queue,
                        )

                    cls._process_delay_messages(channel, delay_queue)
            except (
                exceptions.AMQPError,
                exceptions.ChannelError,
                exceptions.ReentrancyError,
                gaierror,
            ):
                logger.warning('AMQP connection error. Reconnecting...', exc_info=True)
                time.sleep(cls.CONSUMER_RETRY_TIMEOUT)
            finally:
                if connection and not connection.is_closed:
                    connection.close()
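
    # Usage sketch (illustrative, not part of the library API): with this class configured as
    # the project's CQRS transport, a replica worker blocks on consume() and reconnects after
    # AMQP errors every CONSUMER_RETRY_TIMEOUT seconds. The cqrs_ids below are made-up ids:
    #
    #     RabbitMQTransport.consume(cqrs_ids={'author', 'book'})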

    @classmethod
    def produce(cls, payload):
        """Send data from master model to replicas.

        Args:
            payload (dj_cqrs.dataclasses.TransportPayload): Transport payload from master model.
        """
        cls._produce_with_retries(payload, retries=cls.PRODUCER_RETRIES)
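
    # Producing sketch (illustrative): payloads are normally built by the master-model
    # machinery, but a payload assembled from a message dict containing the required keys
    # (see _consume_message) can be pushed through produce() directly. The cqrs_id and
    # instance data below are made up:
    #
    #     payload = TransportPayload.from_message({
    #         'signal_type': SignalType.SAVE,
    #         'cqrs_id': 'author',
    #         'instance_data': {'id': 1, 'name': 'A. Author'},
    #         'instance_pk': 1,
    #     })
    #     RabbitMQTransport.produce(payload)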

    @classmethod
    def _produce_with_retries(cls, payload, retries):
        try:
            rmq_settings = cls._get_common_settings()
            exchange = rmq_settings[-1]
            # Decided not to create a context manager to stay within the class.
            _, channel = cls._get_producer_rmq_objects(
                *rmq_settings,
                signal_type=payload.signal_type,
            )

            cls._produce_message(channel, exchange, payload)
            cls.log_produced(payload)
        except (
            exceptions.AMQPError,
            exceptions.ChannelError,
            exceptions.ReentrancyError,
            AMQPConnectorException,
            AssertionError,
        ) as e:
            # In case of any error, close the connection and try to reconnect.
            cls.clean_connection()

            base_log_message = "CQRS couldn't be published: pk = {0} ({1}).".format(
                payload.pk,
                payload.cqrs_id,
            )
            if not retries:
                logger.exception(base_log_message)
                return

            logger.warning(
                '{0} Error: {1}. Reconnect...'.format(
                    base_log_message,
                    e.__class__.__name__,
                ),
            )
            cls._produce_with_retries(payload, retries - 1)

    @classmethod
    def _consume_message(cls, ch, method, properties, body, delay_queue):
        try:
            dct = ujson.loads(body)
        except ValueError:
            logger.error("CQRS couldn't be parsed: {0}.".format(body))
            ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
            return

        required_keys = {'instance_pk', 'signal_type', 'cqrs_id', 'instance_data'}
        for key in required_keys:
            if key not in dct:
                msg = "CQRS couldn't proceed, %s isn't found in body: %s."
                logger.error(msg, key, body)
                ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
                return

        payload = TransportPayload.from_message(dct)
        cls.log_consumed(payload)

        delivery_tag = method.delivery_tag
        if payload.is_expired():
            cls._add_to_dead_letter_queue(ch, payload)
            cls._nack(ch, delivery_tag)
            return

        instance, exception = None, None
        try:
            instance = consumer.consume(payload)
        except Exception as e:
            exception = e
            logger.error('CQRS service exception', exc_info=True)

        if instance and exception is None:
            cls._ack(ch, delivery_tag, payload)
        else:
            cls._fail_message(
                ch,
                delivery_tag,
                payload,
                exception,
                delay_queue,
            )

    @classmethod
    def _fail_message(cls, channel, delivery_tag, payload, exception, delay_queue):
        cls.log_consumed_failed(payload)

        model_cls = ReplicaRegistry.get_model_by_cqrs_id(payload.cqrs_id)
        if model_cls is None:
            logger.error('Model for cqrs_id {0} is not found.'.format(payload.cqrs_id))
            cls._nack(channel, delivery_tag)
            return

        if model_cls.should_retry_cqrs(payload.retries, exception):
            delay = model_cls.get_cqrs_retry_delay(payload.retries)
            cls._delay_message(channel, delivery_tag, payload, delay, delay_queue)
        else:
            cls._add_to_dead_letter_queue(channel, payload)
            cls._nack(channel, delivery_tag)

    @classmethod
    def _delay_message(cls, channel, delivery_tag, payload, delay, delay_queue):
        if delay_queue.full():
            # Memory limits handling, requeuing message with lowest ETA
            requeue_message = delay_queue.get()
            cls._requeue_message(
                channel,
                requeue_message.delivery_tag,
                requeue_message.payload,
            )

        eta = timezone.now() + timedelta(seconds=delay)
        delay_message = DelayMessage(delivery_tag, payload, eta)
        delay_queue.put(delay_message)
        cls.log_delayed(payload, delay, delay_message.eta)
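
    # Delay mechanics (as implemented above): a failed message stays unacknowledged on the
    # broker while a DelayMessage holding its delivery tag and ETA waits in the in-memory
    # DelayQueue; once the ETA passes, _process_delay_messages requeues it via
    # _requeue_message, which bumps payload.retries and nacks the original delivery.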

    @classmethod
    def _add_to_dead_letter_queue(cls, channel, payload):
        replica_settings = settings.CQRS.get('replica', {})
        dead_message_ttl = DEFAULT_DEAD_MESSAGE_TTL
        if 'dead_message_ttl' in replica_settings:
            dead_message_ttl = replica_settings['dead_message_ttl']

        expiration = None
        if dead_message_ttl is not None:
            expiration = str(dead_message_ttl * 1000)  # milliseconds

        payload.is_dead_letter = True
        exchange = cls._get_common_settings()[-1]
        cls._produce_message(channel, exchange, payload, expiration)
        cls.log_dead_letter(payload)

    @classmethod
    def _requeue_message(cls, channel, delivery_tag, payload):
        payload.retries += 1
        payload.is_requeue = True

        cls.produce(payload)
        cls._nack(channel, delivery_tag)
        cls.log_requeued(payload)

    @classmethod
    def _process_delay_messages(cls, channel, delay_queue):
        for delay_message in delay_queue.get_ready():
            cls._requeue_message(channel, delay_message.delivery_tag, delay_message.payload)

    @classmethod
    def _produce_message(cls, channel, exchange, payload, expiration=None):
        routing_key = cls._get_produced_message_routing_key(payload)

        channel.basic_publish(
            exchange=exchange,
            routing_key=routing_key,
            body=ujson.dumps(payload.to_dict()),
            mandatory=True,
            properties=BasicProperties(
                content_type='text/plain',
                delivery_mode=2,  # make message persistent
                expiration=expiration,
            ),
        )

    @classmethod
    def _get_produced_message_routing_key(cls, payload):
        routing_key = payload.cqrs_id

        if payload.signal_type == SignalType.SYNC and payload.queue:
            routing_key = 'cqrs.{0}.{1}'.format(payload.queue, routing_key)
        elif getattr(payload, 'is_dead_letter', False):
            dead_letter_queue_name = cls._get_consumer_settings()[1]
            routing_key = 'cqrs.{0}.{1}'.format(dead_letter_queue_name, routing_key)
        elif getattr(payload, 'is_requeue', False):
            queue = cls._get_consumer_settings()[0]
            routing_key = 'cqrs.{0}.{1}'.format(queue, routing_key)

        return routing_key
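
    # Routing examples (illustrative names): for a replica queue 'replica' and cqrs_id
    # 'author', a regular broadcast message is routed as 'author', a SYNC message targeted at
    # that queue or a requeued message as 'cqrs.replica.author', and a dead letter as
    # 'cqrs.dead_letter_replica.author' (the default dead letter queue name built in
    # _get_consumer_settings).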

    @classmethod
    def _get_consumer_rmq_objects(
        cls,
        host,
        port,
        creds,
        exchange,
        queue_name,
        dead_letter_queue_name,
        prefetch_count,
        cqrs_ids=None,
    ):
        connection = BlockingConnection(
            ConnectionParameters(host=host, port=port, credentials=creds),
        )
        channel = connection.channel()
        channel.basic_qos(prefetch_count=prefetch_count)

        cls._declare_exchange(channel, exchange)
        channel.queue_declare(queue_name, durable=True, exclusive=False)
        channel.queue_declare(dead_letter_queue_name, durable=True, exclusive=False)

        for cqrs_id, _ in ReplicaRegistry.models.items():
            if cqrs_ids and cqrs_id not in cqrs_ids:
                continue

            channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=cqrs_id)

            # Every service must have specific SYNC or requeue routes
            channel.queue_bind(
                exchange=exchange,
                queue=queue_name,
                routing_key='cqrs.{0}.{1}'.format(queue_name, cqrs_id),
            )

            # Dead letter
            channel.queue_bind(
                exchange=exchange,
                queue=dead_letter_queue_name,
                routing_key='cqrs.{0}.{1}'.format(dead_letter_queue_name, cqrs_id),
            )

        delay_queue_check_timeout = 1  # seconds
        consumer_generator = channel.consume(
            queue=queue_name,
            auto_ack=False,
            exclusive=False,
            inactivity_timeout=delay_queue_check_timeout,
        )
        return connection, channel, consumer_generator

    @classmethod
    def _get_producer_rmq_objects(cls, host, port, creds, exchange, signal_type=None):
        """
        Use the shared connection in case of sync mode, otherwise create a new connection
        for each message.
        """
        if signal_type == SignalType.SYNC:
            if cls._producer_connection is None:
                connection, channel = cls._create_connection(host, port, creds, exchange)

                cls._producer_connection = connection
                cls._producer_channel = channel

            return cls._producer_connection, cls._producer_channel
        else:
            return cls._create_connection(host, port, creds, exchange)

    @classmethod
    def _create_connection(cls, host, port, creds, exchange):
        connection = BlockingConnection(
            ConnectionParameters(
                host=host,
                port=port,
                credentials=creds,
                blocked_connection_timeout=10,
            ),
        )
        channel = connection.channel()
        channel.basic_qos(prefetch_count=get_messages_prefetch_count_per_worker())
        cls._declare_exchange(channel, exchange)
        return connection, channel

    @staticmethod
    def _declare_exchange(channel, exchange):
        channel.exchange_declare(
            exchange=exchange,
            exchange_type='topic',
            durable=True,
        )

    @staticmethod
    def _parse_url(url):
        scheme = urlparse(url).scheme
        assert scheme == 'amqp', 'Scheme must be "amqp" for RabbitMQTransport.'

        schemeless = url[len(scheme) + 3:]
        parts = urlparse('http://' + schemeless)

        return (
            unquote(parts.hostname or '') or ConnectionParameters.DEFAULT_HOST,
            parts.port or ConnectionParameters.DEFAULT_PORT,
            unquote(parts.username or '') or ConnectionParameters.DEFAULT_USERNAME,
            unquote(parts.password or '') or ConnectionParameters.DEFAULT_PASSWORD,
        )
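
    # Parsing example (illustrative): _parse_url('amqp://user:pass@rabbit.local:5672/')
    # returns ('rabbit.local', 5672, 'user', 'pass'); any missing component falls back to the
    # corresponding pika ConnectionParameters default.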

    @classmethod
    def _get_common_settings(cls):
        if 'url' in settings.CQRS:
            host, port, user, password = cls._parse_url(settings.CQRS.get('url'))
        else:
            host = settings.CQRS.get('host', ConnectionParameters.DEFAULT_HOST)
            port = settings.CQRS.get('port', ConnectionParameters.DEFAULT_PORT)
            user = settings.CQRS.get('user', ConnectionParameters.DEFAULT_USERNAME)
            password = settings.CQRS.get('password', ConnectionParameters.DEFAULT_PASSWORD)

        exchange = settings.CQRS.get('exchange', 'cqrs')
        return (
            host,
            port,
            credentials.PlainCredentials(user, password, erase_on_connect=True),
            exchange,
        )

    @staticmethod
    def _get_consumer_settings():
        queue_name = settings.CQRS['queue']

        replica_settings = settings.CQRS.get('replica', {})
        dead_letter_queue_name = 'dead_letter_{0}'.format(queue_name)
        if 'dead_letter_queue' in replica_settings:
            dead_letter_queue_name = replica_settings['dead_letter_queue']

        if 'consumer_prefetch_count' in settings.CQRS:
            logger.warning(
                "The 'consumer_prefetch_count' setting is ignored for RabbitMQTransport.",
            )
        prefetch_count = get_messages_prefetch_count_per_worker()

        return (
            queue_name,
            dead_letter_queue_name,
            prefetch_count,
        )
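
    # Settings sketch (hedged): the key names below are the ones read by this transport; the
    # 'transport' entry and the concrete values are assumptions for illustration only.
    #
    #     CQRS = {
    #         'transport': 'dj_cqrs.transport.RabbitMQTransport',
    #         'url': 'amqp://guest:guest@localhost:5672',  # or host/port/user/password keys
    #         'exchange': 'cqrs',
    #         'queue': 'my_replica_service',
    #         'replica': {
    #             'dead_letter_queue': 'dead_letter_my_replica_service',
    #             'dead_message_ttl': 60 * 60 * 24,  # seconds
    #         },
    #     }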

    @classmethod
    def _ack(cls, channel, delivery_tag, payload=None):
        channel.basic_ack(delivery_tag)
        if payload is not None:
            cls.log_consumed_accepted(payload)

    @classmethod
    def _nack(cls, channel, delivery_tag, payload=None):
        channel.basic_nack(delivery_tag, requeue=False)
        if payload is not None:
            cls.log_consumed_denied(payload)