blob: 1630b1a823b5e4e6cb0c16d7bb6f0ce178e10061 [file] [log] [blame]
Akronb7e1f352025-05-16 15:45:23 +02001package matcher
2
3import (
Akrond5850f82025-05-23 16:44:44 +02004 "fmt"
5
Akronfa55bb22025-05-26 15:10:42 +02006 "github.com/KorAP/KoralPipe-TermMapper/ast"
Akronb7e1f352025-05-16 15:45:23 +02007)
8
9// Matcher handles pattern matching and replacement in the AST
10type Matcher struct {
11 pattern ast.Pattern
12 replacement ast.Replacement
13}
14
Akrond5850f82025-05-23 16:44:44 +020015// validateNode checks if a node is valid for pattern/replacement ASTs
16func validateNode(node ast.Node) error {
17 if node == nil {
18 return fmt.Errorf("nil node")
19 }
20
21 switch n := node.(type) {
22 case *ast.Token:
23 if n.Wrap != nil {
24 return validateNode(n.Wrap)
25 }
26 return nil
27 case *ast.Term:
28 return nil
29 case *ast.TermGroup:
30 if len(n.Operands) == 0 {
31 return fmt.Errorf("empty term group")
32 }
33 for _, op := range n.Operands {
34 if err := validateNode(op); err != nil {
35 return fmt.Errorf("invalid operand: %v", err)
36 }
37 }
38 return nil
39 case *ast.CatchallNode:
40 return fmt.Errorf("catchall nodes are not allowed in pattern/replacement ASTs")
41 default:
42 return fmt.Errorf("unknown node type: %T", node)
43 }
44}
45
Akronb7e1f352025-05-16 15:45:23 +020046// NewMatcher creates a new Matcher with the given pattern and replacement
Akrond5850f82025-05-23 16:44:44 +020047func NewMatcher(pattern ast.Pattern, replacement ast.Replacement) (*Matcher, error) {
48 if err := validateNode(pattern.Root); err != nil {
49 return nil, fmt.Errorf("invalid pattern: %v", err)
50 }
51 if err := validateNode(replacement.Root); err != nil {
52 return nil, fmt.Errorf("invalid replacement: %v", err)
53 }
Akronb7e1f352025-05-16 15:45:23 +020054 return &Matcher{
55 pattern: pattern,
56 replacement: replacement,
Akrond5850f82025-05-23 16:44:44 +020057 }, nil
Akronb7e1f352025-05-16 15:45:23 +020058}
59
60// Match checks if the given node matches the pattern
61func (m *Matcher) Match(node ast.Node) bool {
62 return m.matchNode(node, m.pattern.Root)
63}
64
65// Replace replaces all occurrences of the pattern in the given node with the replacement
66func (m *Matcher) Replace(node ast.Node) ast.Node {
Akrond5850f82025-05-23 16:44:44 +020067 // First step: Create complete structure with replacements
68 replaced := m.replaceNode(node)
69 // Second step: Simplify the structure
70 simplified := m.simplifyNode(replaced)
71 // If the input was a Token, ensure the output is also a Token
72 if _, isToken := node.(*ast.Token); isToken {
73 if _, isToken := simplified.(*ast.Token); !isToken {
74 return &ast.Token{Wrap: simplified}
75 }
76 }
77 return simplified
78}
79
80// replaceNode creates a complete structure with replacements
81func (m *Matcher) replaceNode(node ast.Node) ast.Node {
82 if node == nil {
83 return nil
84 }
85
86 // First handle Token nodes specially to preserve their structure
87 if token, ok := node.(*ast.Token); ok {
88 if token.Wrap == nil {
89 return token
90 }
91 // Process the wrapped node
92 wrap := m.replaceNode(token.Wrap)
93 return &ast.Token{Wrap: wrap}
94 }
95
Akron6f455152025-05-27 09:03:00 +020096 // Handle TermGroup nodes
97 if tg, ok := node.(*ast.TermGroup); ok {
98 // Check if any operand matches the pattern
99 hasMatch := false
100 newOperands := make([]ast.Node, 0, len(tg.Operands))
101 for _, op := range tg.Operands {
102 if !hasMatch && m.matchNode(op, m.pattern.Root) {
103 newOperands = append(newOperands, m.cloneNode(m.replacement.Root))
104 hasMatch = true
105 } else {
106 newOperands = append(newOperands, m.replaceNode(op))
Akrond5850f82025-05-23 16:44:44 +0200107 }
108 }
Akron6f455152025-05-27 09:03:00 +0200109 // If we found a match, return the modified TermGroup
110 if hasMatch {
Akrond5850f82025-05-23 16:44:44 +0200111 return &ast.TermGroup{
112 Operands: newOperands,
Akron6f455152025-05-27 09:03:00 +0200113 Relation: tg.Relation,
Akrond5850f82025-05-23 16:44:44 +0200114 }
115 }
Akron6f455152025-05-27 09:03:00 +0200116 // If this TermGroup matches the pattern exactly, replace it
117 if m.matchNode(node, m.pattern.Root) {
118 return m.cloneNode(m.replacement.Root)
Akronb7e1f352025-05-16 15:45:23 +0200119 }
Akron6f455152025-05-27 09:03:00 +0200120 // Otherwise, return the modified TermGroup
Akronbf5149c2025-05-20 15:53:41 +0200121 return &ast.TermGroup{
122 Operands: newOperands,
Akron6f455152025-05-27 09:03:00 +0200123 Relation: tg.Relation,
Akronbf5149c2025-05-20 15:53:41 +0200124 }
Akron6f455152025-05-27 09:03:00 +0200125 }
Akronb7e1f352025-05-16 15:45:23 +0200126
Akron6f455152025-05-27 09:03:00 +0200127 // Handle CatchallNode nodes
128 if c, ok := node.(*ast.CatchallNode); ok {
Akron32958422025-05-16 16:33:05 +0200129 newNode := &ast.CatchallNode{
Akron6f455152025-05-27 09:03:00 +0200130 NodeType: c.NodeType,
131 RawContent: c.RawContent,
Akron32958422025-05-16 16:33:05 +0200132 }
Akron6f455152025-05-27 09:03:00 +0200133 if c.Wrap != nil {
134 newNode.Wrap = m.replaceNode(c.Wrap)
Akron32958422025-05-16 16:33:05 +0200135 }
Akron6f455152025-05-27 09:03:00 +0200136 if len(c.Operands) > 0 {
137 newNode.Operands = make([]ast.Node, len(c.Operands))
138 for i, op := range c.Operands {
Akrond5850f82025-05-23 16:44:44 +0200139 newNode.Operands[i] = m.replaceNode(op)
140 }
141 }
142 return newNode
Akrond5850f82025-05-23 16:44:44 +0200143 }
Akron6f455152025-05-27 09:03:00 +0200144
145 // If this node matches the pattern exactly, replace it
146 if m.matchNode(node, m.pattern.Root) {
147 return m.cloneNode(m.replacement.Root)
148 }
149
150 return node
Akrond5850f82025-05-23 16:44:44 +0200151}
152
153// simplifyNode removes unnecessary wrappers and empty nodes
154func (m *Matcher) simplifyNode(node ast.Node) ast.Node {
155 if node == nil {
156 return nil
157 }
158
159 switch n := node.(type) {
160 case *ast.Token:
161 if n.Wrap == nil {
162 return nil
163 }
164 simplified := m.simplifyNode(n.Wrap)
165 if simplified == nil {
166 return nil
167 }
168 return &ast.Token{Wrap: simplified}
169
170 case *ast.TermGroup:
171 // First simplify all operands
172 simplified := make([]ast.Node, 0, len(n.Operands))
173 for _, op := range n.Operands {
174 if s := m.simplifyNode(op); s != nil {
175 simplified = append(simplified, s)
176 }
177 }
178
179 // Handle special cases
180 if len(simplified) == 0 {
181 return nil
182 }
183 if len(simplified) == 1 {
184 // If we have a single operand, return it directly
185 // But only if we're not inside a Token
186 if _, isToken := node.(*ast.Token); !isToken {
187 return simplified[0]
188 }
189 }
190
191 return &ast.TermGroup{
192 Operands: simplified,
193 Relation: n.Relation,
194 }
195
196 case *ast.CatchallNode:
197 newNode := &ast.CatchallNode{
198 NodeType: n.NodeType,
199 RawContent: n.RawContent,
200 }
201 if n.Wrap != nil {
202 newNode.Wrap = m.simplifyNode(n.Wrap)
203 }
204 if len(n.Operands) > 0 {
205 simplified := make([]ast.Node, 0, len(n.Operands))
206 for _, op := range n.Operands {
207 if s := m.simplifyNode(op); s != nil {
208 simplified = append(simplified, s)
209 }
210 }
211 if len(simplified) > 0 {
212 newNode.Operands = simplified
Akron32958422025-05-16 16:33:05 +0200213 }
214 }
215 return newNode
216
Akronb7e1f352025-05-16 15:45:23 +0200217 default:
218 return node
219 }
220}
221
222// matchNode recursively checks if two nodes match
223func (m *Matcher) matchNode(node, pattern ast.Node) bool {
224 if pattern == nil {
225 return true
226 }
227 if node == nil {
228 return false
229 }
230
Akron6f455152025-05-27 09:03:00 +0200231 // Handle wrapped nodes (Token and CatchallNode)
232 if m.tryMatchWrapped(node, pattern) {
233 return true
Akronbf5149c2025-05-20 15:53:41 +0200234 }
Akronb7e1f352025-05-16 15:45:23 +0200235
Akron6f455152025-05-27 09:03:00 +0200236 switch p := pattern.(type) {
237 case *ast.Token:
238 if n, ok := node.(*ast.Token); ok {
239 return m.matchNode(n.Wrap, p.Wrap)
Akronbf5149c2025-05-20 15:53:41 +0200240 }
Akronbf5149c2025-05-20 15:53:41 +0200241
Akron6f455152025-05-27 09:03:00 +0200242 case *ast.Term:
243 return m.matchTerm(node, p)
244
245 case *ast.TermGroup:
246 if p.Relation == ast.OrRelation {
247 // For OR relations, check if any operand matches
248 for _, pOp := range p.Operands {
Akronbf5149c2025-05-20 15:53:41 +0200249 if m.matchNode(node, pOp) {
Akronb7e1f352025-05-16 15:45:23 +0200250 return true
251 }
252 }
Akron6f455152025-05-27 09:03:00 +0200253 } else if tg, ok := node.(*ast.TermGroup); ok && tg.Relation == p.Relation {
254 // For AND relations, all pattern operands must match in any order
255 return m.matchAndTermGroup(tg, p)
Akronb7e1f352025-05-16 15:45:23 +0200256 }
Akronb7e1f352025-05-16 15:45:23 +0200257 }
258
259 return false
260}
261
Akron6f455152025-05-27 09:03:00 +0200262// tryMatchWrapped attempts to match a node that might wrap other nodes
263func (m *Matcher) tryMatchWrapped(node, pattern ast.Node) bool {
264 switch n := node.(type) {
265 case *ast.Token:
266 if n.Wrap != nil {
267 return m.matchNode(n.Wrap, pattern)
268 }
269 case *ast.CatchallNode:
270 if n.Wrap != nil && m.matchNode(n.Wrap, pattern) {
271 return true
272 }
273 for _, op := range n.Operands {
274 if m.matchNode(op, pattern) {
275 return true
276 }
277 }
278 case *ast.TermGroup:
279 for _, op := range n.Operands {
280 if m.matchNode(op, pattern) {
281 return true
282 }
283 }
284 }
285 return false
286}
287
288// matchTerm checks if a node matches a term pattern
289func (m *Matcher) matchTerm(node ast.Node, pattern *ast.Term) bool {
290 if t, ok := node.(*ast.Term); ok {
291 return t.Foundry == pattern.Foundry &&
292 t.Key == pattern.Key &&
293 t.Layer == pattern.Layer &&
294 t.Match == pattern.Match &&
295 (pattern.Value == "" || t.Value == pattern.Value)
296 }
297 return m.tryMatchWrapped(node, pattern)
298}
299
300// matchAndTermGroup checks if a TermGroup matches an AND pattern
301func (m *Matcher) matchAndTermGroup(node *ast.TermGroup, pattern *ast.TermGroup) bool {
302 if len(node.Operands) < len(pattern.Operands) {
303 return false
304 }
305 matched := make([]bool, len(node.Operands))
306 for _, pOp := range pattern.Operands {
307 found := false
308 for j, tOp := range node.Operands {
309 if !matched[j] && m.matchNode(tOp, pOp) {
310 matched[j] = true
311 found = true
312 break
313 }
314 }
315 if !found {
316 return false
317 }
318 }
319 return true
320}
321
Akronb7e1f352025-05-16 15:45:23 +0200322// cloneNode creates a deep copy of a node
323func (m *Matcher) cloneNode(node ast.Node) ast.Node {
324 if node == nil {
325 return nil
326 }
327
328 switch n := node.(type) {
329 case *ast.Token:
330 return &ast.Token{
331 Wrap: m.cloneNode(n.Wrap),
332 }
333
334 case *ast.TermGroup:
335 operands := make([]ast.Node, len(n.Operands))
336 for i, op := range n.Operands {
337 operands[i] = m.cloneNode(op)
338 }
339 return &ast.TermGroup{
340 Operands: operands,
341 Relation: n.Relation,
342 }
343
344 case *ast.Term:
345 return &ast.Term{
346 Foundry: n.Foundry,
347 Key: n.Key,
348 Layer: n.Layer,
349 Match: n.Match,
350 Value: n.Value,
351 }
352
Akron32958422025-05-16 16:33:05 +0200353 case *ast.CatchallNode:
354 newNode := &ast.CatchallNode{
355 NodeType: n.NodeType,
356 RawContent: n.RawContent,
357 }
358 if n.Wrap != nil {
359 newNode.Wrap = m.cloneNode(n.Wrap)
360 }
361 if len(n.Operands) > 0 {
362 newNode.Operands = make([]ast.Node, len(n.Operands))
363 for i, op := range n.Operands {
364 newNode.Operands[i] = m.cloneNode(op)
365 }
366 }
367 return newNode
368
Akronb7e1f352025-05-16 15:45:23 +0200369 default:
370 return nil
371 }
372}