Add RestrictObligatory filter for response rewriting
Change-Id: I30a386ac48fa8dcbd0635b77fa6449c755f0cd59
diff --git a/ast/ast.go b/ast/ast.go
index e47a78e..2da55b8 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -308,3 +308,92 @@
}
}
}
+
+// RestrictToObligatory reduces a mapping-rule replacement node to its obligatory
+// parts and then applies the foundry/layer overrides, mirroring
+// ApplyFoundryAndLayerOverrides().
+//
+// The reduction operates on node.Clone() rather than on node itself:
+//   - OR-groups are treated as optional and removed.
+//   - AND-groups keep only their obligatory operands. A group collapses to its
+//     single remaining operand, or disappears entirely when no operand remains.
+//   - Terms are always kept.
+//
+// Overrides are applied only after the reduction, so only the (usually smaller)
+// restricted tree is visited. This function is meant for replacement nodes of
+// mapping rules. Node types other than Term, Token and TermGroup (e.g.
+// CatchallNodes) are not descended into by the restriction and are kept as they are.
+//
+// It returns nil when node is nil or when nothing obligatory remains.
+//
+// Examples:
+//   - (a & b & c)           -> (a & b & c)  (kept as is)
+//   - (a & b & (c | d) & e) -> (a & b & e)  (OR-relation removed)
+//   - (a & (b | c))         -> a            (single remaining operand unwrapped)
+//   - (a | b)               -> nil          (completely optional)
+func RestrictToObligatory(node Node, foundry, layer string) Node {
+	if node == nil {
+		return nil
+	}
+
+	// Restrict first so that the overrides below only touch what survives.
+	obligatory := restrictToObligatoryRecursive(node.Clone())
+	if obligatory == nil {
+		return nil
+	}
+
+	ApplyFoundryAndLayerOverrides(obligatory, foundry, layer)
+	return obligatory
+}
+
+// restrictToObligatoryRecursive returns the obligatory part of node, or nil when
+// nothing obligatory is left. It does the actual work for RestrictToObligatory.
+//
+// Handling per node type:
+//   - *Term: always obligatory, so it is returned unchanged.
+//   - *Token: the wrapped node is restricted.
+//     If that yields nil, the whole token is dropped.
+//     Otherwise a fresh Token with the restricted wrap and the original Rewrites is returned.
+//     A Token whose Wrap is nil is returned unchanged.
+//   - *TermGroup with OrRelation: optional, so it is dropped (nil).
+//   - *TermGroup with AndRelation: each operand is restricted and nil results are discarded.
+//     No survivors yields nil.
+//     Exactly one survivor is returned on its own, without the group's Rewrites.
+//     Otherwise a fresh AND group holding the survivors and the original Rewrites is returned.
+//   - Anything else, including a TermGroup with any other relation, is returned unchanged.
+//
+// NOTE(review): the Token and TermGroup results are rebuilt from Wrap/Operands,
+// Relation and Rewrites only. Confirm these are the only fields those types carry.
+func restrictToObligatoryRecursive(node Node) Node {
+	if node == nil {
+		return nil
+	}
+
+	switch n := node.(type) {
+	case *Term:
+		return n
+
+	case *Token:
+		if n.Wrap == nil {
+			return n
+		}
+		inner := restrictToObligatoryRecursive(n.Wrap)
+		if inner == nil {
+			// The wrapped content was entirely optional.
+			return nil
+		}
+		return &Token{
+			Wrap:     inner,
+			Rewrites: n.Rewrites,
+		}
+
+	case *TermGroup:
+		switch n.Relation {
+		case OrRelation:
+			// Any single alternative may match, so none of them is required.
+			return nil
+
+		case AndRelation:
+			var kept []Node
+			for i := range n.Operands {
+				if r := restrictToObligatoryRecursive(n.Operands[i]); r != nil {
+					kept = append(kept, r)
+				}
+			}
+
+			switch len(kept) {
+			case 0:
+				return nil
+			case 1:
+				// A one-operand AND group is equivalent to that operand.
+				return kept[0]
+			default:
+				return &TermGroup{
+					Operands: kept,
+					Relation: AndRelation,
+					Rewrites: n.Rewrites,
+				}
+			}
+		}
+	}
+
+	// Unrecognized node types and relations are passed through untouched.
+	return node
+}