Fix JSON encoding for Enums

This commit is contained in:
simon987 2023-02-25 15:51:28 -05:00
parent 5275c332cc
commit 42e33b72b2
3 changed files with 16 additions and 4 deletions

View File

@ -2,11 +2,11 @@ import base64
import sqlite3 import sqlite3
import traceback import traceback
from datetime import datetime from datetime import datetime
from enum import Enum
import psycopg2 import psycopg2
import umsgpack import umsgpack
from psycopg2.errorcodes import UNIQUE_VIOLATION from psycopg2.errorcodes import UNIQUE_VIOLATION
import json
from pydantic import BaseModel from pydantic import BaseModel
from hexlib.env import get_redis from hexlib.env import get_redis
@ -15,7 +15,10 @@ from hexlib.env import get_redis
def _json_encoder(x): def _json_encoder(x):
if isinstance(x, datetime): if isinstance(x, datetime):
return x.isoformat() return x.isoformat()
return x if isinstance(x, Enum):
return x.value
raise Exception(f"I don't know how to JSON encode {x} ({type(x)})")
class VolatileState: class VolatileState:

View File

@ -2,7 +2,7 @@ from setuptools import setup
setup( setup(
name="hexlib", name="hexlib",
version="1.83", version="1.84",
description="Misc utility methods", description="Misc utility methods",
author="simon987", author="simon987",
author_email="me@simon987.net", author_email="me@simon987.net",

View File

@ -1,5 +1,7 @@
import os import os
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Optional
from unittest import TestCase from unittest import TestCase
from pydantic import BaseModel from pydantic import BaseModel
@ -8,6 +10,11 @@ from pydantic.types import List
from hexlib.db import PersistentState from hexlib.db import PersistentState
class Status(Enum):
yes = "yes"
no = "no"
class Point(BaseModel): class Point(BaseModel):
x: int x: int
y: int y: int
@ -16,6 +23,7 @@ class Point(BaseModel):
class Polygon(BaseModel): class Polygon(BaseModel):
points: List[Point] = [] points: List[Point] = []
created_date: datetime created_date: datetime
status: Status = Status("yes")
class TestPydanticTable(TestCase): class TestPydanticTable(TestCase):
@ -35,12 +43,13 @@ class TestPydanticTable(TestCase):
points=[ points=[
Point(x=1, y=2), Point(x=1, y=2),
Point(x=3, y=4), Point(x=3, y=4),
] ],
) )
s["a"]["1"] = val s["a"]["1"] = val
self.assertEqual(s["a"]["1"].points[0].x, 1) self.assertEqual(s["a"]["1"].points[0].x, 1)
self.assertEqual(s["a"]["1"].status, Status("yes"))
self.assertEqual(s["a"]["1"].points[1].x, 3) self.assertEqual(s["a"]["1"].points[1].x, 3)
self.assertEqual(s["a"]["1"].created_date.year, 2000) self.assertEqual(s["a"]["1"].created_date.year, 2000)