Support arbitrary koral nodes in AST
diff --git a/.gitignore b/.gitignore
index 9aa3b8e..4fb5c14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
testdata/sandbox
-cmd/termmapper
\ No newline at end of file
+cmd/termmapper
+README.md
\ No newline at end of file
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index aac7016..1bc8bb4 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -1,5 +1,9 @@
package ast
+import (
+ "encoding/json"
+)
+
// NodeType represents the type of a node in the AST
type NodeType string
@@ -71,3 +75,15 @@
type Replacement struct {
Root Node
}
+
+// CatchallNode represents any node type not explicitly handled
+type CatchallNode struct {
+ NodeType string // The original @type value
+ RawContent json.RawMessage // The original JSON content
+ Wrap Node // Optional wrapped node
+ Operands []Node // Optional operands
+}
+
+func (c *CatchallNode) Type() NodeType {
+ return NodeType(c.NodeType)
+}
diff --git a/pkg/matcher/matcher.go b/pkg/matcher/matcher.go
index 46b4cc4..8780550 100644
--- a/pkg/matcher/matcher.go
+++ b/pkg/matcher/matcher.go
@@ -42,6 +42,22 @@
n.Operands = newOperands
return n
+ case *ast.CatchallNode:
+ newNode := &ast.CatchallNode{
+ NodeType: n.NodeType,
+ RawContent: n.RawContent,
+ }
+ if n.Wrap != nil {
+ newNode.Wrap = m.Replace(n.Wrap)
+ }
+ if len(n.Operands) > 0 {
+ newNode.Operands = make([]ast.Node, len(n.Operands))
+ for i, op := range n.Operands {
+ newNode.Operands[i] = m.Replace(op)
+ }
+ }
+ return newNode
+
default:
return node
}
@@ -61,6 +77,7 @@
if t, ok := node.(*ast.Token); ok {
return m.matchNode(t.Wrap, p.Wrap)
}
+ return false
case *ast.TermGroup:
// If we're matching against a term, try to match it against any operand
@@ -113,6 +130,44 @@
}
return true
}
+ return false
+
+ case *ast.CatchallNode:
+ // For catchall nodes, we need to check both wrap and operands
+ if t, ok := node.(*ast.CatchallNode); ok {
+ // If pattern has wrap, match it
+ if p.Wrap != nil && !m.matchNode(t.Wrap, p.Wrap) {
+ return false
+ }
+
+ // If pattern has operands, match them
+ if len(p.Operands) > 0 {
+ if len(t.Operands) < len(p.Operands) {
+ return false
+ }
+
+ // Try to match pattern operands against node operands in any order
+ matched := make([]bool, len(t.Operands))
+ for _, pOp := range p.Operands {
+ found := false
+ for j, tOp := range t.Operands {
+ if !matched[j] && m.matchNode(tOp, pOp) {
+ matched[j] = true
+ found = true
+ break
+ }
+ }
+ if !found {
+ return false
+ }
+ }
+ return true
+ }
+
+ // If no wrap or operands to match, it's a match
+ return true
+ }
+ return false
case *ast.Term:
// If we're matching against a term group with OR relation,
@@ -134,6 +189,7 @@
t.Match == p.Match &&
(p.Value == "" || t.Value == p.Value)
}
+ return false
}
return false
@@ -170,6 +226,22 @@
Value: n.Value,
}
+ case *ast.CatchallNode:
+ newNode := &ast.CatchallNode{
+ NodeType: n.NodeType,
+ RawContent: n.RawContent,
+ }
+ if n.Wrap != nil {
+ newNode.Wrap = m.cloneNode(n.Wrap)
+ }
+ if len(n.Operands) > 0 {
+ newNode.Operands = make([]ast.Node, len(n.Operands))
+ for i, op := range n.Operands {
+ newNode.Operands[i] = m.cloneNode(op)
+ }
+ }
+ return newNode
+
default:
return nil
}
diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go
index be11da9..d0ff222 100644
--- a/pkg/parser/parser.go
+++ b/pkg/parser/parser.go
@@ -27,6 +27,9 @@
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse JSON: %w", err)
}
+ if raw.Type == "" {
+ return nil, fmt.Errorf("missing @type field")
+ }
return parseNode(raw)
}
@@ -82,7 +85,45 @@
}, nil
default:
- return nil, fmt.Errorf("unknown node type: %s", raw.Type)
+ // Store the original JSON content
+ rawContent, err := json.Marshal(raw)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal unknown node: %w", err)
+ }
+
+ // Create a catchall node
+ catchall := &ast.CatchallNode{
+ NodeType: raw.Type,
+ RawContent: rawContent,
+ }
+
+ // Parse wrap if present
+ if raw.Wrap != nil {
+ var wrapRaw rawNode
+ if err := json.Unmarshal(raw.Wrap, &wrapRaw); err != nil {
+ return nil, fmt.Errorf("failed to parse wrap in unknown node: %w", err)
+ }
+ wrap, err := parseNode(wrapRaw)
+ if err != nil {
+ return nil, err
+ }
+ catchall.Wrap = wrap
+ }
+
+ // Parse operands if present
+ if len(raw.Operands) > 0 {
+ operands := make([]ast.Node, len(raw.Operands))
+ for i, op := range raw.Operands {
+ node, err := parseNode(op)
+ if err != nil {
+ return nil, err
+ }
+ operands[i] = node
+ }
+ catchall.Operands = operands
+ }
+
+ return catchall, nil
}
}
@@ -122,6 +163,38 @@
Value: n.Value,
}
+ case *ast.CatchallNode:
+ // For catchall nodes, use the stored raw content
+ if n.RawContent != nil {
+ // If we have operands or wrap that were modified, we need to update the raw content
+ if len(n.Operands) > 0 || n.Wrap != nil {
+ var raw rawNode
+ if err := json.Unmarshal(n.RawContent, &raw); err != nil {
+ return rawNode{}
+ }
+
+ // Update operands if present
+ if len(n.Operands) > 0 {
+ raw.Operands = make([]rawNode, len(n.Operands))
+ for i, op := range n.Operands {
+ raw.Operands[i] = nodeToRaw(op)
+ }
+ }
+
+ // Update wrap if present
+ if n.Wrap != nil {
+ raw.Wrap = json.RawMessage(nodeToRaw(n.Wrap).toJSON())
+ }
+
+ return raw
+ }
+ // If no modifications, return the original content as is
+ var raw rawNode
+ _ = json.Unmarshal(n.RawContent, &raw)
+ return raw
+ }
+ return rawNode{}
+
default:
return rawNode{}
}
diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go
index 464c497..d2964ce 100644
--- a/pkg/parser/parser_test.go
+++ b/pkg/parser/parser_test.go
@@ -183,12 +183,16 @@
wantErr: true,
},
{
- name: "Invalid node type",
+ name: "Unknown node type",
input: `{
"@type": "koral:unknown",
"key": "value"
}`,
- wantErr: true,
+ expected: &ast.CatchallNode{
+ NodeType: "koral:unknown",
+ RawContent: json.RawMessage(`{"@type":"koral:unknown","key":"value"}`),
+ },
+ wantErr: false,
},
}
@@ -273,6 +277,21 @@
}`,
wantErr: false,
},
+ {
+ name: "Serialize unknown node type",
+ input: &ast.CatchallNode{
+ NodeType: "koral:unknown",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:unknown",
+ "key": "value"
+}`),
+ },
+ expected: `{
+ "@type": "koral:unknown",
+ "key": "value"
+}`,
+ wantErr: false,
+ },
}
for _, tt := range tests {
@@ -338,3 +357,53 @@
require.NoError(t, err)
assert.Equal(t, expected, actual)
}
+
+func TestRoundTripUnknownType(t *testing.T) {
+ // Test that parsing and then serializing an unknown node type preserves the structure
+ input := `{
+ "@type": "koral:unknown",
+ "key": "value",
+ "wrap": {
+ "@type": "koral:term",
+ "foundry": "opennlp",
+ "key": "DET",
+ "layer": "p",
+ "match": "match:eq"
+ },
+ "operands": [
+ {
+ "@type": "koral:term",
+ "foundry": "opennlp",
+ "key": "AdjType",
+ "layer": "m",
+ "match": "match:eq",
+ "value": "Pdt"
+ }
+ ]
+ }`
+
+ // Parse JSON to AST
+ node, err := ParseJSON([]byte(input))
+ require.NoError(t, err)
+
+ // Check that it's a CatchallNode
+ catchall, ok := node.(*ast.CatchallNode)
+ require.True(t, ok)
+ assert.Equal(t, "koral:unknown", catchall.NodeType)
+
+ // Check that wrap and operands were parsed
+ require.NotNil(t, catchall.Wrap)
+ require.Len(t, catchall.Operands, 1)
+
+ // Serialize AST back to JSON
+ output, err := SerializeToJSON(node)
+ require.NoError(t, err)
+
+ // Compare JSON objects
+ var expected, actual interface{}
+ err = json.Unmarshal([]byte(input), &expected)
+ require.NoError(t, err)
+ err = json.Unmarshal(output, &actual)
+ require.NoError(t, err)
+ assert.Equal(t, expected, actual)
+}