Mirror of https://github.com/simon987/hexlib.git, synced 2025-12-19 09:39:02 +00:00

Compare commits: cd5a1ac50c...1.89 (16 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | b1a1da3bac |  |
|  | a047366926 |  |
|  | 24230cdc1e |  |
|  | 3bd9f03996 |  |
|  | e267bbf1c8 |  |
|  | 42e33b72b2 |  |
|  | 5275c332cc |  |
|  | a7b1a6e1ec |  |
|  | 826312115c |  |
|  | 372abb0076 |  |
|  | 78c04ef6f3 |  |
|  | a51ad2cbb4 |  |
|  | 4befc3973d |  |
|  | c9fac7151a |  |
|  | 084acbe184 |  |
|  | d578be3218 |  |
```diff
@@ -34,7 +34,7 @@ class StatelessStreamWorker:
 
 class StatelessStreamProcessor:
 
-    def __init__(self, worker_factory, chunk_size=128, processes=1):
+    def __init__(self, worker_factory, chunk_size=128, processes=1, timeout=60):
         self._chunk_size = 128
         self._queue = MPQueue(maxsize=chunk_size)
         self._queue_out = MPQueue(maxsize=processes * 2)
@@ -42,6 +42,7 @@ class StatelessStreamProcessor:
         self._processes = []
         self._factory = worker_factory
         self._workers = []
+        self._timeout = timeout
 
         if processes > 1:
             for _ in range(processes):
@@ -67,7 +68,7 @@ class StatelessStreamProcessor:
         ingest_thread = Thread(target=self._ingest, args=(iterable,))
         ingest_thread.start()
 
-        for results in queue_iter(self._queue_out, joinable=False, timeout=10):
+        for results in queue_iter(self._queue_out, joinable=False, timeout=self._timeout):
             yield from results
 
         ingest_thread.join()
```
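The new `timeout` argument is forwarded to `queue_iter`, replacing the hard-coded 10-second poll. For orientation, a minimal sketch of a `queue_iter`-style helper with this signature (assumed semantics; hexlib's actual implementation may differ):

```python
from queue import Empty

def queue_iter(q, joinable=False, timeout=10):
    # Drain a queue, treating a get() that times out as end-of-stream.
    # A configurable timeout lets slow workers keep the stream alive.
    while True:
        try:
            item = q.get(timeout=timeout)
        except Empty:
            break
        yield item
        if joinable:
            q.task_done()
```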
hexlib/db.py (136 changed lines)
```diff
@@ -1,26 +1,24 @@
 import base64
 import sqlite3
 import traceback
+from datetime import datetime
+from enum import Enum
 
 import psycopg2
 import umsgpack
 from psycopg2.errorcodes import UNIQUE_VIOLATION
+from pydantic import BaseModel
 
 from hexlib.env import get_redis
 
 
-class PersistentState:
-    """Quick and dirty persistent dict-like SQLite wrapper"""
-
-    def __init__(self, dbfile="state.db", logger=None, **dbargs):
-        self.dbfile = dbfile
-        self.logger = logger
-        if dbargs is None:
-            dbargs = {"timeout": 30000}
-        self.dbargs = dbargs
-
-    def __getitem__(self, table):
-        return Table(self, table)
+def _json_encoder(x):
+    if isinstance(x, datetime):
+        return x.isoformat()
+    if isinstance(x, Enum):
+        return x.value
+
+    raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
 
 
 class VolatileState:
```
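The module-level `_json_encoder` now backs pydantic serialization (see `setitem_pydantic` below); it maps `datetime` to ISO-8601 and `Enum` to its value, and raises on anything else:

```python
from datetime import datetime
from enum import Enum

from hexlib.db import _json_encoder

class Color(Enum):
    red = "red"

print(_json_encoder(datetime(2000, 1, 1)))  # 2000-01-01T00:00:00
print(_json_encoder(Color.red))             # red
```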
```diff
@@ -128,32 +126,54 @@ class Table:
         self._state = state
         self._table = table
 
-    def sql(self, where_clause, *params):
-        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
-            conn.row_factory = sqlite3.Row
-            try:
-                cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
-                for row in cur:
-                    yield dict(row)
-            except:
-                return None
-
-    def __iter__(self):
-        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
-            conn.row_factory = sqlite3.Row
-            try:
-                cur = conn.execute("SELECT * FROM %s" % (self._table,))
-                for row in cur:
-                    yield dict(row)
-            except:
-                return None
-
-    def __getitem__(self, item):
+    def _sql_dict(self, where_clause, *params):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            conn.row_factory = sqlite3.Row
+            try:
+                col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
+                cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
+                for row in cur:
+                    yield dict(
+                        (col[0], _deserialize(row[col[0]], col_types[i]["type"]))
+                        for i, col in enumerate(cur.description)
+                    )
+            except:
+                return None
+
+    def sql(self, where_clause, *params):
+        for row in self._sql_dict(where_clause, *params):
+            if row and "__pydantic" in row:
+                yield self._deserialize_pydantic(row)
+            else:
+                yield row
+
+    def _iter_dict(self):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            conn.row_factory = sqlite3.Row
+            try:
+                col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
+                cur = conn.execute("SELECT * FROM %s" % (self._table,))
+                for row in cur:
+                    yield dict(
+                        (col[0], _deserialize(row[col[0]], col_types[i]["type"]))
+                        for i, col in enumerate(cur.description)
+                    )
+            except:
+                return None
+
+    def __iter__(self):
+        for row in self._iter_dict():
+            if row and "__pydantic" in row:
+                yield self._deserialize_pydantic(row)
+            else:
+                yield row
+
+    def _getitem_dict(self, key):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
                 col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
-                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
+                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
 
                 row = cur.fetchone()
                 if row:
```
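With the read paths refactored through `_sql_dict` / `_iter_dict`, every row now passes through `_deserialize`, so BLOB columns come back as `bytes` everywhere. This matches the new tests further down:

```python
from hexlib.db import PersistentState

s = PersistentState()
s["a"][0] = {"x": b"abc"}

# bytes survive the round-trip through the base64 BLOB encoding
assert list(s["a"].sql("WHERE 1=1"))[0]["x"] == b"abc"
assert list(s["a"])[0]["x"] == b"abc"
```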
```diff
@@ -164,8 +184,32 @@ class Table:
             except:
                 return None
 
+    @staticmethod
+    def _deserialize_pydantic(row):
+        module = __import__(row["__module"])
+        cls = getattr(module, row["__class"])
+        return cls.parse_raw(row["json"])
+
+    def __getitem__(self, key):
+        row = self._getitem_dict(key)
+        if row and "__pydantic" in row:
+            return self._deserialize_pydantic(row)
+        return row
+
+    def setitem_pydantic(self, key, value: BaseModel):
+        self.__setitem__(key, {
+            "json": value.json(encoder=_json_encoder, indent=2),
+            "__class": value.__class__.__name__,
+            "__module": value.__class__.__module__,
+            "__pydantic": 1
+        })
+
     def __setitem__(self, key, value):
 
+        if isinstance(value, BaseModel):
+            self.setitem_pydantic(key, value)
+            return
+
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
 
```
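Together these methods make pydantic models first-class values: assigning a `BaseModel` routes through `setitem_pydantic`, and any read that finds the `__pydantic` marker re-hydrates the original class. A short round-trip, mirroring test/test_PydanticTable.py below:

```python
from pydantic import BaseModel

from hexlib.db import PersistentState

class Point(BaseModel):
    x: int
    y: int

s = PersistentState()
s["points"]["origin"] = Point(x=0, y=0)  # stored as JSON plus class metadata
restored = s["points"]["origin"]         # comes back as a Point instance
assert restored.x == 0
```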
```diff
@@ -223,11 +267,33 @@ def _serialize(value):
 
 
 def _deserialize(value, col_type):
-    if col_type == "blob":
+    if col_type.lower() == "blob":
         return base64.b64decode(value)
     return value
 
 
+class PersistentState:
+    """Quick and dirty persistent dict-like SQLite wrapper"""
+
+    def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
+        self.dbfile = dbfile
+        self.logger = logger
+        if dbargs is None or dbargs == {}:
+            dbargs = {"timeout": 30000}
+        self.dbargs = dbargs
+        self._table_factory = table_factory
+
+    def __getitem__(self, table):
+        return self._table_factory(self, table)
+
+    def __delitem__(self, key):
+        with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
+            try:
+                conn.execute(f"DROP TABLE {key}")
+            except:
+                pass
+
+
 def pg_fetch_cursor_all(cur, name, batch_size=1000):
     while True:
         cur.execute("FETCH FORWARD %d FROM %s" % (batch_size, name))
```
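The `.lower()` guard matters because SQLite reports a column's declared type exactly as it was written in the CREATE TABLE statement, so both `BLOB` and `blob` can show up in `PRAGMA table_info`:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id text, x BLOB)")
print(conn.execute("PRAGMA table_info(t)").fetchall())
# [(0, 'id', 'text', 0, None, 0), (1, 'x', 'BLOB', 0, None, 0)]
```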
```diff
@@ -250,10 +316,10 @@ class PgConn:
     def __init__(self, logger=None, **kwargs):
         self._conn_args = kwargs
         self.conn = psycopg2.connect(**kwargs)
-        self.cur = self.conn.cursor()
         self._logger = logger
 
     def __enter__(self):
+        self.cur = self.conn.cursor()
         return self
 
     def exec(self, query_string, args=None):
```
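Creating the cursor in `__enter__` gives each `with` block a fresh cursor. A hypothetical usage sketch (the connection parameters and target table are placeholders, and `PgConn.__exit__` is not shown in this diff):

```python
pg = PgConn(dbname="hexlib", user="hexlib", password="secret", host="localhost")
with pg as conn:
    conn.exec("INSERT INTO items (id) VALUES (%s)", ("a",))
```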
```diff
@@ -85,7 +85,7 @@ def ndjson_iter(*files, compression=""):
         line_iter = BufferedReader(gzip.open(file))
     elif compression == COMPRESSION_ZSTD:
         fp = open(file, "rb")
-        dctx = zstandard.ZstdDecompressor()
+        dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
         reader = dctx.stream_reader(fp)
         line_iter = BufferedReader(reader)
 
```
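A 2 GiB `max_window_size` lets the decompressor accept archives written with long-window settings (e.g. `zstd --long=31`), which the default window limit rejects. A stand-alone sketch of the same read path (`iter_zstd_ndjson` is illustrative, not part of hexlib):

```python
import io
import json

import zstandard

def iter_zstd_ndjson(path):
    # Same construction as ndjson_iter's zstd branch: a streaming
    # decompressor wrapped in a BufferedReader for line iteration.
    with open(path, "rb") as fp:
        dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
        for line in io.BufferedReader(dctx.stream_reader(fp)):
            yield json.loads(line)
```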
hexlib/plot.py (227 lines, file removed)
The entire file is deleted (`@@ -1,227 +0,0 @@`); its contents were:

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

DATA = [
    *["apple"] * 5,
    *["banana"] * 12,
    *["strawberry"] * 8,
    *["pineapple"] * 2,
]


class Cmap:
    Accent = "Accent"
    Accent_r = "Accent_r"
    Blues = "Blues"
    Blues_r = "Blues_r"
    BrBG = "BrBG"
    BrBG_r = "BrBG_r"
    BuGn = "BuGn"
    BuGn_r = "BuGn_r"
    BuPu = "BuPu"
    BuPu_r = "BuPu_r"
    CMRmap = "CMRmap"
    CMRmap_r = "CMRmap_r"
    Dark2 = "Dark2"
    Dark2_r = "Dark2_r"
    GnBu = "GnBu"
    GnBu_r = "GnBu_r"
    Greens = "Greens"
    Greens_r = "Greens_r"
    Greys = "Greys"
    Greys_r = "Greys_r"
    OrRd = "OrRd"
    OrRd_r = "OrRd_r"
    Oranges = "Oranges"
    Oranges_r = "Oranges_r"
    PRGn = "PRGn"
    PRGn_r = "PRGn_r"
    Paired = "Paired"
    Paired_r = "Paired_r"
    Pastel1 = "Pastel1"
    Pastel1_r = "Pastel1_r"
    Pastel2 = "Pastel2"
    Pastel2_r = "Pastel2_r"
    PiYG = "PiYG"
    PiYG_r = "PiYG_r"
    PuBu = "PuBu"
    PuBuGn = "PuBuGn"
    PuBuGn_r = "PuBuGn_r"
    PuBu_r = "PuBu_r"
    PuOr = "PuOr"
    PuOr_r = "PuOr_r"
    PuRd = "PuRd"
    PuRd_r = "PuRd_r"
    Purples = "Purples"
    Purples_r = "Purples_r"
    RdBu = "RdBu"
    RdBu_r = "RdBu_r"
    RdGy = "RdGy"
    RdGy_r = "RdGy_r"
    RdPu = "RdPu"
    RdPu_r = "RdPu_r"
    RdYlBu = "RdYlBu"
    RdYlBu_r = "RdYlBu_r"
    RdYlGn = "RdYlGn"
    RdYlGn_r = "RdYlGn_r"
    Reds = "Reds"
    Reds_r = "Reds_r"
    Set1 = "Set1"
    Set1_r = "Set1_r"
    Set2 = "Set2"
    Set2_r = "Set2_r"
    Set3 = "Set3"
    Set3_r = "Set3_r"
    Spectral = "Spectral"
    Spectral_r = "Spectral_r"
    Wistia = "Wistia"
    Wistia_r = "Wistia_r"
    YlGn = "YlGn"
    YlGnBu = "YlGnBu"
    YlGnBu_r = "YlGnBu_r"
    YlGn_r = "YlGn_r"
    YlOrBr = "YlOrBr"
    YlOrBr_r = "YlOrBr_r"
    YlOrRd = "YlOrRd"
    YlOrRd_r = "YlOrRd_r"
    afmhot = "afmhot"
    afmhot_r = "afmhot_r"
    autumn = "autumn"
    autumn_r = "autumn_r"
    binary = "binary"
    binary_r = "binary_r"
    bone = "bone"
    bone_r = "bone_r"
    brg = "brg"
    brg_r = "brg_r"
    bwr = "bwr"
    bwr_r = "bwr_r"
    cividis = "cividis"
    cividis_r = "cividis_r"
    cool = "cool"
    cool_r = "cool_r"
    coolwarm = "coolwarm"
    coolwarm_r = "coolwarm_r"
    copper = "copper"
    copper_r = "copper_r"
    cubehelix = "cubehelix"
    cubehelix_r = "cubehelix_r"
    flag = "flag"
    flag_r = "flag_r"
    gist_earth = "gist_earth"
    gist_earth_r = "gist_earth_r"
    gist_gray = "gist_gray"
    gist_gray_r = "gist_gray_r"
    gist_heat = "gist_heat"
    gist_heat_r = "gist_heat_r"
    gist_ncar = "gist_ncar"
    gist_ncar_r = "gist_ncar_r"
    gist_rainbow = "gist_rainbow"
    gist_rainbow_r = "gist_rainbow_r"
    gist_stern = "gist_stern"
    gist_stern_r = "gist_stern_r"
    gist_yarg = "gist_yarg"
    gist_yarg_r = "gist_yarg_r"
    gnuplot = "gnuplot"
    gnuplot2 = "gnuplot2"
    gnuplot2_r = "gnuplot2_r"
    gnuplot_r = "gnuplot_r"
    gray = "gray"
    gray_r = "gray_r"
    hot = "hot"
    hot_r = "hot_r"
    hsv = "hsv"
    hsv_r = "hsv_r"
    inferno = "inferno"
    inferno_r = "inferno_r"
    jet = "jet"
    jet_r = "jet_r"
    magma = "magma"
    magma_r = "magma_r"
    nipy_spectral = "nipy_spectral"
    nipy_spectral_r = "nipy_spectral_r"
    ocean = "ocean"
    ocean_r = "ocean_r"
    pink = "pink"
    pink_r = "pink_r"
    plasma = "plasma"
    plasma_r = "plasma_r"
    prism = "prism"
    prism_r = "prism_r"
    rainbow = "rainbow"
    rainbow_r = "rainbow_r"
    seismic = "seismic"
    seismic_r = "seismic_r"
    spring = "spring"
    spring_r = "spring_r"
    summer = "summer"
    summer_r = "summer_r"
    tab10 = "tab10"
    tab10_r = "tab10_r"
    tab20 = "tab20"
    tab20_r = "tab20_r"
    tab20b = "tab20b"
    tab20b_r = "tab20b_r"
    tab20c = "tab20c"
    tab20c_r = "tab20c_r"
    terrain = "terrain"
    terrain_r = "terrain_r"
    turbo = "turbo"
    turbo_r = "turbo_r"
    twilight = "twilight"
    twilight_r = "twilight_r"
    twilight_shifted = "twilight_shifted"
    twilight_shifted_r = "twilight_shifted_r"
    viridis = "viridis"
    viridis_r = "viridis_r"
    winter = "winter"
    winter_r = "winter_r"


def plot_freq_bar(items, ylabel="frequency", title=""):
    item_set, item_counts = np.unique(items, return_counts=True)

    plt.bar(item_set, item_counts)
    plt.xticks(rotation=35)
    plt.ylabel(ylabel)
    plt.title(title)

    for i, cnt in enumerate(item_counts):
        plt.text(x=i, y=cnt / 2, s=cnt, ha="center", color="white")

    plt.tight_layout()


def plot_confusion_matrix(y_true=None, y_pred=None, cm=None, labels=None, title=None, cmap=None):
    if not cm:
        cm = confusion_matrix(y_true, y_pred, labels=labels)

    if type(cm) == list:
        cm = np.array(cm)

    cm_display = ConfusionMatrixDisplay(cm, display_labels=labels)
    cm_display.plot(cmap=cmap)

    if title:
        plt.title(title)

    if labels:
        plt.xticks(rotation=30)

    plt.tight_layout()


if __name__ == '__main__':
    plot_freq_bar(DATA, title="My title")
    plt.show()

    plot_confusion_matrix(
        cm=[[12, 1, 0],
            [3, 14, 1],
            [5, 6, 7]],
        title="My title",
        labels=["apple", "orange", "grape"],
        cmap=Cmap.viridis
    )
    plt.show()
```
```diff
@@ -1,16 +1,20 @@
 from functools import partial
+import re
 from itertools import chain, repeat
 from multiprocessing.pool import Pool
 
 import nltk.corpus
 from lxml import etree
+from nltk import word_tokenize
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 
-from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
+from .regex_util import LINK_RE
 
 get_text = etree.XPath("//text()")
 
+nltk.download("stopwords", quiet=True)
+nltk.download("wordnet", quiet=True)
+nltk.download("punkt", quiet=True)
+
 stop_words_en = set(stopwords.words("english"))
 
 extra_stop_words_en = [
@@ -19,9 +23,6 @@ extra_stop_words_en = [
 
 stop_words_en.update(extra_stop_words_en)
 
-nltk.download("stopwords", quiet=True)
-nltk.download("wordnet", quiet=True)
-
 lemmatizer = WordNetLemmatizer()
 
 
@@ -53,13 +54,20 @@ SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", l
 DASHES = ("–", "⸺", "–", "—")
 DASHES_TRANS = str.maketrans("".join(DASHES), "".join(repeat("-", len(DASHES))))
 
-PUNCTUATION = ".,;:\"!?/()|*=>"
+DASHES_RE = re.compile(r"-+")
+
+SPECIAL_PUNCTUATION = ";:\"/()|*=>"
+SPECIAL_PUNCTUATION_TRANS = str.maketrans(SPECIAL_PUNCTUATION, " " * len(SPECIAL_PUNCTUATION))
+
+PUNCTUATION = ".,!?"
 PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION))
 
 
-def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_stopwords_en=False,
-               lemmatize=False, fix_single_quotes=False, strip_quotes=False, remove_urls=False, bigrams: set = None,
-               trigrams: set = None, remove_numbers=False):
+def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_special_punctuation=False,
+               remove_stopwords_en=False, lemmatize=False, fix_single_quotes=False, strip_quotes=False,
+               strip_dashes=False,
+               remove_urls=False, bigrams: set = None, trigrams: set = None, remove_numbers=False,
+               use_nltk_tokenizer=False):
     if lowercase:
         text = text.lower()
@@ -68,6 +76,9 @@ def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False
 
     text = text.translate(DASHES_TRANS)
 
+    if strip_dashes:
+        text = DASHES_RE.sub("-", text)
+
     if remove_urls:
         text = LINK_RE.sub(" ", text)
 
@@ -85,11 +96,20 @@ def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False
     if remove_punctuation:
         text = text.translate(PUNCTUATION_TRANS)
 
-    words = text.split()
+    if remove_special_punctuation:
+        text = text.translate(SPECIAL_PUNCTUATION_TRANS)
+
+    if use_nltk_tokenizer:
+        words = word_tokenize(text, language="english")
+    else:
+        words = text.split()
 
     if strip_quotes:
         words = map(lambda w: w.strip("\"'“”"), words)
 
+    if strip_dashes:
+        words = map(lambda w: w.strip("-"), words)
+
     if bigrams:
         words = _transform_bigram(nltk.bigrams(chain(words, ("*",))), bigrams)
```
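The new flags in action, taken from the test cases at the end of this compare (assuming the module is `hexlib.text`):

```python
from hexlib.text import preprocess

# strip_dashes: collapse dash runs and strip leading/trailing dashes per word
cleaned = preprocess("yes -But something-something -- hello aa--bb", strip_dashes=True)
assert " ".join(cleaned) == "yes But something-something hello aa-bb"

# use_nltk_tokenizer: delegate word splitting to nltk.word_tokenize
cleaned = preprocess("i cannot believe'", use_nltk_tokenizer=True)
assert " ".join(cleaned) == "i can not believe '"
```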
setup.py (8 changed lines)
```diff
@@ -2,7 +2,7 @@ from setuptools import setup
 
 setup(
     name="hexlib",
-    version="1.73",
+    version="1.89",
     description="Misc utility methods",
     author="simon987",
     author_email="me@simon987.net",
@@ -12,9 +12,9 @@ setup(
         "data/*"
     ]},
     install_requires=[
-        "ImageHash", "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
+        "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
         "u-msgpack-python", "psycopg2-binary", "bs4", "lxml", "nltk", "numpy",
-        "matplotlib", "scikit-learn", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
-        "requests"
+        "matplotlib", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
+        "requests", "pydantic==1.10.11"
     ]
 )
```
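The exact `pydantic==1.10.11` pin lines up with the v1-only API used in hexlib/db.py above; `BaseModel.json()` and `parse_raw()` were deprecated in pydantic 2 in favor of `model_dump_json()` / `model_validate_json()`:

```python
from pydantic import BaseModel

class M(BaseModel):
    x: int

M(x=1).json()            # v1 API, used by setitem_pydantic
M.parse_raw('{"x": 1}')  # v1 API, used by _deserialize_pydantic
```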
```diff
@@ -110,3 +110,34 @@ class TestPersistentState(TestCase):
             del s["a"][456]
         except Exception as e:
             self.fail(e)
+
+    def test_deserialize_get_set(self):
+        s = PersistentState()
+
+        s["a"][0] = {"x": b'abc'}
+
+        self.assertEqual(s["a"][0]["x"], b'abc')
+
+    def test_deserialize_sql(self):
+        s = PersistentState()
+
+        s["a"][0] = {"x": b'abc'}
+
+        self.assertEqual(list(s["a"].sql("WHERE 1=1"))[0]["x"], b'abc')
+
+    def test_deserialize_iter(self):
+        s = PersistentState()
+
+        s["a"][0] = {"x": b'abc'}
+
+        self.assertEqual(list(s["a"])[0]["x"], b'abc')
+
+    def test_drop_table(self):
+        s = PersistentState()
+
+        s["a"][0] = {"x": 1}
+        s["a"][1] = {"x": 2}
+        self.assertEqual(len(list(s["a"])), 2)
+
+        del s["a"]
+        self.assertEqual(len(list(s["a"])), 0)
```
test/test_PydanticTable.py (new file, 110 lines)
The entire file is new (`@@ -0,0 +1,110 @@`):

```python
import os
from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase

from pydantic import BaseModel
from pydantic.types import List

from hexlib.db import PersistentState


class Status(Enum):
    yes = "yes"
    no = "no"


class Point(BaseModel):
    x: int
    y: int


class Polygon(BaseModel):
    points: List[Point] = []
    created_date: datetime
    status: Status = Status("yes")


class TestPydanticTable(TestCase):
    def tearDown(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def setUp(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def test_get_set(self):
        s = PersistentState()

        val = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[
                Point(x=1, y=2),
                Point(x=3, y=4),
            ],
        )

        s["a"]["1"] = val

        self.assertEqual(s["a"]["1"].points[0].x, 1)
        self.assertEqual(s["a"]["1"].status, Status("yes"))
        self.assertEqual(s["a"]["1"].points[1].x, 3)
        self.assertEqual(s["a"]["1"].created_date.year, 2000)

    def test_update(self):
        s = PersistentState()

        val = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[
                Point(x=1, y=2),
                Point(x=3, y=4),
            ]
        )

        s["a"]["1"] = val

        self.assertEqual(s["a"]["1"].points[0].x, 1)

        val.points[0].x = 2
        s["a"]["1"] = val
        self.assertEqual(s["a"]["1"].points[0].x, 2)

    def test_sql(self):
        s = PersistentState()

        s["b"]["1"] = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[]
        )
        s["b"]["2"] = Polygon(
            created_date=datetime(year=2010, day=1, month=1),
            points=[]
        )

        result = list(s["b"].sql(
            "WHERE json->>'created_date' LIKE '2000-%'"
        ))

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].created_date.year, 2000)

    def test_iterate(self):
        s = PersistentState()

        s["b"]["1"] = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[]
        )
        s["b"]["2"] = Polygon(
            created_date=datetime(year=2010, day=1, month=1),
            points=[]
        )

        result = list(s["b"])

        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].created_date.year, 2000)
        self.assertEqual(result[1].created_date.year, 2010)
```
```diff
@@ -152,7 +152,7 @@ class TestText(TestCase):
             remove_stopwords_en=True,
             remove_urls=True
         )
-        expected = "hello world"
+        expected = "hello world |"
 
         self.assertEqual(" ".join(cleaned), expected)
 
@@ -170,7 +170,7 @@ class TestText(TestCase):
             remove_urls=False
         )
 
-        expected = "217709510 is there a servant that is against civilization and humanity literally instant summon"
+        expected = ">>217709510 is there a servant that is against civilization and humanity literally instant summon"
         self.assertEqual(" ".join(cleaned), expected)
 
     def test_html_entity(self):
@@ -257,3 +257,23 @@ class TestText(TestCase):
         expected = "hi test hello"
 
         self.assertEqual(" ".join(cleaned), expected)
+
+    def test_strip_dashes(self):
+        text = "yes -But something-something -- hello aa--bb"
+        cleaned = preprocess(
+            text,
+            strip_dashes=True
+        )
+        expected = "yes But something-something hello aa-bb"
+
+        self.assertEqual(" ".join(cleaned), expected)
+
+    def test_word_tokenize(self):
+        text = "i cannot believe'"
+        cleaned = preprocess(
+            text,
+            use_nltk_tokenizer=True
+        )
+        expected = "i can not believe '"
+
+        self.assertEqual(" ".join(cleaned), expected)
```