Mirror of https://github.com/simon987/hexlib.git, synced 2025-12-14 15:19:05 +00:00.

Compare commits: b845d96295...1.89 (92 commits)
The compare view lists the 92 commits below (the author and date columns were empty in this capture):

| SHA1 |
|---|
| b1a1da3bac |
| a047366926 |
| 24230cdc1e |
| 3bd9f03996 |
| e267bbf1c8 |
| 42e33b72b2 |
| 5275c332cc |
| a7b1a6e1ec |
| 826312115c |
| 372abb0076 |
| 78c04ef6f3 |
| a51ad2cbb4 |
| 4befc3973d |
| c9fac7151a |
| 084acbe184 |
| d578be3218 |
| cd5a1ac50c |
| 62e74ed292 |
| 428c82bcfd |
| 4b3583358b |
| 90d434ec73 |
| 55fd4a66d2 |
| 3677815d57 |
| 1ce795a759 |
| e1537297d7 |
| 8d8f9e8751 |
| 18ba0024ea |
| 408735a926 |
| 2f6c2822b6 |
| d85ad919b3 |
| ed9d148411 |
| 5e00ddccdb |
| 7ecd55a1c6 |
| b746a91281 |
| 333083e8b9 |
| c295b5d30b |
| da0e117550 |
| c3fef7e7f8 |
| 9bd1f4b799 |
| c560cc2010 |
| f4a5e6cf53 |
| 71cd00c063 |
| 7349c9a5f1 |
| d19442b00e |
| 4711cd1b66 |
| 7e0ffafb8c |
| 60273fb6bd |
| 67c09cc10c |
| a7bf5b2d15 |
| 31b35e3a32 |
| 4cff343370 |
| 4d6c8018df |
| db3e191983 |
| 33e9734991 |
| 3238f92e4d |
| f8e93354a4 |
| 75bf2c2d85 |
| 9002ae7506 |
| 88f3124f85 |
| 8edad0255b |
| 32119535ae |
| 2ffaa4a5b3 |
| 067a20f7a8 |
| 00323ea576 |
| 45b5803c40 |
| 18cd59fc4a |
| d895ac837e |
| ae59522b27 |
| 765f6f59b7 |
| 30902c8235 |
| 53a262a138 |
| 6e1aa53455 |
| 53ac0c37e8 |
| 8378ed6526 |
| c79b3bfafd |
| 5ee1629c79 |
| 00f5aef721 |
| 021da84433 |
| 53a03baaa4 |
| 66d37e0be2 |
| 43cb6c4a7b |
| 4278b0f89e |
| 9738819428 |
| ce2e5b2af6 |
| 9cadce62ac |
| 7d330a0f9f |
| a2cfab55bc |
| c4fca1b754 |
| d615ebdbd9 |
| f914759b71 |
| 58d150279f |
| b2efaa99a4 |
.gitignore (vendored, 6 changed lines)

```diff
@@ -1,3 +1,7 @@
 *.iml
 .idea/
 *.db
+*.png
+hexlib.egg-info
+build/
+dist/
```
README.md (the install URL scheme was updated)

````diff
@@ -1,5 +1,5 @@
 Misc utility methods in Python

 ```
-git+git://github.com/simon987/hexlib.git
+git+https://github.com/simon987/hexlib.git
 ```
````
bench/text.py (new file, 27 lines)

```python
from timeit import timeit

# Note: maketrans requires both arguments to have the same length; the page
# capture showed a single-space second argument, which would raise ValueError,
# so the replacement string is padded here to make the benchmark runnable.
PUNCT = b".,;:\"!?/()|*=>"
t = bytes.maketrans(PUNCT, b" " * len(PUNCT))


def translate(x: str):
    arr = x.encode("utf8")

    return arr.translate(t).decode("utf8")


if __name__ == '__main__':
    res = timeit(
        setup='t = str.maketrans(".,;:\\"!?/()|*=>", " " * 14)',  # padded for the same reason
        stmt='x = "Hello, world %123 & *".translate(t)'
    )

    # 0.865953s
    print("translate = %fs" % res)

    res = timeit(
        setup='from text import translate',  # imports this same module; run from bench/
        stmt='x = translate("Hello, world %123 & *")'
    )

    # 0.865953s
    print("custom = %fs" % res)
```
hexlib/concurrency.py (file header missing from the capture; +139 −9). The old `queue_iter` helper gains a `joinable` flag and the module grows stateless/stateful stream-processing classes:

```diff
@@ -1,9 +1,139 @@
+from multiprocessing import Process
+from multiprocessing import Queue as MPQueue
 from queue import Queue, Empty
+from threading import Thread
+
+from hexlib.misc import ichunks
 
 
-def queue_iter(q: Queue, **get_args):
+class StatelessStreamWorker:
+
+    def __init__(self):
+        self._q_out = None
+
+    def run(self, q: Queue, q_out: Queue):
+        self._q_out: Queue = q_out
+
+        for chunk in queue_iter(q, joinable=False, timeout=10):
+            self._process_chunk(chunk)
+
+    def _process_chunk(self, chunk):
+        results = []
+
+        for item in chunk:
+            result = self.process(item)
+            if result is not None:
+                results.append(result)
+
+        if results:
+            self._q_out.put(results)
+
+    def process(self, item):
+        raise NotImplementedError
+
+
+class StatelessStreamProcessor:
+    def __init__(self, worker_factory, chunk_size=128, processes=1, timeout=60):
+        self._chunk_size = 128
+        self._queue = MPQueue(maxsize=chunk_size)
+        self._queue_out = MPQueue(maxsize=processes * 2)
+        self._process_count = processes
+        self._processes = []
+        self._factory = worker_factory
+        self._workers = []
+        self._timeout = timeout
+
+        if processes > 1:
+            for _ in range(processes):
+                worker = self._factory()
+                p = Process(target=worker.run, args=(self._queue, self._queue_out))
+                p.start()
+
+                self._processes.append(p)
+                self._workers.append(worker)
+        else:
+            self._workers.append(self._factory())
+
+    def _ingest(self, iterable):
+        if self._process_count > 1:
+            for chunk in ichunks(iterable, self._chunk_size):
+                self._queue.put(chunk)
+        else:
+            for item in iterable:
+                self._workers[0].process(item)
+
+    def ingest(self, iterable):
+        ingest_thread = Thread(target=self._ingest, args=(iterable,))
+        ingest_thread.start()
+
+        for results in queue_iter(self._queue_out, joinable=False, timeout=self._timeout):
+            yield from results
+
+        ingest_thread.join()
+
+
+class StatefulStreamWorker:
+
+    def __init__(self):
+        pass
+
+    def run(self, q: Queue, q_out: Queue):
+        for chunk in queue_iter(q, joinable=False, timeout=3):
+            self._process_chunk(chunk)
+
+        q_out.put(self.results())
+
+    def _process_chunk(self, chunk):
+        for item in chunk:
+            self.process(item)
+
+    def process(self, item) -> None:
+        raise NotImplementedError
+
+    def results(self):
+        raise NotImplementedError
+
+
+class StatefulStreamProcessor:
+    def __init__(self, worker_factory, chunk_size=128, processes=1):
+        self._chunk_size = 128
+        self._queue = MPQueue(maxsize=chunk_size)
+        self._queue_out = MPQueue()
+        self._process_count = processes
+        self._processes = []
+        self._factory = worker_factory
+        self._workers = []
+
+        if processes > 1:
+            for _ in range(processes):
+                worker = self._factory()
+                p = Process(target=worker.run, args=(self._queue, self._queue_out))
+                p.start()
+
+                self._processes.append(p)
+                self._workers.append(worker)
+        else:
+            self._workers.append(self._factory())
+
+    def ingest(self, iterable):
+        if self._process_count > 1:
+            for chunk in ichunks(iterable, self._chunk_size):
+                self._queue.put(chunk)
+        else:
+            for item in iterable:
+                self._workers[0].process(item)
+
+    def get_results(self):
+        for _ in range(self._process_count):
+            yield self._queue_out.get()
+        for p in self._processes:
+            p.join()
+
+
+def queue_iter(q: Queue, joinable=True, **get_args):
     while True:
         try:
             task = q.get(**get_args)
@@ -12,7 +142,8 @@ def queue_iter(q: Queue, **get_args):
                 break
 
             yield task
-            q.task_done()
+            if joinable:
+                q.task_done()
         except Empty:
             break
         except KeyboardInterrupt:
```
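For context, a minimal usage sketch of the stateful variant. The `WordCounter` worker and its input data are hypothetical, and the `hexlib.concurrency` module path is an assumption since the file header is missing from the capture:

```python
from hexlib.concurrency import StatefulStreamWorker, StatefulStreamProcessor  # module path assumed


class WordCounter(StatefulStreamWorker):
    def __init__(self):
        super().__init__()
        self.count = 0

    def process(self, item) -> None:
        self.count += len(item.split())

    def results(self):
        return self.count


if __name__ == "__main__":
    # One WordCounter per process; each accumulates its own count
    processor = StatefulStreamProcessor(worker_factory=WordCounter, chunk_size=64, processes=4)
    processor.ingest(["hello world", "foo bar baz"] * 1000)
    print(sum(processor.get_results()))  # total word count across all workers
```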
hexlib/db.py (276 changed lines). The redis wrappers now take an injected connection and a key separator, `PersistentState` moves below its helpers and gains pydantic (de)serialization, and a `PgConn` PostgreSQL wrapper is added. The +/- attribution below is reconstructed from the flattened side-by-side capture:

```diff
@@ -1,88 +1,124 @@
 import base64
 import sqlite3
-import redis
 import orjson as json
+import traceback
+from datetime import datetime
+from enum import Enum
 
+import psycopg2
 import umsgpack
+from psycopg2.errorcodes import UNIQUE_VIOLATION
+from pydantic import BaseModel
+
+from hexlib.env import get_redis
 
 
-class PersistentState:
-    """Quick and dirty persistent dict-like SQLite wrapper"""
-
-    def __init__(self, dbfile="state.db", **dbargs):
-        self.dbfile = dbfile
-        if dbargs is None:
-            dbargs = {"timeout": 30000}
-        self.dbargs = dbargs
-
-    def __getitem__(self, table):
-        return Table(self, table)
+def _json_encoder(x):
+    if isinstance(x, datetime):
+        return x.isoformat()
+    if isinstance(x, Enum):
+        return x.value
+
+    raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
 
 
 class VolatileState:
     """Quick and dirty volatile dict-like redis wrapper"""
 
-    def __init__(self, prefix, **redis_args):
-        self.rdb = redis.Redis(**redis_args)
+    def __init__(self, prefix, redis_db=None, sep=""):
+        if redis_db is None:
+            redis_db = get_redis()
+        self.rdb = redis_db
         self.prefix = prefix
+        self._sep = sep
 
     def __getitem__(self, table):
-        return RedisTable(self, table)
+        return RedisTable(self, table, self._sep)
+
+    def __delitem__(self, key):
+        self.rdb.delete(f"{self.prefix}{self._sep}{key}")
+
+
+class VolatileQueue:
+    """Quick and dirty volatile queue-like redis wrapper"""
+
+    def __init__(self, key, redis_db=None):
+        if redis_db is None:
+            redis_db = get_redis()
+        self.rdb = redis_db
+        self.key = key
+
+    def put(self, item):
+        self.rdb.sadd(self.key, umsgpack.dumps(item))
+
+    def get(self):
+        v = self.rdb.spop(self.key)
+        if v:
+            return umsgpack.loads(v)
 
 
 class VolatileBooleanState:
     """Quick and dirty volatile dict-like redis wrapper for boolean values"""
 
-    def __init__(self, prefix, **redis_args):
-        self.rdb = redis.Redis(**redis_args)
+    def __init__(self, prefix, redis_db=None, sep=""):
+        if redis_db is None:
+            redis_db = get_redis()
+        self.rdb = redis_db
         self.prefix = prefix
+        self._sep = sep
 
     def __getitem__(self, table):
-        return RedisBooleanTable(self, table)
+        return RedisBooleanTable(self, table, self._sep)
+
+    def __delitem__(self, table):
+        self.rdb.delete(f"{self.prefix}{self._sep}{table}")
 
 
 class RedisTable:
-    def __init__(self, state, table):
+    def __init__(self, state, table, sep=""):
         self._state = state
         self._table = table
+        self._sep = sep
+        self._key = f"{self._state.prefix}{self._sep}{self._table}"
 
     def __setitem__(self, key, value):
-        self._state.rdb.hset(self._state.prefix + self._table, str(key), umsgpack.dumps(value))
+        self._state.rdb.hset(self._key, str(key), umsgpack.dumps(value))
 
     def __getitem__(self, key):
-        val = self._state.rdb.hget(self._state.prefix + self._table, str(key))
+        val = self._state.rdb.hget(self._key, str(key))
         if val:
             return umsgpack.loads(val)
         return None
 
     def __delitem__(self, key):
-        self._state.rdb.hdel(self._state.prefix + self._table, str(key))
+        self._state.rdb.hdel(self._key, str(key))
 
     def __iter__(self):
-        val = self._state.rdb.hgetall(self._state.prefix + self._table)
-        if val:
-            return ((k, umsgpack.loads(v)) for k, v in
-                    self._state.rdb.hgetall(self._state.prefix + self._table).items())
+        for val in self._state.rdb.hscan(self._key):
+            if val:
+                return ((k, umsgpack.loads(v)) for k, v in val.items())
 
 
 class RedisBooleanTable:
-    def __init__(self, state, table):
+    def __init__(self, state, table, sep=""):
         self._state = state
         self._table = table
+        self._sep = sep
+        self._key = f"{self._state.prefix}{self._sep}{self._table}"
 
     def __setitem__(self, key, value):
         if value:
-            self._state.rdb.sadd(self._state.prefix + self._table, str(key))
+            self._state.rdb.sadd(self._key, str(key))
         else:
             self.__delitem__(key)
 
     def __getitem__(self, key):
-        return self._state.rdb.sismember(self._state.prefix + self._table, str(key))
+        return self._state.rdb.sismember(self._key, str(key))
 
     def __delitem__(self, key):
-        self._state.rdb.srem(self._state.prefix + self._table, str(key))
+        self._state.rdb.srem(self._key, str(key))
 
     def __iter__(self):
-        return iter(self._state.rdb.smembers(self._state.prefix + self._table))
+        yield from self._state.rdb.sscan_iter(self._key)
 
 
 class Table:
@@ -90,22 +126,54 @@ class Table:
         self._state = state
         self._table = table
 
-    def __iter__(self):
-        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
-            conn.row_factory = sqlite3.Row
-            try:
-                cur = conn.execute("SELECT * FROM %s" % (self._table,))
-                for row in cur:
-                    yield dict(row)
-            except:
-                return None
-
-    def __getitem__(self, item):
+    def _sql_dict(self, where_clause, *params):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
                 col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
-                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
+                cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
+                for row in cur:
+                    yield dict(
+                        (col[0], _deserialize(row[col[0]], col_types[i]["type"]))
+                        for i, col in enumerate(cur.description)
+                    )
+            except:
+                return None
+
+    def sql(self, where_clause, *params):
+        for row in self._sql_dict(where_clause, *params):
+            if row and "__pydantic" in row:
+                yield self._deserialize_pydantic(row)
+            else:
+                yield row
+
+    def _iter_dict(self):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            conn.row_factory = sqlite3.Row
+            try:
+                col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
+                cur = conn.execute("SELECT * FROM %s" % (self._table,))
+                for row in cur:
+                    yield dict(
+                        (col[0], _deserialize(row[col[0]], col_types[i]["type"]))
+                        for i, col in enumerate(cur.description)
+                    )
+            except:
+                return None
+
+    def __iter__(self):
+        for row in self._iter_dict():
+            if row and "__pydantic" in row:
+                yield self._deserialize_pydantic(row)
+            else:
+                yield row
+
+    def _getitem_dict(self, key):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            conn.row_factory = sqlite3.Row
+            try:
+                col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
+                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
 
                 row = cur.fetchone()
                 if row:
@@ -116,8 +184,32 @@ class Table:
             except:
                 return None
 
+    @staticmethod
+    def _deserialize_pydantic(row):
+        module = __import__(row["__module"])
+        cls = getattr(module, row["__class"])
+        return cls.parse_raw(row["json"])
+
+    def __getitem__(self, key):
+        row = self._getitem_dict(key)
+        if row and "__pydantic" in row:
+            return self._deserialize_pydantic(row)
+        return row
+
+    def setitem_pydantic(self, key, value: BaseModel):
+        self.__setitem__(key, {
+            "json": value.json(encoder=_json_encoder, indent=2),
+            "__class": value.__class__.__name__,
+            "__module": value.__class__.__module__,
+            "__pydantic": 1
+        })
+
     def __setitem__(self, key, value):
 
+        if isinstance(value, BaseModel):
+            self.setitem_pydantic(key, value)
+            return
+
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
 
@@ -146,6 +238,13 @@ class Table:
             args
         )
 
+    def __delitem__(self, key):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            try:
+                conn.execute("DELETE FROM %s WHERE id=?" % self._table, (key,))
+            except sqlite3.OperationalError:
+                pass
+
 
 def _sqlite_type(value):
     if isinstance(value, int):
@@ -160,15 +259,41 @@ def _sqlite_type(value):
 def _serialize(value):
     if isinstance(value, bytes):
         return base64.b64encode(value)
+    if value is None:
+        return None
+    if isinstance(value, bool):
+        return value
     return str(value)
 
 
 def _deserialize(value, col_type):
-    if col_type == "blob":
+    if col_type.lower() == "blob":
         return base64.b64decode(value)
     return value
 
 
+class PersistentState:
+    """Quick and dirty persistent dict-like SQLite wrapper"""
+
+    def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
+        self.dbfile = dbfile
+        self.logger = logger
+        if dbargs is None or dbargs == {}:
+            dbargs = {"timeout": 30000}
+        self.dbargs = dbargs
+        self._table_factory = table_factory
+
+    def __getitem__(self, table):
+        return self._table_factory(self, table)
+
+    def __delitem__(self, key):
+        with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
+            try:
+                conn.execute(f"DROP TABLE {key}")
+            except:
+                pass
+
+
 def pg_fetch_cursor_all(cur, name, batch_size=1000):
     while True:
         cur.execute("FETCH FORWARD %d FROM %s" % (batch_size, name))
@@ -183,3 +308,66 @@ def pg_fetch_cursor_all(cur, name, batch_size=1000):
             for row in cur:
                 yield row
             break
+
+
+class PgConn:
+    """Wrapper for PostgreSQL connection"""
+
+    def __init__(self, logger=None, **kwargs):
+        self._conn_args = kwargs
+        self.conn = psycopg2.connect(**kwargs)
+        self._logger = logger
+
+    def __enter__(self):
+        self.cur = self.conn.cursor()
+        return self
+
+    def exec(self, query_string, args=None):
+        while True:
+            try:
+                if self._logger:
+                    self._logger.debug(query_string)
+                    self._logger.debug("With args " + str(args))
+
+                self.cur.execute(query_string, args)
+                break
+            except psycopg2.Error as e:
+                if e.pgcode == UNIQUE_VIOLATION:
+                    break
+                traceback.print_stack()
+                self._handle_err(e, query_string, args)
+
+    def query(self, query_string, args=None, max_retries=1):
+        retries = max_retries
+        while retries > 0:
+            try:
+                if self._logger:
+                    self._logger.debug(query_string)
+                    self._logger.debug("With args " + str(args))
+
+                self.cur.execute(query_string, args)
+                res = self.cur.fetchall()
+
+                if self._logger:
+                    self._logger.debug("result: " + str(res))
+
+                return res
+            except psycopg2.Error as e:
+                if e.pgcode == UNIQUE_VIOLATION:
+                    break
+                self._handle_err(e, query_string, args)
+                retries -= 1
+
+    def _handle_err(self, err, query, args):
+        if self._logger:
+            self._logger.warning(
+                "Error during query '%s' with args %s: %s %s (%s)" % (query, args, type(err), err, err.pgcode))
+        self.conn = psycopg2.connect(**self._conn_args)
+        self.cur = self.conn.cursor()
+
+    def __exit__(self, type, value, traceback):
+        try:
+            self.conn.commit()
+            self.cur.close()
+        except:
+            pass
```
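A short sketch of the SQLite-backed state in use, based on the behaviour exercised by the tests further down (the table and key names are made up):

```python
from hexlib.db import PersistentState

state = PersistentState("state.db")

# Tables and columns are created on first write; bytes round-trip via base64
state["posts"]["42"] = {"title": "hello", "raw": b"\x00\x01"}

print(state["posts"]["42"]["title"])        # -> "hello"
for row in state["posts"].sql("WHERE id=?", "42"):
    print(row)

del state["posts"]  # drops the whole table
```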
hexlib/env.py (new file, 63 lines)

```python
import os

import redis
from fake_useragent import UserAgent

from hexlib.log import stdout_logger
from hexlib.web import Web

ARC_LISTS = os.environ.get("ARC_LISTS", "arc").split(",")
PUBLISH_CHANNEL = os.environ.get("PUBLISH_CHANNEL", None)


def get_redis():
    return redis.Redis(
        host=os.environ.get("REDIS_HOST", "localhost"),
        port=int(os.environ.get("REDIS_PORT", 6379))
    )


def redis_publish(rdb, item, item_project, item_type, item_subproject=None, item_category="x"):
    item_project = item_project.replace(".", "-")
    item_subproject = item_subproject.replace(".", "-") if item_subproject else None

    item_source = item_project if not item_subproject else f"{item_project}.{item_subproject}"

    item_type = item_type.replace(".", "-")
    item_category = item_category.replace(".", "-")

    if PUBLISH_CHANNEL is not None:
        routing_key = f"{PUBLISH_CHANNEL}.{item_source}.{item_type}.{item_category}"
        rdb.publish(routing_key, item)
    for arc_list in ARC_LISTS:
        routing_key = f"{arc_list}.{item_source}.{item_type}.{item_category}"
        rdb.lpush(routing_key, item)


def get_web(session=None):
    ua = UserAgent()

    retry_codes = os.environ.get("RETRY_CODES", "")

    web = Web(
        session=session,
        proxy=os.environ.get("PROXY", None),
        rps=os.environ.get("RPS", 1),
        logger=stdout_logger,
        cookie_file=os.environ.get("COOKIE_FILE", None),
        retry_codes=set(int(x) for x in retry_codes.split(",")) if retry_codes else None,
        retries=int(os.environ.get("RETRIES", 3)),
        retry_sleep=int(os.environ.get("RETRY_SLEEP", 0)),
        ua=ua[os.environ.get("USER_AGENT")] if os.environ.get("USER_AGENT", None) is not None else None
    )

    if hasattr(web._session, "cipherSuite"):
        stdout_logger.debug("Web>cipherSuite=%s" % web._session.cipherSuite)
    if hasattr(web._session, "headers"):
        stdout_logger.debug("Web>headers=%s" % web._session.headers)
    if hasattr(web._session, "cookies"):
        stdout_logger.debug("Web>cookies=%s" % web._session.cookies)

    stdout_logger.debug("Web>rps=%s" % os.environ.get("RPS", 1))

    return web
```
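Everything in this module is driven by environment variables; a sketch of how a consumer might configure it (the values are illustrative only):

```python
import os

os.environ["REDIS_HOST"] = "10.0.0.5"  # hypothetical host
os.environ["RPS"] = "0.5"              # one request every 2 seconds

from hexlib.env import get_redis, get_web

rdb = get_redis()  # redis.Redis bound to 10.0.0.5:6379
web = get_web()    # rate-limited, retrying HTTP client built from the env
```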
hexlib/files.py (file header missing from the capture). `ndjson_iter` now yields lazy `NDJsonLine` wrappers instead of eagerly parsed objects, and the zstd decompressor accepts larger windows:

```diff
@@ -62,6 +62,16 @@ COMPRESSION_GZIP = "gz"
 COMPRESSION_ZSTD = "zstd"
 
 
+class NDJsonLine:
+    __slots__ = "text"
+
+    def __init__(self, text):
+        self.text = text
+
+    def json(self):
+        return json.loads(self.text)
+
+
 def ndjson_iter(*files, compression=""):
     for file in files:
         cleanup = None
@@ -75,7 +85,7 @@ def ndjson_iter(*files, compression=""):
             line_iter = BufferedReader(gzip.open(file))
         elif compression == COMPRESSION_ZSTD:
             fp = open(file, "rb")
-            dctx = zstandard.ZstdDecompressor()
+            dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
             reader = dctx.stream_reader(fp)
             line_iter = BufferedReader(reader)
 
@@ -90,7 +100,6 @@ def ndjson_iter(*files, compression=""):
                 line_iter.close()
 
         for line in line_iter:
-            yield json.loads(line)
+            yield NDJsonLine(line)
         if cleanup:
             cleanup()
```
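With this change, callers receive lazy wrappers and decide when to pay the parsing cost; a consuming sketch (the module path and the input file name are assumptions, since the file header is missing from the capture):

```python
from hexlib.files import ndjson_iter  # module path assumed

for line in ndjson_iter("dump.ndjson.zst", compression="zstd"):  # hypothetical file
    doc = line.json()  # JSON parsing is now deferred until .json() is called
    print(doc.get("_id"))
```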
hexlib/log.py (new file, 51 lines)

```python
import logging
import os
import sys
from logging import StreamHandler

DATE_FMT = "%Y-%m-%d %H:%M:%S"


class ColorFormatter(logging.Formatter):

    def __init__(self, fmt):
        super().__init__()

        grey = "\x1b[38;21m"
        yellow = "\x1b[33;21m"
        red = "\x1b[31;21m"
        bold_red = "\x1b[31;1m"
        reset = "\x1b[0m"

        self.formats = {
            logging.DEBUG: logging.Formatter(grey + fmt + reset, datefmt=DATE_FMT),
            logging.INFO: logging.Formatter(grey + fmt + reset, datefmt=DATE_FMT),
            logging.WARNING: logging.Formatter(yellow + fmt + reset, datefmt=DATE_FMT),
            logging.ERROR: logging.Formatter(red + fmt + reset, datefmt=DATE_FMT),
            logging.CRITICAL: logging.Formatter(bold_red + fmt + reset, datefmt=DATE_FMT)
        }

    def format(self, record):
        return self.formats[record.levelno].format(record)


stdout_logger = logging.getLogger("default")

if os.environ.get("LOG_LEVEL", "debug") == "debug":
    stdout_logger.setLevel(logging.DEBUG)

for h in stdout_logger.handlers:
    stdout_logger.removeHandler(h)

handler = StreamHandler(sys.stdout)
if os.environ.get("LOG_THREAD_NAME", "0") == "1":
    fmt = "%(asctime)s %(levelname)-5s>%(threadName)s %(message)s"
else:
    fmt = "%(asctime)s %(levelname)-5s>%(message)s"

if os.environ.get("LOG_COLORS", "1") == "1":
    handler.formatter = ColorFormatter(fmt)
else:
    handler.formatter = logging.Formatter(fmt, datefmt='%Y-%m-%d %H:%M:%S')
stdout_logger.addHandler(handler)
logger = stdout_logger
```
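The module configures a ready-to-use logger at import time; typical use:

```python
from hexlib.log import logger

logger.debug("connecting...")       # grey
logger.warning("disk is 90% full")  # yellow
logger.error("cannot connect")      # red
```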
hexlib/misc.py (file header missing from the capture). The `retry` decorator previously invoked the function at decoration time; it now returns a proper wrapper with arguments, return value, and an optional sleep between attempts. `ichunks`, a flush-on-None path for `buffered`, and `CustomStdErr` are added:

```diff
@@ -1,23 +1,31 @@
+import atexit
+import itertools
 import os
 import sys
 import time
 from threading import Lock
+from time import sleep
 
-import atexit
 import siphash
 
 last_time_called = dict()
 
 
-def retry(attempts, callback=None):
+def retry(attempts, callback=None, retry_sleep=0):
     def decorate(func):
-        retries = attempts
-        while retries > 0:
-            try:
-                func()
-            except Exception as e:
-                if callback:
-                    callback(e)
+        def wrapper(*args, **kwargs):
+            retries = attempts
+            while retries > 0:
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    if callback:
+                        callback(e)
+                    retries -= 1
+                    sleep(retry_sleep)
+
+        return wrapper
 
     return decorate
 
@@ -26,6 +34,15 @@ def chunks(lst: list, chunk_len: int):
         yield lst[i:i + chunk_len]
 
 
+def ichunks(iterable, chunk_len: int):
+    it = iter(iterable)
+    while True:
+        chunk = tuple(itertools.islice(it, chunk_len))
+        if not chunk:
+            break
+        yield chunk
+
+
 def rate_limit(per_second):
     min_interval = 1.0 / float(per_second)
 
@@ -54,6 +71,12 @@ def buffered(batch_size: int, flush_on_exit: bool = False):
             atexit.register(func, buffer)
 
         def wrapper(items):
+
+            if items is None:
+                func(buffer)
+                buffer.clear()
+                return
+
             with lock:
                 for item in items:
                     buffer.append(item)
@@ -93,4 +116,20 @@ class CustomStdOut:
         self.fp.close()
 
 
+class CustomStdErr:
+    original_stderr = sys.stderr
+
+    def __init__(self, fname):
+        self.fname = fname
+
+    def __enter__(self):
+        self.fp = open(self.fname, "w")
+        sys.stderr = self.fp
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout = CustomStdErr.original_stderr
+        self.fp.close()
+
+
 silent_stdout = CustomStdOut(os.devnull)
 silent_stderr = CustomStdErr(os.devnull)
```
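A quick sketch of the reworked decorators (the wrapped functions are hypothetical):

```python
from hexlib.misc import buffered, retry


@retry(attempts=3, retry_sleep=1)
def fetch(url):
    ...  # any flaky operation; retried up to 3 times, sleeping 1s between attempts


@buffered(batch_size=100, flush_on_exit=True)
def write_rows(rows):
    print("flushing %d rows" % len(rows))


for i in range(1000):
    write_rows([i])  # buffered; flushed in batches of 100
write_rows(None)     # None forces a flush of any partial batch
```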
hexlib/monitoring.py (file header missing from the capture). `Monitoring` becomes an abstract base class and the InfluxDB implementation moves to `BufferedInfluxDBMonitoring` with an optional logger:

```diff
@@ -1,22 +1,29 @@
 import logging
 import traceback
+from abc import ABC
 
 from influxdb import InfluxDBClient
 
 from hexlib.misc import buffered
 
 
-class Monitoring:
-    def __init__(self, db, host="localhost", logger=logging.getLogger("default"), batch_size=1, flush_on_exit=False):
-        self._db = db
-        self._client = InfluxDBClient(host, 8086, "", "", db)
+class Monitoring(ABC):
+    def log(self, points):
+        raise NotImplementedError()
+
+
+class BufferedInfluxDBMonitoring(Monitoring):
+    def __init__(self, db_name, host="localhost", port=8086, logger=None, batch_size=1, flush_on_exit=False):
+        self._db = db_name
+        self._client = InfluxDBClient(host, port, "", "", db_name)
         self._logger = logger
 
-        self._init()
+        if not self.db_exists(self._db):
+            self._client.create_database(self._db)
 
         @buffered(batch_size, flush_on_exit)
         def log(points):
            self._log(points)
 
         self.log = log
 
     def db_exists(self, name):
@@ -25,14 +32,16 @@ class Monitoring:
                 return True
         return False
 
-    def _init(self):
-        if not self.db_exists(self._db):
-            self._client.create_database(self._db)
+    def log(self, points):
+        # Is overwritten in __init__()
+        pass
 
     def _log(self, points):
         try:
             self._client.write_points(points)
-            self._logger.debug("InfluxDB: Wrote %d points" % len(points))
+            if self._logger:
+                self._logger.debug("InfluxDB: Wrote %d points" % len(points))
         except Exception as e:
-            self._logger.debug(traceback.format_exc())
-            self._logger.error(str(e))
+            if self._logger:
+                self._logger.debug(traceback.format_exc())
+                self._logger.error(str(e))
```
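A sketch of the renamed class in use (the database name and point payload are illustrative, following the influxdb client's `write_points` format):

```python
from hexlib.monitoring import BufferedInfluxDBMonitoring

mon = BufferedInfluxDBMonitoring(db_name="crawler", batch_size=100, flush_on_exit=True)

# Points are buffered and written to InfluxDB in batches of 100
mon.log([{"measurement": "fetch", "fields": {"ok": 1}}])
```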
hexlib/mq.py (new file, 161 lines)

```python
import json
from collections import namedtuple
from functools import partial
from itertools import islice
from time import sleep, time

from orjson import orjson
from redis import Redis

RoutingKeyParts = namedtuple(
    "RoutingKeyParts",
    ["arc_list", "project", "subproject", "type", "category"]
)


def parse_routing_key(key):
    tokens = key.split(".")

    if len(tokens) == 4:
        arc_list, project, type_, category = tokens
        return RoutingKeyParts(
            arc_list=arc_list,
            project=project,
            subproject=None,
            type=type_,
            category=category
        )
    else:
        arc_list, project, subproject, type_, category = tokens
        return RoutingKeyParts(
            arc_list=arc_list,
            project=project,
            subproject=subproject,
            type=type_,
            category=category
        )


class MessageQueue:
    def read_messages(self, topics):
        raise NotImplementedError()

    def publish(self, item, item_project, item_type, item_subproject, item_category):
        raise NotImplementedError()


class RedisMQ(MessageQueue):
    _MAX_KEYS = 30

    def __init__(self, rdb, consumer_name="redis_mq", sep=".", max_pending_time=120, logger=None, publish_channel=None,
                 arc_lists=None, wait=1):
        self._rdb: Redis = rdb
        self._key_cache = None
        self._consumer_id = consumer_name
        self._pending_list = f"pending{sep}{consumer_name}"
        self._max_pending_time = max_pending_time
        self._logger = logger
        self._publish_channel = publish_channel
        self._arc_lists = arc_lists
        self._wait = wait

    def _get_keys(self, pattern):
        if self._key_cache:
            return self._key_cache

        keys = list(islice(
            self._rdb.scan_iter(match=pattern, count=RedisMQ._MAX_KEYS), RedisMQ._MAX_KEYS
        ))
        self._key_cache = keys

        return keys

    def _get_pending_tasks(self):
        for task_id, pending_task in self._rdb.hscan_iter(self._pending_list):

            pending_task_json = orjson.loads(pending_task)

            if time() >= pending_task_json["resubmit_at"]:
                yield pending_task_json["topic"], pending_task_json["task"], partial(self._ack, task_id)

    def _ack(self, task_id):
        self._rdb.hdel(self._pending_list, task_id)

    def read_messages(self, topics):
        """
        Assumes json-encoded tasks with an _id field

        Tasks are automatically put into a pending list until ack() is called.
        When a task has been in the pending list for at least max_pending_time seconds, it
        gets submitted again
        """

        assert len(topics) == 1, "RedisMQ only supports 1 topic pattern"

        pattern = topics[0]
        counter = 0

        if self._logger:
            self._logger.info(f"MQ>Listening for new messages in {pattern}")

        while True:
            counter += 1

            if counter % 1000 == 0:
                yield from self._get_pending_tasks()

            keys = self._get_keys(pattern)
            if not keys:
                sleep(self._wait)
                self._key_cache = None
                continue

            result = self._rdb.blpop(keys, timeout=1)
            if not result:
                self._key_cache = None
                continue

            topic, task = result

            task_json = orjson.loads(task)
            topic = topic.decode()

            if "_id" not in task_json or not task_json["_id"]:
                raise ValueError(f"Task doesn't have _id field: {task}")

            # Immediately put in pending queue
            self._rdb.hset(
                self._pending_list, task_json["_id"],
                orjson.dumps({
                    "resubmit_at": time() + self._max_pending_time,
                    "topic": topic,
                    "task": task_json
                })
            )

            yield topic, task_json, partial(self._ack, task_json["_id"])

    def publish(self, item, item_project, item_type, item_subproject=None, item_category="x"):

        if "_id" not in item:
            raise ValueError("_id field must be set for item")

        item = json.dumps(item, separators=(',', ':'), ensure_ascii=False, sort_keys=True)

        item_project = item_project.replace(".", "-")
        item_subproject = item_subproject.replace(".", "-") if item_subproject else None

        item_source = item_project if not item_subproject else f"{item_project}.{item_subproject}"

        item_type = item_type.replace(".", "-")
        item_category = item_category.replace(".", "-")

        # If specified, fan-out to pub/sub channel
        if self._publish_channel is not None:
            routing_key = f"{self._publish_channel}.{item_source}.{item_type}.{item_category}"
            self._rdb.publish(routing_key, item)

        # Save to list
        for arc_list in self._arc_lists:
            routing_key = f"{arc_list}.{item_source}.{item_type}.{item_category}"
            self._rdb.lpush(routing_key, item)
```
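The test file further down exercises this class against a live Redis; here is a condensed sketch of the publish/consume/ack cycle (the consumer and key names are illustrative):

```python
from hexlib.env import get_redis
from hexlib.mq import RedisMQ

rdb = get_redis()
mq = RedisMQ(rdb, consumer_name="worker1", arc_lists=["arc"])

mq.publish({"_id": "abc", "body": "hi"}, item_project="demo", item_type="msg")

for topic, task, ack in mq.read_messages(topics=["arc.demo.*"]):
    print(topic, task["_id"])
    ack()  # removes the task from the pending list; otherwise it is re-queued
    break
```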
hexlib/regex_util.py (three patterns added)

```diff
@@ -2,3 +2,6 @@ import re
 
 LINK_RE = re.compile(r"(https?://[\w\-_.]+\.[a-z]{2,4}([^\s<'\"]*|$))")
 HTML_HREF_RE = re.compile(r"href=\"([^\"]+)\"")
+WHITESPACE_RE = re.compile(r"\s+")
+PUNCTUATION_RE = re.compile(r"[.,;:\"“!?/()|*=>]+")
+XML_ENTITY_RE = re.compile(r"&[a-z]+;")
```
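The new patterns in use (the input string is illustrative):

```python
from hexlib.regex_util import LINK_RE, WHITESPACE_RE, XML_ENTITY_RE

text = "see   https://example.com/page &amp; more"
urls = [m[0] for m in LINK_RE.findall(text)]  # ['https://example.com/page']
clean = WHITESPACE_RE.sub(" ", text)          # collapse runs of whitespace
clean = XML_ENTITY_RE.sub(" ", clean)         # strip entities like &amp;
```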
hexlib/text.py (new file, 128 lines)

```python
import re
from itertools import chain, repeat

import nltk.corpus
from lxml import etree
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

from .regex_util import LINK_RE

get_text = etree.XPath("//text()")

nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)

stop_words_en = set(stopwords.words("english"))

extra_stop_words_en = [
    "u", "&", "-", "--"
]

stop_words_en.update(extra_stop_words_en)

lemmatizer = WordNetLemmatizer()


def _transform_bigram(ngram_seq, ngrams):
    for ngram in ngram_seq:
        if ngram in ngrams:
            yield ngram[0] + "_" + ngram[1]

            next(ngram_seq)
        else:
            yield ngram[0]


def _transform_trigram(ngram_seq, ngrams):
    for ngram in ngram_seq:
        if ngram in ngrams:
            # yield ngram[0] + "_" + ngram[1] + "_" + ngram[2]
            yield "_".join(ngram)

            next(ngram_seq)
            next(ngram_seq)
        else:
            yield ngram[0]


SINGLE_QUOTES = ("’", "`", "‘")
SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", len(SINGLE_QUOTES))))

DASHES = ("–", "⸺", "–", "—")
DASHES_TRANS = str.maketrans("".join(DASHES), "".join(repeat("-", len(DASHES))))

DASHES_RE = re.compile(r"-+")

SPECIAL_PUNCTUATION = ";:\"/()|*=>"
SPECIAL_PUNCTUATION_TRANS = str.maketrans(SPECIAL_PUNCTUATION, " " * len(SPECIAL_PUNCTUATION))

PUNCTUATION = ".,!?"
PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION))


def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_special_punctuation=False,
               remove_stopwords_en=False, lemmatize=False, fix_single_quotes=False, strip_quotes=False,
               strip_dashes=False,
               remove_urls=False, bigrams: set = None, trigrams: set = None, remove_numbers=False,
               use_nltk_tokenizer=False):
    if lowercase:
        text = text.lower()

    if fix_single_quotes:
        text = text.translate(SINGLE_QUOTE_TRANS)

    text = text.translate(DASHES_TRANS)

    if strip_dashes:
        text = DASHES_RE.sub("-", text)

    if remove_urls:
        text = LINK_RE.sub(" ", text)

    if clean_html:
        try:
            text = "<root>" + text + "</root>"

            parser = etree.XMLParser(recover=True)
            root = etree.fromstring(text, parser)

            text = " ".join(get_text(root))
        except:
            pass

    if remove_punctuation:
        text = text.translate(PUNCTUATION_TRANS)

    if remove_special_punctuation:
        text = text.translate(SPECIAL_PUNCTUATION_TRANS)

    if use_nltk_tokenizer:
        words = word_tokenize(text, language="english")
    else:
        words = text.split()

    if strip_quotes:
        words = map(lambda w: w.strip("\"'“”"), words)

    if strip_dashes:
        words = map(lambda w: w.strip("-"), words)

    if bigrams:
        words = _transform_bigram(nltk.bigrams(chain(words, ("*",))), bigrams)

    if trigrams:
        words = _transform_trigram(nltk.trigrams(chain(words, ("*", "*"))), trigrams)

    if remove_numbers:
        words = filter(lambda w: not w.isnumeric(), words)

    if lemmatize:
        words = map(lambda w: lemmatizer.lemmatize(w), words)

    if remove_stopwords_en:
        words = filter(lambda w: w not in stop_words_en, words)

    return filter(lambda w: w != "", words)
```
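A sketch of the pipeline on a small input (the output is approximate; the exact tokens depend on the installed NLTK stopword list):

```python
from hexlib.text import preprocess

tokens = list(preprocess(
    "<p>Check out https://example.com it's GREAT!!!</p>",
    lowercase=True, clean_html=True, remove_urls=True,
    remove_punctuation=True, remove_stopwords_en=True,
    fix_single_quotes=True,
))
print(tokens)  # -> roughly ['check', 'great']
```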
hexlib/web.py (123 changed lines). Adds `cookiejar_filter_name`, `url_query_value`, and a rate-limited, retrying `Web` client:

```diff
@@ -4,12 +4,19 @@ import os
 from datetime import datetime
 from base64 import b64encode, b64decode
 from http.cookiejar import Cookie
+from time import time
 from urllib.parse import urlparse, parse_qs
 
+from bs4 import BeautifulSoup
+
 import requests
 import orjson as json
 
 from dateutil.parser import parse
 from requests.cookies import RequestsCookieJar
 
+from hexlib.misc import rate_limit, retry
+
 
 def cookie_from_string(text: str, domain: str) -> Cookie:
     tokens = [t.strip() for t in text.split(";")]
@@ -74,6 +81,26 @@ def cookiejar_filter(cj, pattern):
     return filtered_cj
 
 
+def cookiejar_filter_name(cj, pattern):
+    filtered_cj = RequestsCookieJar()
+    for c in cj:
+        if re.match(pattern, c.name):
+            filtered_cj.set_cookie(c)
+    return filtered_cj
+
+
+def url_query_value(url, arg, as_list=False):
+    qs = urlparse(url).query
+    parsed_qs = parse_qs(qs)
+
+    arg = parsed_qs.get(arg, [])
+
+    if as_list:
+        return arg if arg else []
+    else:
+        return arg[0] if arg else None
+
+
 def download_file(url, destination, session=None, headers=None, overwrite=False, retries=1, err_cb=None,
                   save_meta=False):
     if os.path.exists(destination) and not overwrite:
@@ -99,8 +126,104 @@ def download_file(url, destination, session=None, headers=None, overwrite=False,
                         "url": url,
                         "timestamp": datetime.utcnow().replace(microsecond=0).isoformat()
                     }))
+            r.close()
             break
         except Exception as e:
             if err_cb:
                 err_cb(e)
             retries -= 1
+
+
+class Web:
+    def __init__(self, proxy=None, rps=1, retries=3, retry_sleep=0, logger=None, cookie_file=None, retry_codes=None,
+                 session=None,
+                 ua=None):
+        self._cookie_file = cookie_file
+        self._proxy = proxy
+        self._logger = logger
+        self._current_req = None
+        if retry_codes is None or not retry_codes:
+            retry_codes = {500, 502, 503, 504, 520, 522, 524, 429}
+        self._retry_codes = retry_codes
+
+        if session is None:
+            session = requests.session()
+
+        self._session = session
+
+        if ua is not None:
+            session.headers["User-Agent"] = ua
+
+        if self._cookie_file:
+            self._session.cookies = load_cookiejar(cookie_file)
+
+        if self._proxy:
+            self._session.proxies = {
+                "http": proxy,
+                "https": proxy,
+            }
+
+        @rate_limit(rps)
+        @retry(retries, callback=self._error_callback, retry_sleep=retry_sleep)
+        def get(url, **kwargs):
+            self._current_req = "GET", url, kwargs
+            r = self._session.get(url, **kwargs)
+
+            if r.status_code in self._retry_codes:
+                raise Exception(f"HTTP {r.status_code}")
+            return r
+
+        self._get = get
+
+        @rate_limit(rps)
+        @retry(retries, callback=self._error_callback, retry_sleep=retry_sleep)
+        def post(url, **kwargs):
+            self._current_req = "POST", url, kwargs
+            r = self._session.post(url, **kwargs)
+
+            if r.status_code in self._retry_codes:
+                raise Exception(f"HTTP {r.status_code}")
+            return r
+
+        self._post = post
+
+    def _error_callback(self, e):
+        if self._logger:
+            self._logger.critical(f"{self._format_url(*self._current_req)}: {e}")
+
+    def _format_url(self, method, url, kwargs, r=None):
+        if "params" in kwargs and kwargs["params"]:
+            return "%s %s?%s <%s>" % (method, url, "&".join(f"{k}={v}" for k, v in kwargs["params"].items()),
+                                      r.status_code if r is not None else "ERR")
+        else:
+            return "%s %s <%s>" % (method, url, r.status_code if r is not None else "ERR",)
+
+    def get(self, url, **kwargs):
+
+        time_start = time()
+        r = self._get(url, **kwargs)
+
+        if self._cookie_file:
+            save_cookiejar(self._session.cookies, self._cookie_file)
+
+        if self._logger and r is not None:
+            self._logger.debug(self._format_url("GET", url, kwargs, r) + " %.2fs" % (time() - time_start))
+        return r
+
+    def post(self, url, **kwargs):
+
+        time_start = time()
+        r = self._post(url, **kwargs)
+
+        if self._cookie_file:
+            save_cookiejar(self._session.cookies, self._cookie_file)
+
+        if self._logger and r is not None:
+            self._logger.debug(self._format_url("POST", url, kwargs, r) + " %.2fs" % (time() - time_start))
+        return r
+
+    def get_soup(self, url, **kwargs):
+        r = self.get(url, **kwargs)
+        if not r:
+            return None
+        return BeautifulSoup(r.content, "html.parser")
```
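A sketch of the new client in use (the URL is illustrative):

```python
from hexlib.log import stdout_logger
from hexlib.web import Web

web = Web(rps=0.5, retries=3, logger=stdout_logger)

r = web.get("https://example.com", params={"q": "test"})  # rate-limited + retried
soup = web.get_soup("https://example.com")                 # parsed with BeautifulSoup
```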
setup.py (version bump and updated dependencies)

```diff
@@ -2,7 +2,7 @@ from setuptools import setup
 
 setup(
     name="hexlib",
-    version="1.20",
+    version="1.89",
     description="Misc utility methods",
     author="simon987",
     author_email="me@simon987.net",
@@ -12,7 +12,9 @@ setup(
         "data/*"
     ]},
     install_requires=[
-        "ImageHash", "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
-        "u-msgpack-python"
+        "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
+        "u-msgpack-python", "psycopg2-binary", "bs4", "lxml", "nltk", "numpy",
+        "matplotlib", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
+        "requests", "pydantic==1.10.11"
     ]
 )
```
test/__init__.py (new, empty file)

A 49-line test file was also deleted (`@@ -1,49 +0,0 @@`; its name is missing from the capture). Its `TestPersistentState` case reappears verbatim, with additional tests, in test/test_PersistentState.py below, so the deleted content is not repeated here.
test/test_PersistentState.py (new file, 143 lines)

```python
import os
from unittest import TestCase

from hexlib.db import PersistentState


class TestPersistentState(TestCase):

    def tearDown(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def setUp(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def test_get_set(self):
        s = PersistentState()

        val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
        s["a"]["1"] = val

        val["id"] = "1"

        self.assertDictEqual(val, s["a"]["1"])

    def test_get_set_int_id(self):
        s = PersistentState()

        val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
        s["a"][1] = val

        val["id"] = 1

        self.assertDictEqual(val, s["a"][1])

    def test_update_partial(self):
        s = PersistentState()

        val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
        s["a"][1] = val
        s["a"][1] = {
            "a": 2
        }

        val["a"] = 2
        val["id"] = 1

        self.assertDictEqual(val, s["a"][1])

    def test_none(self):
        s = PersistentState()

        val = {"a": 1, "b": None}
        s["a"][1] = val
        s["a"][1] = {
            "a": None
        }

        val["a"] = None
        val["id"] = 1

        self.assertDictEqual(val, s["a"][1])

    def test_bool(self):
        s = PersistentState()

        val = {"a": True, "b": False}
        s["a"][1] = val
        s["a"][1] = {
            "a": True
        }

        val["a"] = True
        val["id"] = 1

        self.assertDictEqual(val, s["a"][1])

    def test_sql(self):
        s = PersistentState()

        s["a"][1] = {"a": True}
        s["a"][2] = {"a": False}
        s["a"][3] = {"a": True}

        items = list(s["a"].sql("WHERE a=0 ORDER BY id"))

        self.assertDictEqual(items[0], s["a"][2])

    def test_delitem(self):
        s = PersistentState()

        s["a"][1] = {"a": True}
        del s["a"][1]

        self.assertIsNone(s["a"][1])

    def test_delitem_nonexistent(self):
        s = PersistentState()

        s["a"][1] = {"a": True}
        del s["a"][456]

        self.assertIsNotNone(s["a"][1])

    def test_delitem_no_table(self):
        s = PersistentState()

        try:
            del s["a"][456]
        except Exception as e:
            self.fail(e)

    def test_deserialize_get_set(self):
        s = PersistentState()

        s["a"][0] = {"x": b'abc'}

        self.assertEqual(s["a"][0]["x"], b'abc')

    def test_deserialize_sql(self):
        s = PersistentState()

        s["a"][0] = {"x": b'abc'}

        self.assertEqual(list(s["a"].sql("WHERE 1=1"))[0]["x"], b'abc')

    def test_deserialize_iter(self):
        s = PersistentState()

        s["a"][0] = {"x": b'abc'}

        self.assertEqual(list(s["a"])[0]["x"], b'abc')

    def test_drop_table(self):
        s = PersistentState()

        s["a"][0] = {"x": 1}
        s["a"][1] = {"x": 2}
        self.assertEqual(len(list(s["a"])), 2)

        del s["a"]
        self.assertEqual(len(list(s["a"])), 0)
```
test/test_PydanticTable.py (new file, 110 lines)

```python
import os
from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase

from pydantic import BaseModel
from pydantic.types import List

from hexlib.db import PersistentState


class Status(Enum):
    yes = "yes"
    no = "no"


class Point(BaseModel):
    x: int
    y: int


class Polygon(BaseModel):
    points: List[Point] = []
    created_date: datetime
    status: Status = Status("yes")


class TestPydanticTable(TestCase):
    def tearDown(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def setUp(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def test_get_set(self):
        s = PersistentState()

        val = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[
                Point(x=1, y=2),
                Point(x=3, y=4),
            ],
        )

        s["a"]["1"] = val

        self.assertEqual(s["a"]["1"].points[0].x, 1)
        self.assertEqual(s["a"]["1"].status, Status("yes"))
        self.assertEqual(s["a"]["1"].points[1].x, 3)
        self.assertEqual(s["a"]["1"].created_date.year, 2000)

    def test_update(self):
        s = PersistentState()

        val = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[
                Point(x=1, y=2),
                Point(x=3, y=4),
            ]
        )

        s["a"]["1"] = val

        self.assertEqual(s["a"]["1"].points[0].x, 1)

        val.points[0].x = 2
        s["a"]["1"] = val
        self.assertEqual(s["a"]["1"].points[0].x, 2)

    def test_sql(self):
        s = PersistentState()

        s["b"]["1"] = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[]
        )
        s["b"]["2"] = Polygon(
            created_date=datetime(year=2010, day=1, month=1),
            points=[]
        )

        result = list(s["b"].sql(
            "WHERE json->>'created_date' LIKE '2000-%'"
        ))

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].created_date.year, 2000)

    def test_iterate(self):
        s = PersistentState()

        s["b"]["1"] = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[]
        )
        s["b"]["2"] = Polygon(
            created_date=datetime(year=2010, day=1, month=1),
            points=[]
        )

        result = list(s["b"])

        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].created_date.year, 2000)
        self.assertEqual(result[1].created_date.year, 2010)
```
Test file for VolatileState (header missing from the capture). Adds per-test Redis cleanup, `sep` tests, and a `TestVolatileQueue` case:

```diff
@@ -1,9 +1,15 @@
 from unittest import TestCase
-from hexlib.db import VolatileState, VolatileBooleanState
+
+from hexlib.db import VolatileState, VolatileBooleanState, VolatileQueue
+from hexlib.env import get_redis
 
 
 class TestVolatileState(TestCase):
 
+    def setUp(self) -> None:
+        rdb = get_redis()
+        rdb.delete("test1a", "test1b", "test1c", "test1:a", "test2b")
+
     def test_get_set(self):
         s = VolatileState(prefix="test1")
         val = {
@@ -15,6 +21,17 @@ class TestVolatileState(TestCase):
 
         self.assertDictEqual(val, s["a"]["1"])
 
+    def test_sep(self):
+        s = VolatileState(prefix="test1", sep=":")
+        val = {
+            "field1": 1,
+            "arr1": [1, 2, 3]
+        }
+
+        s["a"]["1"] = val
+
+        self.assertDictEqual(val, s["a"]["1"])
+
     def test_iter(self):
         s = VolatileState(prefix="test2")
 
@@ -23,7 +40,7 @@ class TestVolatileState(TestCase):
         s["b"]["3"] = 3
         s["b"]["4"] = 4
 
-        self.assertEqual(sum(v for k,v in s["b"]), 10)
+        self.assertEqual(sum(v for k, v in s["b"]), 10)
 
     def test_int_key(self):
         s = VolatileState(prefix="test2")
@@ -41,6 +58,10 @@ class TestVolatileState(TestCase):
 
 class TestVolatileBoolState(TestCase):
 
+    def setUp(self) -> None:
+        rdb = get_redis()
+        rdb.delete("test1a", "test1b", "test1c", "test1:a", "test2b")
+
     def test_get_set(self):
         s = VolatileBooleanState(prefix="test1")
 
@@ -51,6 +72,16 @@ class TestVolatileBoolState(TestCase):
         self.assertTrue(s["a"]["2"])
         self.assertFalse(s["a"]["3"])
 
+    def test_sep(self):
+        s = VolatileBooleanState(prefix="test1", sep=":")
+
+        s["a"]["1"] = True
+        s["a"]["2"] = True
+
+        self.assertTrue(s["a"]["1"])
+        self.assertTrue(s["a"]["2"])
+        self.assertFalse(s["a"]["3"])
+
     def test_iter(self):
         s = VolatileBooleanState(prefix="test2")
 
@@ -68,3 +99,30 @@ class TestVolatileBoolState(TestCase):
         self.assertTrue(s["c"]["1"])
         del s["c"]["1"]
         self.assertFalse(s["c"]["1"])
+
+
+class TestVolatileQueue(TestCase):
+
+    def test_simple(self):
+        s = VolatileQueue(key="test5")
+
+        s.put("test")
+        item = s.get()
+
+        self.assertTrue(item == "test")
+
+    def test_dict(self):
+        s = VolatileQueue(key="test5")
+
+        s.put({"a": 1})
+        item = s.get()
+
+        self.assertTrue(item["a"] == 1)
+
+    def test_int(self):
+        s = VolatileQueue(key="test5")
+
+        s.put(123)
+        item = s.get()
+
+        self.assertTrue(item == 123)
```
test/test_buffered.py (new file, 36 lines)

```python
from unittest import TestCase

from hexlib.misc import buffered


class TestBuffered(TestCase):

    def test_simple(self):
        my_list = []

        @buffered(batch_size=2)
        def put_item(items):
            my_list.extend(items)

        put_item([1, 2])
        put_item([1])
        put_item([1])
        put_item([1])

        self.assertEqual(len(my_list), 4)

    def test_flush(self):
        my_list = []

        @buffered(batch_size=2)
        def put_item(items):
            my_list.extend(items)

        put_item([1, 2])
        put_item([1])
        put_item([1])
        put_item([1])

        put_item(None)

        self.assertEqual(len(my_list), 5)
```
Test file for download_file (header missing from the capture). The dead ovh.net fixture URL is replaced with a file hosted in the repo, and ResourceWarnings are silenced:

```diff
@@ -1,13 +1,17 @@
-from unittest import TestCase
 import os
+from unittest import TestCase
 
 from hexlib.web import download_file
 import warnings
 
 
 class TestDownloadFile(TestCase):
 
+    def setUp(self) -> None:
+        warnings.filterwarnings(action="ignore", category=ResourceWarning)
+
     def test_download_file(self):
-        download_file("http://ovh.net/files/10Mb.dat", "/tmp/10Mb.dat")
+        download_file("https://github.com/simon987/hexlib/raw/master/10MB.bin", "/tmp/10Mb.dat")
         self.assertTrue(os.path.exists("/tmp/10Mb.dat"))
         os.remove("/tmp/10Mb.dat")
 
@@ -22,8 +26,8 @@ class TestDownloadFile(TestCase):
         self.assertEqual(len(exceptions), 3)
 
     def test_download_file_meta(self):
-        download_file("http://ovh.net/files/10Mb.dat", "/tmp/10Mb.dat", save_meta=True)
+        download_file("https://github.com/simon987/hexlib/raw/master/10MB.bin", "/tmp/10Mb.dat", save_meta=True)
         self.assertTrue(os.path.exists("/tmp/10Mb.dat"))
         self.assertTrue(os.path.exists("/tmp/10Mb.dat.meta"))
         os.remove("/tmp/10Mb.dat")
-        # os.remove("/tmp/10Mb.dat.meta")
+        os.remove("/tmp/10Mb.dat.meta")
```
62
test/test_redis_mq.py
Normal file
62
test/test_redis_mq.py
Normal file
@@ -0,0 +1,62 @@
from unittest import TestCase

from hexlib.env import get_redis
from hexlib.mq import RedisMQ, parse_routing_key, RoutingKeyParts


class TestRedisMQ(TestCase):

    def setUp(self) -> None:
        self.rdb = get_redis()
        self.rdb.delete("pending.test", "test_mq", "arc.test.msg.x")

    def test_ack(self):
        mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=2, arc_lists=["arc"])

        mq.publish({"_id": 1}, item_project="test", item_type="msg")

        topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.*"]))

        self.assertEqual(self.rdb.hlen("pending.test"), 1)

        ack1()

        self.assertEqual(self.rdb.hlen("pending.test"), 0)

    def test_pending_timeout(self):
        mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)

        mq.publish({"_id": 1}, item_project="test", item_type="msg")

        topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.test.*"]))

        self.assertEqual(self.rdb.hlen("pending.test"), 1)

        # msg1 will time out after 0.5s; the next iteration takes ceil(0.5)s
        topic1_, msg1_, ack1_ = next(mq.read_messages(topics=["arc.test.*"]))
        self.assertEqual(self.rdb.hlen("pending.test"), 1)

        ack1_()

        self.assertEqual(self.rdb.hlen("pending.test"), 0)

        self.assertEqual(msg1, msg1_)

    def test_no_id_field(self):
        mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)

        with self.assertRaises(ValueError):
            mq.publish({"a": 1}, item_project="test", item_type="msg")


class TestRoutingKey(TestCase):

    def test1(self):
        key = "arc.chan.4chan.post.b"
        parts = parse_routing_key(key)
        self.assertEqual(parts, RoutingKeyParts("arc", "chan", "4chan", "post", "b"))

    def test2(self):
        key = "arc.reddit.submission.birdpics"
        parts = parse_routing_key(key)
        self.assertEqual(parts, RoutingKeyParts("arc", "reddit", None, "submission", "birdpics"))
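The two routing-key tests fully determine the parsing rule: a five-token key carries a subproject ("arc.chan.4chan.post.b"), while a four-token key does not ("arc.reddit.submission.birdpics", where the third field comes back as None). A sketch consistent with both cases; the namedtuple field names are assumptions:

```
from collections import namedtuple

RoutingKeyParts = namedtuple(
    "RoutingKeyParts",
    ["arc_list", "project", "subproject", "item_type", "suffix"]
)


def parse_routing_key(key):
    # Hypothetical sketch matching both test cases: a routing key is either
    # arc_list.project.subproject.item_type.suffix (five tokens) or
    # arc_list.project.item_type.suffix (four tokens, no subproject).
    tokens = key.split(".")
    if len(tokens) == 5:
        return RoutingKeyParts(*tokens)
    return RoutingKeyParts(tokens[0], tokens[1], None, tokens[2], tokens[3])
```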
27
test/test_retry.py
Normal file
@@ -0,0 +1,27 @@
from unittest import TestCase

from hexlib.misc import retry


class TestRetry(TestCase):

    def test_simple(self):
        @retry(attempts=3)
        def a(i):
            return i + 1

        self.assertEqual(a(1), 2)

    def test_error(self):
        arr = []

        def cb(e):
            arr.append(e)

        @retry(attempts=3, callback=cb)
        def a(i):
            raise Exception("err")

        a(1)

        self.assertEqual(3, len(arr))
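From these tests, `retry` re-invokes the wrapped function up to `attempts` times, passes each exception to `callback`, and gives up silently once the attempts are exhausted (`a(1)` does not raise in test_error). A minimal sketch matching that contract:

```
import functools


def retry(attempts, callback=None):
    # Hypothetical sketch: re-run the function up to `attempts` times,
    # report each exception through `callback`, and return None if
    # every attempt fails.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if callback:
                        callback(e)
        return wrapper
    return decorator
```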
279
test/test_text.py
Normal file
@@ -0,0 +1,279 @@
from unittest import TestCase

from hexlib.text import preprocess


class TestText(TestCase):

    def test_html_invalid(self):
        text = ""
        cleaned = preprocess(
            text,
            clean_html=True,
        )
        expected = ""

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_1(self):
        text = "<div>Hello, <strong>world</strong></div>"
        cleaned = preprocess(
            text,
            clean_html=True,
        )
        expected = "Hello, world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_2(self):
        text = "<div>Hello, <strong>world</strong></div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True
        )
        expected = "hello, world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_4(self):
        text = "<div>\n Hello, \t\n<strong> world </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
        )
        expected = "hello, world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_5(self):
        text = "<div>\n Hello, \t\n<strong> world </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True
        )
        expected = "hello world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_6(self):
        text = "<div>\n Hello, \t\n<strong>a the world </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            remove_stopwords_en=True
        )
        expected = "hello world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_7(self):
        text = "<div>\n Hello, \t\n<strong>a the worlds </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            remove_stopwords_en=True,
            lemmatize=True
        )
        expected = "hello world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_8(self):
        text = "<div>\n Hello, \t\n<strong>a the worlds! </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            remove_stopwords_en=True,
            lemmatize=True
        )
        expected = "hello world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_9(self):
        text = "<div>\n Hello, \t\n<strong>world! it's it`s </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            lemmatize=True,
            fix_single_quotes=True
        )
        expected = "hello world it's it's"

        self.assertEqual(" ".join(cleaned), expected)

    def test_single_quote(self):
        text = "it's it`s it’s"
        cleaned = preprocess(
            text,
            lowercase=True,
            fix_single_quotes=True
        )
        expected = "it's it's it's"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_10(self):
        text = "<div>\n Hello, \t\n<strong>world! it's it`s https://google.ca/test/abc.pdf </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            lemmatize=True,
            fix_single_quotes=True,
            remove_urls=True
        )
        expected = "hello world it's it's"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_11(self):
        text = "<div>\n Hello, \t\n<strong>world! it's it`s & | </strong>\n\t</div>"
        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            lemmatize=True,
            fix_single_quotes=True,
            remove_stopwords_en=True,
            remove_urls=True
        )
        expected = "hello world |"

        self.assertEqual(" ".join(cleaned), expected)

    def test_html_no_root(self):
        text = "<a href=\"#p217709510\" class=\"quotelink\">>>217709510</a><br>Is there a<wbr>servant that is against civilization and humanity?<br>Literally instant summon."

        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            lemmatize=False,
            fix_single_quotes=True,
            remove_stopwords_en=False,
            remove_urls=False
        )

        expected = ">>217709510 is there a servant that is against civilization and humanity literally instant summon"
        self.assertEqual(" ".join(cleaned), expected)

    def test_html_entity(self):
        text = "doesn&#39;t"

        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            lemmatize=False,
            fix_single_quotes=True,
            remove_stopwords_en=False,
            remove_urls=False
        )

        expected = "doesn't"
        self.assertEqual(" ".join(cleaned), expected)

    def test_html_invalid_attribute(self):
        text = '<root><iframe width="560" height="315" src=" " title="youtube video player" frameborder="0" allowfullscreen></iframe></root>'

        cleaned = preprocess(
            text,
            clean_html=True,
            lowercase=True,
            remove_punctuation=True,
            lemmatize=False,
            fix_single_quotes=True,
            remove_stopwords_en=False,
            remove_urls=False
        )

        expected = ""

        self.assertEqual(" ".join(cleaned), expected)

    def test_bigrams(self):
        text = "x A b c d e f g h"
        cleaned = preprocess(
            text,
            lowercase=True,
            bigrams={
                ("a", "b"),
                ("c", "d"),
                ("f", "g"),
            }
        )
        expected = "x a_b c_d e f_g h"

        self.assertEqual(" ".join(cleaned), expected)

    def test_trigrams(self):
        text = "x A b c d e f g h"
        cleaned = preprocess(
            text,
            lowercase=True,
            trigrams={
                ("a", "b", "c"),
                ("e", "f", "g"),
            }
        )
        expected = "x a_b_c d e_f_g h"

        self.assertEqual(" ".join(cleaned), expected)

    def test_remove_numbers(self):
        text = "Hello1 test1124test 12 1 1111111 world"
        cleaned = preprocess(
            text,
            lowercase=True,
            remove_numbers=True
        )
        expected = "hello1 test1124test world"

        self.assertEqual(" ".join(cleaned), expected)

    def test_strip_quotes(self):
        text = "'hi' “test” 'hello\""
        cleaned = preprocess(
            text,
            strip_quotes=True
        )
        expected = "hi test hello"

        self.assertEqual(" ".join(cleaned), expected)

    def test_strip_dashes(self):
        text = "yes -But something-something -- hello aa--bb"
        cleaned = preprocess(
            text,
            strip_dashes=True
        )
        expected = "yes But something-something hello aa-bb"

        self.assertEqual(" ".join(cleaned), expected)

    def test_word_tokenize(self):
        text = "i cannot believe'"
        cleaned = preprocess(
            text,
            use_nltk_tokenizer=True
        )
        expected = "i can not believe '"

        self.assertEqual(" ".join(cleaned), expected)
21
test/test_web.py
Normal file
@@ -0,0 +1,21 @@
from unittest import TestCase

from hexlib.web import url_query_value


class TestWebMiscFuncs(TestCase):
    def test_qs_1(self):
        url = "https://test.com/page?a=1&b=2&a=2&c=hello"

        self.assertEqual(url_query_value(url, "a"), "1")
        self.assertEqual(url_query_value(url, "b"), "2")
        self.assertEqual(url_query_value(url, "c"), "hello")
        self.assertEqual(url_query_value(url, "D"), None)

    def test_qs_as_list(self):
        url = "https://test.com/page?a=1&b=2&a=2&c=hello"

        self.assertEqual(url_query_value(url, "a", as_list=True), ["1", "2"])
        self.assertEqual(url_query_value(url, "b", as_list=True), ["2"])
        self.assertEqual(url_query_value(url, "c", as_list=True), ["hello"])
        self.assertEqual(url_query_value(url, "D", as_list=True), [])