Fix tests, add pydantic row support for PersistentState

This commit is contained in:
simon987 2023-02-25 15:20:17 -05:00
parent 826312115c
commit a7b1a6e1ec
5 changed files with 157 additions and 11 deletions

View File

@@ -1,14 +1,23 @@
import base64
import sqlite3
import traceback
from datetime import datetime
import psycopg2
import umsgpack
from psycopg2.errorcodes import UNIQUE_VIOLATION
import json
from pydantic import BaseModel
from hexlib.env import get_redis
def _json_encoder(x):
if isinstance(x, datetime):
return x.isoformat()
return x
class VolatileState:
"""Quick and dirty volatile dict-like redis wrapper"""
@@ -114,7 +123,7 @@ class Table:
self._state = state
self._table = table
def sql(self, where_clause, *params):
def _sql_dict(self, where_clause, *params):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
@@ -128,7 +137,14 @@
except:
return None
def __iter__(self):
def sql(self, where_clause, *params):
    """Run a filtered SELECT and yield the matching rows.

    Rows that carry the "__pydantic" marker (written by setitem_pydantic)
    are revived into their original model class; plain rows are yielded
    unchanged.
    """
    for record in self._sql_dict(where_clause, *params):
        if not record or "__pydantic" not in record:
            yield record
        else:
            yield self._deserialize_pydantic(record)
def _iter_dict(self):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
@@ -142,12 +158,19 @@
except:
return None
def __getitem__(self, item):
def __iter__(self):
    """Iterate over every row of the table.

    Serialized pydantic rows (flagged with "__pydantic") are rebuilt into
    model instances; ordinary rows come back as-is.
    """
    for record in self._iter_dict():
        is_model = bool(record) and "__pydantic" in record
        yield self._deserialize_pydantic(record) if is_model else record
def _getitem_dict(self, key):
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row
try:
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
row = cur.fetchone()
if row:
@@ -158,8 +181,32 @@
except:
return None
@staticmethod
def _deserialize_pydantic(row):
    """Rebuild a pydantic model instance from a serialized table row.

    The row must contain the "__module" and "__class" names recorded by
    setitem_pydantic plus the model's JSON payload under "json".
    """
    import importlib

    # Bug fix: __import__("a.b.c") returns the *top-level* package "a",
    # so any model class defined in a nested module could never be
    # resolved by getattr. importlib.import_module returns the leaf
    # module itself, which is what we need here.
    module = importlib.import_module(row["__module"])
    cls = getattr(module, row["__class"])
    return cls.parse_raw(row["json"])
def __getitem__(self, key):
    """Fetch a single row by its id.

    Returns a revived pydantic model when the stored row carries the
    "__pydantic" marker, otherwise the raw row (which may be None when
    no row matches).
    """
    record = self._getitem_dict(key)
    if not record or "__pydantic" not in record:
        return record
    return self._deserialize_pydantic(record)
def setitem_pydantic(self, key, value: BaseModel):
    """Persist a pydantic model, recording enough metadata to revive it.

    The model is serialized to JSON (datetimes become ISO-8601 via
    _json_encoder) together with its class and module names and a
    "__pydantic" marker so reads can reconstruct the original instance.
    """
    serialized = {
        "json": value.json(encoder=_json_encoder),
        "__class": value.__class__.__name__,
        "__module": value.__class__.__module__,
        "__pydantic": 1,
    }
    self[key] = serialized
def __setitem__(self, key, value):
if isinstance(value, BaseModel):
self.setitem_pydantic(key, value)
return
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
conn.row_factory = sqlite3.Row

View File

@@ -1,17 +1,18 @@
import re
from functools import partial
from itertools import chain, repeat
from multiprocessing.pool import Pool
import nltk.corpus
from lxml import etree
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
from .regex_util import LINK_RE
get_text = etree.XPath("//text()")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
stop_words_en = set(stopwords.words("english"))
extra_stop_words_en = [
@@ -20,9 +21,6 @@ extra_stop_words_en = [
stop_words_en.update(extra_stop_words_en)
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
lemmatizer = WordNetLemmatizer()

View File

@@ -2,7 +2,7 @@ from setuptools import setup
setup(
name="hexlib",
version="1.81",
version="1.82",
description="Misc utility methods",
author="simon987",
author_email="me@simon987.net",

101
test/test_PydanticTable.py Normal file
View File

@@ -0,0 +1,101 @@
import os
from datetime import datetime
from unittest import TestCase
from pydantic import BaseModel
from pydantic.types import List
from hexlib.db import PersistentState
class Point(BaseModel):
    """Minimal pydantic model, used as a nested field of Polygon."""

    # Integer coordinates; pydantic validates/coerces inputs on construction.
    x: int
    y: int
class Polygon(BaseModel):
    """Pydantic model with a nested list and a datetime field.

    Exercises serialization round-trips through PersistentState.
    """

    # NOTE(review): the mutable [] default is safe here because pydantic
    # copies field defaults per instance (unlike plain class attributes).
    points: List[Point] = []
    created_date: datetime
class TestPydanticTable(TestCase):
    """End-to-end tests for pydantic row support in PersistentState."""

    @staticmethod
    def _cleanup():
        # Remove the on-disk sqlite file so every test runs isolated.
        if os.path.exists("state.db"):
            os.remove("state.db")

    def setUp(self) -> None:
        self._cleanup()

    def tearDown(self) -> None:
        self._cleanup()

    @staticmethod
    def _make_polygon(year, points=()):
        # Shared fixture: a Polygon dated January 1st of the given year.
        return Polygon(
            created_date=datetime(year=year, month=1, day=1),
            points=list(points),
        )

    def test_get_set(self):
        state = PersistentState()
        state["a"]["1"] = self._make_polygon(
            2000, [Point(x=1, y=2), Point(x=3, y=4)]
        )

        self.assertEqual(state["a"]["1"].points[0].x, 1)
        self.assertEqual(state["a"]["1"].points[1].x, 3)
        self.assertEqual(state["a"]["1"].created_date.year, 2000)

    def test_update(self):
        state = PersistentState()
        shape = self._make_polygon(2000, [Point(x=1, y=2), Point(x=3, y=4)])

        state["a"]["1"] = shape
        self.assertEqual(state["a"]["1"].points[0].x, 1)

        # Overwriting the same key must persist the modified model.
        shape.points[0].x = 2
        state["a"]["1"] = shape
        self.assertEqual(state["a"]["1"].points[0].x, 2)

    def test_sql(self):
        state = PersistentState()
        state["b"]["1"] = self._make_polygon(2000)
        state["b"]["2"] = self._make_polygon(2010)

        # Filter on the serialized JSON payload; only the year-2000 row
        # should match the LIKE pattern.
        matches = list(state["b"].sql(
            "WHERE json->>'created_date' LIKE '2000-%'"
        ))
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].created_date.year, 2000)

    def test_iterate(self):
        state = PersistentState()
        state["b"]["1"] = self._make_polygon(2000)
        state["b"]["2"] = self._make_polygon(2010)

        rows = list(state["b"])
        self.assertEqual(len(rows), 2)
        self.assertEqual(rows[0].created_date.year, 2000)
        self.assertEqual(rows[1].created_date.year, 2010)