Skip to content

Commit b7eff1c

Browse files
committed
Update json serializer so that we automatically short-circuit circular references and thus can serialize more of the AST
1 parent 96e9bf7 commit b7eff1c

File tree

6 files changed

+325
-84
lines changed

6 files changed

+325
-84
lines changed

datajunction-server/datajunction_server/models/node.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,15 @@ def has_available_materialization(self, build_criteria: BuildCriteria) -> bool:
835835
)
836836
)
837837

838+
def __json_encode__(self):
    """
    Build the JSON-serializable summary of this node revision.

    Only the ``name`` and ``type`` attributes are included.
    """
    encoded = dict(name=self.name, type=self.type)
    return encoded
846+
838847

839848
class ImmutableNodeFields(BaseSQLModel):
840849
"""

datajunction-server/datajunction_server/sql/parsing/ast.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ class Node(ABC):
106106
def json_ignore_keys(self):
107107
return ["parent", "parent_key", "_is_compiled"]
108108

109+
def __json_encode__(self):
    """
    Return this object's instance attributes as a dict for JSON encoding.

    Every entry of ``self.__dict__`` is kept except those whose name is
    listed in ``self.json_ignore_keys``.
    """
    encoded = {}
    for attr, value in self.__dict__.items():
        if attr in self.json_ignore_keys:
            continue
        encoded[attr] = value
    return encoded
115+
109116
def __post_init__(self):
110117
self.add_self_as_parent()
111118

@@ -628,6 +635,10 @@ def identifier(self, quotes: bool = True) -> str:
628635
f"{namespace}{quote_style}{self.name}{quote_style}" # pylint: disable=C0301
629636
)
630637

638+
@property
def json_ignore_keys(self):
    """
    Names of the instance attributes to leave out of the JSON encoding.
    """
    ignored = ["names", "parent", "parent_key"]
    return ignored
641+
631642

632643
TNamed = TypeVar("TNamed", bound="Named") # pylint: disable=C0103
633644

@@ -711,9 +722,7 @@ class Column(Aliasable, Named, Expression):
711722

712723
@property
713724
def json_ignore_keys(self):
714-
if set(self._expression.columns).intersection(self.columns):
715-
return ["parent", "parent_key", "_is_compiled", "_expression", "columns"]
716-
return ["parent", "parent_key", "_is_compiled", "columns"]
725+
return ["parent", "parent_key", "columns"]
717726

718727
@property
719728
def type(self):
@@ -1000,10 +1009,11 @@ def json_ignore_keys(self):
10001009
return [
10011010
"parent",
10021011
"parent_key",
1003-
"_is_compiled",
1012+
# "_is_compiled",
10041013
"_columns",
1005-
"column_list",
1014+
# "column_list",
10061015
"_ref_columns",
1016+
"columns",
10071017
]
10081018

10091019
@property
@@ -1250,6 +1260,11 @@ class BinaryOpKind(DJEnum):
12501260
Minus = "-"
12511261
Modulo = "%"
12521262

1263+
def __json_encode__(self):
    """
    Encode this enum member for JSON as a dict holding its ``value``.
    """
    encoded = dict(value=self.value)
    return encoded
1267+
12531268

12541269
@dataclass(eq=False)
12551270
class BinaryOp(Operation):
@@ -2026,7 +2041,16 @@ class FunctionTable(FunctionTableExpression):
20262041

20272042
@property
20282043
def json_ignore_keys(self):
2029-
return ["parent", "parent_key", "_is_compiled", "_table"]
2044+
return [
2045+
"parent",
2046+
"parent_key",
2047+
"_is_compiled",
2048+
"_table",
2049+
"_columns",
2050+
"column_list",
2051+
"_ref_columns",
2052+
"columns",
2053+
]
20302054

20312055
def __str__(self) -> str:
20322056
alias = f" {self.alias}" if self.alias else ""

datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,30 @@
33
"""
44
from json import JSONEncoder
55

6+
from sqlmodel import select
7+
8+
from datajunction_server.models import Node
9+
from datajunction_server.sql.parsing import ast
10+
11+
12+
def remove_circular_refs(obj, _seen: set = None):
    """
    Break reference cycles in an AST node graph.

    The walk is depth-first. ``_seen`` holds the ``id`` of every object on
    the current traversal path; an object reached again while it is still on
    that path is replaced by ``None``. Only ``ast.Node`` instances are
    descended into, and only through attributes not listed in their
    ``json_ignore_keys``. Any other object is returned unchanged.

    Note that attributes are rewritten on the nodes in place via ``setattr``.
    """
    seen_ids = set() if _seen is None else _seen
    marker = id(obj)
    if marker in seen_ids:
        # Already on the current path: this reference closes a cycle.
        return None

    seen_ids.add(marker)
    if issubclass(obj.__class__, ast.Node):
        # Snapshot the attribute names before any attribute is reassigned.
        kept_attrs = [
            attr for attr in obj.__dict__ if attr not in obj.json_ignore_keys
        ]
        for attr in kept_attrs:
            child = getattr(obj, attr)
            setattr(obj, attr, remove_circular_refs(child, seen_ids))
    # Leaving this object: it may legitimately appear again on another path.
    seen_ids.remove(marker)
    return obj
29+
630

731
class ASTEncoder(JSONEncoder):
832
"""
@@ -12,26 +36,50 @@ class ASTEncoder(JSONEncoder):
1236
"""
1337

1438
def __init__(self, *args, **kwargs):
15-
kwargs["check_circular"] = False # no need to check anymore
39+
kwargs["check_circular"] = False
40+
self.markers = set()
1641
super().__init__(*args, **kwargs)
17-
self._processed = set()
1842

1943
def default(self, o):
20-
if id(o) in self._processed:
21-
return None
22-
self._processed.add(id(o))
23-
24-
if o.__class__.__name__ == "NodeRevision":
25-
return {
26-
"__class__": o.__class__.__name__,
27-
"name": o.name,
28-
"type": o.type,
29-
}
30-
44+
o = remove_circular_refs(o)
3145
json_dict = {
32-
k: o.__dict__[k]
33-
for k in o.__dict__
34-
if hasattr(o, "json_ignore_keys") and k not in o.json_ignore_keys
46+
"__class__": o.__class__.__name__,
3547
}
36-
json_dict["__class__"] = o.__class__.__name__
48+
if hasattr(o, "__json_encode__"):
49+
json_dict = {**json_dict, **o.__json_encode__()}
3750
return json_dict
51+
52+
53+
def ast_decoder(session, json_dict):
    """
    Rebuild an AST object from a decoded JSON dict.

    Args:
        session: database session; used only to look up a node by name when
            the encoded class is ``NodeRevision``.
        json_dict: decoded JSON object whose ``"__class__"`` entry names the
            attribute of the ``ast`` module to instantiate.

    Returns:
        The reconstructed instance, or ``None`` when ``"__class__"`` is
        missing, empty, or not an attribute of the ``ast`` module.
    """
    # Keys that are never passed to the constructor / never assigned back.
    skip_on_init = {"__class__", "_type", "laterals", "_is_compiled"}
    skip_on_setattr = {"__class__", "_is_compiled"}

    # ``.get`` so that a dict without a class tag yields None instead of
    # raising KeyError before the guard below can run.
    class_name = json_dict.get("__class__")
    if not class_name or not hasattr(ast, class_name):
        return None
    clazz = getattr(ast, class_name)

    if class_name == "NodeRevision":
        # Reached only if the ``ast`` module exposes a ``NodeRevision``
        # attribute (see the guard above).
        instance = (
            session.exec(select(Node).where(Node.name == json_dict["name"]))
            .one()
            .current
        )
    else:
        init_kwargs = {
            key: value
            for key, value in json_dict.items()
            if key not in skip_on_init
        }
        instance = clazz(**init_kwargs)

    # NOTE(review): the flattened source does not show whether this loop was
    # nested under the ``else`` branch; it runs for both branches here --
    # confirm against the repository.
    for key, value in json_dict.items():
        if key in skip_on_setattr:
            continue
        try:
            setattr(instance, key, value)
        except AttributeError:
            # Best-effort: attributes that cannot be assigned are skipped.
            pass

    if class_name == "Table":
        instance._columns = [  # pylint: disable=protected-access
            ast.Column(ast.Name(col.name), _table=instance, _type=col.type)
            for col in instance._dj_node.columns  # pylint: disable=protected-access
        ]
    return instance

datajunction-server/datajunction_server/sql/parsing/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def __str__(self):
7474
def __deepcopy__(self, memo):
7575
return self
7676

77+
def __json_encode__(self):
    """
    Encode this object for JSON as a dict holding only its class name.
    """
    class_name = self.__class__.__name__
    return dict([("__class__", class_name)])
81+
7782
@classmethod
7883
def __get_validators__(cls) -> Generator[AnyCallable, None, None]:
7984
"""

0 commit comments

Comments
 (0)