blob: d0a2259e8d30a6b608d63f4d1d89d4e35b3f795c [file] [log] [blame]
Akronb7e1f352025-05-16 15:45:23 +02001package matcher
2
3import (
Akrond5850f82025-05-23 16:44:44 +02004 "fmt"
5
Akronb7e1f352025-05-16 15:45:23 +02006 "github.com/KorAP/KoralPipe-TermMapper2/pkg/ast"
7)
8
9// Matcher handles pattern matching and replacement in the AST
10type Matcher struct {
11 pattern ast.Pattern
12 replacement ast.Replacement
13}
14
Akrond5850f82025-05-23 16:44:44 +020015// validateNode checks if a node is valid for pattern/replacement ASTs
16func validateNode(node ast.Node) error {
17 if node == nil {
18 return fmt.Errorf("nil node")
19 }
20
21 switch n := node.(type) {
22 case *ast.Token:
23 if n.Wrap != nil {
24 return validateNode(n.Wrap)
25 }
26 return nil
27 case *ast.Term:
28 return nil
29 case *ast.TermGroup:
30 if len(n.Operands) == 0 {
31 return fmt.Errorf("empty term group")
32 }
33 for _, op := range n.Operands {
34 if err := validateNode(op); err != nil {
35 return fmt.Errorf("invalid operand: %v", err)
36 }
37 }
38 return nil
39 case *ast.CatchallNode:
40 return fmt.Errorf("catchall nodes are not allowed in pattern/replacement ASTs")
41 default:
42 return fmt.Errorf("unknown node type: %T", node)
43 }
44}
45
Akronb7e1f352025-05-16 15:45:23 +020046// NewMatcher creates a new Matcher with the given pattern and replacement
Akrond5850f82025-05-23 16:44:44 +020047func NewMatcher(pattern ast.Pattern, replacement ast.Replacement) (*Matcher, error) {
48 if err := validateNode(pattern.Root); err != nil {
49 return nil, fmt.Errorf("invalid pattern: %v", err)
50 }
51 if err := validateNode(replacement.Root); err != nil {
52 return nil, fmt.Errorf("invalid replacement: %v", err)
53 }
Akronb7e1f352025-05-16 15:45:23 +020054 return &Matcher{
55 pattern: pattern,
56 replacement: replacement,
Akrond5850f82025-05-23 16:44:44 +020057 }, nil
Akronb7e1f352025-05-16 15:45:23 +020058}
59
60// Match checks if the given node matches the pattern
61func (m *Matcher) Match(node ast.Node) bool {
62 return m.matchNode(node, m.pattern.Root)
63}
64
65// Replace replaces all occurrences of the pattern in the given node with the replacement
66func (m *Matcher) Replace(node ast.Node) ast.Node {
Akrond5850f82025-05-23 16:44:44 +020067 // First step: Create complete structure with replacements
68 replaced := m.replaceNode(node)
69 // Second step: Simplify the structure
70 simplified := m.simplifyNode(replaced)
71 // If the input was a Token, ensure the output is also a Token
72 if _, isToken := node.(*ast.Token); isToken {
73 if _, isToken := simplified.(*ast.Token); !isToken {
74 return &ast.Token{Wrap: simplified}
75 }
76 }
77 return simplified
78}
79
80// replaceNode creates a complete structure with replacements
81func (m *Matcher) replaceNode(node ast.Node) ast.Node {
82 if node == nil {
83 return nil
84 }
85
86 // First handle Token nodes specially to preserve their structure
87 if token, ok := node.(*ast.Token); ok {
88 if token.Wrap == nil {
89 return token
90 }
91 // Process the wrapped node
92 wrap := m.replaceNode(token.Wrap)
93 return &ast.Token{Wrap: wrap}
94 }
95
96 // If this node matches the pattern
Akronb7e1f352025-05-16 15:45:23 +020097 if m.Match(node) {
Akrond5850f82025-05-23 16:44:44 +020098 // For TermGroups that contain a matching Term, preserve unmatched operands
99 if tg, ok := node.(*ast.TermGroup); ok {
100 // Check if any operand matches the pattern exactly
101 hasExactMatch := false
102 for _, op := range tg.Operands {
103 if m.matchNode(op, m.pattern.Root) {
104 hasExactMatch = true
105 break
106 }
Akronbf5149c2025-05-20 15:53:41 +0200107 }
Akrond5850f82025-05-23 16:44:44 +0200108
109 // If we have an exact match, replace matching operands
110 if hasExactMatch {
111 hasMatch := false
112 newOperands := make([]ast.Node, 0, len(tg.Operands))
113 for _, op := range tg.Operands {
114 if m.matchNode(op, m.pattern.Root) {
115 if !hasMatch {
116 newOperands = append(newOperands, m.cloneNode(m.replacement.Root))
117 hasMatch = true
118 } else {
119 newOperands = append(newOperands, m.replaceNode(op))
120 }
121 } else {
122 newOperands = append(newOperands, m.replaceNode(op))
123 }
124 }
125 return &ast.TermGroup{
126 Operands: newOperands,
127 Relation: tg.Relation,
128 }
129 }
130 // Otherwise, replace the entire TermGroup
Akronbf5149c2025-05-20 15:53:41 +0200131 return m.cloneNode(m.replacement.Root)
132 }
Akrond5850f82025-05-23 16:44:44 +0200133 // For other nodes, return the replacement
134 return m.cloneNode(m.replacement.Root)
Akronb7e1f352025-05-16 15:45:23 +0200135 }
136
Akronbf5149c2025-05-20 15:53:41 +0200137 // Otherwise recursively process children
Akronb7e1f352025-05-16 15:45:23 +0200138 switch n := node.(type) {
Akronb7e1f352025-05-16 15:45:23 +0200139 case *ast.TermGroup:
Akrond5850f82025-05-23 16:44:44 +0200140 // Check if any operand matches the pattern exactly
141 hasExactMatch := false
142 for _, op := range n.Operands {
143 if m.matchNode(op, m.pattern.Root) {
144 hasExactMatch = true
145 break
146 }
147 }
148
149 // If we have an exact match, replace matching operands
150 if hasExactMatch {
151 hasMatch := false
152 newOperands := make([]ast.Node, 0, len(n.Operands))
153 for _, op := range n.Operands {
154 if m.matchNode(op, m.pattern.Root) {
155 if !hasMatch {
156 newOperands = append(newOperands, m.cloneNode(m.replacement.Root))
157 hasMatch = true
158 } else {
159 newOperands = append(newOperands, m.replaceNode(op))
160 }
161 } else {
162 newOperands = append(newOperands, m.replaceNode(op))
163 }
164 }
165 return &ast.TermGroup{
166 Operands: newOperands,
167 Relation: n.Relation,
168 }
169 }
170 // Otherwise, recursively process operands
Akronb7e1f352025-05-16 15:45:23 +0200171 newOperands := make([]ast.Node, len(n.Operands))
172 for i, op := range n.Operands {
Akrond5850f82025-05-23 16:44:44 +0200173 newOperands[i] = m.replaceNode(op)
Akronb7e1f352025-05-16 15:45:23 +0200174 }
Akronbf5149c2025-05-20 15:53:41 +0200175 return &ast.TermGroup{
176 Operands: newOperands,
177 Relation: n.Relation,
178 }
Akronb7e1f352025-05-16 15:45:23 +0200179
Akron32958422025-05-16 16:33:05 +0200180 case *ast.CatchallNode:
181 newNode := &ast.CatchallNode{
182 NodeType: n.NodeType,
183 RawContent: n.RawContent,
184 }
185 if n.Wrap != nil {
Akrond5850f82025-05-23 16:44:44 +0200186 newNode.Wrap = m.replaceNode(n.Wrap)
Akron32958422025-05-16 16:33:05 +0200187 }
188 if len(n.Operands) > 0 {
189 newNode.Operands = make([]ast.Node, len(n.Operands))
190 for i, op := range n.Operands {
Akrond5850f82025-05-23 16:44:44 +0200191 newNode.Operands[i] = m.replaceNode(op)
192 }
193 }
194 return newNode
195
196 default:
197 return node
198 }
199}
200
201// simplifyNode removes unnecessary wrappers and empty nodes
202func (m *Matcher) simplifyNode(node ast.Node) ast.Node {
203 if node == nil {
204 return nil
205 }
206
207 switch n := node.(type) {
208 case *ast.Token:
209 if n.Wrap == nil {
210 return nil
211 }
212 simplified := m.simplifyNode(n.Wrap)
213 if simplified == nil {
214 return nil
215 }
216 return &ast.Token{Wrap: simplified}
217
218 case *ast.TermGroup:
219 // First simplify all operands
220 simplified := make([]ast.Node, 0, len(n.Operands))
221 for _, op := range n.Operands {
222 if s := m.simplifyNode(op); s != nil {
223 simplified = append(simplified, s)
224 }
225 }
226
227 // Handle special cases
228 if len(simplified) == 0 {
229 return nil
230 }
231 if len(simplified) == 1 {
232 // If we have a single operand, return it directly
233 // But only if we're not inside a Token
234 if _, isToken := node.(*ast.Token); !isToken {
235 return simplified[0]
236 }
237 }
238
239 return &ast.TermGroup{
240 Operands: simplified,
241 Relation: n.Relation,
242 }
243
244 case *ast.CatchallNode:
245 newNode := &ast.CatchallNode{
246 NodeType: n.NodeType,
247 RawContent: n.RawContent,
248 }
249 if n.Wrap != nil {
250 newNode.Wrap = m.simplifyNode(n.Wrap)
251 }
252 if len(n.Operands) > 0 {
253 simplified := make([]ast.Node, 0, len(n.Operands))
254 for _, op := range n.Operands {
255 if s := m.simplifyNode(op); s != nil {
256 simplified = append(simplified, s)
257 }
258 }
259 if len(simplified) > 0 {
260 newNode.Operands = simplified
Akron32958422025-05-16 16:33:05 +0200261 }
262 }
263 return newNode
264
Akronb7e1f352025-05-16 15:45:23 +0200265 default:
266 return node
267 }
268}
269
270// matchNode recursively checks if two nodes match
271func (m *Matcher) matchNode(node, pattern ast.Node) bool {
272 if pattern == nil {
273 return true
274 }
275 if node == nil {
276 return false
277 }
278
Akronbf5149c2025-05-20 15:53:41 +0200279 // Handle pattern being a Token
280 if pToken, ok := pattern.(*ast.Token); ok {
281 if nToken, ok := node.(*ast.Token); ok {
282 return m.matchNode(nToken.Wrap, pToken.Wrap)
Akronb7e1f352025-05-16 15:45:23 +0200283 }
Akron32958422025-05-16 16:33:05 +0200284 return false
Akronbf5149c2025-05-20 15:53:41 +0200285 }
Akronb7e1f352025-05-16 15:45:23 +0200286
Akronbf5149c2025-05-20 15:53:41 +0200287 // Handle pattern being a Term
288 if pTerm, ok := pattern.(*ast.Term); ok {
289 // Direct term to term matching
290 if t, ok := node.(*ast.Term); ok {
291 return t.Foundry == pTerm.Foundry &&
292 t.Key == pTerm.Key &&
293 t.Layer == pTerm.Layer &&
294 t.Match == pTerm.Match &&
295 (pTerm.Value == "" || t.Value == pTerm.Value)
296 }
297 // If node is a Token, check its wrap
298 if tkn, ok := node.(*ast.Token); ok {
299 if tkn.Wrap == nil {
300 return false
301 }
302 return m.matchNode(tkn.Wrap, pattern)
303 }
304 // If node is a TermGroup, check its operands
305 if tg, ok := node.(*ast.TermGroup); ok {
306 for _, op := range tg.Operands {
307 if m.matchNode(op, pattern) {
308 return true
309 }
310 }
311 return false
312 }
313 // If node is a CatchallNode, check its wrap and operands
314 if c, ok := node.(*ast.CatchallNode); ok {
315 if c.Wrap != nil && m.matchNode(c.Wrap, pattern) {
316 return true
317 }
318 for _, op := range c.Operands {
319 if m.matchNode(op, pattern) {
320 return true
321 }
322 }
323 return false
324 }
325 return false
326 }
327
328 // Handle pattern being a TermGroup
329 if pGroup, ok := pattern.(*ast.TermGroup); ok {
330 // For OR relations, check if any operand matches the node
331 if pGroup.Relation == ast.OrRelation {
332 for _, pOp := range pGroup.Operands {
333 if m.matchNode(node, pOp) {
Akronb7e1f352025-05-16 15:45:23 +0200334 return true
335 }
336 }
337 return false
338 }
339
Akronbf5149c2025-05-20 15:53:41 +0200340 // For AND relations, node must be a TermGroup with matching relation
341 if tg, ok := node.(*ast.TermGroup); ok {
342 if tg.Relation != pGroup.Relation {
Akronb7e1f352025-05-16 15:45:23 +0200343 return false
344 }
Akronbf5149c2025-05-20 15:53:41 +0200345 // Check that all pattern operands match in any order
346 if len(tg.Operands) < len(pGroup.Operands) {
Akronb7e1f352025-05-16 15:45:23 +0200347 return false
348 }
Akronbf5149c2025-05-20 15:53:41 +0200349 matched := make([]bool, len(tg.Operands))
350 for _, pOp := range pGroup.Operands {
Akronb7e1f352025-05-16 15:45:23 +0200351 found := false
Akronbf5149c2025-05-20 15:53:41 +0200352 for j, tOp := range tg.Operands {
Akronb7e1f352025-05-16 15:45:23 +0200353 if !matched[j] && m.matchNode(tOp, pOp) {
354 matched[j] = true
355 found = true
356 break
357 }
358 }
359 if !found {
360 return false
361 }
362 }
363 return true
364 }
Akron32958422025-05-16 16:33:05 +0200365
Akronbf5149c2025-05-20 15:53:41 +0200366 // If node is a Token, check its wrap
367 if tkn, ok := node.(*ast.Token); ok {
368 if tkn.Wrap == nil {
Akron32958422025-05-16 16:33:05 +0200369 return false
370 }
Akronbf5149c2025-05-20 15:53:41 +0200371 return m.matchNode(tkn.Wrap, pattern)
372 }
Akron32958422025-05-16 16:33:05 +0200373
Akronbf5149c2025-05-20 15:53:41 +0200374 // If node is a CatchallNode, check its wrap and operands
375 if c, ok := node.(*ast.CatchallNode); ok {
376 if c.Wrap != nil && m.matchNode(c.Wrap, pattern) {
Akron32958422025-05-16 16:33:05 +0200377 return true
378 }
Akronbf5149c2025-05-20 15:53:41 +0200379 for _, op := range c.Operands {
380 if m.matchNode(op, pattern) {
Akronb7e1f352025-05-16 15:45:23 +0200381 return true
382 }
383 }
384 return false
385 }
386
Akron32958422025-05-16 16:33:05 +0200387 return false
Akronb7e1f352025-05-16 15:45:23 +0200388 }
389
390 return false
391}
392
393// cloneNode creates a deep copy of a node
394func (m *Matcher) cloneNode(node ast.Node) ast.Node {
395 if node == nil {
396 return nil
397 }
398
399 switch n := node.(type) {
400 case *ast.Token:
401 return &ast.Token{
402 Wrap: m.cloneNode(n.Wrap),
403 }
404
405 case *ast.TermGroup:
406 operands := make([]ast.Node, len(n.Operands))
407 for i, op := range n.Operands {
408 operands[i] = m.cloneNode(op)
409 }
410 return &ast.TermGroup{
411 Operands: operands,
412 Relation: n.Relation,
413 }
414
415 case *ast.Term:
416 return &ast.Term{
417 Foundry: n.Foundry,
418 Key: n.Key,
419 Layer: n.Layer,
420 Match: n.Match,
421 Value: n.Value,
422 }
423
Akron32958422025-05-16 16:33:05 +0200424 case *ast.CatchallNode:
425 newNode := &ast.CatchallNode{
426 NodeType: n.NodeType,
427 RawContent: n.RawContent,
428 }
429 if n.Wrap != nil {
430 newNode.Wrap = m.cloneNode(n.Wrap)
431 }
432 if len(n.Operands) > 0 {
433 newNode.Operands = make([]ast.Node, len(n.Operands))
434 for i, op := range n.Operands {
435 newNode.Operands[i] = m.cloneNode(op)
436 }
437 }
438 return newNode
439
Akronb7e1f352025-05-16 15:45:23 +0200440 default:
441 return nil
442 }
443}