Disallow unsupported nodes in pattern and replacement
diff --git a/pkg/mapper/mapper.go b/pkg/mapper/mapper.go
index f540525..8350823 100644
--- a/pkg/mapper/mapper.go
+++ b/pkg/mapper/mapper.go
@@ -70,7 +70,7 @@
}
// ApplyMappings applies the specified mapping rules to a JSON object
-func (m *Mapper) ApplyMappings(mappingID string, opts MappingOptions, jsonData interface{}) (interface{}, error) {
+func (m *Mapper) ApplyMappings(mappingID string, opts MappingOptions, jsonData any) (any, error) {
// Validate mapping ID
if _, exists := m.mappingLists[mappingID]; !exists {
return nil, fmt.Errorf("mapping list with ID %s not found", mappingID)
@@ -95,9 +95,13 @@
return nil, fmt.Errorf("failed to parse JSON into AST: %w", err)
}
- // Extract the inner node if it's a token
+ // Store whether the input was a Token
+ isToken := false
+ var tokenWrap ast.Node
if token, ok := node.(*ast.Token); ok {
- node = token.Wrap
+ isToken = true
+ tokenWrap = token.Wrap
+ node = tokenWrap
}
// Apply each rule to the AST
@@ -130,12 +134,20 @@
}
// Create matcher and apply replacement
- m := matcher.NewMatcher(ast.Pattern{Root: pattern}, ast.Replacement{Root: replacement})
+ m, err := matcher.NewMatcher(ast.Pattern{Root: pattern}, ast.Replacement{Root: replacement})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create matcher: %w", err)
+ }
node = m.Replace(node)
}
- // Wrap the result in a token
- result := &ast.Token{Wrap: node}
+ // Wrap the result in a token if the input was a token
+ var result ast.Node
+ if isToken {
+ result = &ast.Token{Wrap: node}
+ } else {
+ result = node
+ }
// Convert AST back to JSON
resultBytes, err := parser.SerializeToJSON(result)
diff --git a/pkg/mapper/mapper_test.go b/pkg/mapper/mapper_test.go
index b5a8950..40ec88e 100644
--- a/pkg/mapper/mapper_test.go
+++ b/pkg/mapper/mapper_test.go
@@ -6,6 +6,8 @@
"path/filepath"
"testing"
+ "github.com/KorAP/KoralPipe-TermMapper2/pkg/ast"
+ "github.com/KorAP/KoralPipe-TermMapper2/pkg/matcher"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -226,3 +228,151 @@
})
}
}
+
+// TestMatchComplexPatterns runs the matcher against deeply nested term groups
+// that mix AND and OR relations and compares the replaced tree with the
+// expected one.
+func TestMatchComplexPatterns(t *testing.T) {
+	// Small constructors keep the nested fixtures readable.
+	term := func(key string) ast.Node {
+		return &ast.Term{Key: key, Match: ast.MatchEqual}
+	}
+	allOf := func(operands ...ast.Node) ast.Node {
+		return &ast.TermGroup{Operands: operands, Relation: ast.AndRelation}
+	}
+	anyOf := func(operands ...ast.Node) ast.Node {
+		return &ast.TermGroup{Operands: operands, Relation: ast.OrRelation}
+	}
+
+	type testCase struct {
+		name        string
+		pattern     ast.Pattern
+		replacement ast.Replacement
+		input       ast.Node
+		expected    ast.Node
+	}
+
+	cases := []testCase{
+		{
+			name: "Deep nested pattern with mixed operators",
+			// Pattern: A & (B | (C & D))
+			pattern: ast.Pattern{
+				Root: allOf(
+					term("A"),
+					anyOf(
+						term("B"),
+						allOf(term("C"), term("D")),
+					),
+				),
+			},
+			replacement: ast.Replacement{Root: term("RESULT")},
+			// Input: A & (C & D)
+			input:    allOf(term("A"), allOf(term("C"), term("D"))),
+			expected: term("RESULT"),
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			replacer, err := matcher.NewMatcher(tc.pattern, tc.replacement)
+			require.NoError(t, err)
+			got := replacer.Replace(tc.input)
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
+
+// TestInvalidPatternReplacement checks that ApplyMappings rejects malformed
+// input. It writes a one-rule mapping list to a temporary file, loads it with
+// NewMapper, and runs each JSON input through the "test-mapper" list in the
+// AtoB direction.
+func TestInvalidPatternReplacement(t *testing.T) {
+	// Create a temporary config file
+	tmpDir := t.TempDir()
+	configFile := filepath.Join(tmpDir, "test-config.yaml")
+
+	configContent := `- id: test-mapper
+ foundryA: opennlp
+ layerA: p
+ foundryB: upos
+ layerB: p
+ mappings:
+ - "[PIDAT] <> [opennlp/p=PIDAT & opennlp/p=AdjType:Pdt]"`
+
+	err := os.WriteFile(configFile, []byte(configContent), 0644)
+	require.NoError(t, err)
+
+	// Create a new mapper
+	m, err := NewMapper(configFile)
+	require.NoError(t, err)
+
+	tests := []struct {
+		name        string
+		input       string
+		expectError bool
+		errorMsg    string // exact error text; only checked when expectError is true
+	}{
+		{
+			name: "Invalid input - empty term group",
+			input: `{
+ "@type": "koral:token",
+ "wrap": {
+ "@type": "koral:termGroup",
+ "operands": [],
+ "relation": "relation:and"
+ }
+ }`,
+			expectError: true,
+			errorMsg:    "failed to parse JSON into AST: error parsing wrapped node: term group must have at least one operand",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var inputData any
+			err := json.Unmarshal([]byte(tt.input), &inputData)
+			require.NoError(t, err)
+
+			result, err := m.ApplyMappings("test-mapper", MappingOptions{Direction: AtoB}, inputData)
+			if !tt.expectError {
+				assert.NoError(t, err)
+				assert.NotNil(t, result)
+				return
+			}
+			// EqualError reports a normal failure when err is nil. Calling
+			// err.Error() after a non-fatal assert.Error would panic instead.
+			assert.EqualError(t, err, tt.errorMsg)
+			assert.Nil(t, result)
+		})
+	}
+}
diff --git a/pkg/matcher/matcher.go b/pkg/matcher/matcher.go
index a26404b..d0a2259 100644
--- a/pkg/matcher/matcher.go
+++ b/pkg/matcher/matcher.go
@@ -1,6 +1,8 @@
package matcher
import (
+ "fmt"
+
"github.com/KorAP/KoralPipe-TermMapper2/pkg/ast"
)
@@ -10,12 +12,49 @@
replacement ast.Replacement
}
+// validateNode reports whether node may appear in a pattern or replacement AST.
+//
+// Terms are accepted. A token is accepted when it wraps nothing or when its
+// wrapped node is itself valid. A term group needs at least one operand and
+// every operand must be valid. Nil nodes, catchall nodes and any other node
+// type are rejected. An operand's error is wrapped with %w, so the original
+// cause stays reachable through errors.Is and errors.As.
+func validateNode(node ast.Node) error {
+	if node == nil {
+		return fmt.Errorf("nil node")
+	}
+
+	switch n := node.(type) {
+	case *ast.Token:
+		// An interface holding a typed nil pointer is not == nil, so the
+		// check above does not catch it; guard before dereferencing.
+		if n == nil {
+			return fmt.Errorf("nil node")
+		}
+		if n.Wrap == nil {
+			return nil
+		}
+		return validateNode(n.Wrap)
+	case *ast.Term:
+		if n == nil {
+			return fmt.Errorf("nil node")
+		}
+		return nil
+	case *ast.TermGroup:
+		if n == nil {
+			return fmt.Errorf("nil node")
+		}
+		if len(n.Operands) == 0 {
+			return fmt.Errorf("empty term group")
+		}
+		for _, op := range n.Operands {
+			if err := validateNode(op); err != nil {
+				return fmt.Errorf("invalid operand: %w", err)
+			}
+		}
+		return nil
+	case *ast.CatchallNode:
+		return fmt.Errorf("catchall nodes are not allowed in pattern/replacement ASTs")
+	default:
+		return fmt.Errorf("unknown node type: %T", node)
+	}
+}
+
// NewMatcher creates a new Matcher with the given pattern and replacement
+//
+// The pattern root and the replacement root are each passed to validateNode.
+// If either is rejected, NewMatcher returns a nil Matcher and an error
+// prefixed with "invalid pattern" or "invalid replacement".
-func NewMatcher(pattern ast.Pattern, replacement ast.Replacement) *Matcher {
+func NewMatcher(pattern ast.Pattern, replacement ast.Replacement) (*Matcher, error) {
+	// Wrap with %w so the underlying validation error stays in the error
+	// chain; the message text is the same as with %v.
+	if err := validateNode(pattern.Root); err != nil {
+		return nil, fmt.Errorf("invalid pattern: %w", err)
+	}
+	if err := validateNode(replacement.Root); err != nil {
+		return nil, fmt.Errorf("invalid replacement: %w", err)
+	}
return &Matcher{
pattern: pattern,
replacement: replacement,
- }
+	}, nil
}
// Match checks if the given node matches the pattern
@@ -25,32 +64,113 @@
// Replace replaces all occurrences of the pattern in the given node with the replacement
+//
+// The tree is first rebuilt by replaceNode and then passed through
+// simplifyNode. A Token input always yields a Token output: when the
+// simplified result is not a Token, it is wrapped in a new one.
func (m *Matcher) Replace(node ast.Node) ast.Node {
- // If this node matches the pattern, create replacement while preserving outer structure
+	// Substitute first, then simplify the resulting structure.
+	result := m.simplifyNode(m.replaceNode(node))
+
+	// Non-token input is returned exactly as simplified.
+	if _, inputIsToken := node.(*ast.Token); !inputIsToken {
+		return result
+	}
+	// Token input: keep a surviving Token, otherwise restore the wrapper.
+	if _, stillToken := result.(*ast.Token); stillToken {
+		return result
+	}
+	return &ast.Token{Wrap: result}
+}
+
+// replaceNode creates a complete structure with replacements
+func (m *Matcher) replaceNode(node ast.Node) ast.Node {
+ if node == nil {
+ return nil
+ }
+
+ // First handle Token nodes specially to preserve their structure
+ if token, ok := node.(*ast.Token); ok {
+ if token.Wrap == nil {
+ return token
+ }
+ // Process the wrapped node
+ wrap := m.replaceNode(token.Wrap)
+ return &ast.Token{Wrap: wrap}
+ }
+
+ // If this node matches the pattern
if m.Match(node) {
- switch node.(type) {
- case *ast.Token:
- // For Token nodes, preserve the Token wrapper but replace its wrap
- newToken := &ast.Token{
- Wrap: m.cloneNode(m.replacement.Root),
+ // For TermGroups that contain a matching Term, preserve unmatched operands
+ if tg, ok := node.(*ast.TermGroup); ok {
+ // Check if any operand matches the pattern exactly
+ hasExactMatch := false
+ for _, op := range tg.Operands {
+ if m.matchNode(op, m.pattern.Root) {
+ hasExactMatch = true
+ break
+ }
}
- return newToken
- default:
+
+ // If we have an exact match, replace matching operands
+ if hasExactMatch {
+ hasMatch := false
+ newOperands := make([]ast.Node, 0, len(tg.Operands))
+ for _, op := range tg.Operands {
+ if m.matchNode(op, m.pattern.Root) {
+ if !hasMatch {
+ newOperands = append(newOperands, m.cloneNode(m.replacement.Root))
+ hasMatch = true
+ } else {
+ newOperands = append(newOperands, m.replaceNode(op))
+ }
+ } else {
+ newOperands = append(newOperands, m.replaceNode(op))
+ }
+ }
+ return &ast.TermGroup{
+ Operands: newOperands,
+ Relation: tg.Relation,
+ }
+ }
+ // Otherwise, replace the entire TermGroup
return m.cloneNode(m.replacement.Root)
}
+ // For other nodes, return the replacement
+ return m.cloneNode(m.replacement.Root)
}
// Otherwise recursively process children
switch n := node.(type) {
- case *ast.Token:
- newToken := &ast.Token{
- Wrap: m.Replace(n.Wrap),
- }
- return newToken
-
case *ast.TermGroup:
+ // Check if any operand matches the pattern exactly
+ hasExactMatch := false
+ for _, op := range n.Operands {
+ if m.matchNode(op, m.pattern.Root) {
+ hasExactMatch = true
+ break
+ }
+ }
+
+ // If we have an exact match, replace matching operands
+ if hasExactMatch {
+ hasMatch := false
+ newOperands := make([]ast.Node, 0, len(n.Operands))
+ for _, op := range n.Operands {
+ if m.matchNode(op, m.pattern.Root) {
+ if !hasMatch {
+ newOperands = append(newOperands, m.cloneNode(m.replacement.Root))
+ hasMatch = true
+ } else {
+ newOperands = append(newOperands, m.replaceNode(op))
+ }
+ } else {
+ newOperands = append(newOperands, m.replaceNode(op))
+ }
+ }
+ return &ast.TermGroup{
+ Operands: newOperands,
+ Relation: n.Relation,
+ }
+ }
+ // Otherwise, recursively process operands
newOperands := make([]ast.Node, len(n.Operands))
for i, op := range n.Operands {
- newOperands[i] = m.Replace(op)
+ newOperands[i] = m.replaceNode(op)
}
return &ast.TermGroup{
Operands: newOperands,
@@ -63,12 +183,81 @@
RawContent: n.RawContent,
}
if n.Wrap != nil {
- newNode.Wrap = m.Replace(n.Wrap)
+ newNode.Wrap = m.replaceNode(n.Wrap)
}
if len(n.Operands) > 0 {
newNode.Operands = make([]ast.Node, len(n.Operands))
for i, op := range n.Operands {
- newNode.Operands[i] = m.Replace(op)
+ newNode.Operands[i] = m.replaceNode(op)
+ }
+ }
+ return newNode
+
+ default:
+ return node
+ }
+}
+
+// simplifyNode removes unnecessary wrappers and empty nodes
+func (m *Matcher) simplifyNode(node ast.Node) ast.Node {
+ if node == nil {
+ return nil
+ }
+
+ switch n := node.(type) {
+ case *ast.Token:
+ if n.Wrap == nil {
+ return nil
+ }
+ simplified := m.simplifyNode(n.Wrap)
+ if simplified == nil {
+ return nil
+ }
+ return &ast.Token{Wrap: simplified}
+
+ case *ast.TermGroup:
+ // First simplify all operands
+ simplified := make([]ast.Node, 0, len(n.Operands))
+ for _, op := range n.Operands {
+ if s := m.simplifyNode(op); s != nil {
+ simplified = append(simplified, s)
+ }
+ }
+
+ // Handle special cases
+ if len(simplified) == 0 {
+ return nil
+ }
+ if len(simplified) == 1 {
+ // If we have a single operand, return it directly
+ // But only if we're not inside a Token
+ if _, isToken := node.(*ast.Token); !isToken {
+ return simplified[0]
+ }
+ }
+
+ return &ast.TermGroup{
+ Operands: simplified,
+ Relation: n.Relation,
+ }
+
+ case *ast.CatchallNode:
+ newNode := &ast.CatchallNode{
+ NodeType: n.NodeType,
+ RawContent: n.RawContent,
+ }
+ if n.Wrap != nil {
+ newNode.Wrap = m.simplifyNode(n.Wrap)
+ }
+ if len(n.Operands) > 0 {
+ simplified := make([]ast.Node, 0, len(n.Operands))
+ for _, op := range n.Operands {
+ if s := m.simplifyNode(op); s != nil {
+ simplified = append(simplified, s)
+ }
+ }
+ if len(simplified) > 0 {
+ newNode.Operands = simplified
}
}
return newNode
diff --git a/pkg/matcher/matcher_test.go b/pkg/matcher/matcher_test.go
index adc8784..47af1f9 100644
--- a/pkg/matcher/matcher_test.go
+++ b/pkg/matcher/matcher_test.go
@@ -11,6 +11,130 @@
"github.com/stretchr/testify/assert"
)
+// TestNewMatcherValidation checks that NewMatcher accepts a plain term as
+// pattern and replacement, and rejects catchall nodes (at the root or nested
+// in a term group) and empty term groups with the expected error messages.
+func TestNewMatcherValidation(t *testing.T) {
+	// opennlpTerm builds the opennlp/p term shared by all table entries.
+	opennlpTerm := func(key string) *ast.Term {
+		return &ast.Term{
+			Foundry: "opennlp",
+			Key:     key,
+			Layer:   "p",
+			Match:   ast.MatchEqual,
+		}
+	}
+
+	tests := []struct {
+		name          string
+		pattern       ast.Pattern
+		replacement   ast.Replacement
+		expectedError string // empty means NewMatcher must succeed
+	}{
+		{
+			name:          "Valid pattern and replacement",
+			pattern:       ast.Pattern{Root: opennlpTerm("DET")},
+			replacement:   ast.Replacement{Root: opennlpTerm("COMBINED_DET")},
+			expectedError: "",
+		},
+		{
+			name:          "Invalid pattern - CatchallNode",
+			pattern:       ast.Pattern{Root: &ast.CatchallNode{NodeType: "custom"}},
+			replacement:   ast.Replacement{Root: opennlpTerm("DET")},
+			expectedError: "invalid pattern: catchall nodes are not allowed in pattern/replacement ASTs",
+		},
+		{
+			name:          "Invalid replacement - CatchallNode",
+			pattern:       ast.Pattern{Root: opennlpTerm("DET")},
+			replacement:   ast.Replacement{Root: &ast.CatchallNode{NodeType: "custom"}},
+			expectedError: "invalid replacement: catchall nodes are not allowed in pattern/replacement ASTs",
+		},
+		{
+			name: "Invalid pattern - Empty TermGroup",
+			pattern: ast.Pattern{
+				Root: &ast.TermGroup{
+					Operands: []ast.Node{},
+					Relation: ast.AndRelation,
+				},
+			},
+			replacement:   ast.Replacement{Root: opennlpTerm("DET")},
+			expectedError: "invalid pattern: empty term group",
+		},
+		{
+			name: "Invalid pattern - Nested CatchallNode",
+			pattern: ast.Pattern{
+				Root: &ast.TermGroup{
+					Operands: []ast.Node{
+						opennlpTerm("DET"),
+						&ast.CatchallNode{NodeType: "custom"},
+					},
+					Relation: ast.AndRelation,
+				},
+			},
+			replacement:   ast.Replacement{Root: opennlpTerm("DET")},
+			expectedError: "invalid pattern: invalid operand: catchall nodes are not allowed in pattern/replacement ASTs",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := NewMatcher(tt.pattern, tt.replacement)
+			if tt.expectedError == "" {
+				assert.NoError(t, err)
+				assert.NotNil(t, got)
+				return
+			}
+			// EqualError reports a normal failure when err is nil. Calling
+			// err.Error() after a non-fatal assert.Error would panic instead.
+			assert.EqualError(t, err, tt.expectedError)
+			assert.Nil(t, got)
+		})
+	}
+}
+
func TestMatchSimplePattern(t *testing.T) {
// Create a simple pattern: match a term with DET
pattern := ast.Pattern{
@@ -32,7 +156,9 @@
},
}
- m := NewMatcher(pattern, replacement)
+ m, err := NewMatcher(pattern, replacement)
+ assert.NoError(t, err)
+ assert.NotNil(t, m)
tests := []struct {
name string
@@ -149,7 +275,9 @@
},
}
- m := NewMatcher(pattern, replacement)
+ m, err := NewMatcher(pattern, replacement)
+ assert.NoError(t, err)
+ assert.NotNil(t, m)
tests := []struct {
name string
@@ -263,25 +391,12 @@
}
func TestReplace(t *testing.T) {
- // Create pattern and replacement
pattern := ast.Pattern{
- Root: &ast.TermGroup{
- Operands: []ast.Node{
- &ast.Term{
- Foundry: "opennlp",
- Key: "DET",
- Layer: "p",
- Match: ast.MatchEqual,
- },
- &ast.Term{
- Foundry: "opennlp",
- Key: "AdjType",
- Layer: "m",
- Match: ast.MatchEqual,
- Value: "Pdt",
- },
- },
- Relation: ast.AndRelation,
+ Root: &ast.Term{
+ Foundry: "opennlp",
+ Key: "DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
},
}
@@ -294,7 +409,9 @@
},
}
- m := NewMatcher(pattern, replacement)
+ m, err := NewMatcher(pattern, replacement)
+ assert.NoError(t, err)
+ assert.NotNil(t, m)
tests := []struct {
name string
@@ -321,11 +438,23 @@
},
Relation: ast.AndRelation,
},
- expected: &ast.Term{
- Foundry: "opennlp",
- Key: "COMBINED_DET",
- Layer: "p",
- Match: ast.MatchEqual,
+ expected: &ast.TermGroup{
+ Operands: []ast.Node{
+ &ast.Term{
+ Foundry: "opennlp",
+ Key: "COMBINED_DET",
+ Layer: "p",
+ Match: ast.MatchEqual,
+ },
+ &ast.Term{
+ Foundry: "opennlp",
+ Key: "AdjType",
+ Layer: "m",
+ Match: ast.MatchEqual,
+ Value: "Pdt",
+ },
+ },
+ Relation: ast.AndRelation,
},
},
{
@@ -431,7 +560,6 @@
}
func TestMatchNodeOrder(t *testing.T) {
- // Test that operands can match in any order
pattern := ast.Pattern{
Root: &ast.TermGroup{
Operands: []ast.Node{
@@ -462,7 +590,9 @@
},
}
- m := NewMatcher(pattern, replacement)
+ m, err := NewMatcher(pattern, replacement)
+ assert.NoError(t, err)
+ assert.NotNil(t, m)
// Test with operands in different orders
input1 := &ast.TermGroup{
@@ -508,7 +638,6 @@
}
func TestMatchWithUnknownNodes(t *testing.T) {
- // Create a pattern that looks for a term with DET inside any structure
pattern := ast.Pattern{
Root: &ast.Term{
Foundry: "opennlp",
@@ -527,7 +656,9 @@
},
}
- m := NewMatcher(pattern, replacement)
+ m, err := NewMatcher(pattern, replacement)
+ assert.NoError(t, err)
+ assert.NotNil(t, m)
tests := []struct {
name string