Compare commits


No commits in common. "3bd9f0399675b555c402ecbbfb8d3227886636bb" and "826312115c7a87e8af5bf27a4d2c181630130481" have entirely different histories.

6 changed files with 12 additions and 187 deletions

View File

@@ -1,26 +1,14 @@
 import base64
 import sqlite3
 import traceback
-from datetime import datetime
-from enum import Enum
 
 import psycopg2
 import umsgpack
 from psycopg2.errorcodes import UNIQUE_VIOLATION
-from pydantic import BaseModel
 
 from hexlib.env import get_redis
 
 
-def _json_encoder(x):
-    if isinstance(x, datetime):
-        return x.isoformat()
-    if isinstance(x, Enum):
-        return x.value
-    raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
-
-
 class VolatileState:
     """Quick and dirty volatile dict-like redis wrapper"""
@@ -126,7 +114,7 @@ class Table:
         self._state = state
         self._table = table
 
-    def _sql_dict(self, where_clause, *params):
+    def sql(self, where_clause, *params):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
@@ -140,14 +128,7 @@
         except:
             return None
 
-    def sql(self, where_clause, *params):
-        for row in self._sql_dict(where_clause, *params):
-            if row and "__pydantic" in row:
-                yield self._deserialize_pydantic(row)
-            else:
-                yield row
-
-    def _iter_dict(self):
+    def __iter__(self):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
@@ -161,19 +142,12 @@
         except:
             return None
 
-    def __iter__(self):
-        for row in self._iter_dict():
-            if row and "__pydantic" in row:
-                yield self._deserialize_pydantic(row)
-            else:
-                yield row
-
-    def _getitem_dict(self, key):
+    def __getitem__(self, item):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
                 col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
-                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
+                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
                 row = cur.fetchone()
                 if row:
@@ -184,32 +158,8 @@
         except:
             return None
 
-    @staticmethod
-    def _deserialize_pydantic(row):
-        module = __import__(row["__module"])
-        cls = getattr(module, row["__class"])
-        return cls.parse_raw(row["json"])
-
-    def __getitem__(self, key):
-        row = self._getitem_dict(key)
-        if row and "__pydantic" in row:
-            return self._deserialize_pydantic(row)
-        return row
-
-    def setitem_pydantic(self, key, value: BaseModel):
-        self.__setitem__(key, {
-            "json": value.json(encoder=_json_encoder, indent=2),
-            "__class": value.__class__.__name__,
-            "__module": value.__class__.__module__,
-            "__pydantic": 1
-        })
-
     def __setitem__(self, key, value):
-        if isinstance(value, BaseModel):
-            self.setitem_pydantic(key, value)
-            return
-
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
@@ -278,7 +228,7 @@ class PersistentState:
     def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
         self.dbfile = dbfile
         self.logger = logger
-        if dbargs is None or dbargs == {}:
+        if dbargs is None:
             dbargs = {"timeout": 30000}
         self.dbargs = dbargs
         self._table_factory = table_factory
@@ -286,13 +236,6 @@
     def __getitem__(self, table):
         return self._table_factory(self, table)
 
-    def __delitem__(self, key):
-        with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
-            try:
-                conn.execute(f"DROP TABLE {key}")
-            except:
-                pass
-
 
 def pg_fetch_cursor_all(cur, name, batch_size=1000):
     while True:
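
Two notes on this file's diff. First, since **dbargs always binds to a dict (an empty one when no keyword arguments are passed), the `if dbargs is None:` check on the 1.81 side can never be true, so the `{"timeout": 30000}` default is only ever applied by the 1.86 side's `dbargs == {}` test. Second, the pydantic round-trip being stripped out here worked roughly as sketched below; this is a reconstruction from the deleted code above, and MyModel is a hypothetical stand-in model, not part of hexlib:

    # Sketch of the pydantic support removed in this diff (1.86 behaviour).
    # MyModel is a hypothetical example model.
    from pydantic import BaseModel
    from hexlib.db import PersistentState

    class MyModel(BaseModel):
        x: int

    s = PersistentState()
    # __setitem__ detects the BaseModel and routes to setitem_pydantic(),
    # storing the JSON plus __class, __module and __pydantic columns.
    s["t"]["1"] = MyModel(x=1)
    # __getitem__ sees "__pydantic" in the row and rebuilds the model via
    # __import__(row["__module"]), getattr(..., row["__class"]), parse_raw(...).
    obj = s["t"]["1"]
    assert obj.x == 1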

View File

@@ -1,18 +1,17 @@
 import re
-from functools import partial
 from itertools import chain, repeat
-from multiprocessing.pool import Pool
 
 import nltk.corpus
 from lxml import etree
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 
-from .regex_util import LINK_RE
+from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
 
 get_text = etree.XPath("//text()")
 
-nltk.download("stopwords", quiet=True)
-nltk.download("wordnet", quiet=True)
-
 stop_words_en = set(stopwords.words("english"))
 
 extra_stop_words_en = [
@@ -21,6 +20,9 @@ extra_stop_words_en = [
 ]
 
 stop_words_en.update(extra_stop_words_en)
 
+nltk.download("stopwords", quiet=True)
+nltk.download("wordnet", quiet=True)
+
 lemmatizer = WordNetLemmatizer()
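
One behavioural difference in this file: on the 1.81 side the nltk.download() calls run after stopwords.words("english") has already executed at import time, so importing the module on a machine without the corpora raises LookupError; the 1.86 side downloads them before first use. A minimal standalone illustration of the ordering constraint (hypothetical snippet, not hexlib code):

    import nltk
    from nltk.corpus import stopwords

    # stopwords.words() raises LookupError if the corpus was never
    # downloaded, so the download must happen before first use:
    nltk.download("stopwords", quiet=True)
    stop_words = set(stopwords.words("english"))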

View File

@@ -2,7 +2,7 @@ from setuptools import setup
 
 setup(
     name="hexlib",
-    version="1.86",
+    version="1.81",
     description="Misc utility methods",
     author="simon987",
     author_email="me@simon987.net",

View File

@@ -131,13 +131,3 @@ class TestPersistentState(TestCase):
 
         s["a"][0] = {"x": b'abc'}
         self.assertEqual(list(s["a"])[0]["x"], b'abc')
-
-    def test_drop_table(self):
-        s = PersistentState()
-
-        s["a"][0] = {"x": 1}
-        s["a"][1] = {"x": 2}
-        self.assertEqual(len(list(s["a"])), 2)
-
-        del s["a"]
-        self.assertEqual(len(list(s["a"])), 0)

View File

@ -1,110 +0,0 @@
import os
from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase
from pydantic import BaseModel
from pydantic.types import List
from hexlib.db import PersistentState
class Status(Enum):
yes = "yes"
no = "no"
class Point(BaseModel):
x: int
y: int
class Polygon(BaseModel):
points: List[Point] = []
created_date: datetime
status: Status = Status("yes")
class TestPydanticTable(TestCase):
def tearDown(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def setUp(self) -> None:
if os.path.exists("state.db"):
os.remove("state.db")
def test_get_set(self):
s = PersistentState()
val = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[
Point(x=1, y=2),
Point(x=3, y=4),
],
)
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1)
self.assertEqual(s["a"]["1"].status, Status("yes"))
self.assertEqual(s["a"]["1"].points[1].x, 3)
self.assertEqual(s["a"]["1"].created_date.year, 2000)
def test_update(self):
s = PersistentState()
val = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[
Point(x=1, y=2),
Point(x=3, y=4),
]
)
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1)
val.points[0].x = 2
s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 2)
def test_sql(self):
s = PersistentState()
s["b"]["1"] = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[]
)
s["b"]["2"] = Polygon(
created_date=datetime(year=2010, day=1, month=1),
points=[]
)
result = list(s["b"].sql(
"WHERE json->>'created_date' LIKE '2000-%'"
))
self.assertEqual(len(result), 1)
self.assertEqual(result[0].created_date.year, 2000)
def test_iterate(self):
s = PersistentState()
s["b"]["1"] = Polygon(
created_date=datetime(year=2000, day=1, month=1),
points=[]
)
s["b"]["2"] = Polygon(
created_date=datetime(year=2010, day=1, month=1),
points=[]
)
result = list(s["b"])
self.assertEqual(len(result), 2)
self.assertEqual(result[0].created_date.year, 2000)
self.assertEqual(result[1].created_date.year, 2010)
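
The deleted test_sql() above also documents that Table.sql() accepted a raw WHERE clause and filtered on the stored JSON with SQLite's ->> operator, which requires SQLite 3.38 or newer (older libraries raise sqlite3.OperationalError on that syntax). A self-contained sketch of the same query pattern, independent of hexlib:

    import sqlite3

    # The ->> operator extracts a JSON field as text (SQLite 3.38+).
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE b (id TEXT, json TEXT)")
    conn.execute("INSERT INTO b VALUES ('1', '{\"created_date\": \"2000-01-01T00:00:00\"}')")
    rows = conn.execute(
        "SELECT * FROM b WHERE json->>'created_date' LIKE '2000-%'"
    ).fetchall()
    assert len(rows) == 1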