Compare commits

...

16 Commits

10 changed files with 303 additions and 282 deletions

View File

@@ -34,7 +34,7 @@ class StatelessStreamWorker:
 class StatelessStreamProcessor:
-    def __init__(self, worker_factory, chunk_size=128, processes=1):
+    def __init__(self, worker_factory, chunk_size=128, processes=1, timeout=60):
         self._chunk_size = 128
         self._queue = MPQueue(maxsize=chunk_size)
         self._queue_out = MPQueue(maxsize=processes * 2)
@@ -42,6 +42,7 @@ class StatelessStreamProcessor:
         self._processes = []
         self._factory = worker_factory
         self._workers = []
+        self._timeout = timeout

         if processes > 1:
             for _ in range(processes):
@@ -67,7 +68,7 @@ class StatelessStreamProcessor:
         ingest_thread = Thread(target=self._ingest, args=(iterable,))
         ingest_thread.start()

-        for results in queue_iter(self._queue_out, joinable=False, timeout=10):
+        for results in queue_iter(self._queue_out, joinable=False, timeout=self._timeout):
            yield from results

        ingest_thread.join()
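The new timeout argument replaces the previously hard-coded 10-second read timeout on the output queue. Below is a minimal, hedged sketch of picking a longer timeout at construction time; the import path and the placeholder worker are assumptions, only the constructor signature (worker_factory, chunk_size, processes, timeout) comes from this diff.

# Illustrative only: the module path and MyWorker are assumed, not part of this diff.
from hexlib.concurrency import StatelessStreamProcessor  # import path assumed


class MyWorker:
    """Placeholder for a StatelessStreamWorker subclass."""


processor = StatelessStreamProcessor(
    worker_factory=MyWorker,
    chunk_size=128,
    processes=4,
    timeout=300,  # per-read timeout on the output queue; previously fixed at 10 s
)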

View File

@@ -1,26 +1,24 @@
 import base64
 import sqlite3
 import traceback
+from datetime import datetime
+from enum import Enum

 import psycopg2
 import umsgpack
 from psycopg2.errorcodes import UNIQUE_VIOLATION
+from pydantic import BaseModel

 from hexlib.env import get_redis

-class PersistentState:
-    """Quick and dirty persistent dict-like SQLite wrapper"""
-    def __init__(self, dbfile="state.db", logger=None, **dbargs):
-        self.dbfile = dbfile
-        self.logger = logger
-        if dbargs is None:
-            dbargs = {"timeout": 30000}
-        self.dbargs = dbargs
-    def __getitem__(self, table):
-        return Table(self, table)
+def _json_encoder(x):
+    if isinstance(x, datetime):
+        return x.isoformat()
+    if isinstance(x, Enum):
+        return x.value
+    raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")

 class VolatileState:
@@ -128,32 +126,54 @@ class Table:
         self._state = state
         self._table = table

-    def sql(self, where_clause, *params):
-        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
-            conn.row_factory = sqlite3.Row
-            try:
-                cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
-                for row in cur:
-                    yield dict(row)
-            except:
-                return None
-
-    def __iter__(self):
-        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
-            conn.row_factory = sqlite3.Row
-            try:
-                cur = conn.execute("SELECT * FROM %s" % (self._table,))
-                for row in cur:
-                    yield dict(row)
-            except:
-                return None
-
-    def __getitem__(self, item):
+    def _sql_dict(self, where_clause, *params):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
                 col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
-                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
+                cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
+                for row in cur:
+                    yield dict(
+                        (col[0], _deserialize(row[col[0]], col_types[i]["type"]))
+                        for i, col in enumerate(cur.description)
+                    )
+            except:
+                return None
+
+    def sql(self, where_clause, *params):
+        for row in self._sql_dict(where_clause, *params):
+            if row and "__pydantic" in row:
+                yield self._deserialize_pydantic(row)
+            else:
+                yield row
+
+    def _iter_dict(self):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            conn.row_factory = sqlite3.Row
+            try:
+                col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
+                cur = conn.execute("SELECT * FROM %s" % (self._table,))
+                for row in cur:
+                    yield dict(
+                        (col[0], _deserialize(row[col[0]], col_types[i]["type"]))
+                        for i, col in enumerate(cur.description)
+                    )
+            except:
+                return None
+
+    def __iter__(self):
+        for row in self._iter_dict():
+            if row and "__pydantic" in row:
+                yield self._deserialize_pydantic(row)
+            else:
+                yield row
+
+    def _getitem_dict(self, key):
+        with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
+            conn.row_factory = sqlite3.Row
+            try:
+                col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
+                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
                 row = cur.fetchone()
                 if row:
@@ -164,8 +184,32 @@ class Table:
             except:
                 return None

+    @staticmethod
+    def _deserialize_pydantic(row):
+        module = __import__(row["__module"])
+        cls = getattr(module, row["__class"])
+        return cls.parse_raw(row["json"])
+
+    def __getitem__(self, key):
+        row = self._getitem_dict(key)
+        if row and "__pydantic" in row:
+            return self._deserialize_pydantic(row)
+        return row
+
+    def setitem_pydantic(self, key, value: BaseModel):
+        self.__setitem__(key, {
+            "json": value.json(encoder=_json_encoder, indent=2),
+            "__class": value.__class__.__name__,
+            "__module": value.__class__.__module__,
+            "__pydantic": 1
+        })
+
     def __setitem__(self, key, value):
+        if isinstance(value, BaseModel):
+            self.setitem_pydantic(key, value)
+            return
+
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
@@ -223,11 +267,33 @@ def _serialize(value):
 def _deserialize(value, col_type):
-    if col_type == "blob":
+    if col_type.lower() == "blob":
         return base64.b64decode(value)
     return value

+class PersistentState:
+    """Quick and dirty persistent dict-like SQLite wrapper"""
+
+    def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
+        self.dbfile = dbfile
+        self.logger = logger
+        if dbargs is None or dbargs == {}:
+            dbargs = {"timeout": 30000}
+        self.dbargs = dbargs
+        self._table_factory = table_factory
+
+    def __getitem__(self, table):
+        return self._table_factory(self, table)
+
+    def __delitem__(self, key):
+        with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
+            try:
+                conn.execute(f"DROP TABLE {key}")
+            except:
+                pass

 def pg_fetch_cursor_all(cur, name, batch_size=1000):
     while True:
         cur.execute("FETCH FORWARD %d FROM %s" % (batch_size, name))
@@ -250,10 +316,10 @@ class PgConn:
     def __init__(self, logger=None, **kwargs):
         self._conn_args = kwargs
         self.conn = psycopg2.connect(**kwargs)
-        self.cur = self.conn.cursor()
         self._logger = logger

     def __enter__(self):
+        self.cur = self.conn.cursor()
         return self

     def exec(self, query_string, args=None):
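Taken together, the Table changes let PersistentState round-trip Pydantic models transparently: __setitem__ detects a BaseModel and stores its JSON alongside __class, __module and __pydantic markers, while __getitem__, __iter__ and sql() rebuild the model on the way out. A hedged sketch of that workflow follows; the Task model and its fields are invented for illustration, and the hexlib.db import path is the one used by the new tests.

# Sketch of the Pydantic round-trip added in this diff; Task is illustrative.
from datetime import datetime

from pydantic import BaseModel

from hexlib.db import PersistentState


class Task(BaseModel):
    name: str
    created_date: datetime


state = PersistentState("state.db")

# Storing a BaseModel goes through setitem_pydantic() behind the scenes
state["tasks"]["1"] = Task(name="demo", created_date=datetime(2023, 1, 1))

# Reading it back returns a Task instance, not a dict
task = state["tasks"]["1"]
print(task.name, task.created_date.year)

# Plain dict rows keep working exactly as before
state["misc"][0] = {"x": b"abc"}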

View File

@@ -85,7 +85,7 @@ def ndjson_iter(*files, compression=""):
             line_iter = BufferedReader(gzip.open(file))
         elif compression == COMPRESSION_ZSTD:
             fp = open(file, "rb")
-            dctx = zstandard.ZstdDecompressor()
+            dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
             reader = dctx.stream_reader(fp)
             line_iter = BufferedReader(reader)
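The explicit max_window_size lets the decoder accept frames produced with zstd long-distance matching (for example zstd --long=31), which the default window limit would otherwise reject. A hedged sketch of the same pattern outside ndjson_iter; the file name is a placeholder.

# Sketch mirroring the ndjson_iter change: accept zstd frames that need a
# decompression window of up to 2 GiB. "data.ndjson.zst" is illustrative.
import io

import zstandard

with open("data.ndjson.zst", "rb") as fp:
    dctx = zstandard.ZstdDecompressor(max_window_size=2147483648)
    with dctx.stream_reader(fp) as reader:
        for line in io.BufferedReader(reader):
            print(line.decode("utf-8").rstrip())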

View File

@@ -1,227 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
DATA = [
    *["apple"] * 5,
    *["banana"] * 12,
    *["strawberry"] * 8,
    *["pineapple"] * 2,
]
class Cmap:
Accent = "Accent"
Accent_r = "Accent_r"
Blues = "Blues"
Blues_r = "Blues_r"
BrBG = "BrBG"
BrBG_r = "BrBG_r"
BuGn = "BuGn"
BuGn_r = "BuGn_r"
BuPu = "BuPu"
BuPu_r = "BuPu_r"
CMRmap = "CMRmap"
CMRmap_r = "CMRmap_r"
Dark2 = "Dark2"
Dark2_r = "Dark2_r"
GnBu = "GnBu"
GnBu_r = "GnBu_r"
Greens = "Greens"
Greens_r = "Greens_r"
Greys = "Greys"
Greys_r = "Greys_r"
OrRd = "OrRd"
OrRd_r = "OrRd_r"
Oranges = "Oranges"
Oranges_r = "Oranges_r"
PRGn = "PRGn"
PRGn_r = "PRGn_r"
Paired = "Paired"
Paired_r = "Paired_r"
Pastel1 = "Pastel1"
Pastel1_r = "Pastel1_r"
Pastel2 = "Pastel2"
Pastel2_r = "Pastel2_r"
PiYG = "PiYG"
PiYG_r = "PiYG_r"
PuBu = "PuBu"
PuBuGn = "PuBuGn"
PuBuGn_r = "PuBuGn_r"
PuBu_r = "PuBu_r"
PuOr = "PuOr"
PuOr_r = "PuOr_r"
PuRd = "PuRd"
PuRd_r = "PuRd_r"
Purples = "Purples"
Purples_r = "Purples_r"
RdBu = "RdBu"
RdBu_r = "RdBu_r"
RdGy = "RdGy"
RdGy_r = "RdGy_r"
RdPu = "RdPu"
RdPu_r = "RdPu_r"
RdYlBu = "RdYlBu"
RdYlBu_r = "RdYlBu_r"
RdYlGn = "RdYlGn"
RdYlGn_r = "RdYlGn_r"
Reds = "Reds"
Reds_r = "Reds_r"
Set1 = "Set1"
Set1_r = "Set1_r"
Set2 = "Set2"
Set2_r = "Set2_r"
Set3 = "Set3"
Set3_r = "Set3_r"
Spectral = "Spectral"
Spectral_r = "Spectral_r"
Wistia = "Wistia"
Wistia_r = "Wistia_r"
YlGn = "YlGn"
YlGnBu = "YlGnBu"
YlGnBu_r = "YlGnBu_r"
YlGn_r = "YlGn_r"
YlOrBr = "YlOrBr"
YlOrBr_r = "YlOrBr_r"
YlOrRd = "YlOrRd"
YlOrRd_r = "YlOrRd_r"
afmhot = "afmhot"
afmhot_r = "afmhot_r"
autumn = "autumn"
autumn_r = "autumn_r"
binary = "binary"
binary_r = "binary_r"
bone = "bone"
bone_r = "bone_r"
brg = "brg"
brg_r = "brg_r"
bwr = "bwr"
bwr_r = "bwr_r"
cividis = "cividis"
cividis_r = "cividis_r"
cool = "cool"
cool_r = "cool_r"
coolwarm = "coolwarm"
coolwarm_r = "coolwarm_r"
copper = "copper"
copper_r = "copper_r"
cubehelix = "cubehelix"
cubehelix_r = "cubehelix_r"
flag = "flag"
flag_r = "flag_r"
gist_earth = "gist_earth"
gist_earth_r = "gist_earth_r"
gist_gray = "gist_gray"
gist_gray_r = "gist_gray_r"
gist_heat = "gist_heat"
gist_heat_r = "gist_heat_r"
gist_ncar = "gist_ncar"
gist_ncar_r = "gist_ncar_r"
gist_rainbow = "gist_rainbow"
gist_rainbow_r = "gist_rainbow_r"
gist_stern = "gist_stern"
gist_stern_r = "gist_stern_r"
gist_yarg = "gist_yarg"
gist_yarg_r = "gist_yarg_r"
gnuplot = "gnuplot"
gnuplot2 = "gnuplot2"
gnuplot2_r = "gnuplot2_r"
gnuplot_r = "gnuplot_r"
gray = "gray"
gray_r = "gray_r"
hot = "hot"
hot_r = "hot_r"
hsv = "hsv"
hsv_r = "hsv_r"
inferno = "inferno"
inferno_r = "inferno_r"
jet = "jet"
jet_r = "jet_r"
magma = "magma"
magma_r = "magma_r"
nipy_spectral = "nipy_spectral"
nipy_spectral_r = "nipy_spectral_r"
ocean = "ocean"
ocean_r = "ocean_r"
pink = "pink"
pink_r = "pink_r"
plasma = "plasma"
plasma_r = "plasma_r"
prism = "prism"
prism_r = "prism_r"
rainbow = "rainbow"
rainbow_r = "rainbow_r"
seismic = "seismic"
seismic_r = "seismic_r"
spring = "spring"
spring_r = "spring_r"
summer = "summer"
summer_r = "summer_r"
tab10 = "tab10"
tab10_r = "tab10_r"
tab20 = "tab20"
tab20_r = "tab20_r"
tab20b = "tab20b"
tab20b_r = "tab20b_r"
tab20c = "tab20c"
tab20c_r = "tab20c_r"
terrain = "terrain"
terrain_r = "terrain_r"
turbo = "turbo"
turbo_r = "turbo_r"
twilight = "twilight"
twilight_r = "twilight_r"
twilight_shifted = "twilight_shifted"
twilight_shifted_r = "twilight_shifted_r"
viridis = "viridis"
viridis_r = "viridis_r"
winter = "winter"
winter_r = "winter_r"
def plot_freq_bar(items, ylabel="frequency", title=""):
    item_set, item_counts = np.unique(items, return_counts=True)
    plt.bar(item_set, item_counts)
    plt.xticks(rotation=35)
    plt.ylabel(ylabel)
    plt.title(title)
    for i, cnt in enumerate(item_counts):
        plt.text(x=i, y=cnt / 2, s=cnt, ha="center", color="white")
    plt.tight_layout()


def plot_confusion_matrix(y_true=None, y_pred=None, cm=None, labels=None, title=None, cmap=None):
    if not cm:
        cm = confusion_matrix(y_true, y_pred, labels=labels)
    if type(cm) == list:
        cm = np.array(cm)
    cm_display = ConfusionMatrixDisplay(cm, display_labels=labels)
    cm_display.plot(cmap=cmap)
    if title:
        plt.title(title)
    if labels:
        plt.xticks(rotation=30)
    plt.tight_layout()


if __name__ == '__main__':
    plot_freq_bar(DATA, title="My title")
    plt.show()

    plot_confusion_matrix(
        cm=[[12, 1, 0],
            [3, 14, 1],
            [5, 6, 7]],
        title="My title",
        labels=["apple", "orange", "grape"],
        cmap=Cmap.viridis
    )
    plt.show()

View File

@@ -1,16 +1,20 @@
-from functools import partial
+import re
 from itertools import chain, repeat
-from multiprocessing.pool import Pool

 import nltk.corpus
 from lxml import etree
+from nltk import word_tokenize
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer

-from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
+from .regex_util import LINK_RE

 get_text = etree.XPath("//text()")

+nltk.download("stopwords", quiet=True)
+nltk.download("wordnet", quiet=True)
+nltk.download("punkt", quiet=True)
+
 stop_words_en = set(stopwords.words("english"))

 extra_stop_words_en = [
@@ -19,9 +23,6 @@ extra_stop_words_en = [
 stop_words_en.update(extra_stop_words_en)

-nltk.download("stopwords", quiet=True)
-nltk.download("wordnet", quiet=True)
-
 lemmatizer = WordNetLemmatizer()
@@ -53,13 +54,20 @@ SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", l
 DASHES = ("", "", "", "")
 DASHES_TRANS = str.maketrans("".join(DASHES), "".join(repeat("-", len(DASHES))))
+DASHES_RE = re.compile(r"-+")

-PUNCTUATION = ".,;:\"!?/()|*=>"
+SPECIAL_PUNCTUATION = ";:\"/()|*=>"
+SPECIAL_PUNCTUATION_TRANS = str.maketrans(SPECIAL_PUNCTUATION, " " * len(SPECIAL_PUNCTUATION))
+
+PUNCTUATION = ".,!?"
 PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION))

-def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_stopwords_en=False,
-               lemmatize=False, fix_single_quotes=False, strip_quotes=False, remove_urls=False, bigrams: set = None,
-               trigrams: set = None, remove_numbers=False):
+def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_special_punctuation=False,
+               remove_stopwords_en=False, lemmatize=False, fix_single_quotes=False, strip_quotes=False,
+               strip_dashes=False,
+               remove_urls=False, bigrams: set = None, trigrams: set = None, remove_numbers=False,
+               use_nltk_tokenizer=False):
     if lowercase:
         text = text.lower()
@@ -68,6 +76,9 @@ def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False
         text = text.translate(DASHES_TRANS)

+    if strip_dashes:
+        text = DASHES_RE.sub("-", text)
+
     if remove_urls:
         text = LINK_RE.sub(" ", text)
@@ -85,11 +96,20 @@ def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False
     if remove_punctuation:
         text = text.translate(PUNCTUATION_TRANS)

-    words = text.split()
+    if remove_special_punctuation:
+        text = text.translate(SPECIAL_PUNCTUATION_TRANS)
+
+    if use_nltk_tokenizer:
+        words = word_tokenize(text, language="english")
+    else:
+        words = text.split()

     if strip_quotes:
         words = map(lambda w: w.strip("\"'“”"), words)

+    if strip_dashes:
+        words = map(lambda w: w.strip("-"), words)
+
     if bigrams:
         words = _transform_bigram(nltk.bigrams(chain(words, ("*",))), bigrams)
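preprocess() gains three new switches: remove_special_punctuation strips ;:"/()|*=> independently of the sentence punctuation .,!?, strip_dashes collapses runs of dashes and trims stray dashes from tokens, and use_nltk_tokenizer swaps str.split() for nltk's word_tokenize. A hedged usage sketch follows, assuming the module is importable as hexlib.text (the import path is not shown in this diff).

# Sketch only: the import path is assumed; the flags are the ones added here.
# preprocess() returns an iterable of tokens.
from hexlib.text import preprocess

tokens = preprocess(
    "Check https://example.com -- it's GREAT (really)!",
    lowercase=True,
    remove_urls=True,
    remove_punctuation=True,          # strips . , ! ?
    remove_special_punctuation=True,  # strips ; : " / ( ) | * = >
    strip_dashes=True,                # "--" becomes "-" and edge dashes are trimmed
    fix_single_quotes=True,
)
print(" ".join(tokens))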

View File

@@ -2,7 +2,7 @@ from setuptools import setup
 setup(
     name="hexlib",
-    version="1.73",
+    version="1.89",
     description="Misc utility methods",
     author="simon987",
     author_email="me@simon987.net",
@@ -12,9 +12,9 @@ setup(
         "data/*"
     ]},
     install_requires=[
-        "ImageHash", "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
+        "influxdb", "siphash", "python-dateutil", "redis", "orjson", "zstandard",
         "u-msgpack-python", "psycopg2-binary", "bs4", "lxml", "nltk", "numpy",
-        "matplotlib", "scikit-learn", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
-        "requests"
+        "matplotlib", "fake-useragent @ git+https://github.com/Jordan9675/fake-useragent",
+        "requests", "pydantic==1.10.11"
     ]
 )

View File

@@ -110,3 +110,34 @@ class TestPersistentState(TestCase):
             del s["a"][456]
         except Exception as e:
             self.fail(e)

+    def test_deserialize_get_set(self):
+        s = PersistentState()
+        s["a"][0] = {"x": b'abc'}
+        self.assertEqual(s["a"][0]["x"], b'abc')
+
+    def test_deserialize_sql(self):
+        s = PersistentState()
+        s["a"][0] = {"x": b'abc'}
+        self.assertEqual(list(s["a"].sql("WHERE 1=1"))[0]["x"], b'abc')
+
+    def test_deserialize_iter(self):
+        s = PersistentState()
+        s["a"][0] = {"x": b'abc'}
+        self.assertEqual(list(s["a"])[0]["x"], b'abc')
+
+    def test_drop_table(self):
+        s = PersistentState()
+        s["a"][0] = {"x": 1}
+        s["a"][1] = {"x": 2}
+
+        self.assertEqual(len(list(s["a"])), 2)
+
+        del s["a"]
+
+        self.assertEqual(len(list(s["a"])), 0)

test/test_PydanticTable.py (new file, 110 additions)
View File

@@ -0,0 +1,110 @@
import os
from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase

from pydantic import BaseModel
from pydantic.types import List

from hexlib.db import PersistentState


class Status(Enum):
    yes = "yes"
    no = "no"


class Point(BaseModel):
    x: int
    y: int


class Polygon(BaseModel):
    points: List[Point] = []
    created_date: datetime
    status: Status = Status("yes")


class TestPydanticTable(TestCase):
    def tearDown(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def setUp(self) -> None:
        if os.path.exists("state.db"):
            os.remove("state.db")

    def test_get_set(self):
        s = PersistentState()

        val = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[
                Point(x=1, y=2),
                Point(x=3, y=4),
            ],
        )
        s["a"]["1"] = val

        self.assertEqual(s["a"]["1"].points[0].x, 1)
        self.assertEqual(s["a"]["1"].status, Status("yes"))
        self.assertEqual(s["a"]["1"].points[1].x, 3)
        self.assertEqual(s["a"]["1"].created_date.year, 2000)

    def test_update(self):
        s = PersistentState()

        val = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[
                Point(x=1, y=2),
                Point(x=3, y=4),
            ]
        )
        s["a"]["1"] = val
        self.assertEqual(s["a"]["1"].points[0].x, 1)

        val.points[0].x = 2
        s["a"]["1"] = val
        self.assertEqual(s["a"]["1"].points[0].x, 2)

    def test_sql(self):
        s = PersistentState()

        s["b"]["1"] = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[]
        )
        s["b"]["2"] = Polygon(
            created_date=datetime(year=2010, day=1, month=1),
            points=[]
        )

        result = list(s["b"].sql(
            "WHERE json->>'created_date' LIKE '2000-%'"
        ))
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].created_date.year, 2000)

    def test_iterate(self):
        s = PersistentState()

        s["b"]["1"] = Polygon(
            created_date=datetime(year=2000, day=1, month=1),
            points=[]
        )
        s["b"]["2"] = Polygon(
            created_date=datetime(year=2010, day=1, month=1),
            points=[]
        )

        result = list(s["b"])
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].created_date.year, 2000)
        self.assertEqual(result[1].created_date.year, 2010)

View File

@@ -152,7 +152,7 @@ class TestText(TestCase):
             remove_stopwords_en=True,
             remove_urls=True
         )
-        expected = "hello world"
+        expected = "hello world |"

         self.assertEqual(" ".join(cleaned), expected)
@@ -170,7 +170,7 @@
             remove_urls=False
         )
-        expected = "217709510 is there a servant that is against civilization and humanity literally instant summon"
+        expected = ">>217709510 is there a servant that is against civilization and humanity literally instant summon"

         self.assertEqual(" ".join(cleaned), expected)

     def test_html_entity(self):
@@ -257,3 +257,23 @@
         expected = "hi test hello"
         self.assertEqual(" ".join(cleaned), expected)
+
+    def test_strip_dashes(self):
+        text = "yes -But something-something -- hello aa--bb"
+
+        cleaned = preprocess(
+            text,
+            strip_dashes=True
+        )
+        expected = "yes But something-something hello aa-bb"
+
+        self.assertEqual(" ".join(cleaned), expected)
+
+    def test_word_tokenize(self):
+        text = "i cannot believe'"
+
+        cleaned = preprocess(
+            text,
+            use_nltk_tokenizer=True
+        )
+        expected = "i can not believe '"
+
+        self.assertEqual(" ".join(cleaned), expected)