diff --git a/hexlib/monitoring.py b/hexlib/monitoring.py
index ac2651a..74292c1 100644
--- a/hexlib/monitoring.py
+++ b/hexlib/monitoring.py
@@ -1,22 +1,29 @@
-import logging
 import traceback
+from abc import ABC
 
 from influxdb import InfluxDBClient
 
 from hexlib.misc import buffered
 
 
-class Monitoring:
-    def __init__(self, db, host="localhost", logger=logging.getLogger("default"), batch_size=1, flush_on_exit=False):
-        self._db = db
-        self._client = InfluxDBClient(host, 8086, "", "", db)
+class Monitoring(ABC):
+    def log(self, points):
+        raise NotImplementedError()
+
+
+class BufferedInfluxDBMonitoring(Monitoring):
+    def __init__(self, db_name, host="localhost", port=8086, logger=None, batch_size=1, flush_on_exit=False):
+        self._db = db_name
+        self._client = InfluxDBClient(host, port, "", "", db_name)
         self._logger = logger
-        self._init()
+        if not self.db_exists(self._db):
+            self._client.create_database(self._db)
 
         @buffered(batch_size, flush_on_exit)
         def log(points):
             self._log(points)
+
         self.log = log
 
 
     def db_exists(self, name):
@@ -25,14 +32,16 @@ class Monitoring:
             return True
         return False
 
-    def _init(self):
-        if not self.db_exists(self._db):
-            self._client.create_database(self._db)
+    def log(self, points):
+        # Is overwritten in __init__()
+        pass
 
     def _log(self, points):
         try:
             self._client.write_points(points)
-            self._logger.debug("InfluxDB: Wrote %d points" % len(points))
+            if self._logger:
+                self._logger.debug("InfluxDB: Wrote %d points" % len(points))
         except Exception as e:
-            self._logger.debug(traceback.format_exc())
-            self._logger.error(str(e))
+            if self._logger:
+                self._logger.debug(traceback.format_exc())
+                self._logger.error(str(e))
diff --git a/hexlib/mq.py b/hexlib/mq.py
new file mode 100644
index 0000000..ae3bcce
--- /dev/null
+++ b/hexlib/mq.py
@@ -0,0 +1,132 @@
+import json
+from functools import partial
+from itertools import islice
+from time import sleep, time
+
+from orjson import orjson
+from redis import Redis
+
+
+class MessageQueue:
+    def read_messages(self, topics):
+        raise NotImplementedError()
+
+    def publish(self, item, item_project, item_type, item_subproject, item_category):
+        raise NotImplementedError()
+
+
+class RedisMQ(MessageQueue):
+    _MAX_KEYS = 30
+
+    def __init__(self, rdb, consumer_name="redis_mq", sep=".", max_pending_time=120, logger=None, publish_channel=None,
+                 arc_lists=None, wait=1):
+        self._rdb: Redis = rdb
+        self._key_cache = None
+        self._consumer_id = consumer_name
+        self._pending_list = f"pending{sep}{consumer_name}"
+        self._max_pending_time = max_pending_time
+        self._logger = logger
+        self._publish_channel = publish_channel
+        self._arc_lists = arc_lists
+        self._wait = wait
+
+    def _get_keys(self, pattern):
+        if self._key_cache:
+            return self._key_cache
+
+        keys = list(islice(
+            self._rdb.scan_iter(match=pattern, count=RedisMQ._MAX_KEYS), RedisMQ._MAX_KEYS
+        ))
+        self._key_cache = keys
+
+        return keys
+
+    def _get_pending_tasks(self):
+        for task_id, pending_task in self._rdb.hscan_iter(self._pending_list):
+
+            pending_task_json = orjson.loads(pending_task)
+
+            if time() >= pending_task_json["resubmit_at"]:
+                yield pending_task_json["topic"], pending_task_json["task"], partial(self._ack, task_id)
+
+    def _ack(self, task_id):
+        self._rdb.hdel(self._pending_list, task_id)
+
+    def read_messages(self, topics):
+        """
+        Assumes json-encoded tasks with an _id field
+
+        Tasks are automatically put into a pending list until ack() is called.
+        When a task has been in the pending list for at least max_pending_time seconds,
+        it is submitted again.
+        """
+
+        assert len(topics) == 1, "RedisMQ only supports 1 topic pattern"
+
+        pattern = topics[0]
+        counter = 0
+
+        if self._logger:
+            self._logger.info(f"MQ>Listening for new messages in {pattern}")
+
+        while True:
+            counter += 1
+
+            if counter % 1000 == 0:
+                yield from self._get_pending_tasks()
+
+            keys = self._get_keys(pattern)
+            if not keys:
+                sleep(self._wait)
+                self._key_cache = None
+                continue
+
+            result = self._rdb.blpop(keys, timeout=1)
+            if not result:
+                self._key_cache = None
+                continue
+
+            topic, task = result
+
+            task_json = orjson.loads(task)
+            topic = topic.decode()
+
+            if "_id" not in task_json:
+                raise ValueError(f"Task doesn't have _id field: {task}")
+
+            # Immediately put in pending queue
+            self._rdb.hset(
+                self._pending_list, task_json["_id"],
+                orjson.dumps({
+                    "resubmit_at": time() + self._max_pending_time,
+                    "topic": topic,
+                    "task": task_json
+                })
+            )
+
+            yield topic, task_json, partial(self._ack, task_json["_id"])
+
+    def publish(self, item, item_project, item_type, item_subproject=None, item_category="x"):
+
+        if "_id" not in item:
+            raise ValueError("_id field must be set for item")
+
+        item = json.dumps(item, separators=(',', ':'), ensure_ascii=False, sort_keys=True)
+
+        item_project = item_project.replace(".", "-")
+        item_subproject = item_subproject.replace(".", "-") if item_subproject else None
+
+        item_source = item_project if not item_subproject else f"{item_project}.{item_subproject}"
+
+        item_type = item_type.replace(".", "-")
+        item_category = item_category.replace(".", "-")
+
+        # If specified, fan-out to pub/sub channel
+        if self._publish_channel is not None:
+            routing_key = f"{self._publish_channel}.{item_source}.{item_type}.{item_category}"
+            self._rdb.publish(routing_key, item)
+
+        # Save to list
+        for arc_list in self._arc_lists:
+            routing_key = f"{arc_list}.{item_source}.{item_type}.{item_category}"
+            self._rdb.lpush(routing_key, item)
diff --git a/hexlib/web.py b/hexlib/web.py
index 97b203f..95942b3 100644
--- a/hexlib/web.py
+++ b/hexlib/web.py
@@ -126,6 +126,7 @@ def download_file(url, destination, session=None, headers=None, overwrite=False,
                 "url": url,
                 "timestamp": datetime.utcnow().replace(microsecond=0).isoformat()
             }))
+        r.close()
         break
     except Exception as e:
         if err_cb:
@@ -142,7 +143,7 @@ class Web:
         self._logger = logger
         self._current_req = None
         if retry_codes is None or not retry_codes:
-            retry_codes = {502, 504, 520, 522, 524, 429}
+            retry_codes = {500, 502, 504, 520, 522, 524, 429}
         self._retry_codes = retry_codes
 
         if session is None:
diff --git a/test/test_redis_mq.py b/test/test_redis_mq.py
new file mode 100644
index 0000000..f12f637
--- /dev/null
+++ b/test/test_redis_mq.py
@@ -0,0 +1,49 @@
+from unittest import TestCase
+
+from hexlib.env import get_redis
+from hexlib.mq import RedisMQ
+
+
+class TestRedisMQ(TestCase):
+
+    def setUp(self) -> None:
+        self.rdb = get_redis()
+        self.rdb.delete("pending.test", "test_mq", "arc.test.msg.x")
+
+    def test_ack(self):
+        mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=2, arc_lists=["arc"])
+
+        mq.publish({"_id": 1}, item_project="test", item_type="msg")
+
+        topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.*"]))
+
+        self.assertEqual(self.rdb.hlen("pending.test"), 1)
+
+        ack1()
+
+        self.assertEqual(self.rdb.hlen("pending.test"), 0)
+
+    def test_pending_timeout(self):
+        mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)
+
mq.publish({"_id": 1}, item_project="test", item_type="msg") + + topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.test.*"])) + + self.assertEqual(self.rdb.hlen("pending.test"), 1) + + # msg1 will timeout after 0.5s, next iteration takes ceil(0.5)s + topic1_, msg1_, ack1_ = next(mq.read_messages(topics=["arc.test.*"])) + self.assertEqual(self.rdb.hlen("pending.test"), 1) + + ack1_() + + self.assertEqual(self.rdb.hlen("pending.test"), 0) + + self.assertEqual(msg1, msg1_) + + def test_no_id_field(self): + mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0) + + with self.assertRaises(ValueError): + mq.publish({"a": 1}, item_project="test", item_type="msg")