Mirror of https://github.com/simon987/hexlib.git, synced 2025-04-21 18:46:42 +00:00
Compare commits
No commits in common. "3bd9f0399675b555c402ecbbfb8d3227886636bb" and "826312115c7a87e8af5bf27a4d2c181630130481" have entirely different histories.
3bd9f03996...826312115c
hexlib/db.py (67 lines changed)
@@ -1,26 +1,14 @@
 import base64
 import sqlite3
 import traceback
-from datetime import datetime
-from enum import Enum
 
 import psycopg2
 import umsgpack
 from psycopg2.errorcodes import UNIQUE_VIOLATION
-from pydantic import BaseModel
 
 from hexlib.env import get_redis
 
 
-def _json_encoder(x):
-    if isinstance(x, datetime):
-        return x.isoformat()
-    if isinstance(x, Enum):
-        return x.value
-
-    raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
-
-
 class VolatileState:
     """Quick and dirty volatile dict-like redis wrapper"""
 
@@ -126,7 +114,7 @@ class Table:
         self._state = state
         self._table = table
 
-    def _sql_dict(self, where_clause, *params):
+    def sql(self, where_clause, *params):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
@@ -140,14 +128,7 @@ class Table:
         except:
             return None
 
-    def sql(self, where_clause, *params):
-        for row in self._sql_dict(where_clause, *params):
-            if row and "__pydantic" in row:
-                yield self._deserialize_pydantic(row)
-            else:
-                yield row
-
-    def _iter_dict(self):
+    def __iter__(self):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
@@ -161,19 +142,12 @@ class Table:
         except:
             return None
 
-    def __iter__(self):
-        for row in self._iter_dict():
-            if row and "__pydantic" in row:
-                yield self._deserialize_pydantic(row)
-            else:
-                yield row
-
-    def _getitem_dict(self, key):
+    def __getitem__(self, item):
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
             try:
                 col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
-                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (key,))
+                cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
 
                 row = cur.fetchone()
                 if row:
@@ -184,32 +158,8 @@ class Table:
         except:
             return None
 
-    @staticmethod
-    def _deserialize_pydantic(row):
-        module = __import__(row["__module"])
-        cls = getattr(module, row["__class"])
-        return cls.parse_raw(row["json"])
-
-    def __getitem__(self, key):
-        row = self._getitem_dict(key)
-        if row and "__pydantic" in row:
-            return self._deserialize_pydantic(row)
-        return row
-
-    def setitem_pydantic(self, key, value: BaseModel):
-        self.__setitem__(key, {
-            "json": value.json(encoder=_json_encoder, indent=2),
-            "__class": value.__class__.__name__,
-            "__module": value.__class__.__module__,
-            "__pydantic": 1
-        })
-
     def __setitem__(self, key, value):
 
-        if isinstance(value, BaseModel):
-            self.setitem_pydantic(key, value)
-            return
-
         with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
             conn.row_factory = sqlite3.Row
 
@@ -278,7 +228,7 @@ class PersistentState:
     def __init__(self, dbfile="state.db", logger=None, table_factory=Table, **dbargs):
        self.dbfile = dbfile
        self.logger = logger
-        if dbargs is None or dbargs == {}:
+        if dbargs is None:
            dbargs = {"timeout": 30000}
        self.dbargs = dbargs
        self._table_factory = table_factory
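A note on this hunk: because dbargs is collected with **dbargs it is always a dict and never None, so only the removed side's dbargs == {} test can ever fire; with the plain is None check, the default timeout is effectively dead code. A quick standalone illustration:

def f(**kwargs):
    return kwargs

print(f() is None)  # False: collecting **kwargs always yields a dict
print(f() == {})    # True when no keyword arguments were passed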
@@ -286,13 +236,6 @@ class PersistentState:
     def __getitem__(self, table):
         return self._table_factory(self, table)
 
-    def __delitem__(self, key):
-        with sqlite3.connect(self.dbfile, **self.dbargs) as conn:
-            try:
-                conn.execute(f"DROP TABLE {key}")
-            except:
-                pass
-
 
 def pg_fetch_cursor_all(cur, name, batch_size=1000):
     while True:
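The methods removed above are hexlib's pydantic layer: a model is stored as its JSON payload plus __module/__class/__pydantic markers, and rebuilt on read by importing the recorded module. A minimal standalone sketch of that round-trip (pydantic v1 API; the Point model here is illustrative):

from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


# Write path, mirroring what setitem_pydantic stores:
val = Point(x=1, y=2)
row = {
    "json": val.json(),
    "__class": val.__class__.__name__,     # "Point"
    "__module": val.__class__.__module__,  # e.g. "__main__"
    "__pydantic": 1,
}

# Read path, mirroring _deserialize_pydantic:
module = __import__(row["__module"])
cls = getattr(module, row["__class"])
assert cls.parse_raw(row["json"]) == val

Note that __import__ returns the top-level package for dotted names, so the lookup as written only resolves classes defined in top-level modules.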
hexlib/text.py

@@ -1,18 +1,17 @@
 import re
+from functools import partial
 from itertools import chain, repeat
+from multiprocessing.pool import Pool
 
 import nltk.corpus
 from lxml import etree
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 
-from .regex_util import LINK_RE
+from .regex import WHITESPACE_RE, PUNCTUATION_RE, LINK_RE, XML_ENTITY_RE
 
 get_text = etree.XPath("//text()")
 
-nltk.download("stopwords", quiet=True)
-nltk.download("wordnet", quiet=True)
-
 stop_words_en = set(stopwords.words("english"))
 
 extra_stop_words_en = [
@@ -21,6 +20,9 @@ extra_stop_words_en = [
 
 stop_words_en.update(extra_stop_words_en)
 
+nltk.download("stopwords", quiet=True)
+nltk.download("wordnet", quiet=True)
+
 lemmatizer = WordNetLemmatizer()
 
 
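One behavioral note on this file: the removed side runs the nltk.download(...) calls before stopwords.words("english"), while the added side runs them after it. On an environment that has never fetched the corpora, the latter ordering raises LookupError. A minimal sketch of the safe ordering:

import nltk
from nltk.corpus import stopwords

# The corpus must be on disk before stopwords.words() is called;
# otherwise nltk raises LookupError on a fresh machine.
nltk.download("stopwords", quiet=True)
stop_words_en = set(stopwords.words("english"))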
setup.py (2 lines changed)
@@ -2,7 +2,7 @@ from setuptools import setup
 
 setup(
     name="hexlib",
-    version="1.86",
+    version="1.81",
     description="Misc utility methods",
     author="simon987",
     author_email="me@simon987.net",
@@ -131,13 +131,3 @@ class TestPersistentState(TestCase):
         s["a"][0] = {"x": b'abc'}
 
         self.assertEqual(list(s["a"])[0]["x"], b'abc')
-
-    def test_drop_table(self):
-        s = PersistentState()
-
-        s["a"][0] = {"x": 1}
-        s["a"][1] = {"x": 2}
-        self.assertEqual(len(list(s["a"])), 2)
-
-        del s["a"]
-        self.assertEqual(len(list(s["a"])), 0)
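The removed test_drop_table exercises the removed PersistentState.__delitem__: del s["a"] drops the backing table, and the final assertion passes only because Table's bare except: swallows the resulting "no such table" error, so the iterator yields nothing. A standalone sketch of the error being suppressed:

import sqlite3

conn = sqlite3.connect(":memory:")
try:
    conn.execute("SELECT * FROM a").fetchall()
except sqlite3.OperationalError as e:
    print(e)  # no such table: a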
@@ -1,110 +0,0 @@
-import os
-from datetime import datetime
-from enum import Enum
-from typing import Optional
-from unittest import TestCase
-
-from pydantic import BaseModel
-from pydantic.types import List
-
-from hexlib.db import PersistentState
-
-
-class Status(Enum):
-    yes = "yes"
-    no = "no"
-
-
-class Point(BaseModel):
-    x: int
-    y: int
-
-
-class Polygon(BaseModel):
-    points: List[Point] = []
-    created_date: datetime
-    status: Status = Status("yes")
-
-
-class TestPydanticTable(TestCase):
-    def tearDown(self) -> None:
-        if os.path.exists("state.db"):
-            os.remove("state.db")
-
-    def setUp(self) -> None:
-        if os.path.exists("state.db"):
-            os.remove("state.db")
-
-    def test_get_set(self):
-        s = PersistentState()
-
-        val = Polygon(
-            created_date=datetime(year=2000, day=1, month=1),
-            points=[
-                Point(x=1, y=2),
-                Point(x=3, y=4),
-            ],
-        )
-
-        s["a"]["1"] = val
-
-        self.assertEqual(s["a"]["1"].points[0].x, 1)
-        self.assertEqual(s["a"]["1"].status, Status("yes"))
-        self.assertEqual(s["a"]["1"].points[1].x, 3)
-        self.assertEqual(s["a"]["1"].created_date.year, 2000)
-
-    def test_update(self):
-        s = PersistentState()
-
-        val = Polygon(
-            created_date=datetime(year=2000, day=1, month=1),
-            points=[
-                Point(x=1, y=2),
-                Point(x=3, y=4),
-            ]
-        )
-
-        s["a"]["1"] = val
-
-        self.assertEqual(s["a"]["1"].points[0].x, 1)
-
-        val.points[0].x = 2
-        s["a"]["1"] = val
-        self.assertEqual(s["a"]["1"].points[0].x, 2)
-
-    def test_sql(self):
-        s = PersistentState()
-
-        s["b"]["1"] = Polygon(
-            created_date=datetime(year=2000, day=1, month=1),
-            points=[]
-        )
-        s["b"]["2"] = Polygon(
-            created_date=datetime(year=2010, day=1, month=1),
-            points=[]
-        )
-
-        result = list(s["b"].sql(
-            "WHERE json->>'created_date' LIKE '2000-%'"
-        ))
-
-        self.assertEqual(len(result), 1)
-        self.assertEqual(result[0].created_date.year, 2000)
-
-    def test_iterate(self):
-        s = PersistentState()
-
-        s["b"]["1"] = Polygon(
-            created_date=datetime(year=2000, day=1, month=1),
-            points=[]
-        )
-        s["b"]["2"] = Polygon(
-            created_date=datetime(year=2010, day=1, month=1),
-            points=[]
-        )
-
-        result = list(s["b"])
-
-        self.assertEqual(len(result), 2)
-        self.assertEqual(result[0].created_date.year, 2000)
-        self.assertEqual(result[1].created_date.year, 2010)
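The removed test_sql pushes a raw WHERE clause through Table.sql() and filters with SQLite's ->> JSON operator, which requires SQLite 3.38 or newer. A hedged sketch of the same filter written with json_extract(), which also works on older builds (the table and data mirror the test, not the library internals):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE b (id TEXT, json TEXT)")
conn.execute(
    "INSERT INTO b VALUES ('1', ?)",
    ('{"created_date": "2000-01-01T00:00:00"}',),
)

# json_extract() is the portable spelling of json->>'created_date'
rows = conn.execute(
    "SELECT * FROM b WHERE json_extract(json, '$.created_date') LIKE '2000-%'"
).fetchall()
print(len(rows))  # 1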