From a7b1a6e1ec635e33020e367c49771e0ccb19084f Mon Sep 17 00:00:00 2001 From: simon987 Date: Sat, 25 Feb 2023 15:20:17 -0500 Subject: [PATCH] Fix tests, add pydantic row support for PersistentState --- hexlib/db.py | 55 ++++++++++++++-- hexlib/{regex.py => regex_util.py} | 0 hexlib/text.py | 10 ++- setup.py | 2 +- test/test_PydanticTable.py | 101 +++++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 11 deletions(-) rename hexlib/{regex.py => regex_util.py} (100%) create mode 100644 test/test_PydanticTable.py diff --git a/hexlib/db.py b/hexlib/db.py index b919e91..583778a 100644 --- a/hexlib/db.py +++ b/hexlib/db.py @@ -1,14 +1,23 @@ import base64 import sqlite3 import traceback +from datetime import datetime import psycopg2 import umsgpack from psycopg2.errorcodes import UNIQUE_VIOLATION +import json +from pydantic import BaseModel from hexlib.env import get_redis +def _json_encoder(x): + if isinstance(x, datetime): + return x.isoformat() + return x + + class VolatileState: """Quick and dirty volatile dict-like redis wrapper""" @@ -114,7 +123,7 @@ class Table: self._state = state self._table = table - def sql(self, where_clause, *params): + def _sql_dict(self, where_clause, *params): with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn: conn.row_factory = sqlite3.Row try: @@ -128,7 +137,14 @@ class Table: except: return None - def __iter__(self): + def sql(self, where_clause, *params): + for row in self._sql_dict(where_clause, *params): + if row and "__pydantic" in row: + yield self._deserialize_pydantic(row) + else: + yield row + + def _iter_dict(self): with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn: conn.row_factory = sqlite3.Row try: @@ -142,12 +158,19 @@ class Table: except: return None - def __getitem__(self, item): + def __iter__(self): + for row in self._iter_dict(): + if row and "__pydantic" in row: + yield self._deserialize_pydantic(row) + else: + yield row + + def _getitem_dict(self, key): with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn: conn.row_factory = sqlite3.Row try: col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall() - cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,)) + cur = conn.execute("SELECT * FROM %s WHERE id=?" 
% (self._table,), (key,)) row = cur.fetchone() if row: @@ -158,8 +181,32 @@ class Table: except: return None + @staticmethod + def _deserialize_pydantic(row): + module = __import__(row["__module"]) + cls = getattr(module, row["__class"]) + return cls.parse_raw(row["json"]) + + def __getitem__(self, key): + row = self._getitem_dict(key) + if row and "__pydantic" in row: + return self._deserialize_pydantic(row) + return row + + def setitem_pydantic(self, key, value: BaseModel): + self.__setitem__(key, { + "json": value.json(encoder=_json_encoder), + "__class": value.__class__.__name__, + "__module": value.__class__.__module__, + "__pydantic": 1 + }) + def __setitem__(self, key, value): + if isinstance(value, BaseModel): + self.setitem_pydantic(key, value) + return + with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn: conn.row_factory = sqlite3.Row diff --git a/hexlib/regex.py b/hexlib/regex_util.py similarity index 100% rename from hexlib/regex.py rename to hexlib/regex_util.py diff --git a/hexlib/text.py b/hexlib/text.py index b1400d7..d6e5a22 100644 --- a/hexlib/text.py +++ b/hexlib/text.py @@ -1,17 +1,18 @@ import re -from functools import partial from itertools import chain, repeat -from multiprocessing.pool import Pool import nltk.corpus from lxml import etree from nltk.corpus import stopwords from nltk.stem import WordNetLemmatizer -from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE +from .regex_util import LINK_RE get_text = etree.XPath("//text()") +nltk.download("stopwords", quiet=True) +nltk.download("wordnet", quiet=True) + stop_words_en = set(stopwords.words("english")) extra_stop_words_en = [ @@ -20,9 +21,6 @@ extra_stop_words_en = [ stop_words_en.update(extra_stop_words_en) -nltk.download("stopwords", quiet=True) -nltk.download("wordnet", quiet=True) - lemmatizer = WordNetLemmatizer() diff --git a/setup.py b/setup.py index 2f7bb38..9096835 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup setup( name="hexlib", - version="1.81", + version="1.82", description="Misc utility methods", author="simon987", author_email="me@simon987.net", diff --git a/test/test_PydanticTable.py b/test/test_PydanticTable.py new file mode 100644 index 0000000..d9c86b2 --- /dev/null +++ b/test/test_PydanticTable.py @@ -0,0 +1,101 @@ +import os +from datetime import datetime +from unittest import TestCase + +from pydantic import BaseModel +from pydantic.types import List + +from hexlib.db import PersistentState + + +class Point(BaseModel): + x: int + y: int + + +class Polygon(BaseModel): + points: List[Point] = [] + created_date: datetime + + +class TestPydanticTable(TestCase): + def tearDown(self) -> None: + if os.path.exists("state.db"): + os.remove("state.db") + + def setUp(self) -> None: + if os.path.exists("state.db"): + os.remove("state.db") + + def test_get_set(self): + s = PersistentState() + + val = Polygon( + created_date=datetime(year=2000, day=1, month=1), + points=[ + Point(x=1, y=2), + Point(x=3, y=4), + ] + ) + + s["a"]["1"] = val + + self.assertEqual(s["a"]["1"].points[0].x, 1) + self.assertEqual(s["a"]["1"].points[1].x, 3) + self.assertEqual(s["a"]["1"].created_date.year, 2000) + + def test_update(self): + s = PersistentState() + + val = Polygon( + created_date=datetime(year=2000, day=1, month=1), + points=[ + Point(x=1, y=2), + Point(x=3, y=4), + ] + ) + + s["a"]["1"] = val + + self.assertEqual(s["a"]["1"].points[0].x, 1) + + val.points[0].x = 2 + s["a"]["1"] = val + self.assertEqual(s["a"]["1"].points[0].x, 2) + + 
def test_sql(self): + s = PersistentState() + + s["b"]["1"] = Polygon( + created_date=datetime(year=2000, day=1, month=1), + points=[] + ) + s["b"]["2"] = Polygon( + created_date=datetime(year=2010, day=1, month=1), + points=[] + ) + + result = list(s["b"].sql( + "WHERE json->>'created_date' LIKE '2000-%'" + )) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].created_date.year, 2000) + + def test_iterate(self): + s = PersistentState() + + s["b"]["1"] = Polygon( + created_date=datetime(year=2000, day=1, month=1), + points=[] + ) + s["b"]["2"] = Polygon( + created_date=datetime(year=2010, day=1, month=1), + points=[] + ) + + result = list(s["b"]) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].created_date.year, 2000) + self.assertEqual(result[1].created_date.year, 2010)
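
Usage sketch for reviewers: the pydantic support above works by wrapping the model in a plain dict before it reaches the existing __setitem__ path, so the sqlite schema stays generic (a "json" text column plus "__class", "__module" and "__pydantic" marker columns). The sketch below is not part of the patch; it assumes pydantic v1 (.json() / .parse_raw()), SQLite >= 3.38 for the ->> JSON operator, and the default ./state.db location that the tests rely on. The "shapes"/"poly1" names are arbitrary.

from datetime import datetime
from typing import List

from pydantic import BaseModel

from hexlib.db import PersistentState


class Point(BaseModel):
    x: int
    y: int


class Polygon(BaseModel):
    points: List[Point] = []
    created_date: datetime


state = PersistentState()  # creates/opens ./state.db by default

# Assigning a BaseModel goes through setitem_pydantic(), which stores
# {"json": <model serialized to JSON>, "__class": <class name>,
#  "__module": <module name>, "__pydantic": 1} instead of the model itself.
state["shapes"]["poly1"] = Polygon(
    created_date=datetime(2023, 2, 25),
    points=[Point(x=0, y=0), Point(x=5, y=5)],
)

# Reads see the "__pydantic" marker and hand back a Polygon instance
# rebuilt with parse_raw(), not the raw row dict.
poly = state["shapes"]["poly1"]
assert isinstance(poly, Polygon)
assert poly.points[1].y == 5

# _json_encoder serializes datetimes as ISO-8601 strings, which is what
# makes filtering on them with sqlite's JSON operators possible:
recent = list(state["shapes"].sql("WHERE json->>'created_date' >= '2023-01-01'"))
assert len(recent) == 1

One caveat worth noting: _deserialize_pydantic() resolves the model class with __import__(row["__module"]), which returns the top-level package for dotted module paths (e.g. "myapp.models" resolves to "myapp"), so the getattr() lookup could fail for models defined in nested modules. importlib.import_module(), or __import__ with a non-empty fromlist, would cover that case if it ever matters; the test models live in a top-level module, so the tests are unaffected.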