diff --git a/hexlib/db.py b/hexlib/db.py index 7f4e498..c2b9a9c 100644 --- a/hexlib/db.py +++ b/hexlib/db.py @@ -126,9 +126,13 @@ class Table: try: conn.execute(sql, list(_serialize(v) for v in value.values())) except sqlite3.OperationalError: + if isinstance(key, int): + key_type = "integer" + else: + key_type = "text" conn.execute( - "create table if not exists %s (id text primary key,%s)" % - (self._table, ",".join("%s %s" % (k, _sqlite_type(v)) for k, v in value.items())) + "create table if not exists %s (id %s primary key,%s)" % + (self._table, key_type, ",".join("%s %s" % (k, _sqlite_type(v)) for k, v in value.items())) ) conn.execute(sql, list(_serialize(v) for v in value.values())) diff --git a/test/test_PersistantState.py b/test/test_PersistantState.py new file mode 100644 index 0000000..2148bbf --- /dev/null +++ b/test/test_PersistantState.py @@ -0,0 +1,33 @@ +import os +from unittest import TestCase + +from hexlib.db import PersistentState + + +class TestPersistentState(TestCase): + + def tearDown(self) -> None: + os.remove("state.db") + + def setUp(self) -> None: + os.remove("state.db") + + def test_get_set(self): + s = PersistentState() + + val = {"a": 1, "b": "2", "c": b'3', "d": 4.4} + s["a"]["1"] = val + + val["id"] = "1" + + self.assertDictEqual(val, s["a"]["1"]) + + def test_get_set_int_id(self): + s = PersistentState() + + val = {"a": 1, "b": "2", "c": b'3', "d": 4.4} + s["a"][1] = val + + val["id"] = 1 + + self.assertDictEqual(val, s["a"][1])