Compare commits

...

34 Commits

Author SHA1 Message Date
4c8b74bd8f Merge pull request #1 from simon987/dependabot/pip/pydantic-1.10.13
Bump pydantic from 1.10.11 to 1.10.13
2024-04-25 08:45:06 -04:00
d82e1bccee Bump pydantic from 1.10.11 to 1.10.13 (dependabot[bot])
Bumps [pydantic](https://github.com/pydantic/pydantic) from 1.10.11 to 1.10.13.
- [Release notes](https://github.com/pydantic/pydantic/releases)
- [Changelog](https://github.com/pydantic/pydantic/blob/main/HISTORY.md)
- [Commits](https://github.com/pydantic/pydantic/compare/v1.10.11...v1.10.13)

---
updated-dependencies:
- dependency-name: pydantic
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-25 05:06:23 +00:00
b1a1da3bac Add option to use nltk word_tokenize 2023-09-09 11:11:44 -04:00
a047366926 pin pydantic version 2023-07-13 08:27:48 -04:00
24230cdc1e Update PgConn 2023-05-26 14:09:45 -04:00
3bd9f03996 Fix PersistentState constructor 2023-04-10 17:01:42 -04:00
e267bbf1c8 Set json indent=2 for pydantic rows 2023-02-25 16:07:41 -05:00
42e33b72b2 Fix JSON encoding for Enums 2023-02-25 15:51:28 -05:00
5275c332cc Add drop table 2023-02-25 15:38:40 -05:00
a7b1a6e1ec Fix tests, add pydantic row support for PersistentState 2023-02-25 15:20:17 -05:00
826312115c Fix deserialization in PersistentState again 2022-05-07 09:41:10 -04:00
372abb0076 Fix deserialization in PersistentState 2022-05-07 09:34:50 -04:00
78c04ef6f3 Add option to override Table factory in PersistentState 2022-05-05 15:02:48 -04:00
a51ad2cbb4 Cleanup 2022-05-03 10:59:25 -04:00
4befc3973d Add strip_dashes option in preprocess() 2022-02-26 19:31:22 -05:00
c9fac7151a Split punctuation into punctuation and special_punctuation 2022-02-23 11:01:17 -05:00
084acbe184 Set max_window_size=2147483648 for zstd 2022-01-29 10:44:38 -05:00
d578be3218 Increase timeout 2022-01-29 10:38:23 -05:00
cd5a1ac50c Remove clean_multicore function 2022-01-28 20:11:26 -05:00
62e74ed292 Make sure that _id field is present in redis MQ 2022-01-27 11:06:52 -05:00
428c82bcfd Update retry codes 2021-12-09 19:46:12 -05:00
4b3583358b Update retry codes 2021-12-09 19:44:21 -05:00
90d434ec73 Add more single quotes 2021-11-16 15:32:24 -05:00
55fd4a66d2 Fix strip_quotes 2021-11-16 11:48:23 -05:00
3677815d57 Add more quotes in strip_quotes 2021-11-16 11:39:10 -05:00
1ce795a759 ... 2021-11-16 11:36:17 -05:00
e1537297d7 normalize dashes in preprocess 2021-11-16 11:34:48 -05:00
8d8f9e8751 Fix typo, add stateless stream processor 2021-11-04 13:23:23 -04:00
18ba0024ea Null check on logger 2021-11-03 16:47:21 -04:00
408735a926 Fix git clone 2021-11-02 14:38:12 -04:00
2f6c2822b6 Add functions to handle routing keys 2021-10-22 10:36:05 -04:00
d85ad919b3 Add redis MQ, update influxdb monitoring 2021-10-21 20:19:09 -04:00
ed9d148411 Fix tests 2021-10-21 19:52:42 -04:00
5e00ddccdb Add test file 2021-10-07 15:40:23 -04:00
21 changed files with 666 additions and 315 deletions

5
.gitignore vendored
View File

@@ -1,4 +1,7 @@
*.iml
.idea/
*.db
*.png
hexlib.egg-info
build/
dist/

BIN
10MB.bin Normal file

Binary file not shown.

View File

@@ -1,5 +1,5 @@
Misc utility methods in Python
```
git+git://github.com/simon987/hexlib.git
git+https://github.com/simon987/hexlib.git
```

View File

@@ -6,6 +6,74 @@ from threading import Thread
from hexlib.misc import ichunks
class StatelessStreamWorker:
def __init__(self):
self._q_out = None
def run(self, q: Queue, q_out: Queue):
self._q_out: Queue = q_out
for chunk in queue_iter(q, joinable=False, timeout=10):
self._process_chunk(chunk)
def _process_chunk(self, chunk):
results = []
for item in chunk:
result = self.process(item)
if result is not None:
results.append(result)
if results:
self._q_out.put(results)
def process(self, item):
raise NotImplementedError
class StatelessStreamProcessor:
def __init__(self, worker_factory, chunk_size=128, processes=1, timeout=60):
self._chunk_size = 128
self._queue = MPQueue(maxsize=chunk_size)
self._queue_out = MPQueue(maxsize=processes * 2)
self._process_count = processes
self._processes = []
self._factory = worker_factory
self._workers = []
self._timeout = timeout
if processes > 1:
for _ in range(processes):
worker = self._factory()
p = Process(target=worker.run, args=(self._queue, self._queue_out))
p.start()
self._processes.append(p)
self._workers.append(worker)
else:
self._workers.append(self._factory())
def _ingest(self, iterable):
if self._process_count > 1:
for chunk in ichunks(iterable, self._chunk_size):
self._queue.put(chunk)
else:
for item in iterable:
self._workers[0].process(item)
def ingest(self, iterable):
ingest_thread = Thread(target=self._ingest, args=(iterable,))
ingest_thread.start()
for results in queue_iter(self._queue_out, joinable=False, timeout=self._timeout):
yield from results
ingest_thread.join()
class StatefulStreamWorker:
def __init__(self):
@@ -49,7 +117,7 @@ class StatefulStreamProcessor:
else:
self._workers.append(self._factory())
def injest(self, iterable):
def ingest(self, iterable):
if self._process_count > 1:
for chunk in ichunks(iterable, self._chunk_size):

View File

@@ -1,26 +1,24 @@
import base64
import sqlite3
import traceback
from datetime import datetime
from enum import Enum
import psycopg2
import umsgpack
from psycopg2.errorcodes import UNIQUE_VIOLATION
from pydantic import BaseModel
from hexlib.env import get_redis
class PersistentState:
"""Quick and dirty persistent dict-like SQLite wrapper"""
def _json_encoder(x):
if isinstance(x, datetime):
return x.isoformat()
if isinstance(x, Enum):
return x.value
def __init__(self, dbfile="state.db", logger=None, **dbargs):
self.dbfile = dbfile
self.logger = logger
if dbargs is None:
dbargs = {"timeout": 30000}
self.dbargs = dbargs
def __getitem__(self, table):
return Table(self, table)
raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
class VolatileState:
@@ -36,6 +34,9 @@ class VolatileState:
def __getitem__(self, table):
return RedisTable(self, table, self._sep)
def __delitem__(self, key):
self.rdb.delete(f"{self.prefix}{self._sep}{key}")
class VolatileQueue:
"""Quick and dirty volatile queue-like redis wrapper"""
@@ -68,6 +69,9 @@ class VolatileBooleanState:
def __getitem__(self, table):
return RedisBooleanTable(self, table, self._sep)
def __delitem__(self, table):
self.rdb.delete(f"{self.prefix}{self._sep}{table}")
class RedisTable:
def __init__(self, state, table, sep=""):
@@ -89,9 +93,9 @@ class RedisTable:
self._state.rdb.hdel(self._key, str(key))
def __iter__(self):
val = self._state.rdb.hgetall(self._key)
if val:
return ((k, umsgpack.loads(v)) for k, v in val.items())
for val in self._state.rdb.hscan(self._key):
if val:
return ((k, umsgpack.loads(v)) for k, v in val.items())
class RedisBooleanTable:
@@ -114,7 +118,7 @@ class RedisBooleanTable:
self._state.rdb.srem(self._key, str(key))
def __iter__(self):
return iter(self._state.rdb.smembers(self._key))
yield from self._state.rdb.sscan_iter(self._key)
class Table:
@@ -122,32 +126,54 @@ class Table:
self._state = state
self._table = table
def sql(self, where_clause, *params):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
for row in cur:
yield dict(row)
except:
return None
def __iter__(self):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
cur = conn.execute("SELECT * FROM %s" % (self._table,))
for row in cur:
yield dict(row)
except:
return None
def __getitem__(self, item):
def _sql_dict(self, where_clause, *params):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
for row in cur:
yield dict(
(col[0], _deserialize(row[col[0]], col_types[i]["type"]))
for i, col in enumerate(cur.description)
)
except:
return None
def sql(self, where_clause, *params):
for row in self._sql_dict(where_clause, *params):
if row and "__pydantic" in row:
yield self._deserialize_pydantic(row)
else:
yield row
def _iter_dict(self):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s" % (self._table,))
for row in cur:
yield dict(
(col[0], _deserialize(row[col[0]], col_types[i]["type"]))
for i, col in enumerate(cur.description)
)
except:
return None
def __iter__(self):
for row in self._iter_dict():
if row and "__pydantic" in row:
yield self._deserialize_pydantic(row)
else:
yield row
def _getitem_dict(self, key):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
row = cur.fetchone()
if row:
@@ -158,8 +184,32 @@ class Table:
except:
return None
@staticmethod
def _deserialize_pydantic(row):
module = __import__(row["__module"])
cls = getattr(module, row["__class"])
return cls.parse_raw(row["json"])
def __getitem__(self, key):
row = self._getitem_dict(key)
if row and "__pydantic" in row:
return self._deserialize_pydantic(row)
return row
def setitem_pydantic(self, key, value: BaseModel):
self.__setitem__(key, {
"json": value.json(encoder=_json_encoder, indent=2),
"__class": value.__class__.__name__,
"__module": value.__class__.__module__,
"__pydantic": 1
})
def __setitem__(self, key, value):
if isinstance(value, BaseModel):
self.setitem_pydantic(key, value)
return
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
@@ -217,11 +267,33 @@ def _serialize(value):
def _deserialize(value, col_type):
if col_type == "blob":
if col_type.lower() == "blob":
return base64.b64decode(value)
return value
class PersistentState:
"""Quick and dirty persistent dict-like SQLite wrapper"""
def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
self.dbfile = dbfile
self.logger = logger
if dbargs is None or dbargs == {}:
dbargs = {"timeout": 30000}
self.dbargs = dbargs
self._table_factory = table_factory
def __getitem__(self, table):
return self._table_factory(self, table)
def __delitem__(self, key):
with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
try:
conn.execute(f"DROP TABLE {key}")
except:
pass
def pg_fetch_cursor_all(cur, name, batch_size=1000):
while True:
cur.execute("FETCH FORWARD %d FROM %s" % (batch_size, name))
@@ -244,10 +316,10 @@ class PgConn:
def __init__(self, logger=None, **kwargs):
self._conn_args = kwargs
self.conn = psycopg2.connect(**kwargs)
self.cur = self.conn.cursor()
self._logger = logger
def __enter__(self):
self.cur = self.conn.cursor()
return self
def exec(self, query_string, args=None):
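
For orientation, a short sketch of the reworked PersistentState and its new pydantic row support, based on the code above and on test/test_PersistentState.py and test/test_PydanticTable.py further down.

```
# Sketch only; PersistentState is imported from hexlib.db as in the tests below.
from pydantic import BaseModel

from hexlib.db import PersistentState


class Point(BaseModel):
    x: int
    y: int


s = PersistentState("state.db")

# Plain dict rows still round-trip through SQLite as before
s["points"][0] = {"x": b"abc"}
assert s["points"][0]["x"] == b"abc"

# pydantic models are now stored as indented JSON with __class/__module/__pydantic
# marker columns and rebuilt transparently on read
s["models"]["a"] = Point(x=3, y=4)
assert s["models"]["a"].y == 4

# New in this range: del drops the underlying table
del s["points"]
assert len(list(s["points"])) == 0
```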

View File

@@ -45,7 +45,7 @@ def get_web(session=None):
rps=os.environ.get("RPS", 1),
logger=stdout_logger,
cookie_file=os.environ.get("COOKIE_FILE", None),
retry_codes=set(int(x) for x in retry_codes) if retry_codes else None,
retry_codes=set(int(x) for x in retry_codes.split(",")) if retry_codes else None,
retries=int(os.environ.get("RETRIES", 3)),
retry_sleep=int(os.environ.get("RETRY_SLEEP", 0)),
ua=ua[os.environ.get("USER_AGENT")] if os.environ.get("USER_AGENT", None) is not None else None
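
The one-line fix above means RETRY_CODES is now split on commas before each code is parsed. A hedged sketch of what that enables, assuming get_web() lives in hexlib.env alongside get_redis:

```
# Sketch; the variable names (RETRY_CODES, RETRIES) come from the hunk above,
# the hexlib.env import path for get_web is assumed.
import os

from hexlib.env import get_web

os.environ["RETRY_CODES"] = "429,500,503"   # previously int("429,500,503") would raise
os.environ["RETRIES"] = "5"

web = get_web()   # Web instance configured with retry_codes={429, 500, 503}
```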

View File

@@ -85,7 +85,7 @@ def ndjson_iter(*files, compression=""):
line_iter = BufferedReader(gzip.open(file))
elif compression == COMPRESSION_ZSTD:
fp = open(file, "rb")
dctx = zstandard.ZstdDecompressor()
dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
reader = dctx.stream_reader(fp)
line_iter = BufferedReader(reader)
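
This raises the zstd decompression window to 2 GiB so archives written with long-distance matching can be streamed. A small sketch, assuming ndjson_iter and COMPRESSION_ZSTD are importable from hexlib.misc (the file name is not visible in this hunk):

```
# Sketch; the hexlib.misc import path is an assumption, the constant name
# COMPRESSION_ZSTD appears in the diff above.
from hexlib.misc import COMPRESSION_ZSTD, ndjson_iter

# Each line of the .zst archive is decoded as one JSON document
for doc in ndjson_iter("dump.ndjson.zst", compression=COMPRESSION_ZSTD):
    print(doc)
```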

View File

@@ -1,22 +1,29 @@
import logging
import traceback
from abc import ABC
from influxdb import InfluxDBClient
from hexlib.misc import buffered
class Monitoring:
def __init__(self, db, host="localhost", logger=logging.getLogger("default"), batch_size=1, flush_on_exit=False):
self._db = db
self._client = InfluxDBClient(host, 8086, "", "", db)
class Monitoring(ABC):
def log(self, points):
raise NotImplementedError()
class BufferedInfluxDBMonitoring(Monitoring):
def __init__(self, db_name, host="localhost", port=8086, logger=None, batch_size=1, flush_on_exit=False):
self._db = db_name
self._client = InfluxDBClient(host, port, "", "", db_name)
self._logger = logger
self._init()
if not self.db_exists(self._db):
self._client.create_database(self._db)
@buffered(batch_size, flush_on_exit)
def log(points):
self._log(points)
self.log = log
def db_exists(self, name):
@@ -25,14 +32,16 @@ class Monitoring:
return True
return False
def _init(self):
if not self.db_exists(self._db):
self._client.create_database(self._db)
def log(self, points):
# Is overwritten in __init__()
pass
def _log(self, points):
try:
self._client.write_points(points)
self._logger.debug("InfluxDB: Wrote %d points" % len(points))
if self._logger:
self._logger.debug("InfluxDB: Wrote %d points" % len(points))
except Exception as e:
self._logger.debug(traceback.format_exc())
self._logger.error(str(e))
if self._logger:
self._logger.debug(traceback.format_exc())
self._logger.error(str(e))
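
A usage sketch of the refactored monitoring class. The hexlib.monitoring import path is an assumption; the point dictionaries follow the format that influxdb-python's write_points() expects.

```
# Sketch; the constructor parameters come from the __init__ shown above,
# the module path (hexlib.monitoring) is an assumption.
from hexlib.monitoring import BufferedInfluxDBMonitoring

mon = BufferedInfluxDBMonitoring(
    db_name="crawler_metrics",
    host="localhost",
    port=8086,
    batch_size=100,        # points are buffered and written in batches
    flush_on_exit=True,    # whatever remains in the buffer is flushed at exit
)

# log() is wrapped by @buffered(batch_size, flush_on_exit) in __init__,
# so calls are batched before write_points() is invoked
mon.log([{
    "measurement": "scrape",
    "tags": {"project": "demo"},
    "fields": {"items": 12, "errors": 0},
}])
```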

161
hexlib/mq.py Normal file
View File

@@ -0,0 +1,161 @@
import json
from collections import namedtuple
from functools import partial
from itertools import islice
from time import sleep, time
from orjson import orjson
from redis import Redis
RoutingKeyParts = namedtuple(
"RoutingKeyParts",
["arc_list", "project", "subproject", "type", "category"]
)
def parse_routing_key(key):
tokens = key.split(".")
if len(tokens) == 4:
arc_list, project, type_, category = tokens
return RoutingKeyParts(
arc_list=arc_list,
project=project,
subproject=None,
type=type_,
category=category
)
else:
arc_list, project, subproject, type_, category = tokens
return RoutingKeyParts(
arc_list=arc_list,
project=project,
subproject=subproject,
type=type_,
category=category
)
class MessageQueue:
def read_messages(self, topics):
raise NotImplementedError()
def publish(self, item, item_project, item_type, item_subproject, item_category):
raise NotImplementedError()
class RedisMQ(MessageQueue):
_MAX_KEYS = 30
def __init__(self, rdb, consumer_name="redis_mq", sep=".", max_pending_time=120, logger=None, publish_channel=None,
arc_lists=None, wait=1):
self._rdb: Redis = rdb
self._key_cache = None
self._consumer_id = consumer_name
self._pending_list = f"pending{sep}{consumer_name}"
self._max_pending_time = max_pending_time
self._logger = logger
self._publish_channel = publish_channel
self._arc_lists = arc_lists
self._wait = wait
def _get_keys(self, pattern):
if self._key_cache:
return self._key_cache
keys = list(islice(
self._rdb.scan_iter(match=pattern, count=RedisMQ._MAX_KEYS), RedisMQ._MAX_KEYS
))
self._key_cache = keys
return keys
def _get_pending_tasks(self):
for task_id, pending_task in self._rdb.hscan_iter(self._pending_list):
pending_task_json = orjson.loads(pending_task)
if time() >= pending_task_json["resubmit_at"]:
yield pending_task_json["topic"], pending_task_json["task"], partial(self._ack, task_id)
def _ack(self, task_id):
self._rdb.hdel(self._pending_list, task_id)
def read_messages(self, topics):
"""
Assumes json-encoded tasks with an _id field
Tasks are automatically put into a pending list until ack() is called.
When a task has been in the pending list for at least max_pending_time seconds, it
gets submitted again
"""
assert len(topics) == 1, "RedisMQ only supports 1 topic pattern"
pattern = topics[0]
counter = 0
if self._logger:
self._logger.info(f"MQ>Listening for new messages in {pattern}")
while True:
counter += 1
if counter % 1000 == 0:
yield from self._get_pending_tasks()
keys = self._get_keys(pattern)
if not keys:
sleep(self._wait)
self._key_cache = None
continue
result = self._rdb.blpop(keys, timeout=1)
if not result:
self._key_cache = None
continue
topic, task = result
task_json = orjson.loads(task)
topic = topic.decode()
if "_id" not in task_json or not task_json["_id"]:
raise ValueError(f"Task doesn't have _id field: {task}")
# Immediately put in pending queue
self._rdb.hset(
self._pending_list, task_json["_id"],
orjson.dumps({
"resubmit_at": time() + self._max_pending_time,
"topic": topic,
"task": task_json
})
)
yield topic, task_json, partial(self._ack, task_json["_id"])
def publish(self, item, item_project, item_type, item_subproject=None, item_category="x"):
if "_id" not in item:
raise ValueError("_id field must be set for item")
item = json.dumps(item, separators=(',', ':'), ensure_ascii=False, sort_keys=True)
item_project = item_project.replace(".", "-")
item_subproject = item_subproject.replace(".", "-") if item_subproject else None
item_source = item_project if not item_subproject else f"{item_project}.{item_subproject}"
item_type = item_type.replace(".", "-")
item_category = item_category.replace(".", "-")
# If specified, fan-out to pub/sub channel
if self._publish_channel is not None:
routing_key = f"{self._publish_channel}.{item_source}.{item_type}.{item_category}"
self._rdb.publish(routing_key, item)
# Save to list
for arc_list in self._arc_lists:
routing_key = f"{arc_list}.{item_source}.{item_type}.{item_category}"
self._rdb.lpush(routing_key, item)
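
A short end-to-end sketch of the new Redis-backed message queue, mirroring test/test_redis_mq.py further down (get_redis comes from hexlib.env):

```
# Sketch mirroring the tests below; all names come from hexlib/mq.py above
# and from test/test_redis_mq.py.
from hexlib.env import get_redis
from hexlib.mq import RedisMQ

rdb = get_redis()
mq = RedisMQ(rdb, consumer_name="demo", arc_lists=["arc"], max_pending_time=120)

# publish() requires an _id field and routes to "<arc_list>.<project>.<type>.<category>"
mq.publish({"_id": 1, "body": "hello"}, item_project="test", item_type="msg")

# read_messages() yields (topic, task, ack); the task stays in the pending hash
# until ack() is called and is re-delivered after max_pending_time seconds otherwise
for topic, task, ack in mq.read_messages(topics=["arc.test.*"]):
    print(topic, task)
    ack()
    break
```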

View File

@@ -1,227 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
DATA = [
*["apple"] * 5,
*["banana"] * 12,
*["strawberry"] * 8,
*["pineapple"] * 2,
]
class Cmap:
Accent = "Accent"
Accent_r = "Accent_r"
Blues = "Blues"
Blues_r = "Blues_r"
BrBG = "BrBG"
BrBG_r = "BrBG_r"
BuGn = "BuGn"
BuGn_r = "BuGn_r"
BuPu = "BuPu"
BuPu_r = "BuPu_r"
CMRmap = "CMRmap"
CMRmap_r = "CMRmap_r"
Dark2 = "Dark2"
Dark2_r = "Dark2_r"
GnBu = "GnBu"
GnBu_r = "GnBu_r"
Greens = "Greens"
Greens_r = "Greens_r"
Greys = "Greys"
Greys_r = "Greys_r"
OrRd = "OrRd"
OrRd_r = "OrRd_r"
Oranges = "Oranges"
Oranges_r = "Oranges_r"
PRGn = "PRGn"
PRGn_r = "PRGn_r"
Paired = "Paired"
Paired_r = "Paired_r"
Pastel1 = "Pastel1"
Pastel1_r = "Pastel1_r"
Pastel2 = "Pastel2"
Pastel2_r = "Pastel2_r"
PiYG = "PiYG"
PiYG_r = "PiYG_r"
PuBu = "PuBu"
PuBuGn = "PuBuGn"
PuBuGn_r = "PuBuGn_r"
PuBu_r = "PuBu_r"
PuOr = "PuOr"
PuOr_r = "PuOr_r"
PuRd = "PuRd"
PuRd_r = "PuRd_r"
Purples = "Purples"
Purples_r = "Purples_r"
RdBu = "RdBu"
RdBu_r = "RdBu_r"
RdGy = "RdGy"
RdGy_r = "RdGy_r"
RdPu = "RdPu"
RdPu_r = "RdPu_r"
RdYlBu = "RdYlBu"
RdYlBu_r = "RdYlBu_r"
RdYlGn = "RdYlGn"
RdYlGn_r = "RdYlGn_r"
Reds = "Reds"
Reds_r = "Reds_r"
Set1 = "Set1"
Set1_r = "Set1_r"
Set2 = "Set2"
Set2_r = "Set2_r"
Set3 = "Set3"
Set3_r = "Set3_r"
Spectral = "Spectral"
Spectral_r = "Spectral_r"
Wistia = "Wistia"
Wistia_r = "Wistia_r"
YlGn = "YlGn"
YlGnBu = "YlGnBu"
YlGnBu_r = "YlGnBu_r"
YlGn_r = "YlGn_r"
YlOrBr = "YlOrBr"
YlOrBr_r = "YlOrBr_r"
YlOrRd = "YlOrRd"
YlOrRd_r = "YlOrRd_r"
afmhot = "afmhot"
afmhot_r = "afmhot_r"
autumn = "autumn"
autumn_r = "autumn_r"
binary = "binary"
binary_r = "binary_r"
bone = "bone"
bone_r = "bone_r"
brg = "brg"
brg_r = "brg_r"
bwr = "bwr"
bwr_r = "bwr_r"
cividis = "cividis"
cividis_r = "cividis_r"
cool = "cool"
cool_r = "cool_r"
coolwarm = "coolwarm"
coolwarm_r = "coolwarm_r"
copper = "copper"
copper_r = "copper_r"
cubehelix = "cubehelix"
cubehelix_r = "cubehelix_r"
flag = "flag"
flag_r = "flag_r"
gist_earth = "gist_earth"
gist_earth_r = "gist_earth_r"
gist_gray = "gist_gray"
gist_gray_r = "gist_gray_r"
gist_heat = "gist_heat"
gist_heat_r = "gist_heat_r"
gist_ncar = "gist_ncar"
gist_ncar_r = "gist_ncar_r"
gist_rainbow = "gist_rainbow"
gist_rainbow_r = "gist_rainbow_r"
gist_stern = "gist_stern"
gist_stern_r = "gist_stern_r"
gist_yarg = "gist_yarg"
gist_yarg_r = "gist_yarg_r"
gnuplot = "gnuplot"
gnuplot2 = "gnuplot2"
gnuplot2_r = "gnuplot2_r"
gnuplot_r = "gnuplot_r"
gray = "gray"
gray_r = "gray_r"
hot = "hot"
hot_r = "hot_r"
hsv = "hsv"
hsv_r = "hsv_r"
inferno = "inferno"
inferno_r = "inferno_r"
jet = "jet"
jet_r = "jet_r"
magma = "magma"
magma_r = "magma_r"
nipy_spectral = "nipy_spectral"
nipy_spectral_r = "nipy_spectral_r"
ocean = "ocean"
ocean_r = "ocean_r"
pink = "pink"
pink_r = "pink_r"
plasma = "plasma"
plasma_r = "plasma_r"
prism = "prism"
prism_r = "prism_r"
rainbow = "rainbow"
rainbow_r = "rainbow_r"
seismic = "seismic"
seismic_r = "seismic_r"
spring = "spring"
spring_r = "spring_r"
summer = "summer"
summer_r = "summer_r"
tab10 = "tab10"
tab10_r = "tab10_r"
tab20 = "tab20"
tab20_r = "tab20_r"
tab20b = "tab20b"
tab20b_r = "tab20b_r"
tab20c = "tab20c"
tab20c_r = "tab20c_r"
terrain = "terrain"
terrain_r = "terrain_r"
turbo = "turbo"
turbo_r = "turbo_r"
twilight = "twilight"
twilight_r = "twilight_r"
twilight_shifted = "twilight_shifted"
twilight_shifted_r = "twilight_shifted_r"
viridis = "viridis"
viridis_r = "viridis_r"
winter = "winter"
winter_r = "winter_r"
def plot_freq_bar(items, ylabel="frequency", title=""):
item_set, item_counts = np.unique(items, return_counts=True)
plt.bar(item_set, item_counts)
plt.xticks(rotation=35)
plt.ylabel(ylabel)
plt.title(title)
for i, cnt in enumerate(item_counts):
plt.text(x=i, y=cnt / 2, s=cnt, ha="center", color="white")
plt.tight_layout()
def plot_confusion_matrix(y_true=None, y_pred=None, cm=None, labels=None, title=None, cmap=None):
if not cm:
cm = confusion_matrix(y_true, y_pred, labels=labels)
if type(cm) == list:
cm = np.array(cm)
cm_display = ConfusionMatrixDisplay(cm, display_labels=labels)
cm_display.plot(cmap=cmap)
if title:
plt.title(title)
if labels:
plt.xticks(rotation=30)
plt.tight_layout()
if __name__ == '__main__':
plot_freq_bar(DATA, title="My title")
plt.show()
plot_confusion_matrix(
cm=[[12, 1, 0],
[3, 14, 1],
[5, 6, 7]],
title="My title",
labels=["apple", "orange", "grape"],
cmap=Cmap.viridis
)
plt.show()

View File

@@ -1,16 +1,20 @@
from functools import partial
import re
from itertools import chain, repeat
from multiprocessing.pool import Pool
import nltk.corpus
from lxml import etree
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
from .regex_util import LINK_RE
get_text = etree.XPath("//text()")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
stop_words_en = set(stopwords.words("english"))
extra_stop_words_en = [
@@ -19,21 +23,9 @@ extra_stop_words_en = [
stop_words_en.update(extra_stop_words_en)
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
lemmatizer = WordNetLemmatizer()
def clean_multicore(texts, processes, chunk_size=10000, **kwargs):
pool = Pool(processes=processes)
yield from pool.imap(
func=partial(preprocess, **kwargs),
iterable=texts,
chunksize=chunk_size
)
def _transform_bigram(ngram_seq, ngrams):
for ngram in ngram_seq:
if ngram in ngrams:
@@ -56,22 +48,37 @@ def _transform_trigram(ngram_seq, ngrams):
yield ngram[0]
SINGLE_QUOTES = ("", "`")
SINGLE_QUOTES = ("", "`", "")
SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", len(SINGLE_QUOTES))))
PUNCTUATION = ".,;:\"!?/()|*=>"
DASHES = ("", "", "", "")
DASHES_TRANS = str.maketrans("".join(DASHES), "".join(repeat("-", len(DASHES))))
DASHES_RE = re.compile(r"-+")
SPECIAL_PUNCTUATION = ";:\"/()|*=>"
SPECIAL_PUNCTUATION_TRANS = str.maketrans(SPECIAL_PUNCTUATION, " " * len(SPECIAL_PUNCTUATION))
PUNCTUATION = ".,!?"
PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION))
def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_stopwords_en=False,
lemmatize=False, fix_single_quotes=False, strip_quotes=False, remove_urls=False, bigrams: set = None,
trigrams: set = None, remove_numbers=False):
def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_special_punctuation=False,
remove_stopwords_en=False, lemmatize=False, fix_single_quotes=False, strip_quotes=False,
strip_dashes=False,
remove_urls=False, bigrams: set = None, trigrams: set = None, remove_numbers=False,
use_nltk_tokenizer=False):
if lowercase:
text = text.lower()
if fix_single_quotes:
text = text.translate(SINGLE_QUOTE_TRANS)
text = text.translate(DASHES_TRANS)
if strip_dashes:
text = DASHES_RE.sub("-", text)
if remove_urls:
text = LINK_RE.sub(" ", text)
@@ -89,10 +96,19 @@ def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False
if remove_punctuation:
text = text.translate(PUNCTUATION_TRANS)
words = text.split()
if remove_special_punctuation:
text = text.translate(SPECIAL_PUNCTUATION_TRANS)
if use_nltk_tokenizer:
words = word_tokenize(text, language="english")
else:
words = text.split()
if strip_quotes:
words = filter(lambda w: w.strip("\"'"), words)
words = map(lambda w: w.strip("\"'“”"), words)
if strip_dashes:
words = map(lambda w: w.strip("-"), words)
if bigrams:
words = _transform_bigram(nltk.bigrams(chain(words, ("*",))), bigrams)
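
A quick sketch of the preprocess() flags added in this range. The expected outputs are copied from the test cases in test/test_text.py at the end of this diff; the hexlib.text import path is an assumption.

```
# Sketch; expected outputs are taken verbatim from test/test_text.py below,
# the module path (hexlib.text) is an assumption.
from hexlib.text import preprocess

# strip_dashes: dash runs are collapsed and leading/trailing dashes are stripped
words = preprocess("yes -But something-something -- hello aa--bb", strip_dashes=True)
print(" ".join(words))   # yes But something-something hello aa-bb

# use_nltk_tokenizer: nltk word_tokenize() instead of a plain str.split()
words = preprocess("i cannot believe'", use_nltk_tokenizer=True)
print(" ".join(words))   # i can not believe '
```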

View File

@@ -126,6 +126,7 @@ def download_file(url, destination, session=None, headers=None, overwrite=False,
"url": url,
"timestamp": datetime.utcnow().replace(microsecond=0).isoformat()
}))
r.close()
break
except Exception as e:
if err_cb:
@@ -142,7 +143,7 @@ class Web:
self._logger = logger
self._current_req = None
if retry_codes is None or not retry_codes:
retry_codes = {502, 504, 520, 522, 524, 429}
retry_codes = {500, 502, 503, 504, 520, 522, 524, 429}
self._retry_codes = retry_codes
if session is None:
@@ -187,7 +188,8 @@ class Web:
self._post = post
def _error_callback(self, e):
self._logger.critical(f"{self._format_url(*self._current_req)}: {e}")
if self._logger:
self._logger.critical(f"{self._format_url(*self._current_req)}: {e}")
def _format_url(self, method, url, kwargs, r=None):
if "params" in kwargs and kwargs["params"]:

View File

@@ -2,7 +2,7 @@ from setuptools import setup
setup(
name="hexlib",
version="1.62",
version="1.89",
description="Misc utility methods",
author="simon987",
author_email="me@simon987.net",
@@ -12,8 +12,9 @@ setup(
"data/*"
]},
install_requires=[
"ImageHash", "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
"influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
"u-msgpack-python", "psycopg2-binary", "bs4", "lxml", "nltk", "numpy",
"matplotlib", "scikit-learn", "fake-useragent @ git+git://github.com/Jordan9675/fake-useragent"
"matplotlib", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
"requests", "pydantic==1.10.13"
]
)

0
test/__init__.py Normal file
View File

View File

@@ -110,3 +110,34 @@ class TestPersistentState(TestCase):
del s["a"][456]
except Exception as e:
self.fail(e)
def test_deserialize_get_set(self):
s = PersistentState()
s["a"][0] = {"x": b'abc'}
self.assertEqual(s["a"][0]["x"], b'abc')
def test_deserialize_sql(self):
s = PersistentState()
s["a"][0] = {"x": b'abc'}
self.assertEqual(list(s["a"].sql("WHERE 1=1"))[0]["x"], b'abc')
def test_deserialize_iter(self):
s = PersistentState()
s["a"][0] = {"x": b'abc'}
self.assertEqual(list(s["a"])[0]["x"], b'abc')
def test_drop_table(self):
s = PersistentState()
s["a"][0] = {"x": 1}
s["a"][1] = {"x": 2}
self.assertEqual(len(list(s["a"])), 2)
del s["a"]
self.assertEqual(len(list(s["a"])), 0)

110
test/test_PydanticTable.py Normal file
View File

@@ -0,0 +1,110 @@
import os
from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase
from pydantic import BaseModel
from pydantic.types import List
from hexlib.db import PersistentState
class Status(Enum):
yes = "yes"
no = "no"
class Point(BaseModel):
x: int
y: int
class Polygon(BaseModel):
points: List[Point] = []
created_date: datetime
status: Status = Status("yes")
class TestPydanticTable(TestCase):
def tearDown(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def setUp(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def test_get_set(self):
s = PersistentState()
val = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[
Point(x=1, y=2),
Point(x=3, y=4),
],
)
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1)
self.assertEqual(s["a"]["1"].status, Status("yes"))
self.assertEqual(s["a"]["1"].points[1].x, 3)
self.assertEqual(s["a"]["1"].created_date.year, 2000)
def test_update(self):
s = PersistentState()
val = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[
Point(x=1, y=2),
Point(x=3, y=4),
]
)
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1)
val.points[0].x = 2
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 2)
def test_sql(self):
s = PersistentState()
s["b"]["1"] = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[]
)
s["b"]["2"] = Polygon(
created_date=datetime(year=2010, day=1, month=1),
points=[]
)
result = list(s["b"].sql(
"WHERE json->>'created_date' LIKE '2000-%'"
))
self.assertEqual(len(result), 1)
self.assertEqual(result[0].created_date.year, 2000)
def test_iterate(self):
s = PersistentState()
s["b"]["1"] = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[]
)
s["b"]["2"] = Polygon(
created_date=datetime(year=2010, day=1, month=1),
points=[]
)
result = list(s["b"])
self.assertEqual(len(result), 2)
self.assertEqual(result[0].created_date.year, 2000)
self.assertEqual(result[1].created_date.year, 2010)

View File

@@ -1,10 +1,15 @@
from unittest import TestCase
from hexlib.db import VolatileState, VolatileBooleanState, VolatileQueue
from hexlib.env import get_redis
class TestVolatileState(TestCase):
def setUp(self) -> None:
rdb = get_redis()
rdb.delete("test1a", "test1b", "test1c", "test1:a", "test2b")
def test_get_set(self):
s = VolatileState(prefix="test1")
val = {
@@ -53,6 +58,10 @@ class TestVolatileState(TestCase):
class TestVolatileBoolState(TestCase):
def setUp(self) -> None:
rdb = get_redis()
rdb.delete("test1a", "test1b", "test1c", "test1:a", "test2b")
def test_get_set(self):
s = VolatileBooleanState(prefix="test1")

View File

@@ -2,12 +2,16 @@ import os
from unittest import TestCase
from hexlib.web import download_file
import warnings
class TestDownloadFile(TestCase):
def setUp(self) -> None:
warnings.filterwarnings(action="ignore", category=ResourceWarning)
def test_download_file(self):
download_file("http://ovh.net/files/10Mb.dat", "/tmp/10Mb.dat")
download_file("https://github.com/simon987/hexlib/raw/master/10MB.bin", "/tmp/10Mb.dat")
self.assertTrue(os.path.exists("/tmp/10Mb.dat"))
os.remove("/tmp/10Mb.dat")
@@ -22,8 +26,8 @@ class TestDownloadFile(TestCase):
self.assertEqual(len(exceptions), 3)
def test_download_file_meta(self):
download_file("http://ovh.net/files/10Mb.dat", "/tmp/10Mb.dat", save_meta=True)
download_file("https://github.com/simon987/hexlib/raw/master/10MB.bin", "/tmp/10Mb.dat", save_meta=True)
self.assertTrue(os.path.exists("/tmp/10Mb.dat"))
self.assertTrue(os.path.exists("/tmp/10Mb.dat.meta"))
os.remove("/tmp/10Mb.dat")
# os.remove("/tmp/10Mb.dat.meta")
os.remove("/tmp/10Mb.dat.meta")

62
test/test_redis_mq.py Normal file
View File

@@ -0,0 +1,62 @@
from unittest import TestCase
from hexlib.env import get_redis
from hexlib.mq import RedisMQ, parse_routing_key, RoutingKeyParts
class TestRedisMQ(TestCase):
def setUp(self) -> None:
self.rdb = get_redis()
self.rdb.delete("pending.test", "test_mq", "arc.test.msg.x")
def test_ack(self):
mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=2, arc_lists=["arc"])
mq.publish({"_id": 1}, item_project="test", item_type="msg")
topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.*"]))
self.assertEqual(self.rdb.hlen("pending.test"), 1)
ack1()
self.assertEqual(self.rdb.hlen("pending.test"), 0)
def test_pending_timeout(self):
mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)
mq.publish({"_id": 1}, item_project="test", item_type="msg")
topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.test.*"]))
self.assertEqual(self.rdb.hlen("pending.test"), 1)
# msg1 will timeout after 0.5s, next iteration takes ceil(0.5)s
topic1_, msg1_, ack1_ = next(mq.read_messages(topics=["arc.test.*"]))
self.assertEqual(self.rdb.hlen("pending.test"), 1)
ack1_()
self.assertEqual(self.rdb.hlen("pending.test"), 0)
self.assertEqual(msg1, msg1_)
def test_no_id_field(self):
mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)
with self.assertRaises(ValueError):
mq.publish({"a": 1}, item_project="test", item_type="msg")
class TestRoutingKey(TestCase):
def test1(self):
key = "arc.chan.4chan.post.b"
parts = parse_routing_key(key)
self.assertEqual(parts, RoutingKeyParts("arc", "chan", "4chan", "post", "b"))
def test2(self):
key = "arc.reddit.submission.birdpics"
parts = parse_routing_key(key)
self.assertEqual(parts, RoutingKeyParts("arc", "reddit", None, "submission", "birdpics"))

View File

@@ -152,7 +152,7 @@ class TestText(TestCase):
remove_stopwords_en=True,
remove_urls=True
)
expected = "hello world"
expected = "hello world |"
self.assertEqual(" ".join(cleaned), expected)
@@ -170,7 +170,7 @@ class TestText(TestCase):
remove_urls=False
)
expected = "217709510 is there a servant that is against civilization and humanity literally instant summon"
expected = ">>217709510 is there a servant that is against civilization and humanity literally instant summon"
self.assertEqual(" ".join(cleaned), expected)
def test_html_entity(self):
@@ -247,3 +247,33 @@ class TestText(TestCase):
expected = "hello1 test1124test world"
self.assertEqual(" ".join(cleaned), expected)
def test_strip_quotes(self):
text = "'hi' “test” 'hello\""
cleaned = preprocess(
text,
strip_quotes=True
)
expected = "hi test hello"
self.assertEqual(" ".join(cleaned), expected)
def test_strip_dashes(self):
text = "yes -But something-something -- hello aa--bb"
cleaned = preprocess(
text,
strip_dashes=True
)
expected = "yes But something-something hello aa-bb"
self.assertEqual(" ".join(cleaned), expected)
def test_word_tokenize(self):
text = "i cannot believe'"
cleaned = preprocess(
text,
use_nltk_tokenizer=True
)
expected = "i can not believe '"
self.assertEqual(" ".join(cleaned), expected)