Compare commits

...

92 Commits

Author SHA1 Message Date
b1a1da3bac Add option to use nltk word_tokenize 2023-09-09 11:11:44 -04:00
a047366926 pin pydantic version 2023-07-13 08:27:48 -04:00
24230cdc1e Update PgConn 2023-05-26 14:09:45 -04:00
3bd9f03996 Fix PersistentState constructor 2023-04-10 17:01:42 -04:00
e267bbf1c8 Set json indent=2 for pydantic rows 2023-02-25 16:07:41 -05:00
42e33b72b2 Fix JSON encoding for Enums 2023-02-25 15:51:28 -05:00
5275c332cc Add drop table 2023-02-25 15:38:40 -05:00
a7b1a6e1ec Fix tests, add pydantic row support for PersistentState 2023-02-25 15:20:17 -05:00
826312115c Fix deserialization in PersistentState again 2022-05-07 09:41:10 -04:00
372abb0076 Fix deserialization in PersistentState 2022-05-07 09:34:50 -04:00
78c04ef6f3 Add option to override Table factory in PersistentState 2022-05-05 15:02:48 -04:00
a51ad2cbb4 Cleanup 2022-05-03 10:59:25 -04:00
4befc3973d Add strip_dashes option in preprocess() 2022-02-26 19:31:22 -05:00
c9fac7151a Split punctuation into punctuation and special_punctuation 2022-02-23 11:01:17 -05:00
084acbe184 Set max_window_size=2147483648 for zstd 2022-01-29 10:44:38 -05:00
d578be3218 Increase timeout 2022-01-29 10:38:23 -05:00
cd5a1ac50c Remove clean_multicore function 2022-01-28 20:11:26 -05:00
62e74ed292 Make sure that _id field is present in redis MQ 2022-01-27 11:06:52 -05:00
428c82bcfd Update retry codes 2021-12-09 19:46:12 -05:00
4b3583358b Update retry codes 2021-12-09 19:44:21 -05:00
90d434ec73 Add more single quotes 2021-11-16 15:32:24 -05:00
55fd4a66d2 Fix strip_quotes 2021-11-16 11:48:23 -05:00
3677815d57 Add more quotes in strip_quotes 2021-11-16 11:39:10 -05:00
1ce795a759 ... 2021-11-16 11:36:17 -05:00
e1537297d7 normalize dashes in preprocess 2021-11-16 11:34:48 -05:00
8d8f9e8751 Fix typo, add stateless stream processor 2021-11-04 13:23:23 -04:00
18ba0024ea Null check on logger 2021-11-03 16:47:21 -04:00
408735a926 Fix git clone 2021-11-02 14:38:12 -04:00
2f6c2822b6 Add functions to handle routing keys 2021-10-22 10:36:05 -04:00
d85ad919b3 Add redis MQ, update influxdb monitoring 2021-10-21 20:19:09 -04:00
ed9d148411 Fix tests 2021-10-21 19:52:42 -04:00
5e00ddccdb Add test file 2021-10-07 15:40:23 -04:00
7ecd55a1c6 Add cookiejar_filter_name 2021-10-03 11:02:39 -04:00
b746a91281 Fix default value for retry codes 2021-10-03 10:43:19 -04:00
333083e8b9 Add 520 in default retry codes 2021-10-03 10:33:49 -04:00
c295b5d30b Add Web.post 2021-09-28 15:08:51 -04:00
da0e117550 Fix fake-useragent (for real time time?) 2021-09-25 16:03:45 -04:00
c3fef7e7f8 Fix fake-useragent 2021-09-25 15:53:52 -04:00
9bd1f4b799 unbreak statefulstreamworker 2021-09-23 19:14:11 -04:00
c560cc2010 tweak StatefulStreamWorker interface 2021-09-19 14:19:17 -04:00
f4a5e6cf53 queue_iter fix 2021-09-19 12:44:49 -04:00
71cd00c063 Add StatfulStreamProcessor 2021-09-19 12:39:57 -04:00
7349c9a5f1 Quick optimisation 2021-09-19 10:57:07 -04:00
d19442b00e Update preprocess: now returns generator objects 2021-09-19 09:35:35 -04:00
4711cd1b66 Add trigrams 2021-09-10 17:35:19 -04:00
7e0ffafb8c Update fix_single_quotes 2021-08-28 20:48:00 -04:00
60273fb6bd Use imap instead of map 2021-08-28 20:09:00 -04:00
67c09cc10c Add remove_numbers 2021-08-28 20:06:53 -04:00
a7bf5b2d15 Fix clean html (again!) 2021-08-28 19:59:04 -04:00
31b35e3a32 Fix clean html (again) 2021-08-28 19:44:10 -04:00
4cff343370 version bump 2021-08-28 19:34:09 -04:00
4d6c8018df Fix clean_html 2021-08-28 19:33:11 -04:00
db3e191983 Add plot_confusion_matrix 2021-06-29 14:13:07 -04:00
33e9734991 Add plot_freq_bar 2021-06-23 19:37:22 -04:00
3238f92e4d Revert "get_soup() decode utf8" (This reverts commit f8e93354) 2021-05-14 10:09:07 -04:00
f8e93354a4 get_soup() decode utf8 2021-05-14 09:59:59 -04:00
75bf2c2d85 Rename test.clean to text.preprocess, add QS util func, more debug logging 2021-04-25 12:10:03 -04:00
9002ae7506 Add debug info 2021-04-24 10:21:54 -04:00
88f3124f85 add bigram option for clean function 2021-04-21 21:34:49 -04:00
8edad0255b add retries arg in get_web() 2021-04-21 19:50:59 -04:00
32119535ae improve text cleaning 2021-04-18 21:27:12 -04:00
2ffaa4a5b3 improve text cleaning 2021-04-18 21:10:07 -04:00
067a20f7a8 improve text cleaning 2021-04-18 20:32:34 -04:00
00323ea576 improve text cleaning 2021-04-18 18:50:39 -04:00
45b5803c40 improve text cleaning 2021-04-18 15:40:30 -04:00
18cd59fc4a ignore log in text 2021-04-18 12:20:22 -04:00
d895ac837e ignore log in text 2021-04-18 12:18:27 -04:00
ae59522b27 Add customstderr 2021-04-18 12:17:00 -04:00
765f6f59b7 Add text cleaning function 2021-04-18 12:12:31 -04:00
30902c8235 Add sep option in volatile state 2021-04-16 19:10:45 -04:00
53a262a138 Add retry_sleep to retry 2021-04-06 21:24:23 -04:00
6e1aa53455 Add retry_sleep to retry 2021-04-06 21:23:28 -04:00
53ac0c37e8 Fix redis_publish 2021-04-06 21:08:43 -04:00
8378ed6526 Add session arg to get_web 2021-04-06 20:32:34 -04:00
c79b3bfafd Add get_soup() 2021-04-05 19:17:27 -04:00
5ee1629c79 oops 2021-03-28 09:53:16 -04:00
00f5aef721 Update redis_publish to add subproject 2021-03-28 09:42:52 -04:00
021da84433 add redis_publish 2021-03-25 18:25:46 -04:00
53a03baaa4 Add useragent option in Web 2021-03-07 11:02:26 -05:00
66d37e0be2 Add way to manually flush @buffered 2021-02-28 12:15:44 -05:00
43cb6c4a7b web retry_codes fix 2021-02-27 12:07:48 -05:00
4278b0f89e web retry_codes fix 2021-02-27 08:58:21 -05:00
9738819428 web logging fix 2021-02-25 22:00:44 -05:00
ce2e5b2af6 load redis from env if not specified 2021-02-25 21:35:18 -05:00
9cadce62ac add Web helper & logger 2021-02-25 21:26:27 -05:00
7d330a0f9f Add pgsql wrapper & delitem for persistent state 2021-02-06 15:41:18 -05:00
a2cfab55bc msgpack for queue 2021-01-20 20:30:57 -05:00
c4fca1b754 Add volatile queue 2021-01-20 20:08:35 -05:00
d615ebdbd9 add custum SQL in persistant state 2021-01-17 10:05:39 -05:00
f914759b71 Fix @retry 2021-01-10 22:46:20 -05:00
58d150279f Handle bool values in state 2021-01-10 21:26:34 -05:00
b2efaa99a4 Handle null values in state 2021-01-10 20:56:46 -05:00
27 changed files with 1760 additions and 131 deletions

.gitignore vendored

@@ -1,3 +1,7 @@
*.iml
.idea/
*.db
*.db
*.png
hexlib.egg-info
build/
dist/

10MB.bin Normal file

Binary file not shown.


@@ -1,5 +1,5 @@
Misc utility methods in Python
```
git+git://github.com/simon987/hexlib.git
git+https://github.com/simon987/hexlib.git
```

bench/text.py Normal file

@@ -0,0 +1,27 @@
from timeit import timeit
t = bytes.maketrans(b".,;:\"!?/()|*=>", b" ")
def translate(x: str):
arr = x.encode("utf8")
return arr.translate(t).decode("utf8")
if __name__ == '__main__':
res = timeit(
setup='t = str.maketrans(".,;:\\"!?/()|*=>", " ")',
stmt='x = "Hello, world %123 & *".translate(t)'
)
# 0.865953s
print("translate = %fs" % res)
res = timeit(
setup='from text import translate',
stmt='x = translate("Hello, world %123 & *")'
)
# 0.865953s
print("custom = %fs" % res)


@@ -1,9 +1,139 @@
from queue import Queue, Empty
from multiprocessing import Process
from multiprocessing import Queue as MPQueue
from queue import Queue, Empty
from threading import Thread
from hexlib.misc import ichunks
def queue_iter(q: Queue, **get_args):
class StatelessStreamWorker:
def __init__(self):
self._q_out = None
def run(self, q: Queue, q_out: Queue):
self._q_out: Queue = q_out
for chunk in queue_iter(q, joinable=False, timeout=10):
self._process_chunk(chunk)
def _process_chunk(self, chunk):
results = []
for item in chunk:
result = self.process(item)
if result is not None:
results.append(result)
if results:
self._q_out.put(results)
def process(self, item):
raise NotImplementedError
class StatelessStreamProcessor:
def __init__(self, worker_factory, chunk_size=128, processes=1, timeout=60):
self._chunk_size = chunk_size
self._queue = MPQueue(maxsize=chunk_size)
self._queue_out = MPQueue(maxsize=processes * 2)
self._process_count = processes
self._processes = []
self._factory = worker_factory
self._workers = []
self._timeout = timeout
if processes > 1:
for _ in range(processes):
worker = self._factory()
p = Process(target=worker.run, args=(self._queue, self._queue_out))
p.start()
self._processes.append(p)
self._workers.append(worker)
else:
self._workers.append(self._factory())
def _ingest(self, iterable):
if self._process_count > 1:
for chunk in ichunks(iterable, self._chunk_size):
self._queue.put(chunk)
else:
for item in iterable:
self._workers[0].process(item)
def ingest(self, iterable):
ingest_thread = Thread(target=self._ingest, args=(iterable,))
ingest_thread.start()
for results in queue_iter(self._queue_out, joinable=False, timeout=self._timeout):
yield from results
ingest_thread.join()
class StatefulStreamWorker:
def __init__(self):
pass
def run(self, q: Queue, q_out: Queue):
for chunk in queue_iter(q, joinable=False, timeout=3):
self._process_chunk(chunk)
q_out.put(self.results())
def _process_chunk(self, chunk):
for item in chunk:
self.process(item)
def process(self, item) -> None:
raise NotImplementedError
def results(self):
raise NotImplementedError
class StatefulStreamProcessor:
def __init__(self, worker_factory, chunk_size=128, processes=1):
self._chunk_size = chunk_size
self._queue = MPQueue(maxsize=chunk_size)
self._queue_out = MPQueue()
self._process_count = processes
self._processes = []
self._factory = worker_factory
self._workers = []
if processes > 1:
for _ in range(processes):
worker = self._factory()
p = Process(target=worker.run, args=(self._queue, self._queue_out))
p.start()
self._processes.append(p)
self._workers.append(worker)
else:
self._workers.append(self._factory())
def ingest(self, iterable):
if self._process_count > 1:
for chunk in ichunks(iterable, self._chunk_size):
self._queue.put(chunk)
else:
for item in iterable:
self._workers[0].process(item)
def get_results(self):
for _ in range(self._process_count):
yield self._queue_out.get()
for p in self._processes:
p.join()
def queue_iter(q: Queue, joinable=True, **get_args):
while True:
try:
task = q.get(**get_args)
@@ -12,7 +142,8 @@ def queue_iter(q: Queue, **get_args):
break
yield task
q.task_done()
if joinable:
q.task_done()
except Empty:
break
except KeyboardInterrupt:
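
The new stream-processing classes are used by subclassing a worker and feeding it an iterable. A minimal sketch (the hexlib.concurrency module path and the word-count worker are assumptions for illustration):
```
from hexlib.concurrency import StatefulStreamWorker, StatefulStreamProcessor

class WordCountWorker(StatefulStreamWorker):
    def __init__(self):
        super().__init__()
        self.count = 0

    def process(self, item) -> None:
        # called once for every ingested item
        self.count += len(item.split())

    def results(self):
        # collected by get_results() once the input queue is drained
        return self.count

processor = StatefulStreamProcessor(worker_factory=WordCountWorker, chunk_size=64, processes=4)
processor.ingest(["hello world", "foo bar baz"])
print(sum(processor.get_results()))  # 5
```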


@@ -1,88 +1,124 @@
import base64
import sqlite3
import redis
import orjson as json
import traceback
from datetime import datetime
from enum import Enum
import psycopg2
import umsgpack
from psycopg2.errorcodes import UNIQUE_VIOLATION
from pydantic import BaseModel
from hexlib.env import get_redis
class PersistentState:
"""Quick and dirty persistent dict-like SQLite wrapper"""
def _json_encoder(x):
if isinstance(x, datetime):
return x.isoformat()
if isinstance(x, Enum):
return x.value
def __init__(self, dbfile="state.db", **dbargs):
self.dbfile = dbfile
if dbargs is None:
dbargs = {"timeout": 30000}
self.dbargs = dbargs
def __getitem__(self, table):
return Table(self, table)
raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
class VolatileState:
"""Quick and dirty volatile dict-like redis wrapper"""
def __init__(self, prefix, **redis_args):
self.rdb = redis.Redis(**redis_args)
def __init__(self, prefix, redis_db=None, sep=""):
if redis_db is None:
redis_db = get_redis()
self.rdb = redis_db
self.prefix = prefix
self._sep = sep
def __getitem__(self, table):
return RedisTable(self, table)
return RedisTable(self, table, self._sep)
def __delitem__(self, key):
self.rdb.delete(f"{self.prefix}{self._sep}{key}")
class VolatileQueue:
"""Quick and dirty volatile queue-like redis wrapper"""
def __init__(self, key, redis_db=None):
if redis_db is None:
redis_db = get_redis()
self.rdb = redis_db
self.key = key
def put(self, item):
self.rdb.sadd(self.key, umsgpack.dumps(item))
def get(self):
v = self.rdb.spop(self.key)
if v:
return umsgpack.loads(v)
class VolatileBooleanState:
"""Quick and dirty volatile dict-like redis wrapper for boolean values"""
def __init__(self, prefix, **redis_args):
self.rdb = redis.Redis(**redis_args)
def __init__(self, prefix, redis_db=None, sep=""):
if redis_db is None:
redis_db = get_redis()
self.rdb = redis_db
self.prefix = prefix
self._sep = sep
def __getitem__(self, table):
return RedisBooleanTable(self, table)
return RedisBooleanTable(self, table, self._sep)
def __delitem__(self, table):
self.rdb.delete(f"{self.prefix}{self._sep}{table}")
class RedisTable:
def __init__(self, state, table):
def __init__(self, state, table, sep=""):
self._state = state
self._table = table
self._sep = sep
self._key = f"{self._state.prefix}{self._sep}{self._table}"
def __setitem__(self, key, value):
self._state.rdb.hset(self._state.prefix + self._table, str(key), umsgpack.dumps(value))
self._state.rdb.hset(self._key, str(key), umsgpack.dumps(value))
def __getitem__(self, key):
val = self._state.rdb.hget(self._state.prefix + self._table, str(key))
val = self._state.rdb.hget(self._key, str(key))
if val:
return umsgpack.loads(val)
return None
def __delitem__(self, key):
self._state.rdb.hdel(self._state.prefix + self._table, str(key))
self._state.rdb.hdel(self._key, str(key))
def __iter__(self):
val = self._state.rdb.hgetall(self._state.prefix + self._table)
if val:
return ((k, umsgpack.loads(v)) for k, v in
self._state.rdb.hgetall(self._state.prefix + self._table).items())
for val in self._state.rdb.hscan(self._key):
if val:
return ((k, umsgpack.loads(v)) for k, v in val.items())
class RedisBooleanTable:
def __init__(self, state, table):
def __init__(self, state, table, sep=""):
self._state = state
self._table = table
self._sep = sep
self._key = f"{self._state.prefix}{self._sep}{self._table}"
def __setitem__(self, key, value):
if value:
self._state.rdb.sadd(self._state.prefix + self._table, str(key))
self._state.rdb.sadd(self._key, str(key))
else:
self.__delitem__(key)
def __getitem__(self, key):
return self._state.rdb.sismember(self._state.prefix + self._table, str(key))
return self._state.rdb.sismember(self._key, str(key))
def __delitem__(self, key):
self._state.rdb.srem(self._state.prefix + self._table, str(key))
self._state.rdb.srem(self._key, str(key))
def __iter__(self):
return iter(self._state.rdb.smembers(self._state.prefix + self._table))
yield from self._state.rdb.sscan_iter(self._key)
class Table:
@@ -90,22 +126,54 @@ class Table:
self._state = state
self._table = table
def __iter__(self):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
cur = conn.execute("SELECT * FROM %s" % (self._table,))
for row in cur:
yield dict(row)
except:
return None
def __getitem__(self, item):
def _sql_dict(self, where_clause, *params):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
for row in cur:
yield dict(
(col[0], _deserialize(row[col[0]], col_types[i]["type"]))
for i, col in enumerate(cur.description)
)
except:
return None
def sql(self, where_clause, *params):
for row in self._sql_dict(where_clause, *params):
if row and "__pydantic" in row:
yield self._deserialize_pydantic(row)
else:
yield row
def _iter_dict(self):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s" % (self._table,))
for row in cur:
yield dict(
(col[0], _deserialize(row[col[0]], col_types[i]["type"]))
for i, col in enumerate(cur.description)
)
except:
return None
def __iter__(self):
for row in self._iter_dict():
if row and "__pydantic" in row:
yield self._deserialize_pydantic(row)
else:
yield row
def _getitem_dict(self, key):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
row = cur.fetchone()
if row:
@@ -116,8 +184,32 @@ class Table:
except:
return None
@staticmethod
def _deserialize_pydantic(row):
module = __import__(row["__module"])
cls = getattr(module, row["__class"])
return cls.parse_raw(row["json"])
def __getitem__(self, key):
row = self._getitem_dict(key)
if row and "__pydantic" in row:
return self._deserialize_pydantic(row)
return row
def setitem_pydantic(self, key, value: BaseModel):
self.__setitem__(key, {
"json": value.json(encoder=_json_encoder, indent=2),
"__class": value.__class__.__name__,
"__module": value.__class__.__module__,
"__pydantic": 1
})
def __setitem__(self, key, value):
if isinstance(value, BaseModel):
self.setitem_pydantic(key, value)
return
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
@@ -146,6 +238,13 @@ class Table:
args
)
def __delitem__(self, key):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
try:
conn.execute("DELETE FROM %s WHERE id=?" % self._table, (key,))
except sqlite3.OperationalError:
pass
def _sqlite_type(value):
if isinstance(value, int):
@@ -160,15 +259,41 @@ def _sqlite_type(value):
def _serialize(value):
if isinstance(value, bytes):
return base64.b64encode(value)
if value is None:
return None
if isinstance(value, bool):
return value
return str(value)
def _deserialize(value, col_type):
if col_type == "blob":
if col_type.lower() == "blob":
return base64.b64decode(value)
return value
class PersistentState:
"""Quick and dirty persistent dict-like SQLite wrapper"""
def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
self.dbfile = dbfile
self.logger = logger
if dbargs is None or dbargs == {}:
dbargs = {"timeout": 30000}
self.dbargs = dbargs
self._table_factory = table_factory
def __getitem__(self, table):
return self._table_factory(self, table)
def __delitem__(self, key):
with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
try:
conn.execute(f"DROP TABLE {key}")
except:
pass
def pg_fetch_cursor_all(cur, name, batch_size=1000):
while True:
cur.execute("FETCH FORWARD %d FROM %s" % (batch_size, name))
@@ -183,3 +308,66 @@ def pg_fetch_cursor_all(cur, name, batch_size=1000):
for row in cur:
yield row
break
class PgConn:
"""Wrapper for PostgreSQL connection"""
def __init__(self, logger=None, **kwargs):
self._conn_args = kwargs
self.conn = psycopg2.connect(**kwargs)
self._logger = logger
def __enter__(self):
self.cur = self.conn.cursor()
return self
def exec(self, query_string, args=None):
while True:
try:
if self._logger:
self._logger.debug(query_string)
self._logger.debug("With args " + str(args))
self.cur.execute(query_string, args)
break
except psycopg2.Error as e:
if e.pgcode == UNIQUE_VIOLATION:
break
traceback.print_stack()
self._handle_err(e, query_string, args)
def query(self, query_string, args=None, max_retries=1):
retries = max_retries
while retries > 0:
try:
if self._logger:
self._logger.debug(query_string)
self._logger.debug("With args " + str(args))
self.cur.execute(query_string, args)
res = self.cur.fetchall()
if self._logger:
self._logger.debug("result: " + str(res))
return res
except psycopg2.Error as e:
if e.pgcode == UNIQUE_VIOLATION:
break
self._handle_err(e, query_string, args)
retries -= 1
def _handle_err(self, err, query, args):
if self._logger:
self._logger.warning(
"Error during query '%s' with args %s: %s %s (%s)" % (query, args, type(err), err, err.pgcode))
self.conn = psycopg2.connect(**self._conn_args)
self.cur = self.conn.cursor()
def __exit__(self, type, value, traceback):
try:
self.conn.commit()
self.cur.close()
except:
pass
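
A short usage sketch of the updated PersistentState API (table names and values are illustrative): plain dict rows and pydantic models can both be stored, sql() runs an ad-hoc WHERE clause, and deleting a key on the state object drops the whole table.
```
from pydantic import BaseModel
from hexlib.db import PersistentState

class Job(BaseModel):
    url: str
    done: bool = False

state = PersistentState("crawler.db")

# dict rows: columns are created on first insert
state["visited"][1] = {"status": "ok"}

# pydantic rows are stored as JSON and deserialized back into Job instances
state["jobs"]["1"] = Job(url="https://example.com")
print(state["jobs"]["1"].url)

# ad-hoc filtering with a raw WHERE clause
for row in state["visited"].sql("WHERE status=?", "ok"):
    print(row)

del state["jobs"]  # drops the table
```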

hexlib/env.py Normal file

@@ -0,0 +1,63 @@
import os
import redis
from fake_useragent import UserAgent
from hexlib.log import stdout_logger
from hexlib.web import Web
ARC_LISTS = os.environ.get("ARC_LISTS", "arc").split(",")
PUBLISH_CHANNEL = os.environ.get("PUBLISH_CHANNEL", None)
def get_redis():
return redis.Redis(
host=os.environ.get("REDIS_HOST", "localhost"),
port=int(os.environ.get("REDIS_PORT", 6379))
)
def redis_publish(rdb, item, item_project, item_type, item_subproject=None, item_category="x"):
item_project = item_project.replace(".", "-")
item_subproject = item_subproject.replace(".", "-") if item_subproject else None
item_source = item_project if not item_subproject else f"{item_project}.{item_subproject}"
item_type = item_type.replace(".", "-")
item_category = item_category.replace(".", "-")
if PUBLISH_CHANNEL is not None:
routing_key = f"{PUBLISH_CHANNEL}.{item_source}.{item_type}.{item_category}"
rdb.publish(routing_key, item)
for arc_list in ARC_LISTS:
routing_key = f"{arc_list}.{item_source}.{item_type}.{item_category}"
rdb.lpush(routing_key, item)
def get_web(session=None):
ua = UserAgent()
retry_codes = os.environ.get("RETRY_CODES", "")
web = Web(
session=session,
proxy=os.environ.get("PROXY", None),
rps=os.environ.get("RPS", 1),
logger=stdout_logger,
cookie_file=os.environ.get("COOKIE_FILE", None),
retry_codes=set(int(x) for x in retry_codes.split(",")) if retry_codes else None,
retries=int(os.environ.get("RETRIES", 3)),
retry_sleep=int(os.environ.get("RETRY_SLEEP", 0)),
ua=ua[os.environ.get("USER_AGENT")] if os.environ.get("USER_AGENT", None) is not None else None
)
if hasattr(web._session, "cipherSuite"):
stdout_logger.debug("Web>cipherSuite=%s" % web._session.cipherSuite)
if hasattr(web._session, "headers"):
stdout_logger.debug("Web>headers=%s" % web._session.headers)
if hasattr(web._session, "cookies"):
stdout_logger.debug("Web>cookies=%s" % web._session.cookies)
stdout_logger.debug("Web>rps=%s" % os.environ.get("RPS", 1))
return web
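
A sketch of the new environment helpers (URL and project name are placeholders); behaviour is driven entirely by environment variables:
```
from hexlib.env import get_redis, get_web, redis_publish

rdb = get_redis()  # honours REDIS_HOST / REDIS_PORT
web = get_web()    # honours PROXY, RPS, RETRY_CODES, RETRIES, RETRY_SLEEP, COOKIE_FILE, USER_AGENT

r = web.get("https://example.com")
if r is not None:
    # fans out to PUBLISH_CHANNEL (if set) and pushes to every list in ARC_LISTS
    redis_publish(rdb, r.text, item_project="example", item_type="page")
```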


@@ -62,6 +62,16 @@ COMPRESSION_GZIP = "gz"
COMPRESSION_ZSTD = "zstd"
class NDJsonLine:
__slots__ = "text"
def __init__(self, text):
self.text = text
def json(self):
return json.loads(self.text)
def ndjson_iter(*files, compression=""):
for file in files:
cleanup = None
@@ -75,7 +85,7 @@ def ndjson_iter(*files, compression=""):
line_iter = BufferedReader(gzip.open(file))
elif compression == COMPRESSION_ZSTD:
fp = open(file, "rb")
dctx = zstandard.ZstdDecompressor()
dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
reader = dctx.stream_reader(fp)
line_iter = BufferedReader(reader)
@@ -90,7 +100,6 @@ def ndjson_iter(*files, compression=""):
line_iter.close()
for line in line_iter:
yield json.loads(line)
yield NDJsonLine(line)
if cleanup:
cleanup()
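
With this change ndjson_iter() yields NDJsonLine objects instead of parsed dicts, so JSON decoding can be deferred. A rough example (the hexlib.format module path and file name are assumptions):
```
from hexlib.format import ndjson_iter, COMPRESSION_ZSTD

for line in ndjson_iter("dump.ndjson.zst", compression=COMPRESSION_ZSTD):
    doc = line.json()  # parse lazily, only when needed
    print(doc.get("_id"), len(line.text))
```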

hexlib/log.py Normal file

@@ -0,0 +1,51 @@
import logging
import os
import sys
from logging import StreamHandler
DATE_FMT = "%Y-%m-%d %H:%M:%S"
class ColorFormatter(logging.Formatter):
def __init__(self, fmt):
super().__init__()
grey = "\x1b[38;21m"
yellow = "\x1b[33;21m"
red = "\x1b[31;21m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
self.formats = {
logging.DEBUG: logging.Formatter(grey + fmt + reset, datefmt=DATE_FMT),
logging.INFO: logging.Formatter(grey + fmt + reset, datefmt=DATE_FMT),
logging.WARNING: logging.Formatter(yellow + fmt + reset, datefmt=DATE_FMT),
logging.ERROR: logging.Formatter(red + fmt + reset, datefmt=DATE_FMT),
logging.CRITICAL: logging.Formatter(bold_red + fmt + reset, datefmt=DATE_FMT)
}
def format(self, record):
return self.formats[record.levelno].format(record)
stdout_logger = logging.getLogger("default")
if os.environ.get("LOG_LEVEL", "debug") == "debug":
stdout_logger.setLevel(logging.DEBUG)
for h in stdout_logger.handlers:
stdout_logger.removeHandler(h)
handler = StreamHandler(sys.stdout)
if os.environ.get("LOG_THREAD_NAME", "0") == "1":
fmt = "%(asctime)s %(levelname)-5s>%(threadName)s %(message)s"
else:
fmt = "%(asctime)s %(levelname)-5s>%(message)s"
if os.environ.get("LOG_COLORS", "1") == "1":
handler.formatter = ColorFormatter(fmt)
else:
handler.formatter = logging.Formatter(fmt, datefmt='%Y-%m-%d %H:%M:%S')
stdout_logger.addHandler(handler)
logger = stdout_logger
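
The shared stdout logger is configured once at import time from LOG_LEVEL, LOG_COLORS and LOG_THREAD_NAME; typical use is simply:
```
from hexlib.log import logger

logger.debug("fetching page")       # grey
logger.warning("retrying request")  # yellow
logger.error("giving up")           # red
```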


@@ -1,23 +1,31 @@
import atexit
import itertools
import os
import sys
import time
from threading import Lock
from time import sleep
import atexit
import siphash
last_time_called = dict()
def retry(attempts, callback=None):
def retry(attempts, callback=None, retry_sleep=0):
def decorate(func):
retries = attempts
while retries > 0:
try:
func()
except Exception as e:
if callback:
callback(e)
def wrapper(*args, **kwargs):
retries = attempts
while retries > 0:
try:
return func(*args, **kwargs)
except Exception as e:
if callback:
callback(e)
retries -= 1
sleep(retry_sleep)
return wrapper
return decorate
@@ -26,6 +34,15 @@ def chunks(lst: list, chunk_len: int):
yield lst[i:i + chunk_len]
def ichunks(iterable, chunk_len: int):
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, chunk_len))
if not chunk:
break
yield chunk
def rate_limit(per_second):
min_interval = 1.0 / float(per_second)
@@ -54,6 +71,12 @@ def buffered(batch_size: int, flush_on_exit: bool = False):
atexit.register(func, buffer)
def wrapper(items):
if items is None:
func(buffer)
buffer.clear()
return
with lock:
for item in items:
buffer.append(item)
@@ -93,4 +116,20 @@ class CustomStdOut:
self.fp.close()
class CustomStdErr:
original_stderr = sys.stderr
def __init__(self, fname):
self.fname = fname
def __enter__(self):
self.fp = open(self.fname, "w")
sys.stderr = self.fp
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stderr = CustomStdErr.original_stderr
self.fp.close()
silent_stdout = CustomStdOut(os.devnull)
silent_stderr = CustomStdErr(os.devnull)
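
A sketch of the reworked helpers: @retry now returns a proper wrapper (and can sleep between attempts), @buffered can be flushed manually by passing None, and ichunks() chunks any iterable. Names below are illustrative.
```
from hexlib.misc import retry, buffered, ichunks

@retry(attempts=3, retry_sleep=1, callback=print)
def flaky():
    raise Exception("transient error")  # retried 3 times, 1s apart

@buffered(batch_size=100)
def save(rows):
    print("flushing %d rows" % len(rows))

save([{"a": 1}])  # buffered until batch_size items accumulate
save(None)        # passing None flushes the buffer immediately

for chunk in ichunks(range(10), 4):
    print(chunk)  # (0, 1, 2, 3)  (4, 5, 6, 7)  (8, 9)
```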


@@ -1,22 +1,29 @@
import logging
import traceback
from abc import ABC
from influxdb import InfluxDBClient
from hexlib.misc import buffered
class Monitoring:
def __init__(self, db, host="localhost", logger=logging.getLogger("default"), batch_size=1, flush_on_exit=False):
self._db = db
self._client = InfluxDBClient(host, 8086, "", "", db)
class Monitoring(ABC):
def log(self, points):
raise NotImplementedError()
class BufferedInfluxDBMonitoring(Monitoring):
def __init__(self, db_name, host="localhost", port=8086, logger=None, batch_size=1, flush_on_exit=False):
self._db = db_name
self._client = InfluxDBClient(host, port, "", "", db_name)
self._logger = logger
self._init()
if not self.db_exists(self._db):
self._client.create_database(self._db)
@buffered(batch_size, flush_on_exit)
def log(points):
self._log(points)
self.log = log
def db_exists(self, name):
@@ -25,14 +32,16 @@ class Monitoring:
return True
return False
def _init(self):
if not self.db_exists(self._db):
self._client.create_database(self._db)
def log(self, points):
# Is overwritten in __init__()
pass
def _log(self, points):
try:
self._client.write_points(points)
self._logger.debug("InfluxDB: Wrote %d points" % len(points))
if self._logger:
self._logger.debug("InfluxDB: Wrote %d points" % len(points))
except Exception as e:
self._logger.debug(traceback.format_exc())
self._logger.error(str(e))
if self._logger:
self._logger.debug(traceback.format_exc())
self._logger.error(str(e))
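
A usage sketch of the renamed BufferedInfluxDBMonitoring class (the hexlib.monitoring module path, database and measurement names are assumptions); points use the regular influxdb client format and are written in batches:
```
from hexlib.monitoring import BufferedInfluxDBMonitoring

monitoring = BufferedInfluxDBMonitoring("crawler", host="localhost", batch_size=10, flush_on_exit=True)

monitoring.log([{
    "measurement": "fetch",
    "tags": {"project": "example"},
    "fields": {"count": 1}
}])
```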

hexlib/mq.py Normal file

@@ -0,0 +1,161 @@
import json
from collections import namedtuple
from functools import partial
from itertools import islice
from time import sleep, time
from orjson import orjson
from redis import Redis
RoutingKeyParts = namedtuple(
"RoutingKeyParts",
["arc_list", "project", "subproject", "type", "category"]
)
def parse_routing_key(key):
tokens = key.split(".")
if len(tokens) == 4:
arc_list, project, type_, category = tokens
return RoutingKeyParts(
arc_list=arc_list,
project=project,
subproject=None,
type=type_,
category=category
)
else:
arc_list, project, subproject, type_, category = tokens
return RoutingKeyParts(
arc_list=arc_list,
project=project,
subproject=subproject,
type=type_,
category=category
)
class MessageQueue:
def read_messages(self, topics):
raise NotImplementedError()
def publish(self, item, item_project, item_type, item_subproject, item_category):
raise NotImplementedError()
class RedisMQ(MessageQueue):
_MAX_KEYS = 30
def __init__(self, rdb, consumer_name="redis_mq", sep=".", max_pending_time=120, logger=None, publish_channel=None,
arc_lists=None, wait=1):
self._rdb: Redis = rdb
self._key_cache = None
self._consumer_id = consumer_name
self._pending_list = f"pending{sep}{consumer_name}"
self._max_pending_time = max_pending_time
self._logger = logger
self._publish_channel = publish_channel
self._arc_lists = arc_lists
self._wait = wait
def _get_keys(self, pattern):
if self._key_cache:
return self._key_cache
keys = list(islice(
self._rdb.scan_iter(match=pattern, count=RedisMQ._MAX_KEYS), RedisMQ._MAX_KEYS
))
self._key_cache = keys
return keys
def _get_pending_tasks(self):
for task_id, pending_task in self._rdb.hscan_iter(self._pending_list):
pending_task_json = orjson.loads(pending_task)
if time() >= pending_task_json["resubmit_at"]:
yield pending_task_json["topic"], pending_task_json["task"], partial(self._ack, task_id)
def _ack(self, task_id):
self._rdb.hdel(self._pending_list, task_id)
def read_messages(self, topics):
"""
Assumes json-encoded tasks with an _id field
Tasks are automatically put into a pending list until ack() is called.
When a task has been in the pending list for at least max_pending_time seconds, it
gets submitted again
"""
assert len(topics) == 1, "RedisMQ only supports 1 topic pattern"
pattern = topics[0]
counter = 0
if self._logger:
self._logger.info(f"MQ>Listening for new messages in {pattern}")
while True:
counter += 1
if counter % 1000 == 0:
yield from self._get_pending_tasks()
keys = self._get_keys(pattern)
if not keys:
sleep(self._wait)
self._key_cache = None
continue
result = self._rdb.blpop(keys, timeout=1)
if not result:
self._key_cache = None
continue
topic, task = result
task_json = orjson.loads(task)
topic = topic.decode()
if "_id" not in task_json or not task_json["_id"]:
raise ValueError(f"Task doesn't have _id field: {task}")
# Immediately put in pending queue
self._rdb.hset(
self._pending_list, task_json["_id"],
orjson.dumps({
"resubmit_at": time() + self._max_pending_time,
"topic": topic,
"task": task_json
})
)
yield topic, task_json, partial(self._ack, task_json["_id"])
def publish(self, item, item_project, item_type, item_subproject=None, item_category="x"):
if "_id" not in item:
raise ValueError("_id field must be set for item")
item = json.dumps(item, separators=(',', ':'), ensure_ascii=False, sort_keys=True)
item_project = item_project.replace(".", "-")
item_subproject = item_subproject.replace(".", "-") if item_subproject else None
item_source = item_project if not item_subproject else f"{item_project}.{item_subproject}"
item_type = item_type.replace(".", "-")
item_category = item_category.replace(".", "-")
# If specified, fan-out to pub/sub channel
if self._publish_channel is not None:
routing_key = f"{self._publish_channel}.{item_source}.{item_type}.{item_category}"
self._rdb.publish(routing_key, item)
# Save to list
for arc_list in self._arc_lists:
routing_key = f"{arc_list}.{item_source}.{item_type}.{item_category}"
self._rdb.lpush(routing_key, item)
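
A rough end-to-end example of the Redis message queue (consumer name and payload are illustrative): publish() requires an _id field, and read_messages() keeps each task in a pending hash until its ack() callback is called.
```
from hexlib.env import get_redis
from hexlib.mq import RedisMQ

rdb = get_redis()
mq = RedisMQ(rdb, consumer_name="worker1", arc_lists=["arc"])

mq.publish({"_id": 1, "body": "hello"}, item_project="example", item_type="msg")

for topic, task, ack in mq.read_messages(topics=["arc.*"]):
    print(topic, task["body"])
    ack()  # removes the task from the pending.worker1 hash
    break
```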


@@ -2,3 +2,6 @@ import re
LINK_RE = re.compile(r"(https?://[\w\-_.]+\.[a-z]{2,4}([^\s<'\"]*|$))")
HTML_HREF_RE = re.compile(r"href=\"([^\"]+)\"")
WHITESPACE_RE = re.compile(r"\s+")
PUNCTUATION_RE = re.compile(r"[.,;:\"“!?/()|*=>]+")
XML_ENTITY_RE = re.compile(r"&[a-z]+;")

hexlib/text.py Normal file

@@ -0,0 +1,128 @@
import re
from itertools import chain, repeat
import nltk.corpus
from lxml import etree
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from .regex_util import LINK_RE
get_text = etree.XPath("//text()")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
stop_words_en = set(stopwords.words("english"))
extra_stop_words_en = [
"u", "&", "-", "--"
]
stop_words_en.update(extra_stop_words_en)
lemmatizer = WordNetLemmatizer()
def _transform_bigram(ngram_seq, ngrams):
for ngram in ngram_seq:
if ngram in ngrams:
yield ngram[0] + "_" + ngram[1]
next(ngram_seq)
else:
yield ngram[0]
def _transform_trigram(ngram_seq, ngrams):
for ngram in ngram_seq:
if ngram in ngrams:
# yield ngram[0] + "_" + ngram[1] + "_" + ngram[2]
yield "_".join(ngram)
next(ngram_seq)
next(ngram_seq)
else:
yield ngram[0]
SINGLE_QUOTES = ("’", "`", "‘")
SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", len(SINGLE_QUOTES))))
DASHES = ("–", "—", "―", "‒")
DASHES_TRANS = str.maketrans("".join(DASHES), "".join(repeat("-", len(DASHES))))
DASHES_RE = re.compile(r"-+")
SPECIAL_PUNCTUATION = ";:\"/()|*=>"
SPECIAL_PUNCTUATION_TRANS = str.maketrans(SPECIAL_PUNCTUATION, " " * len(SPECIAL_PUNCTUATION))
PUNCTUATION = ".,!?"
PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION))
def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_special_punctuation=False,
remove_stopwords_en=False, lemmatize=False, fix_single_quotes=False, strip_quotes=False,
strip_dashes=False,
remove_urls=False, bigrams: set = None, trigrams: set = None, remove_numbers=False,
use_nltk_tokenizer=False):
if lowercase:
text = text.lower()
if fix_single_quotes:
text = text.translate(SINGLE_QUOTE_TRANS)
text = text.translate(DASHES_TRANS)
if strip_dashes:
text = DASHES_RE.sub("-", text)
if remove_urls:
text = LINK_RE.sub(" ", text)
if clean_html:
try:
text = "<root>" + text + "</root>"
parser = etree.XMLParser(recover=True)
root = etree.fromstring(text, parser)
text = " ".join(get_text(root))
except:
pass
if remove_punctuation:
text = text.translate(PUNCTUATION_TRANS)
if remove_special_punctuation:
text = text.translate(SPECIAL_PUNCTUATION_TRANS)
if use_nltk_tokenizer:
words = word_tokenize(text, language="english")
else:
words = text.split()
if strip_quotes:
words = map(lambda w: w.strip("\"'“”"), words)
if strip_dashes:
words = map(lambda w: w.strip("-"), words)
if bigrams:
words = _transform_bigram(nltk.bigrams(chain(words, ("*",))), bigrams)
if trigrams:
words = _transform_trigram(nltk.trigrams(chain(words, ("*", "*"))), trigrams)
if remove_numbers:
words = filter(lambda w: not w.isnumeric(), words)
if lemmatize:
words = map(lambda w: lemmatizer.lemmatize(w), words)
if remove_stopwords_en:
words = filter(lambda w: w not in stop_words_en, words)
return filter(lambda w: w != "", words)
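
preprocess() returns a generator of tokens; a sketch mirroring the options exercised in the tests further down:
```
from hexlib.text import preprocess

tokens = preprocess(
    "<div>Hello, <strong>the worlds!</strong> https://example.com</div>",
    clean_html=True,
    lowercase=True,
    remove_punctuation=True,
    remove_stopwords_en=True,
    remove_urls=True,
    lemmatize=True,
)
print(" ".join(tokens))  # hello world
```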


@@ -4,12 +4,19 @@ import os
from datetime import datetime
from base64 import b64encode, b64decode
from http.cookiejar import Cookie
from time import time
from urllib.parse import urlparse, parse_qs
from bs4 import BeautifulSoup
import requests
import orjson as json
from dateutil.parser import parse
from requests.cookies import RequestsCookieJar
from hexlib.misc import rate_limit, retry
def cookie_from_string(text: str, domain: str) -> Cookie:
tokens = [t.strip() for t in text.split(";")]
@@ -74,6 +81,26 @@ def cookiejar_filter(cj, pattern):
return filtered_cj
def cookiejar_filter_name(cj, pattern):
filtered_cj = RequestsCookieJar()
for c in cj:
if re.match(pattern, c.name):
filtered_cj.set_cookie(c)
return filtered_cj
def url_query_value(url, arg, as_list=False):
qs = urlparse(url).query
parsed_qs = parse_qs(qs)
arg = parsed_qs.get(arg, [])
if as_list:
return arg if arg else []
else:
return arg[0] if arg else None
def download_file(url, destination, session=None, headers=None, overwrite=False, retries=1, err_cb=None,
save_meta=False):
if os.path.exists(destination) and not overwrite:
@@ -99,8 +126,104 @@ def download_file(url, destination, session=None, headers=None, overwrite=False,
"url": url,
"timestamp": datetime.utcnow().replace(microsecond=0).isoformat()
}))
r.close()
break
except Exception as e:
if err_cb:
err_cb(e)
retries -= 1
class Web:
def __init__(self, proxy=None, rps=1, retries=3, retry_sleep=0, logger=None, cookie_file=None, retry_codes=None,
session=None,
ua=None):
self._cookie_file = cookie_file
self._proxy = proxy
self._logger = logger
self._current_req = None
if retry_codes is None or not retry_codes:
retry_codes = {500, 502, 503, 504, 520, 522, 524, 429}
self._retry_codes = retry_codes
if session is None:
session = requests.session()
self._session = session
if ua is not None:
session.headers["User-Agent"] = ua
if self._cookie_file:
self._session.cookies = load_cookiejar(cookie_file)
if self._proxy:
self._session.proxies = {
"http": proxy,
"https": proxy,
}
@rate_limit(rps)
@retry(retries, callback=self._error_callback, retry_sleep=retry_sleep)
def get(url, **kwargs):
self._current_req = "GET", url, kwargs
r = self._session.get(url, **kwargs)
if r.status_code in self._retry_codes:
raise Exception(f"HTTP {r.status_code}")
return r
self._get = get
@rate_limit(rps)
@retry(retries, callback=self._error_callback, retry_sleep=retry_sleep)
def post(url, **kwargs):
self._current_req = "POST", url, kwargs
r = self._session.post(url, **kwargs)
if r.status_code in self._retry_codes:
raise Exception(f"HTTP {r.status_code}")
return r
self._post = post
def _error_callback(self, e):
if self._logger:
self._logger.critical(f"{self._format_url(*self._current_req)}: {e}")
def _format_url(self, method, url, kwargs, r=None):
if "params" in kwargs and kwargs["params"]:
return "%s %s?%s <%s>" % (method, url, "&".join(f"{k}={v}" for k, v in kwargs["params"].items()),
r.status_code if r is not None else "ERR")
else:
return "%s %s <%s>" % (method, url, r.status_code if r is not None else "ERR",)
def get(self, url, **kwargs):
time_start = time()
r = self._get(url, **kwargs)
if self._cookie_file:
save_cookiejar(self._session.cookies, self._cookie_file)
if self._logger and r is not None:
self._logger.debug(self._format_url("GET", url, kwargs, r) + " %.2fs" % (time() - time_start))
return r
def post(self, url, **kwargs):
time_start = time()
r = self._post(url, **kwargs)
if self._cookie_file:
save_cookiejar(self._session.cookies, self._cookie_file)
if self._logger and r is not None:
self._logger.debug(self._format_url("POST", url, kwargs, r) + " %.2fs" % (time() - time_start))
return r
def get_soup(self, url, **kwargs):
r = self.get(url, **kwargs)
if not r:
return None
return BeautifulSoup(r.content, "html.parser")
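
A sketch of the Web helper (URL and user agent are placeholders): requests are rate limited, retried on the configured status codes and logged, and get_soup() returns a parsed BeautifulSoup document or None.
```
from hexlib.web import Web, url_query_value

web = Web(rps=0.5, retries=3, retry_codes={429, 503}, ua="Mozilla/5.0 (example)")

r = web.get("https://example.com/search", params={"q": "test"})
soup = web.get_soup("https://example.com/")

print(url_query_value("https://example.com/page?a=1&b=2", "a"))  # "1"
```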


@@ -2,7 +2,7 @@ from setuptools import setup
setup(
name="hexlib",
version="1.20",
version="1.89",
description="Misc utility methods",
author="simon987",
author_email="me@simon987.net",
@@ -12,7 +12,9 @@ setup(
"data/*"
]},
install_requires=[
"ImageHash", "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
"u-msgpack-python"
"influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
"u-msgpack-python", "psycopg2-binary", "bs4", "lxml", "nltk", "numpy",
"matplotlib", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
"requests", "pydantic==1.10.11"
]
)

test/__init__.py Normal file


@@ -1,49 +0,0 @@
import os
from unittest import TestCase
from hexlib.db import PersistentState
class TestPersistentState(TestCase):
def tearDown(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def setUp(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def test_get_set(self):
s = PersistentState()
val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
s["a"]["1"] = val
val["id"] = "1"
self.assertDictEqual(val, s["a"]["1"])
def test_get_set_int_id(self):
s = PersistentState()
val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
s["a"][1] = val
val["id"] = 1
self.assertDictEqual(val, s["a"][1])
def test_update_partial(self):
s = PersistentState()
val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
s["a"][1] = val
s["a"][1] = {
"a": 2
}
val["a"] = 2
val["id"] = 1
self.assertDictEqual(val, s["a"][1])


@@ -0,0 +1,143 @@
import os
from unittest import TestCase
from hexlib.db import PersistentState
class TestPersistentState(TestCase):
def tearDown(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def setUp(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def test_get_set(self):
s = PersistentState()
val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
s["a"]["1"] = val
val["id"] = "1"
self.assertDictEqual(val, s["a"]["1"])
def test_get_set_int_id(self):
s = PersistentState()
val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
s["a"][1] = val
val["id"] = 1
self.assertDictEqual(val, s["a"][1])
def test_update_partial(self):
s = PersistentState()
val = {"a": 1, "b": "2", "c": b'3', "d": 4.4}
s["a"][1] = val
s["a"][1] = {
"a": 2
}
val["a"] = 2
val["id"] = 1
self.assertDictEqual(val, s["a"][1])
def test_none(self):
s = PersistentState()
val = {"a": 1, "b": None}
s["a"][1] = val
s["a"][1] = {
"a": None
}
val["a"] = None
val["id"] = 1
self.assertDictEqual(val, s["a"][1])
def test_bool(self):
s = PersistentState()
val = {"a": True, "b": False}
s["a"][1] = val
s["a"][1] = {
"a": True
}
val["a"] = True
val["id"] = 1
self.assertDictEqual(val, s["a"][1])
def test_sql(self):
s = PersistentState()
s["a"][1] = {"a": True}
s["a"][2] = {"a": False}
s["a"][3] = {"a": True}
items = list(s["a"].sql("WHERE a=0 ORDER BY id"))
self.assertDictEqual(items[0], s["a"][2])
def test_delitem(self):
s = PersistentState()
s["a"][1] = {"a": True}
del s["a"][1]
self.assertIsNone(s["a"][1])
def test_delitem_nonexistent(self):
s = PersistentState()
s["a"][1] = {"a": True}
del s["a"][456]
self.assertIsNotNone(s["a"][1])
def test_delitem_no_table(self):
s = PersistentState()
try:
del s["a"][456]
except Exception as e:
self.fail(e)
def test_deserialize_get_set(self):
s = PersistentState()
s["a"][0] = {"x": b'abc'}
self.assertEqual(s["a"][0]["x"], b'abc')
def test_deserialize_sql(self):
s = PersistentState()
s["a"][0] = {"x": b'abc'}
self.assertEqual(list(s["a"].sql("WHERE 1=1"))[0]["x"], b'abc')
def test_deserialize_iter(self):
s = PersistentState()
s["a"][0] = {"x": b'abc'}
self.assertEqual(list(s["a"])[0]["x"], b'abc')
def test_drop_table(self):
s = PersistentState()
s["a"][0] = {"x": 1}
s["a"][1] = {"x": 2}
self.assertEqual(len(list(s["a"])), 2)
del s["a"]
self.assertEqual(len(list(s["a"])), 0)

test/test_PydanticTable.py Normal file

@@ -0,0 +1,110 @@
import os
from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase
from pydantic import BaseModel
from pydantic.types import List
from hexlib.db import PersistentState
class Status(Enum):
yes = "yes"
no = "no"
class Point(BaseModel):
x: int
y: int
class Polygon(BaseModel):
points: List[Point] = []
created_date: datetime
status: Status = Status("yes")
class TestPydanticTable(TestCase):
def tearDown(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def setUp(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def test_get_set(self):
s = PersistentState()
val = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[
Point(x=1, y=2),
Point(x=3, y=4),
],
)
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1)
self.assertEqual(s["a"]["1"].status, Status("yes"))
self.assertEqual(s["a"]["1"].points[1].x, 3)
self.assertEqual(s["a"]["1"].created_date.year, 2000)
def test_update(self):
s = PersistentState()
val = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[
Point(x=1, y=2),
Point(x=3, y=4),
]
)
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1)
val.points[0].x = 2
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 2)
def test_sql(self):
s = PersistentState()
s["b"]["1"] = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[]
)
s["b"]["2"] = Polygon(
created_date=datetime(year=2010, day=1, month=1),
points=[]
)
result = list(s["b"].sql(
"WHERE json->>'created_date' LIKE '2000-%'"
))
self.assertEqual(len(result), 1)
self.assertEqual(result[0].created_date.year, 2000)
def test_iterate(self):
s = PersistentState()
s["b"]["1"] = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[]
)
s["b"]["2"] = Polygon(
created_date=datetime(year=2010, day=1, month=1),
points=[]
)
result = list(s["b"])
self.assertEqual(len(result), 2)
self.assertEqual(result[0].created_date.year, 2000)
self.assertEqual(result[1].created_date.year, 2010)


@@ -1,9 +1,15 @@
from unittest import TestCase
from hexlib.db import VolatileState, VolatileBooleanState
from hexlib.db import VolatileState, VolatileBooleanState, VolatileQueue
from hexlib.env import get_redis
class TestVolatileState(TestCase):
def setUp(self) -> None:
rdb = get_redis()
rdb.delete("test1a", "test1b", "test1c", "test1:a", "test2b")
def test_get_set(self):
s = VolatileState(prefix="test1")
val = {
@@ -15,6 +21,17 @@ class TestVolatileState(TestCase):
self.assertDictEqual(val, s["a"]["1"])
def test_sep(self):
s = VolatileState(prefix="test1", sep=":")
val = {
"field1": 1,
"arr1": [1, 2, 3]
}
s["a"]["1"] = val
self.assertDictEqual(val, s["a"]["1"])
def test_iter(self):
s = VolatileState(prefix="test2")
@@ -23,7 +40,7 @@ class TestVolatileState(TestCase):
s["b"]["3"] = 3
s["b"]["4"] = 4
self.assertEqual(sum(v for k,v in s["b"]), 10)
self.assertEqual(sum(v for k, v in s["b"]), 10)
def test_int_key(self):
s = VolatileState(prefix="test2")
@@ -41,6 +58,10 @@ class TestVolatileState(TestCase):
class TestVolatileBoolState(TestCase):
def setUp(self) -> None:
rdb = get_redis()
rdb.delete("test1a", "test1b", "test1c", "test1:a", "test2b")
def test_get_set(self):
s = VolatileBooleanState(prefix="test1")
@@ -51,6 +72,16 @@ class TestVolatileBoolState(TestCase):
self.assertTrue(s["a"]["2"])
self.assertFalse(s["a"]["3"])
def test_sep(self):
s = VolatileBooleanState(prefix="test1", sep=":")
s["a"]["1"] = True
s["a"]["2"] = True
self.assertTrue(s["a"]["1"])
self.assertTrue(s["a"]["2"])
self.assertFalse(s["a"]["3"])
def test_iter(self):
s = VolatileBooleanState(prefix="test2")
@@ -68,3 +99,30 @@ class TestVolatileBoolState(TestCase):
self.assertTrue(s["c"]["1"])
del s["c"]["1"]
self.assertFalse(s["c"]["1"])
class TestVolatileQueue(TestCase):
def test_simple(self):
s = VolatileQueue(key="test5")
s.put("test")
item = s.get()
self.assertTrue(item == "test")
def test_dict(self):
s = VolatileQueue(key="test5")
s.put({"a": 1})
item = s.get()
self.assertTrue(item["a"] == 1)
def test_int(self):
s = VolatileQueue(key="test5")
s.put(123)
item = s.get()
self.assertTrue(item == 123)

test/test_buffered.py Normal file

@@ -0,0 +1,36 @@
from unittest import TestCase
from hexlib.misc import buffered
class TestBuffered(TestCase):
def test_simple(self):
my_list = []
@buffered(batch_size=2)
def put_item(items):
my_list.extend(items)
put_item([1, 2])
put_item([1])
put_item([1])
put_item([1])
self.assertEqual(len(my_list), 4)
def test_flush(self):
my_list = []
@buffered(batch_size=2)
def put_item(items):
my_list.extend(items)
put_item([1, 2])
put_item([1])
put_item([1])
put_item([1])
put_item(None)
self.assertEqual(len(my_list), 5)


@@ -1,13 +1,17 @@
from unittest import TestCase
import os
from unittest import TestCase
from hexlib.web import download_file
import warnings
class TestDownloadFile(TestCase):
def setUp(self) -> None:
warnings.filterwarnings(action="ignore", category=ResourceWarning)
def test_download_file(self):
download_file("http://ovh.net/files/10Mb.dat", "/tmp/10Mb.dat")
download_file("https://github.com/simon987/hexlib/raw/master/10MB.bin", "/tmp/10Mb.dat")
self.assertTrue(os.path.exists("/tmp/10Mb.dat"))
os.remove("/tmp/10Mb.dat")
@@ -22,8 +26,8 @@ class TestDownloadFile(TestCase):
self.assertEqual(len(exceptions), 3)
def test_download_file_meta(self):
download_file("http://ovh.net/files/10Mb.dat", "/tmp/10Mb.dat", save_meta=True)
download_file("https://github.com/simon987/hexlib/raw/master/10MB.bin", "/tmp/10Mb.dat", save_meta=True)
self.assertTrue(os.path.exists("/tmp/10Mb.dat"))
self.assertTrue(os.path.exists("/tmp/10Mb.dat.meta"))
os.remove("/tmp/10Mb.dat")
# os.remove("/tmp/10Mb.dat.meta")
os.remove("/tmp/10Mb.dat.meta")

test/test_redis_mq.py Normal file

@@ -0,0 +1,62 @@
from unittest import TestCase
from hexlib.env import get_redis
from hexlib.mq import RedisMQ, parse_routing_key, RoutingKeyParts
class TestRedisMQ(TestCase):
def setUp(self) -> None:
self.rdb = get_redis()
self.rdb.delete("pending.test", "test_mq", "arc.test.msg.x")
def test_ack(self):
mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=2, arc_lists=["arc"])
mq.publish({"_id": 1}, item_project="test", item_type="msg")
topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.*"]))
self.assertEqual(self.rdb.hlen("pending.test"), 1)
ack1()
self.assertEqual(self.rdb.hlen("pending.test"), 0)
def test_pending_timeout(self):
mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)
mq.publish({"_id": 1}, item_project="test", item_type="msg")
topic1, msg1, ack1 = next(mq.read_messages(topics=["arc.test.*"]))
self.assertEqual(self.rdb.hlen("pending.test"), 1)
# msg1 will timeout after 0.5s, next iteration takes ceil(0.5)s
topic1_, msg1_, ack1_ = next(mq.read_messages(topics=["arc.test.*"]))
self.assertEqual(self.rdb.hlen("pending.test"), 1)
ack1_()
self.assertEqual(self.rdb.hlen("pending.test"), 0)
self.assertEqual(msg1, msg1_)
def test_no_id_field(self):
mq = RedisMQ(self.rdb, consumer_name="test", max_pending_time=0.5, arc_lists=["arc"], wait=0)
with self.assertRaises(ValueError):
mq.publish({"a": 1}, item_project="test", item_type="msg")
class TestRoutingKey(TestCase):
def test1(self):
key = "arc.chan.4chan.post.b"
parts = parse_routing_key(key)
self.assertEqual(parts, RoutingKeyParts("arc", "chan", "4chan", "post", "b"))
def test2(self):
key = "arc.reddit.submission.birdpics"
parts = parse_routing_key(key)
self.assertEqual(parts, RoutingKeyParts("arc", "reddit", None, "submission", "birdpics"))

test/test_retry.py Normal file

@@ -0,0 +1,27 @@
from unittest import TestCase
from hexlib.misc import retry
class TestRetry(TestCase):
def test_simple(self):
@retry(attempts=3)
def a(i):
return i + 1
self.assertEqual(a(1), 2)
def test_error(self):
arr = []
def cb(e):
arr.append(e)
@retry(attempts=3, callback=cb)
def a(i):
raise Exception("err")
a(1)
self.assertEqual(3, len(arr))

test/test_text.py Normal file

@@ -0,0 +1,279 @@
from unittest import TestCase
from hexlib.text import preprocess
class TestText(TestCase):
def test_html_invalid(self):
text = ""
cleaned = preprocess(
text,
clean_html=True,
)
expected = ""
self.assertEqual(" ".join(cleaned), expected)
def test_html_1(self):
text = "<div>Hello, <strong>world</strong></div>"
cleaned = preprocess(
text,
clean_html=True,
)
expected = "Hello, world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_2(self):
text = "<div>Hello, <strong>world</strong></div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True
)
expected = "hello, world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_4(self):
text = "<div>\n Hello, \t\n<strong> world </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
)
expected = "hello, world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_5(self):
text = "<div>\n Hello, \t\n<strong> world </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True
)
expected = "hello world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_6(self):
text = "<div>\n Hello, \t\n<strong>a the world </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
remove_stopwords_en=True
)
expected = "hello world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_7(self):
text = "<div>\n Hello, \t\n<strong>a the worlds </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
remove_stopwords_en=True,
lemmatize=True
)
expected = "hello world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_8(self):
text = "<div>\n Hello, \t\n<strong>a the worlds! </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
remove_stopwords_en=True,
lemmatize=True
)
expected = "hello world"
self.assertEqual(" ".join(cleaned), expected)
def test_html_9(self):
text = "<div>\n Hello, \t\n<strong>world! it's it`s </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
lemmatize=True,
fix_single_quotes=True
)
expected = "hello world it's it's"
self.assertEqual(" ".join(cleaned), expected)
def test_single_quote(self):
text = "it's it`s its"
cleaned = preprocess(
text,
lowercase=True,
fix_single_quotes=True
)
expected = "it's it's it's"
self.assertEqual(" ".join(cleaned), expected)
def test_html_10(self):
text = "<div>\n Hello, \t\n<strong>world! it's it`s https://google.ca/test/abc.pdf </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
lemmatize=True,
fix_single_quotes=True,
remove_urls=True
)
expected = "hello world it's it's"
self.assertEqual(" ".join(cleaned), expected)
def test_html_11(self):
text = "<div>\n Hello, \t\n<strong>world! it's it`s & | </strong>\n\t</div>"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
lemmatize=True,
fix_single_quotes=True,
remove_stopwords_en=True,
remove_urls=True
)
expected = "hello world |"
self.assertEqual(" ".join(cleaned), expected)
def test_html_no_root(self):
text = "<a href=\"#p217709510\" class=\"quotelink\">&gt;&gt;217709510</a><br>Is there a<wbr>servant that is against civilization and humanity?<br>Literally instant summon."
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
lemmatize=False,
fix_single_quotes=True,
remove_stopwords_en=False,
remove_urls=False
)
expected = ">>217709510 is there a servant that is against civilization and humanity literally instant summon"
self.assertEqual(" ".join(cleaned), expected)
def test_html_entity(self):
text = "doesn&#039;t"
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
lemmatize=False,
fix_single_quotes=True,
remove_stopwords_en=False,
remove_urls=False
)
expected = "doesn't"
self.assertEqual(" ".join(cleaned), expected)
def test_html_invalid_attribute(self):
text = '<root><iframe width="560" height="315" src=" " title="youtube video player" frameborder="0" allowfullscreen></iframe></root>'
cleaned = preprocess(
text,
clean_html=True,
lowercase=True,
remove_punctuation=True,
lemmatize=False,
fix_single_quotes=True,
remove_stopwords_en=False,
remove_urls=False
)
expected = ""
self.assertEqual(" ".join(cleaned), expected)
def test_bigrams(self):
text = "x A b c d e f g h"
cleaned = preprocess(
text,
lowercase=True,
bigrams={
("a", "b"),
("c", "d"),
("f", "g"),
}
)
expected = "x a_b c_d e f_g h"
self.assertEqual(" ".join(cleaned), expected)
def test_trigrams(self):
text = "x A b c d e f g h"
cleaned = preprocess(
text,
lowercase=True,
trigrams={
("a", "b", "c"),
("e", "f", "g"),
}
)
expected = "x a_b_c d e_f_g h"
self.assertEqual(" ".join(cleaned), expected)
def test_remove_numbers(self):
text = "Hello1 test1124test 12 1 1111111 world"
cleaned = preprocess(
text,
lowercase=True,
remove_numbers=True
)
expected = "hello1 test1124test world"
self.assertEqual(" ".join(cleaned), expected)
def test_strip_quotes(self):
text = "'hi' “test” 'hello\""
cleaned = preprocess(
text,
strip_quotes=True
)
expected = "hi test hello"
self.assertEqual(" ".join(cleaned), expected)
def test_strip_dashes(self):
text = "yes -But something-something -- hello aa--bb"
cleaned = preprocess(
text,
strip_dashes=True
)
expected = "yes But something-something hello aa-bb"
self.assertEqual(" ".join(cleaned), expected)
def test_word_tokenize(self):
text = "i cannot believe'"
cleaned = preprocess(
text,
use_nltk_tokenizer=True
)
expected = "i can not believe '"
self.assertEqual(" ".join(cleaned), expected)

test/test_web.py Normal file

@@ -0,0 +1,21 @@
from unittest import TestCase
from hexlib.web import url_query_value
class TestWebMiscFuncs(TestCase):
def test_qs_1(self):
url = "https://test.com/page?a=1&b=2&a=2&c=hello"
self.assertEqual(url_query_value(url, "a"), "1")
self.assertEqual(url_query_value(url, "b"), "2")
self.assertEqual(url_query_value(url, "c"), "hello")
self.assertEqual(url_query_value(url, "D"), None)
def test_qs_as_list(self):
url = "https://test.com/page?a=1&b=2&a=2&c=hello"
self.assertEqual(url_query_value(url, "a", as_list=True), ["1", "2"])
self.assertEqual(url_query_value(url, "b", as_list=True), ["2"])
self.assertEqual(url_query_value(url, "c", as_list=True), ["hello"])
self.assertEqual(url_query_value(url, "D", as_list=True), [])