Fix remove_column for tables with collapsed rows or group header rows
diff --git a/R/collapse_rows.R b/R/collapse_rows.R
index 5491529..96769be 100644
--- a/R/collapse_rows.R
+++ b/R/collapse_rows.R
@@ -82,7 +82,6 @@
     kable_dt <- kable_dt[-(1:kable_attrs$header_above),]
     names(kable_dt) <- kable_dt_col_names
   }
-  # kable_dt$row_id <- seq(nrow(kable_dt))
   collapse_matrix <- collapse_row_matrix(kable_dt, columns, target = target)
 
   for (i in 1:nrow(collapse_matrix)) {
@@ -107,6 +106,7 @@
   }
 
   out <- as_kable_xml(kable_xml)
+  kable_attrs$collapse_matrix <- collapse_matrix
   attributes(out) <- kable_attrs
   if (!"kableExtra" %in% class(out)) class(out) <- c("kableExtra", class(out))
   return(out)
diff --git a/R/remove_column.R b/R/remove_column.R
index 828edae..aac4b9b 100644
--- a/R/remove_column.R
+++ b/R/remove_column.R
@@ -1,23 +1,24 @@
 #' Remove columns
 #'
 #' @param kable_input Output of [knitr::kable()] with format specified
-#' @param columns A numeric value or vector indicating in which column(s) rows need to be removed
+#' @param columns A numeric value or vector indicating which column(s) to
+#' remove
 #'
 #' @export
 #'
 #' @examples
-#' mtcars %>% 
-#' kable() %>% 
-#'     remove_column(2:3)
+#' remove_column(kable(mtcars), 1)
 remove_column <- function (kable_input, columns) {
-    if(is.null(columns)) return(kable_input)
+    if (is.null(columns)) return(kable_input)
     kable_format <- attr(kable_input, "format")
     if (!kable_format %in% c("html", "latex")) {
-        warning("Please specify format in kable. kableExtra can customize either ", 
-                "HTML or LaTeX outputs. See https://haozhu233.github.io/kableExtra/ ", 
-                "for details.")
+        warning("Please specify format in kable. kableExtra can customize",
+                " either HTML or LaTeX outputs. See ",
+                "https://haozhu233.github.io/kableExtra/ for details.")
         return(kable_input)
     }
+
+    columns <- sort(unique(columns))
     if (kable_format == "html") {
         return(remove_column_html(kable_input, columns))
     } else if (kable_format == "latex") {
@@ -30,48 +31,99 @@
     kable_xml <- kable_as_xml(kable_input)
     kable_tbody <- xml_tpart(kable_xml, "tbody")
     kable_thead <- xml_tpart(kable_xml, "thead")
-    
-    cell_topleft <- xml2::xml_child(kable_thead, 1) %>% 
-        xml2::xml_child(1) %>% 
-        xml2::xml_text() %>% 
-        stringr::str_trim()
-    has_rownames <- cell_topleft==""
-    
-    head_row <- xml2::xml_child(kable_thead, 1)
-    ncols <- xml2::xml_length(head_row)
-    body_nrows <- xml2::xml_length(kable_tbody)
-    
-    rowspan = matrix(1, nrow = body_nrows, ncol=ncols)
-    for(i in 1:body_nrows){
-        target_row <- xml2::xml_child(kable_tbody, i)
-        target_ncols <- xml2::xml_length(target_row)
-        for(j in 1:target_ncols){
-            target_cell <- xml2::xml_child(target_row, j)
-            span = as.numeric(xml2::xml_attr(target_cell, "rowspan")) %>% replace_na(0)
-            if(span>0){
-                rowspan[i,j]=1
-                rowspan[i+seq(from=1, to=span-1),j]=0
+
+    group_header_rows <- attr(kable_input, "group_header_rows")
+    all_contents_rows <- seq(1, length(xml_children(kable_tbody)))
+
+    if (!is.null(group_header_rows)) {
+        warning("It's recommended to use remove_column after add_header_above.",
+                "Right now some column span numbers might not be correct. ")
+        all_contents_rows <- all_contents_rows[!all_contents_rows %in%
+                                                   group_header_rows]
+    }
+
+    collapse_matrix <- attr(kable_input, "collapse_matrix")
+    collapse_columns <- NULL
+    if (!is.null(collapse_matrix)) {
+        collapse_columns <- sort(as.numeric(sub("x", "",
+                                                names(collapse_matrix))))
+        collapse_columns_origin <- collapse_columns
+    }
+
+    while (length(columns) > 0) {
+        xml2::xml_remove(xml2::xml_child(
+            xml2::xml_child(kable_thead, xml2::xml_length(kable_thead)),
+            columns[1]))
+        if (length(collapse_columns) != 0 && collapse_columns[1] <= columns[1]){
+            if (columns[1] %in% collapse_columns) {
+                column_span <- collapse_matrix[[paste0('x', columns[1])]]
+                non_skip_rows <- column_span != 0
+                collapse_columns <- collapse_columns[
+                    collapse_columns != columns[1]
+                    ] - 1
+            } else {
+                non_skip_rows <- rep(TRUE, length(all_contents_rows))
+            }
+            prior_col <- which(collapse_columns_origin < columns[1])
+            for (i in all_contents_rows[non_skip_rows]) {
+                if (length(prior_col) == 0) {
+                    pos_adj <- 0
+                } else {
+                    pos_adj <- sum(collapse_matrix[i, prior_col] == 0)
+                }
+                target_cell <- xml2::xml_child(
+                    xml2::xml_child(kable_tbody, i),
+                    columns[1] - pos_adj)
+                xml2::xml_remove(target_cell)
+            }
+        } else {
+            for (i in all_contents_rows) {
+                target_cell <- xml2::xml_child(
+                    xml2::xml_child(kable_tbody, i),
+                    columns[1])
+                xml2::xml_remove(target_cell)
             }
         }
+        # not very efficient but for finite task it's probably okay
+        columns <- (columns - 1)[-1]
     }
-    
-    for(i in 1:body_nrows){
-        target_row <- xml2::xml_child(kable_tbody, i)
-        for(j in rev(columns)){
-            target_cell <- xml2::xml_child(target_row, j)
-            if(rowspan[i,j]==1)
-            xml2::xml_remove(target_cell)
-        }
-    }
-    
-    for(j in columns){
-        target_cell_head <- xml2::xml_child(head_row, j)
-        xml2::xml_remove(target_cell_head)
-    }
+
+    # head_row <- xml2::xml_child(kable_thead, xml2::xml_length(kable_thead))
+    # ncols <- xml2::xml_length(head_row)
+    # body_nrows <- xml2::xml_length(kable_tbody)
+    #
+    # rowspan = matrix(1, nrow = body_nrows, ncol=ncols)
+    # for(i in 1:body_nrows){
+    #     target_row <- xml2::xml_child(kable_tbody, i)
+    #     target_ncols <- xml2::xml_length(target_row)
+    #     for(j in 1:target_ncols){
+    #         target_cell <- xml2::xml_child(target_row, j)
+    #         span <- as.numeric(xml2::xml_attr(target_cell, "rowspan"))
+    #         span[is.na(span)] <- 0
+    #         if(span>0){
+    #             rowspan[i,j]=1
+    #             rowspan[i+seq(from=1, to=span-1),j]=0
+    #         }
+    #     }
+    # }
+    #
+    # for(i in 1:body_nrows){
+    #     target_row <- xml2::xml_child(kable_tbody, i)
+    #     for(j in rev(columns)){
+    #         target_cell <- xml2::xml_child(target_row, j)
+    #         if(rowspan[i,j]==1)
+    #         xml2::xml_remove(target_cell)
+    #     }
+    # }
+    #
+    # for(j in columns){
+    #     target_cell_head <- xml2::xml_child(head_row, j)
+    #     xml2::xml_remove(target_cell_head)
+    # }
     out <- as_kable_xml(kable_xml)
     attributes(out) <- kable_attrs
-    if (!"kableExtra" %in% class(out)) 
+    if (!"kableExtra" %in% class(out))
         class(out) <- c("kableExtra", class(out))
-    
+
     return(out)
 }