mirror of
https://github.com/simon987/hexlib.git
synced 2025-04-04 02:12:59 +00:00
Fix tests, add pydantic row support for PersistentState
This commit is contained in:
parent
826312115c
commit
a7b1a6e1ec
55
hexlib/db.py
55
hexlib/db.py
@ -1,14 +1,23 @@
|
||||
import base64
|
||||
import sqlite3
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
import umsgpack
|
||||
from psycopg2.errorcodes import UNIQUE_VIOLATION
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
|
||||
from hexlib.env import get_redis
|
||||
|
||||
|
||||
def _json_encoder(x):
|
||||
if isinstance(x, datetime):
|
||||
return x.isoformat()
|
||||
return x
|
||||
|
||||
|
||||
class VolatileState:
|
||||
"""Quick and dirty volatile dict-like redis wrapper"""
|
||||
|
||||
@ -114,7 +123,7 @@ class Table:
|
||||
self._state = state
|
||||
self._table = table
|
||||
|
||||
def sql(self, where_clause, *params):
|
||||
def _sql_dict(self, where_clause, *params):
|
||||
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
@ -128,7 +137,14 @@ class Table:
|
||||
except:
|
||||
return None
|
||||
|
||||
def __iter__(self):
|
||||
def sql(self, where_clause, *params):
|
||||
for row in self._sql_dict(where_clause, *params):
|
||||
if row and "__pydantic" in row:
|
||||
yield self._deserialize_pydantic(row)
|
||||
else:
|
||||
yield row
|
||||
|
||||
def _iter_dict(self):
|
||||
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
@ -142,12 +158,19 @@ class Table:
|
||||
except:
|
||||
return None
|
||||
|
||||
def __getitem__(self, item):
|
||||
def __iter__(self):
|
||||
for row in self._iter_dict():
|
||||
if row and "__pydantic" in row:
|
||||
yield self._deserialize_pydantic(row)
|
||||
else:
|
||||
yield row
|
||||
|
||||
def _getitem_dict(self, key):
|
||||
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
|
||||
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
|
||||
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
|
||||
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
@ -158,8 +181,32 @@ class Table:
|
||||
except:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_pydantic(row):
|
||||
module = __import__(row["__module"])
|
||||
cls = getattr(module, row["__class"])
|
||||
return cls.parse_raw(row["json"])
|
||||
|
||||
def __getitem__(self, key):
|
||||
row = self._getitem_dict(key)
|
||||
if row and "__pydantic" in row:
|
||||
return self._deserialize_pydantic(row)
|
||||
return row
|
||||
|
||||
def setitem_pydantic(self, key, value: BaseModel):
|
||||
self.__setitem__(key, {
|
||||
"json": value.json(encoder=_json_encoder),
|
||||
"__class": value.__class__.__name__,
|
||||
"__module": value.__class__.__module__,
|
||||
"__pydantic": 1
|
||||
})
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
|
||||
if isinstance(value, BaseModel):
|
||||
self.setitem_pydantic(key, value)
|
||||
return
|
||||
|
||||
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
|
@ -1,17 +1,18 @@
|
||||
import re
|
||||
from functools import partial
|
||||
from itertools import chain, repeat
|
||||
from multiprocessing.pool import Pool
|
||||
|
||||
import nltk.corpus
|
||||
from lxml import etree
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
|
||||
from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
|
||||
from .regex_util import LINK_RE
|
||||
|
||||
get_text = etree.XPath("//text()")
|
||||
|
||||
nltk.download("stopwords", quiet=True)
|
||||
nltk.download("wordnet", quiet=True)
|
||||
|
||||
stop_words_en = set(stopwords.words("english"))
|
||||
|
||||
extra_stop_words_en = [
|
||||
@ -20,9 +21,6 @@ extra_stop_words_en = [
|
||||
|
||||
stop_words_en.update(extra_stop_words_en)
|
||||
|
||||
nltk.download("stopwords", quiet=True)
|
||||
nltk.download("wordnet", quiet=True)
|
||||
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -2,7 +2,7 @@ from setuptools import setup
|
||||
|
||||
setup(
|
||||
name="hexlib",
|
||||
version="1.81",
|
||||
version="1.82",
|
||||
description="Misc utility methods",
|
||||
author="simon987",
|
||||
author_email="me@simon987.net",
|
||||
|
101
test/test_PydanticTable.py
Normal file
101
test/test_PydanticTable.py
Normal file
@ -0,0 +1,101 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from unittest import TestCase
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import List
|
||||
|
||||
from hexlib.db import PersistentState
|
||||
|
||||
|
||||
class Point(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
|
||||
class Polygon(BaseModel):
|
||||
points: List[Point] = []
|
||||
created_date: datetime
|
||||
|
||||
|
||||
class TestPydanticTable(TestCase):
|
||||
def tearDown(self) -> None:
|
||||
if os.path.exists("state.db"):
|
||||
os.remove("state.db")
|
||||
|
||||
def setUp(self) -> None:
|
||||
if os.path.exists("state.db"):
|
||||
os.remove("state.db")
|
||||
|
||||
def test_get_set(self):
|
||||
s = PersistentState()
|
||||
|
||||
val = Polygon(
|
||||
created_date=datetime(year=2000, day=1, month=1),
|
||||
points=[
|
||||
Point(x=1, y=2),
|
||||
Point(x=3, y=4),
|
||||
]
|
||||
)
|
||||
|
||||
s["a"]["1"] = val
|
||||
|
||||
self.assertEqual(s["a"]["1"].points[0].x, 1)
|
||||
self.assertEqual(s["a"]["1"].points[1].x, 3)
|
||||
self.assertEqual(s["a"]["1"].created_date.year, 2000)
|
||||
|
||||
def test_update(self):
|
||||
s = PersistentState()
|
||||
|
||||
val = Polygon(
|
||||
created_date=datetime(year=2000, day=1, month=1),
|
||||
points=[
|
||||
Point(x=1, y=2),
|
||||
Point(x=3, y=4),
|
||||
]
|
||||
)
|
||||
|
||||
s["a"]["1"] = val
|
||||
|
||||
self.assertEqual(s["a"]["1"].points[0].x, 1)
|
||||
|
||||
val.points[0].x = 2
|
||||
s["a"]["1"] = val
|
||||
self.assertEqual(s["a"]["1"].points[0].x, 2)
|
||||
|
||||
def test_sql(self):
|
||||
s = PersistentState()
|
||||
|
||||
s["b"]["1"] = Polygon(
|
||||
created_date=datetime(year=2000, day=1, month=1),
|
||||
points=[]
|
||||
)
|
||||
s["b"]["2"] = Polygon(
|
||||
created_date=datetime(year=2010, day=1, month=1),
|
||||
points=[]
|
||||
)
|
||||
|
||||
result = list(s["b"].sql(
|
||||
"WHERE json->>'created_date' LIKE '2000-%'"
|
||||
))
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].created_date.year, 2000)
|
||||
|
||||
def test_iterate(self):
|
||||
s = PersistentState()
|
||||
|
||||
s["b"]["1"] = Polygon(
|
||||
created_date=datetime(year=2000, day=1, month=1),
|
||||
points=[]
|
||||
)
|
||||
s["b"]["2"] = Polygon(
|
||||
created_date=datetime(year=2010, day=1, month=1),
|
||||
points=[]
|
||||
)
|
||||
|
||||
result = list(s["b"])
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].created_date.year, 2000)
|
||||
self.assertEqual(result[1].created_date.year, 2010)
|
Loading…
x
Reference in New Issue
Block a user