mirror of
https://github.com/simon987/hexlib.git
synced 2025-04-04 02:12:59 +00:00
308 lines
9.1 KiB
Python
308 lines
9.1 KiB
Python
import base64
|
|
import sqlite3
|
|
import traceback
|
|
|
|
import psycopg2
|
|
import umsgpack
|
|
from psycopg2.errorcodes import UNIQUE_VIOLATION
|
|
|
|
from hexlib.env import get_redis
|
|
|
|
|
|
class PersistentState:
|
|
"""Quick and dirty persistent dict-like SQLite wrapper"""
|
|
|
|
def __init__(self, dbfile="state.db", logger=None, **dbargs):
|
|
self.dbfile = dbfile
|
|
self.logger = logger
|
|
if dbargs is None:
|
|
dbargs = {"timeout": 30000}
|
|
self.dbargs = dbargs
|
|
|
|
def __getitem__(self, table):
|
|
return Table(self, table)
|
|
|
|
|
|
class VolatileState:
|
|
"""Quick and dirty volatile dict-like redis wrapper"""
|
|
|
|
def __init__(self, prefix, redis_db=None, sep=""):
|
|
if redis_db is None:
|
|
redis_db = get_redis()
|
|
self.rdb = redis_db
|
|
self.prefix = prefix
|
|
self._sep = sep
|
|
|
|
def __getitem__(self, table):
|
|
return RedisTable(self, table, self._sep)
|
|
|
|
def __delitem__(self, key):
|
|
self.rdb.delete(f"{self.prefix}{self._sep}{key}")
|
|
|
|
|
|
class VolatileQueue:
|
|
"""Quick and dirty volatile queue-like redis wrapper"""
|
|
|
|
def __init__(self, key, redis_db=None):
|
|
if redis_db is None:
|
|
redis_db = get_redis()
|
|
self.rdb = redis_db
|
|
self.key = key
|
|
|
|
def put(self, item):
|
|
self.rdb.sadd(self.key, umsgpack.dumps(item))
|
|
|
|
def get(self):
|
|
v = self.rdb.spop(self.key)
|
|
if v:
|
|
return umsgpack.loads(v)
|
|
|
|
|
|
class VolatileBooleanState:
|
|
"""Quick and dirty volatile dict-like redis wrapper for boolean values"""
|
|
|
|
def __init__(self, prefix, redis_db=None, sep=""):
|
|
if redis_db is None:
|
|
redis_db = get_redis()
|
|
self.rdb = redis_db
|
|
self.prefix = prefix
|
|
self._sep = sep
|
|
|
|
def __getitem__(self, table):
|
|
return RedisBooleanTable(self, table, self._sep)
|
|
|
|
def __delitem__(self, table):
|
|
self.rdb.delete(f"{self.prefix}{self._sep}{table}")
|
|
|
|
|
|
class RedisTable:
|
|
def __init__(self, state, table, sep=""):
|
|
self._state = state
|
|
self._table = table
|
|
self._sep = sep
|
|
self._key = f"{self._state.prefix}{self._sep}{self._table}"
|
|
|
|
def __setitem__(self, key, value):
|
|
self._state.rdb.hset(self._key, str(key), umsgpack.dumps(value))
|
|
|
|
def __getitem__(self, key):
|
|
val = self._state.rdb.hget(self._key, str(key))
|
|
if val:
|
|
return umsgpack.loads(val)
|
|
return None
|
|
|
|
def __delitem__(self, key):
|
|
self._state.rdb.hdel(self._key, str(key))
|
|
|
|
def __iter__(self):
|
|
for val in self._state.rdb.hscan(self._key):
|
|
if val:
|
|
return ((k, umsgpack.loads(v)) for k, v in val.items())
|
|
|
|
|
|
class RedisBooleanTable:
|
|
def __init__(self, state, table, sep=""):
|
|
self._state = state
|
|
self._table = table
|
|
self._sep = sep
|
|
self._key = f"{self._state.prefix}{self._sep}{self._table}"
|
|
|
|
def __setitem__(self, key, value):
|
|
if value:
|
|
self._state.rdb.sadd(self._key, str(key))
|
|
else:
|
|
self.__delitem__(key)
|
|
|
|
def __getitem__(self, key):
|
|
return self._state.rdb.sismember(self._key, str(key))
|
|
|
|
def __delitem__(self, key):
|
|
self._state.rdb.srem(self._key, str(key))
|
|
|
|
def __iter__(self):
|
|
yield from self._state.rdb.sscan_iter(self._key)
|
|
|
|
|
|
class Table:
|
|
def __init__(self, state, table):
|
|
self._state = state
|
|
self._table = table
|
|
|
|
def sql(self, where_clause, *params):
|
|
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
try:
|
|
cur = conn.execute("SELECT * FROM %s %s" % (self._table, where_clause), params)
|
|
for row in cur:
|
|
yield dict(row)
|
|
except:
|
|
return None
|
|
|
|
def __iter__(self):
|
|
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
try:
|
|
cur = conn.execute("SELECT * FROM %s" % (self._table,))
|
|
for row in cur:
|
|
yield dict(row)
|
|
except:
|
|
return None
|
|
|
|
def __getitem__(self, item):
|
|
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
try:
|
|
col_types = conn.execute("PRAGMA table_info(%s)" % self._table).fetchall()
|
|
cur = conn.execute("SELECT * FROM %s WHERE id=?" % (self._table,), (item,))
|
|
|
|
row = cur.fetchone()
|
|
if row:
|
|
return dict(
|
|
(col[0], _deserialize(row[col[0]], col_types[i]["type"]))
|
|
for i, col in enumerate(cur.description)
|
|
)
|
|
except:
|
|
return None
|
|
|
|
def __setitem__(self, key, value):
|
|
|
|
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
|
|
sql = "INSERT INTO %s (id,%s) VALUES ('%s',%s)" % \
|
|
(self._table, ",".join(value.keys()), key, ",".join("?" for _ in value.values()))
|
|
try:
|
|
conn.execute(sql, list(_serialize(v) for v in value.values()))
|
|
except sqlite3.OperationalError:
|
|
if isinstance(key, int):
|
|
key_type = "integer"
|
|
else:
|
|
key_type = "text"
|
|
conn.execute(
|
|
"create table if not exists %s (id %s primary key,%s)" %
|
|
(self._table, key_type, ",".join("%s %s" % (k, _sqlite_type(v)) for k, v in value.items()))
|
|
)
|
|
conn.execute(sql, list(_serialize(v) for v in value.values()))
|
|
|
|
except sqlite3.IntegrityError:
|
|
sql = "UPDATE %s SET (%s) = (%s) WHERE id=?" \
|
|
% (self._table, ",".join(value.keys()), ",".join("?" for _ in value.values()))
|
|
args = list(_serialize(v) for v in value.values())
|
|
args.append(key)
|
|
conn.execute(
|
|
sql,
|
|
args
|
|
)
|
|
|
|
def __delitem__(self, key):
|
|
with sqlite3.connect(self._state.dbfile, **self._state.dbargs) as conn:
|
|
try:
|
|
conn.execute("DELETE FROM %s WHERE id=?" % self._table, (key,))
|
|
except sqlite3.OperationalError:
|
|
pass
|
|
|
|
|
|
def _sqlite_type(value):
|
|
if isinstance(value, int):
|
|
return "integer"
|
|
if isinstance(value, float):
|
|
return "real"
|
|
if isinstance(value, bytes):
|
|
return "blob"
|
|
return "text"
|
|
|
|
|
|
def _serialize(value):
|
|
if isinstance(value, bytes):
|
|
return base64.b64encode(value)
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, bool):
|
|
return value
|
|
return str(value)
|
|
|
|
|
|
def _deserialize(value, col_type):
|
|
if col_type == "blob":
|
|
return base64.b64decode(value)
|
|
return value
|
|
|
|
|
|
def pg_fetch_cursor_all(cur, name, batch_size=1000):
|
|
while True:
|
|
cur.execute("FETCH FORWARD %d FROM %s" % (batch_size, name))
|
|
cnt = 0
|
|
|
|
for row in cur:
|
|
cnt += 1
|
|
yield row
|
|
|
|
if cnt != batch_size:
|
|
cur.execute("FETCH ALL FROM %s" % (name,))
|
|
for row in cur:
|
|
yield row
|
|
break
|
|
|
|
|
|
class PgConn:
|
|
"""Wrapper for PostgreSQL connection"""
|
|
|
|
def __init__(self, logger=None, **kwargs):
|
|
self._conn_args = kwargs
|
|
self.conn = psycopg2.connect(**kwargs)
|
|
self.cur = self.conn.cursor()
|
|
self._logger = logger
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def exec(self, query_string, args=None):
|
|
while True:
|
|
try:
|
|
if self._logger:
|
|
self._logger.debug(query_string)
|
|
self._logger.debug("With args " + str(args))
|
|
|
|
self.cur.execute(query_string, args)
|
|
break
|
|
except psycopg2.Error as e:
|
|
if e.pgcode == UNIQUE_VIOLATION:
|
|
break
|
|
traceback.print_stack()
|
|
self._handle_err(e, query_string, args)
|
|
|
|
def query(self, query_string, args=None, max_retries=1):
|
|
retries = max_retries
|
|
while retries > 0:
|
|
try:
|
|
if self._logger:
|
|
self._logger.debug(query_string)
|
|
self._logger.debug("With args " + str(args))
|
|
|
|
self.cur.execute(query_string, args)
|
|
res = self.cur.fetchall()
|
|
|
|
if self._logger:
|
|
self._logger.debug("result: " + str(res))
|
|
|
|
return res
|
|
except psycopg2.Error as e:
|
|
if e.pgcode == UNIQUE_VIOLATION:
|
|
break
|
|
self._handle_err(e, query_string, args)
|
|
retries -= 1
|
|
|
|
def _handle_err(self, err, query, args):
|
|
if self._logger:
|
|
self._logger.warning(
|
|
"Error during query '%s' with args %s: %s %s (%s)" % (query, args, type(err), err, err.pgcode))
|
|
self.conn = psycopg2.connect(**self._conn_args)
|
|
self.cur = self.conn.cursor()
|
|
|
|
def __exit__(self, type, value, traceback):
|
|
try:
|
|
self.conn.commit()
|
|
self.cur.close()
|
|
except:
|
|
pass
|