Support arbitrary koral nodes in AST
diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go
index be11da9..d0ff222 100644
--- a/pkg/parser/parser.go
+++ b/pkg/parser/parser.go
@@ -27,6 +27,9 @@
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse JSON: %w", err)
}
+ if raw.Type == "" {
+ return nil, fmt.Errorf("missing @type field")
+ }
return parseNode(raw)
}
@@ -82,7 +85,45 @@
}, nil
default:
- return nil, fmt.Errorf("unknown node type: %s", raw.Type)
+ // Store the original JSON content
+ rawContent, err := json.Marshal(raw)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal unknown node: %w", err)
+ }
+
+ // Create a catchall node
+ catchall := &ast.CatchallNode{
+ NodeType: raw.Type,
+ RawContent: rawContent,
+ }
+
+ // Parse wrap if present
+ if raw.Wrap != nil {
+ var wrapRaw rawNode
+ if err := json.Unmarshal(raw.Wrap, &wrapRaw); err != nil {
+ return nil, fmt.Errorf("failed to parse wrap in unknown node: %w", err)
+ }
+ wrap, err := parseNode(wrapRaw)
+ if err != nil {
+ return nil, err
+ }
+ catchall.Wrap = wrap
+ }
+
+ // Parse operands if present
+ if len(raw.Operands) > 0 {
+ operands := make([]ast.Node, len(raw.Operands))
+ for i, op := range raw.Operands {
+ node, err := parseNode(op)
+ if err != nil {
+ return nil, err
+ }
+ operands[i] = node
+ }
+ catchall.Operands = operands
+ }
+
+ return catchall, nil
}
}
@@ -122,6 +163,38 @@
Value: n.Value,
}
+ case *ast.CatchallNode:
+ // For catchall nodes, use the stored raw content
+ if n.RawContent != nil {
+ // If we have operands or wrap that were modified, we need to update the raw content
+ if len(n.Operands) > 0 || n.Wrap != nil {
+ var raw rawNode
+ if err := json.Unmarshal(n.RawContent, &raw); err != nil {
+ return rawNode{}
+ }
+
+ // Update operands if present
+ if len(n.Operands) > 0 {
+ raw.Operands = make([]rawNode, len(n.Operands))
+ for i, op := range n.Operands {
+ raw.Operands[i] = nodeToRaw(op)
+ }
+ }
+
+ // Update wrap if present
+ if n.Wrap != nil {
+ raw.Wrap = json.RawMessage(nodeToRaw(n.Wrap).toJSON())
+ }
+
+ return raw
+ }
+ // If no modifications, return the original content as is
+ var raw rawNode
+ _ = json.Unmarshal(n.RawContent, &raw)
+ return raw
+ }
+ return rawNode{}
+
default:
return rawNode{}
}
diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go
index 464c497..d2964ce 100644
--- a/pkg/parser/parser_test.go
+++ b/pkg/parser/parser_test.go
@@ -183,12 +183,16 @@
wantErr: true,
},
{
- name: "Invalid node type",
+ name: "Unknown node type",
input: `{
"@type": "koral:unknown",
"key": "value"
}`,
- wantErr: true,
+ expected: &ast.CatchallNode{
+ NodeType: "koral:unknown",
+ RawContent: json.RawMessage(`{"@type":"koral:unknown","key":"value"}`),
+ },
+ wantErr: false,
},
}
@@ -273,6 +277,21 @@
}`,
wantErr: false,
},
+ {
+ name: "Serialize unknown node type",
+ input: &ast.CatchallNode{
+ NodeType: "koral:unknown",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:unknown",
+ "key": "value"
+}`),
+ },
+ expected: `{
+ "@type": "koral:unknown",
+ "key": "value"
+}`,
+ wantErr: false,
+ },
}
for _, tt := range tests {
@@ -338,3 +357,53 @@
require.NoError(t, err)
assert.Equal(t, expected, actual)
}
+
+func TestRoundTripUnknownType(t *testing.T) {
+ // An unknown @type must parse into a CatchallNode (wrap and operands parsed as child nodes) and serialize back to JSON equal to the input.
+ input := `{
+ "@type": "koral:unknown",
+ "key": "value",
+ "wrap": {
+ "@type": "koral:term",
+ "foundry": "opennlp",
+ "key": "DET",
+ "layer": "p",
+ "match": "match:eq"
+ },
+ "operands": [
+ {
+ "@type": "koral:term",
+ "foundry": "opennlp",
+ "key": "AdjType",
+ "layer": "m",
+ "match": "match:eq",
+ "value": "Pdt"
+ }
+ ]
+ }`
+
+ // Parse JSON to AST; an unknown @type must not be rejected.
+ node, err := ParseJSON([]byte(input))
+ require.NoError(t, err)
+
+ // Unknown types are expected to come back as *ast.CatchallNode keeping the original @type string.
+ catchall, ok := node.(*ast.CatchallNode)
+ require.True(t, ok)
+ assert.Equal(t, "koral:unknown", catchall.NodeType)
+
+ // The nested wrap object and the single operand must be parsed into child nodes.
+ require.NotNil(t, catchall.Wrap)
+ require.Len(t, catchall.Operands, 1)
+
+ // Serialize the AST back to JSON.
+ output, err := SerializeToJSON(node)
+ require.NoError(t, err)
+
+ // Compare decoded JSON values rather than raw bytes, so key order and whitespace are ignored.
+ var expected, actual interface{}
+ err = json.Unmarshal([]byte(input), &expected)
+ require.NoError(t, err)
+ err = json.Unmarshal(output, &actual)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+}