Support catchall nodes
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index 1bc8bb4..7159393 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -1,5 +1,7 @@
package ast
+// ast is the abstract syntax tree for the term mapper.
+
import (
"encoding/json"
)
diff --git a/pkg/matcher/matcher.go b/pkg/matcher/matcher.go
index 8780550..a26404b 100644
--- a/pkg/matcher/matcher.go
+++ b/pkg/matcher/matcher.go
@@ -25,22 +25,37 @@
// Replace replaces all occurrences of the pattern in the given node with the replacement
func (m *Matcher) Replace(node ast.Node) ast.Node {
+ // If this node matches the pattern, create replacement while preserving outer structure
if m.Match(node) {
- return m.cloneNode(m.replacement.Root)
+ switch node.(type) {
+ case *ast.Token:
+ // For Token nodes, preserve the Token wrapper but replace its wrap
+ newToken := &ast.Token{
+ Wrap: m.cloneNode(m.replacement.Root),
+ }
+ return newToken
+ default:
+ return m.cloneNode(m.replacement.Root)
+ }
}
+ // Otherwise recursively process children
switch n := node.(type) {
case *ast.Token:
- n.Wrap = m.Replace(n.Wrap)
- return n
+ newToken := &ast.Token{
+ Wrap: m.Replace(n.Wrap),
+ }
+ return newToken
case *ast.TermGroup:
newOperands := make([]ast.Node, len(n.Operands))
for i, op := range n.Operands {
newOperands[i] = m.Replace(op)
}
- n.Operands = newOperands
- return n
+ return &ast.TermGroup{
+ Operands: newOperands,
+ Relation: n.Relation,
+ }
case *ast.CatchallNode:
newNode := &ast.CatchallNode{
@@ -72,52 +87,80 @@
return false
}
- switch p := pattern.(type) {
- case *ast.Token:
- if t, ok := node.(*ast.Token); ok {
- return m.matchNode(t.Wrap, p.Wrap)
+ // Handle pattern being a Token
+ if pToken, ok := pattern.(*ast.Token); ok {
+ if nToken, ok := node.(*ast.Token); ok {
+ return m.matchNode(nToken.Wrap, pToken.Wrap)
}
return false
+ }
- case *ast.TermGroup:
- // If we're matching against a term, try to match it against any operand
- if t, ok := node.(*ast.Term); ok && p.Relation == ast.OrRelation {
- for _, op := range p.Operands {
- if m.matchNode(t, op) {
+ // Handle pattern being a Term
+ if pTerm, ok := pattern.(*ast.Term); ok {
+ // Direct term to term matching
+ if t, ok := node.(*ast.Term); ok {
+ return t.Foundry == pTerm.Foundry &&
+ t.Key == pTerm.Key &&
+ t.Layer == pTerm.Layer &&
+ t.Match == pTerm.Match &&
+ (pTerm.Value == "" || t.Value == pTerm.Value)
+ }
+ // If node is a Token, check its wrap
+ if tkn, ok := node.(*ast.Token); ok {
+ if tkn.Wrap == nil {
+ return false
+ }
+ return m.matchNode(tkn.Wrap, pattern)
+ }
+ // If node is a TermGroup, check its operands
+ if tg, ok := node.(*ast.TermGroup); ok {
+ for _, op := range tg.Operands {
+ if m.matchNode(op, pattern) {
+ return true
+ }
+ }
+ return false
+ }
+ // If node is a CatchallNode, check its wrap and operands
+ if c, ok := node.(*ast.CatchallNode); ok {
+ if c.Wrap != nil && m.matchNode(c.Wrap, pattern) {
+ return true
+ }
+ for _, op := range c.Operands {
+ if m.matchNode(op, pattern) {
+ return true
+ }
+ }
+ return false
+ }
+ return false
+ }
+
+ // Handle pattern being a TermGroup
+ if pGroup, ok := pattern.(*ast.TermGroup); ok {
+ // For OR relations, check if any operand matches the node
+ if pGroup.Relation == ast.OrRelation {
+ for _, pOp := range pGroup.Operands {
+ if m.matchNode(node, pOp) {
return true
}
}
return false
}
- // If we're matching against a term group
- if t, ok := node.(*ast.TermGroup); ok {
- if t.Relation != p.Relation {
+ // For AND relations, node must be a TermGroup with matching relation
+ if tg, ok := node.(*ast.TermGroup); ok {
+ if tg.Relation != pGroup.Relation {
return false
}
-
- if p.Relation == ast.OrRelation {
- // For OR relation, at least one operand must match
- for _, pOp := range p.Operands {
- for _, tOp := range t.Operands {
- if m.matchNode(tOp, pOp) {
- return true
- }
- }
- }
+ // Check that all pattern operands match in any order
+ if len(tg.Operands) < len(pGroup.Operands) {
return false
}
-
- // For AND relation, all pattern operands must match
- if len(t.Operands) < len(p.Operands) {
- return false
- }
-
- // Try to match pattern operands against node operands in any order
- matched := make([]bool, len(t.Operands))
- for _, pOp := range p.Operands {
+ matched := make([]bool, len(tg.Operands))
+ for _, pOp := range pGroup.Operands {
found := false
- for j, tOp := range t.Operands {
+ for j, tOp := range tg.Operands {
if !matched[j] && m.matchNode(tOp, pOp) {
matched[j] = true
found = true
@@ -130,65 +173,28 @@
}
return true
}
- return false
- case *ast.CatchallNode:
- // For catchall nodes, we need to check both wrap and operands
- if t, ok := node.(*ast.CatchallNode); ok {
- // If pattern has wrap, match it
- if p.Wrap != nil && !m.matchNode(t.Wrap, p.Wrap) {
+ // If node is a Token, check its wrap
+ if tkn, ok := node.(*ast.Token); ok {
+ if tkn.Wrap == nil {
return false
}
+ return m.matchNode(tkn.Wrap, pattern)
+ }
- // If pattern has operands, match them
- if len(p.Operands) > 0 {
- if len(t.Operands) < len(p.Operands) {
- return false
- }
-
- // Try to match pattern operands against node operands in any order
- matched := make([]bool, len(t.Operands))
- for _, pOp := range p.Operands {
- found := false
- for j, tOp := range t.Operands {
- if !matched[j] && m.matchNode(tOp, pOp) {
- matched[j] = true
- found = true
- break
- }
- }
- if !found {
- return false
- }
- }
+ // If node is a CatchallNode, check its wrap and operands
+ if c, ok := node.(*ast.CatchallNode); ok {
+ if c.Wrap != nil && m.matchNode(c.Wrap, pattern) {
return true
}
-
- // If no wrap or operands to match, it's a match
- return true
- }
- return false
-
- case *ast.Term:
- // If we're matching against a term group with OR relation,
- // try to match against any of its operands
- if t, ok := node.(*ast.TermGroup); ok && t.Relation == ast.OrRelation {
- for _, op := range t.Operands {
- if m.matchNode(op, p) {
+ for _, op := range c.Operands {
+ if m.matchNode(op, pattern) {
return true
}
}
return false
}
- // Direct term to term matching
- if t, ok := node.(*ast.Term); ok {
- return t.Foundry == p.Foundry &&
- t.Key == p.Key &&
- t.Layer == p.Layer &&
- t.Match == p.Match &&
- (p.Value == "" || t.Value == p.Value)
- }
return false
}
diff --git a/pkg/matcher/matcher_test.go b/pkg/matcher/matcher_test.go
index d058a7c..adc8784 100644
--- a/pkg/matcher/matcher_test.go
+++ b/pkg/matcher/matcher_test.go
@@ -1,6 +1,10 @@
package matcher
+// matcher is a function that takes a pattern and a node and returns true if the node matches the pattern.
+// It is used to match a pattern against a node in the AST.
+
import (
+ "encoding/json"
"testing"
"github.com/KorAP/KoralPipe-TermMapper2/pkg/ast"
@@ -76,7 +80,7 @@
expected: false,
},
{
- name: "Wrong node type",
+ name: "Nested node",
input: &ast.Token{
Wrap: &ast.Term{
Foundry: "opennlp",
@@ -85,7 +89,7 @@
Match: ast.MatchEqual,
},
},
- expected: false,
+ expected: true,
},
}
@@ -502,3 +506,176 @@
assert.True(t, m.Match(input1), "Should match with original order")
assert.True(t, m.Match(input2), "Should match with reversed order")
}
+
+func TestMatchWithUnknownNodes(t *testing.T) {
+ // Create a pattern that looks for a term with DET inside any structure
+ pattern := ast.Pattern{
+ Root: &ast.Term{
+ Foundry: "opennlp",
+ Key: "DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ }
+
+ replacement := ast.Replacement{
+ Root: &ast.Term{
+ Foundry: "opennlp",
+ Key: "COMBINED_DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ }
+
+ m := NewMatcher(pattern, replacement)
+
+ tests := []struct {
+ name string
+ input ast.Node
+ expected bool
+ }{
+ {
+ name: "Match term inside unknown node with wrap",
+ input: &ast.CatchallNode{
+ NodeType: "koral:custom",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:custom",
+ "customField": "value"
+ }`),
+ Wrap: &ast.Term{
+ Foundry: "opennlp",
+ Key: "DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ },
+ expected: true,
+ },
+ {
+ name: "Match term inside unknown node's operands",
+ input: &ast.CatchallNode{
+ NodeType: "koral:custom",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:custom",
+ "customField": "value"
+ }`),
+ Operands: []ast.Node{
+ &ast.Term{
+ Foundry: "opennlp",
+ Key: "DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ },
+ },
+ expected: true,
+ },
+ {
+ name: "No match in unknown node with different term",
+ input: &ast.CatchallNode{
+ NodeType: "koral:custom",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:custom",
+ "customField": "value"
+ }`),
+ Wrap: &ast.Term{
+ Foundry: "opennlp",
+ Key: "NOUN",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ },
+ expected: false,
+ },
+ {
+ name: "Match in deeply nested unknown nodes",
+ input: &ast.CatchallNode{
+ NodeType: "koral:outer",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:outer",
+ "outerField": "value"
+ }`),
+ Wrap: &ast.CatchallNode{
+ NodeType: "koral:inner",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:inner",
+ "innerField": "value"
+ }`),
+ Wrap: &ast.Term{
+ Foundry: "opennlp",
+ Key: "DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ },
+ },
+ expected: true,
+ },
+ {
+ name: "Match in mixed known and unknown nodes",
+ input: &ast.Token{
+ Wrap: &ast.CatchallNode{
+ NodeType: "koral:custom",
+ RawContent: json.RawMessage(`{
+ "@type": "koral:custom",
+ "customField": "value"
+ }`),
+ Operands: []ast.Node{
+ &ast.TermGroup{
+ Operands: []ast.Node{
+ &ast.Term{
+ Foundry: "opennlp",
+ Key: "DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ },
+ Relation: ast.AndRelation,
+ },
+ },
+ },
+ },
+ expected: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := m.Match(tt.input)
+ assert.Equal(t, tt.expected, result)
+
+ if tt.expected {
+ // Test replacement when there's a match
+ replaced := m.Replace(tt.input)
+ // Verify the replacement happened somewhere in the structure
+ containsReplacement := false
+ var checkNode func(ast.Node)
+ checkNode = func(node ast.Node) {
+ switch n := node.(type) {
+ case *ast.Term:
+ if n.Key == "COMBINED_DET" {
+ containsReplacement = true
+ }
+ case *ast.Token:
+ if n.Wrap != nil {
+ checkNode(n.Wrap)
+ }
+ case *ast.TermGroup:
+ for _, op := range n.Operands {
+ checkNode(op)
+ }
+ case *ast.CatchallNode:
+ if n.Wrap != nil {
+ checkNode(n.Wrap)
+ }
+ for _, op := range n.Operands {
+ checkNode(op)
+ }
+ }
+ }
+ checkNode(replaced)
+ assert.True(t, containsReplacement, "Replacement should be found in the result")
+ }
+ })
+ }
+}
diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go
index d0ff222..db20bc8 100644
--- a/pkg/parser/parser.go
+++ b/pkg/parser/parser.go
@@ -1,5 +1,8 @@
package parser
+// parser is a function that takes a JSON string and returns an AST node.
+// It is used to parse a JSON string into an AST node.
+
import (
"encoding/json"
"fmt"