Add delta columns for CA with multiple VCs

Change-Id: Ie7d7bbfec226277f07e0f089e8ed6e847ca048f2
diff --git a/R/collocationAnalysis.R b/R/collocationAnalysis.R
index 0bcd691..bbd017d 100644
--- a/R/collocationAnalysis.R
+++ b/R/collocationAnalysis.R
@@ -42,13 +42,15 @@
 #' @param threshold              minimum value of `thresholdScore` function call to apply collocation analysis recursively
 #' @param localStopwords         vector of stopwords that will not be considered as collocates in the current function call, but that will not be passed to recursive calls
 #' @param collocateFilterRegex   allow only collocates matching the regular expression
+#' @param multiVcMissingScoreFactor factor that is multiplied with the minimum observed score when imputing missing scores for delta computations between virtual corpora
 #' @param ...                    more arguments will be passed to [collocationScoreQuery()]
 #' @inheritParams collocationScoreQuery,KorAPConnection-method
 #' @return Tibble with top collocates, association scores, corresponding URLs for web user interface queries, etc.
 #'
-#' @importFrom dplyr arrange desc slice_head bind_rows
+#' @importFrom dplyr arrange desc slice_head bind_rows group_by mutate ungroup left_join select row_number all_of first
 #' @importFrom purrr pmap
-#' @importFrom tidyr expand_grid
+#' @importFrom tidyr expand_grid pivot_wider
+#' @importFrom rlang sym
 #'
 #' @examples
 #' \dontrun{
@@ -96,6 +98,7 @@
            threshold = 2.0,
            localStopwords = c(),
            collocateFilterRegex = "^[:alnum:]+-?[:alnum:]*$",
+           multiVcMissingScoreFactor = 0.9,
            ...) {
     # https://stackoverflow.com/questions/8096313/no-visible-binding-for-global-variable-note-in-r-cmd-check
     word <- frequency <- O <- NULL
@@ -129,11 +132,13 @@
           localStopwords = localStopwords,
           seed = seed,
           expand = expand,
+          multiVcMissingScoreFactor = multiVcMissingScoreFactor,
           ...
         )
       }) |>
         bind_rows() |>
-        mutate(label = queryStringToLabel(vc))
+        mutate(label = queryStringToLabel(vc)) |>
+        add_multi_vc_comparisons(thresholdScore = thresholdScore, missingScoreFactor = multiVcMissingScoreFactor)
     } else {
       set.seed(seed)
       candidates <- collocatesQuery(
@@ -194,7 +199,8 @@
         exactFrequencies = exactFrequencies,
         searchHitsSampleLimit = searchHitsSampleLimit,
         topCollocatesLimit = topCollocatesLimit,
-        addExamples = FALSE
+        addExamples = FALSE,
+        multiVcMissingScoreFactor = multiVcMissingScoreFactor
       ) |>
         bind_rows(result) |>
         filter(logDice >= 2) |>
@@ -231,6 +237,89 @@
   return(res)
 }
 
+add_multi_vc_comparisons <- function(result, thresholdScore, missingScoreFactor) {
+  label <- node <- collocate <- rankWithinLabel <- NULL
+
+  if (!"label" %in% names(result) || dplyr::n_distinct(result$label) < 2) {
+    return(result)
+  }
+
+  numeric_cols <- names(result)[vapply(result, is.numeric, logical(1))]
+  non_score_cols <- c("N", "O", "O1", "O2", "E", "w", "leftContextSize", "rightContextSize", "frequency")
+  score_cols <- setdiff(numeric_cols, non_score_cols)
+
+  if (length(score_cols) == 0) {
+    return(result)
+  }
+
+  ranking_col <- thresholdScore
+  if (is.null(ranking_col) || is.na(ranking_col) || !ranking_col %in% score_cols) {
+    ranking_col <- if ("logDice" %in% score_cols) "logDice" else score_cols[1]
+  }
+
+  ranking_sym <- rlang::sym(ranking_col)
+
+  result <- result |>
+    dplyr::group_by(label) |>
+    dplyr::mutate(rankWithinLabel = dplyr::row_number(dplyr::desc(!!ranking_sym))) |>
+    dplyr::ungroup()
+
+  comparison <- result |>
+    dplyr::select(node, collocate, label, rankWithinLabel, dplyr::all_of(score_cols)) |>
+    pivot_wider(
+      names_from = label,
+      values_from = c(rankWithinLabel, dplyr::all_of(score_cols)),
+      names_glue = "{.value}_{make.names(label)}",
+      values_fn = dplyr::first
+    )
+
+  labels <- make.names(unique(result$label))
+
+  if (length(labels) == 2) {
+    fill_scores <- function(x, y) {
+      min_val <- suppressWarnings(min(c(x, y), na.rm = TRUE))
+      if (!is.finite(min_val)) {
+        min_val <- 0
+      }
+      x[is.na(x)] <- missingScoreFactor * min_val
+      y[is.na(y)] <- missingScoreFactor * min_val
+      list(x = x, y = y)
+    }
+
+    fill_ranks <- function(x, y) {
+      max_val <- suppressWarnings(max(c(x, y), na.rm = TRUE))
+      if (!is.finite(max_val)) {
+        max_val <- 0
+      }
+      x[is.na(x)] <- max_val + 1
+      y[is.na(y)] <- max_val + 1
+      list(x = x, y = y)
+    }
+
+    left_label <- labels[1]
+    right_label <- labels[2]
+
+    for (col in score_cols) {
+      left_col <- paste0(col, "_", left_label)
+      right_col <- paste0(col, "_", right_label)
+      if (!all(c(left_col, right_col) %in% names(comparison))) {
+        next
+      }
+      filled <- fill_scores(comparison[[left_col]], comparison[[right_col]])
+      comparison[[paste0("delta_", col)]] <- filled$x - filled$y
+    }
+
+    left_rank <- paste0("rankWithinLabel_", left_label)
+    right_rank <- paste0("rankWithinLabel_", right_label)
+    if (all(c(left_rank, right_rank) %in% names(comparison))) {
+      filled_rank <- fill_ranks(comparison[[left_rank]], comparison[[right_rank]])
+      comparison[["delta_rankWithinLabel"]] <- filled_rank$x - filled_rank$y
+    }
+  }
+
+  dplyr::left_join(result, comparison, by = c("node", "collocate"))
+}
+
 #' @importFrom magrittr debug_pipe
 #' @importFrom stringr str_detect
 #' @importFrom dplyr as_tibble tibble rename filter anti_join tibble bind_rows case_when