blob: bf5aaa59051b9b0e77709dc166b19ad3bf1eef05 [file] [log] [blame]
Akronb7e1f352025-05-16 15:45:23 +02001package matcher
2
3import (
Akrond5850f82025-05-23 16:44:44 +02004 "fmt"
5
Akronfa55bb22025-05-26 15:10:42 +02006 "github.com/KorAP/KoralPipe-TermMapper/ast"
Akronb7e1f352025-05-16 15:45:23 +02007)
8
9// Matcher handles pattern matching and replacement in the AST
10type Matcher struct {
11 pattern ast.Pattern
12 replacement ast.Replacement
13}
14
Akrond5850f82025-05-23 16:44:44 +020015// validateNode checks if a node is valid for pattern/replacement ASTs
16func validateNode(node ast.Node) error {
17 if node == nil {
18 return fmt.Errorf("nil node")
19 }
20
21 switch n := node.(type) {
22 case *ast.Token:
23 if n.Wrap != nil {
24 return validateNode(n.Wrap)
25 }
26 return nil
27 case *ast.Term:
28 return nil
29 case *ast.TermGroup:
30 if len(n.Operands) == 0 {
31 return fmt.Errorf("empty term group")
32 }
33 for _, op := range n.Operands {
34 if err := validateNode(op); err != nil {
35 return fmt.Errorf("invalid operand: %v", err)
36 }
37 }
38 return nil
39 case *ast.CatchallNode:
40 return fmt.Errorf("catchall nodes are not allowed in pattern/replacement ASTs")
41 default:
42 return fmt.Errorf("unknown node type: %T", node)
43 }
44}
45
Akronb7e1f352025-05-16 15:45:23 +020046// NewMatcher creates a new Matcher with the given pattern and replacement
Akrond5850f82025-05-23 16:44:44 +020047func NewMatcher(pattern ast.Pattern, replacement ast.Replacement) (*Matcher, error) {
48 if err := validateNode(pattern.Root); err != nil {
49 return nil, fmt.Errorf("invalid pattern: %v", err)
50 }
51 if err := validateNode(replacement.Root); err != nil {
52 return nil, fmt.Errorf("invalid replacement: %v", err)
53 }
Akronb7e1f352025-05-16 15:45:23 +020054 return &Matcher{
55 pattern: pattern,
56 replacement: replacement,
Akrond5850f82025-05-23 16:44:44 +020057 }, nil
Akronb7e1f352025-05-16 15:45:23 +020058}
59
60// Match checks if the given node matches the pattern
61func (m *Matcher) Match(node ast.Node) bool {
62 return m.matchNode(node, m.pattern.Root)
63}
64
65// Replace replaces all occurrences of the pattern in the given node with the replacement
66func (m *Matcher) Replace(node ast.Node) ast.Node {
Akrond5850f82025-05-23 16:44:44 +020067 // First step: Create complete structure with replacements
68 replaced := m.replaceNode(node)
69 // Second step: Simplify the structure
70 simplified := m.simplifyNode(replaced)
71 // If the input was a Token, ensure the output is also a Token
72 if _, isToken := node.(*ast.Token); isToken {
73 if _, isToken := simplified.(*ast.Token); !isToken {
74 return &ast.Token{Wrap: simplified}
75 }
76 }
77 return simplified
78}
79
80// replaceNode creates a complete structure with replacements
81func (m *Matcher) replaceNode(node ast.Node) ast.Node {
82 if node == nil {
83 return nil
84 }
85
86 // First handle Token nodes specially to preserve their structure
87 if token, ok := node.(*ast.Token); ok {
88 if token.Wrap == nil {
89 return token
90 }
91 // Process the wrapped node
92 wrap := m.replaceNode(token.Wrap)
93 return &ast.Token{Wrap: wrap}
94 }
95
Akron6f455152025-05-27 09:03:00 +020096 // Handle TermGroup nodes
97 if tg, ok := node.(*ast.TermGroup); ok {
98 // Check if any operand matches the pattern
99 hasMatch := false
100 newOperands := make([]ast.Node, 0, len(tg.Operands))
101 for _, op := range tg.Operands {
102 if !hasMatch && m.matchNode(op, m.pattern.Root) {
103 newOperands = append(newOperands, m.cloneNode(m.replacement.Root))
104 hasMatch = true
105 } else {
106 newOperands = append(newOperands, m.replaceNode(op))
Akrond5850f82025-05-23 16:44:44 +0200107 }
108 }
Akron6f455152025-05-27 09:03:00 +0200109 // If we found a match, return the modified TermGroup
110 if hasMatch {
Akrond5850f82025-05-23 16:44:44 +0200111 return &ast.TermGroup{
112 Operands: newOperands,
Akron6f455152025-05-27 09:03:00 +0200113 Relation: tg.Relation,
Akrond5850f82025-05-23 16:44:44 +0200114 }
115 }
Akron6f455152025-05-27 09:03:00 +0200116 // If this TermGroup matches the pattern exactly, replace it
117 if m.matchNode(node, m.pattern.Root) {
118 return m.cloneNode(m.replacement.Root)
Akronb7e1f352025-05-16 15:45:23 +0200119 }
Akron6f455152025-05-27 09:03:00 +0200120 // Otherwise, return the modified TermGroup
Akronbf5149c2025-05-20 15:53:41 +0200121 return &ast.TermGroup{
122 Operands: newOperands,
Akron6f455152025-05-27 09:03:00 +0200123 Relation: tg.Relation,
Akronbf5149c2025-05-20 15:53:41 +0200124 }
Akron6f455152025-05-27 09:03:00 +0200125 }
Akronb7e1f352025-05-16 15:45:23 +0200126
Akron6f455152025-05-27 09:03:00 +0200127 // Handle CatchallNode nodes
128 if c, ok := node.(*ast.CatchallNode); ok {
Akron32958422025-05-16 16:33:05 +0200129 newNode := &ast.CatchallNode{
Akron6f455152025-05-27 09:03:00 +0200130 NodeType: c.NodeType,
131 RawContent: c.RawContent,
Akron32958422025-05-16 16:33:05 +0200132 }
Akron6f455152025-05-27 09:03:00 +0200133 if c.Wrap != nil {
134 newNode.Wrap = m.replaceNode(c.Wrap)
Akron32958422025-05-16 16:33:05 +0200135 }
Akron6f455152025-05-27 09:03:00 +0200136 if len(c.Operands) > 0 {
137 newNode.Operands = make([]ast.Node, len(c.Operands))
138 for i, op := range c.Operands {
Akrond5850f82025-05-23 16:44:44 +0200139 newNode.Operands[i] = m.replaceNode(op)
140 }
141 }
142 return newNode
Akrond5850f82025-05-23 16:44:44 +0200143 }
Akron6f455152025-05-27 09:03:00 +0200144
145 // If this node matches the pattern exactly, replace it
146 if m.matchNode(node, m.pattern.Root) {
147 return m.cloneNode(m.replacement.Root)
148 }
149
150 return node
Akrond5850f82025-05-23 16:44:44 +0200151}
152
153// simplifyNode removes unnecessary wrappers and empty nodes
154func (m *Matcher) simplifyNode(node ast.Node) ast.Node {
155 if node == nil {
156 return nil
157 }
158
159 switch n := node.(type) {
160 case *ast.Token:
161 if n.Wrap == nil {
162 return nil
163 }
164 simplified := m.simplifyNode(n.Wrap)
165 if simplified == nil {
166 return nil
167 }
168 return &ast.Token{Wrap: simplified}
169
170 case *ast.TermGroup:
171 // First simplify all operands
172 simplified := make([]ast.Node, 0, len(n.Operands))
173 for _, op := range n.Operands {
174 if s := m.simplifyNode(op); s != nil {
175 simplified = append(simplified, s)
176 }
177 }
178
179 // Handle special cases
180 if len(simplified) == 0 {
181 return nil
182 }
183 if len(simplified) == 1 {
Akron21e47762025-07-03 14:11:11 +0200184 return simplified[0]
Akrond5850f82025-05-23 16:44:44 +0200185 }
186
187 return &ast.TermGroup{
188 Operands: simplified,
189 Relation: n.Relation,
190 }
191
192 case *ast.CatchallNode:
193 newNode := &ast.CatchallNode{
194 NodeType: n.NodeType,
195 RawContent: n.RawContent,
196 }
197 if n.Wrap != nil {
198 newNode.Wrap = m.simplifyNode(n.Wrap)
199 }
200 if len(n.Operands) > 0 {
201 simplified := make([]ast.Node, 0, len(n.Operands))
202 for _, op := range n.Operands {
203 if s := m.simplifyNode(op); s != nil {
204 simplified = append(simplified, s)
205 }
206 }
207 if len(simplified) > 0 {
208 newNode.Operands = simplified
Akron32958422025-05-16 16:33:05 +0200209 }
210 }
211 return newNode
212
Akronb7e1f352025-05-16 15:45:23 +0200213 default:
214 return node
215 }
216}
217
218// matchNode recursively checks if two nodes match
219func (m *Matcher) matchNode(node, pattern ast.Node) bool {
220 if pattern == nil {
221 return true
222 }
223 if node == nil {
224 return false
225 }
226
Akron6f455152025-05-27 09:03:00 +0200227 // Handle wrapped nodes (Token and CatchallNode)
228 if m.tryMatchWrapped(node, pattern) {
229 return true
Akronbf5149c2025-05-20 15:53:41 +0200230 }
Akronb7e1f352025-05-16 15:45:23 +0200231
Akron6f455152025-05-27 09:03:00 +0200232 switch p := pattern.(type) {
233 case *ast.Token:
234 if n, ok := node.(*ast.Token); ok {
235 return m.matchNode(n.Wrap, p.Wrap)
Akronbf5149c2025-05-20 15:53:41 +0200236 }
Akronbf5149c2025-05-20 15:53:41 +0200237
Akron6f455152025-05-27 09:03:00 +0200238 case *ast.Term:
239 return m.matchTerm(node, p)
240
241 case *ast.TermGroup:
242 if p.Relation == ast.OrRelation {
243 // For OR relations, check if any operand matches
244 for _, pOp := range p.Operands {
Akronbf5149c2025-05-20 15:53:41 +0200245 if m.matchNode(node, pOp) {
Akronb7e1f352025-05-16 15:45:23 +0200246 return true
247 }
248 }
Akron6f455152025-05-27 09:03:00 +0200249 } else if tg, ok := node.(*ast.TermGroup); ok && tg.Relation == p.Relation {
250 // For AND relations, all pattern operands must match in any order
251 return m.matchAndTermGroup(tg, p)
Akronb7e1f352025-05-16 15:45:23 +0200252 }
Akronb7e1f352025-05-16 15:45:23 +0200253 }
254
255 return false
256}
257
Akron6f455152025-05-27 09:03:00 +0200258// tryMatchWrapped attempts to match a node that might wrap other nodes
259func (m *Matcher) tryMatchWrapped(node, pattern ast.Node) bool {
260 switch n := node.(type) {
261 case *ast.Token:
Akron21e47762025-07-03 14:11:11 +0200262 return n.Wrap != nil && m.matchNode(n.Wrap, pattern)
Akron6f455152025-05-27 09:03:00 +0200263 case *ast.CatchallNode:
264 if n.Wrap != nil && m.matchNode(n.Wrap, pattern) {
265 return true
266 }
267 for _, op := range n.Operands {
268 if m.matchNode(op, pattern) {
269 return true
270 }
271 }
272 case *ast.TermGroup:
273 for _, op := range n.Operands {
274 if m.matchNode(op, pattern) {
275 return true
276 }
277 }
278 }
279 return false
280}
281
282// matchTerm checks if a node matches a term pattern
283func (m *Matcher) matchTerm(node ast.Node, pattern *ast.Term) bool {
284 if t, ok := node.(*ast.Term); ok {
285 return t.Foundry == pattern.Foundry &&
286 t.Key == pattern.Key &&
287 t.Layer == pattern.Layer &&
288 t.Match == pattern.Match &&
289 (pattern.Value == "" || t.Value == pattern.Value)
290 }
291 return m.tryMatchWrapped(node, pattern)
292}
293
294// matchAndTermGroup checks if a TermGroup matches an AND pattern
295func (m *Matcher) matchAndTermGroup(node *ast.TermGroup, pattern *ast.TermGroup) bool {
296 if len(node.Operands) < len(pattern.Operands) {
297 return false
298 }
299 matched := make([]bool, len(node.Operands))
300 for _, pOp := range pattern.Operands {
301 found := false
302 for j, tOp := range node.Operands {
303 if !matched[j] && m.matchNode(tOp, pOp) {
304 matched[j] = true
305 found = true
306 break
307 }
308 }
309 if !found {
310 return false
311 }
312 }
313 return true
314}
315
Akronb7e1f352025-05-16 15:45:23 +0200316// cloneNode creates a deep copy of a node
317func (m *Matcher) cloneNode(node ast.Node) ast.Node {
318 if node == nil {
319 return nil
320 }
321
322 switch n := node.(type) {
323 case *ast.Token:
324 return &ast.Token{
325 Wrap: m.cloneNode(n.Wrap),
326 }
327
328 case *ast.TermGroup:
329 operands := make([]ast.Node, len(n.Operands))
330 for i, op := range n.Operands {
331 operands[i] = m.cloneNode(op)
332 }
333 return &ast.TermGroup{
334 Operands: operands,
335 Relation: n.Relation,
336 }
337
338 case *ast.Term:
339 return &ast.Term{
340 Foundry: n.Foundry,
341 Key: n.Key,
342 Layer: n.Layer,
343 Match: n.Match,
344 Value: n.Value,
345 }
346
Akron32958422025-05-16 16:33:05 +0200347 case *ast.CatchallNode:
348 newNode := &ast.CatchallNode{
349 NodeType: n.NodeType,
350 RawContent: n.RawContent,
351 }
352 if n.Wrap != nil {
353 newNode.Wrap = m.cloneNode(n.Wrap)
354 }
355 if len(n.Operands) > 0 {
356 newNode.Operands = make([]ast.Node, len(n.Operands))
357 for i, op := range n.Operands {
358 newNode.Operands[i] = m.cloneNode(op)
359 }
360 }
361 return newNode
362
Akronb7e1f352025-05-16 15:45:23 +0200363 default:
364 return nil
365 }
366}