#include <collocatordb.h>
#include <fcntl.h>
#include <malloc.h>
#include <math.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

/* Capacity of the fixed-size float/char work buffers declared with it
 * (query vector in _get_neighbours, token buffer in getTargetWords). */
#define max_size 2000
/* Bytes reserved per vocabulary entry; word i starts at vocab[i * max_w]. */
#define max_w 50
/* Capacity of wordlist arrays; get_neighbours also clamps N to this value. */
#define MAX_NEIGHBOURS 1000
/* Optional cap on the number of words read by init_net; only applied when > 0. */
#define MAX_WORDS -1
/* Size of the per-call thread bookkeeping arrays (knn / knnpars). */
#define MAX_THREADS 100
/* Not referenced in the visible part of this file. */
#define MAX_CC 50
/* Number of precomputed entries in expTable (see init_net). */
#define EXP_TABLE_SIZE 1000
/* expTable covers the input range [-MAX_EXP, MAX_EXP]; values outside are skipped. */
#define MAX_EXP 6
/* Activation threshold handed to getCollocators by get_neighbours. */
#define MIN_RESP 0.50

//the thread function
void *connection_handler(void *);

/* One candidate word together with the scores computed for it. */
typedef struct {
  long long wordi;     /* vocabulary index, or -1 for an unused slot */
  long position;       /* getCollocators stores (window - a) here; get_neighbours later reuses it for a window bit mask */
  float activation;    /* score the entry is ranked by (table-looked-up activation or dot product) */
  float average;
  float cprobability;  // column wise probability
  float cprobability_sum;
  float probability;
  float activation_sum;
  float max_activation;
  float heat[16];      /* indexed by window position in get_neighbours; holds the largest activation seen there */
} collocator;

/* Result handed back by a search thread: an array of candidates and a count. */
typedef struct {
  collocator *best; /* heap array filled by the thread */
  int length;       /* number of entries callers iterate over (getCollocators stores b - 1) */
} knn;

/* Query words resolved to vocabulary indices. */
typedef struct {
  long long wordi[MAX_NEIGHBOURS]; /* vocabulary index per query word, -1 if not found */
  char sep[MAX_NEIGHBOURS];        /* separator after each word; '-' makes _get_neighbours subtract the next vector */
  int length;                      /* number of words that were found */
} wordlist;

/* Argument block passed to the search threads (_get_neighbours, getCollocators). */
typedef struct {
  long cutoff;         /* number of vocabulary entries getCollocators scans */
  wordlist *wl;        /* resolved query words */
  char *token;         /* raw query string (set by get_neighbours) */
  int N;               /* number of candidates to keep */
  long from;           /* first index of the slice: vocabulary index for _get_neighbours, window position for getCollocators */
  unsigned long upto;  /* end of the slice (exclusive) */
  collocator *best;    /* caller-provided output array used by _get_neighbours */
  float *target_sums;  /* per-target activation sums shared by all getCollocators threads */
  float *window_sums;  /* per-window-position activation sums */
  float threshold;     /* initial minimum activation in getCollocators */
} knnpars;

/* One entry of a precomputed profile: a vocabulary index with its score. */
typedef struct {
  uint32_t index; /* used as a vocabulary index by getSimilarProfiles */
  float value;    /* emitted as "v" in the JSON output */
} sparse_t;

/* Fixed-size record of the memory-mapped .sprofiles.bin file (one per word). */
typedef struct {
  uint32_t len;       /* number of valid entries in nbr; readers cap it at 100 */
  sparse_t nbr[100];
} profile_t;

/* M: words x size matrix of word vectors (each row scaled to unit length by init_net).
 * M2: mapping of the optional net file, NULL when none was loaded.
 * syn1neg_window: points into M2, words * size floats past its start.
 * expTable: table of x / (x + 1) with x = exp(...), filled by init_net. */
float *M, *M2 = 0L, *syn1neg_window, *expTable;
/* Per-window-position activation sums written by getCollocators. */
float *window_sums;
/* Vocabulary: words entries of max_w bytes each, NUL-terminated. */
char *vocab;
/* Optional per-word flags set by filter_garbage; NULL means no filtering. */
char *garbage = NULL;
/* Collocation database handle, opened by init_net when requested and available. */
COLLOCATORDB *cdb = NULL;
/* Memory-mapped precomputed profiles and their count (see load_sprofiles). */
profile_t *sprofiles = NULL;
size_t sprofiles_qty = 0;

/* words: vocabulary size; size: vector dimensionality;
 * merged_end: number of merged entries placed in front by mergeVectors. */
long long words, size, merged_end;
/* Number of words read from the merged file; 0 when nothing was merged. */
long long merge_words = 0;
int num_threads = 20;
/* Non-zero when the vocabulary is not UTF-8 (get_neighbours skips SvUTF8_on then). */
int latin_enc = 0;
/* Window size derived from the net file size in init_net. */
int window;

/*
 * Load precomputed collocation profiles if the companion file exists.
 *
 * The file name is derived from vecsname by cutting it at the first
 * occurrence of ".vecs" (if any) and appending ".sprofiles.bin".  The file
 * is mapped read-only into the global sprofiles array and sprofiles_qty is
 * set to the number of complete profile_t records it contains.
 *
 * Returns 1 when the profiles were mapped, 0 otherwise (missing file, name
 * too long, or a failing open/mmap).  A missing file is not an error.
 */
int load_sprofiles(char *vecsname) {
  char binsprofiles_fname[256];
  const char *ext = strstr(vecsname, ".vecs");
  size_t stem_len = ext ? (size_t)(ext - vecsname) : strlen(vecsname);

  /* Build the name with a bounded write instead of strcpy/strcat. */
  int n = snprintf(binsprofiles_fname, sizeof binsprofiles_fname, "%.*s%s",
                   (int)stem_len, vecsname, ".sprofiles.bin");
  if (n < 0 || (size_t)n >= sizeof binsprofiles_fname) {
    fprintf(stderr, "Collocation profile name derived from %s is too long\n", vecsname);
    return 0;
  }

  FILE *fp = fopen(binsprofiles_fname, "rb");
  if (fp == NULL) {
    printf("Collocation profiles %s not found. No problem.\n", binsprofiles_fname);
    return 0;
  }
  long end = -1;
  if (fseek(fp, 0L, SEEK_END) == 0)
    end = ftell(fp);
  fclose(fp);
  if (end < 0) {
    fprintf(stderr, "Cannot determine size of %s\n", binsprofiles_fname);
    return 0;
  }
  size_t sz = (size_t)end;

  int fd = open(binsprofiles_fname, O_RDONLY);
  if (fd < 0) {
    fprintf(stderr, "Cannot mmap %s\n", binsprofiles_fname);
    return 0;
  }
  void *map = mmap(0, sz, PROT_READ, MAP_SHARED, fd, 0);
  /* The mapping stays valid after the descriptor is closed. */
  close(fd);
  if (map == MAP_FAILED) {
    fprintf(stderr, "Cannot mmap %s\n", binsprofiles_fname);
    sprofiles = NULL;
    return 0;
  }

  sprofiles = map;
  sprofiles_qty = sz / sizeof(profile_t);
  fprintf(stderr, "Successfully mmaped %s containing similar profiles for %zu word forms.\n", binsprofiles_fname, sprofiles_qty);
  return 1;
}

/*
 * Return a newly allocated copy of myStr with everything from the last '.'
 * onwards removed.  A string without '.' is copied unchanged.
 * Returns NULL when myStr is NULL or the allocation fails; the caller owns
 * (and must free) the returned buffer.
 */
char *removeExtension(char* myStr) {
    if (!myStr)
        return NULL;

    const size_t total = strlen(myStr) + 1;
    char *copy = malloc(total);
    if (!copy)
        return NULL;
    memcpy(copy, myStr, total);

    char *dot = strrchr(copy, '.');
    if (dot)
        *dot = '\0';
    return copy;
}

int init_net(char *file_name, char *net_name, int latin, int do_open_cdb) {
  FILE *f, *binvecs, *binwords;
  int binwords_fd, binvecs_fd, net_fd, i;
  long long a, b;
  float len;
  double val;

  char binvecs_fname[1024], binwords_fname[1024];

  if (strstr(file_name, ".txt")) {
    strcpy(binwords_fname, removeExtension(file_name));
  } else {
    strcpy(binwords_fname, file_name);
  }
  strcat(binwords_fname, ".words");
  strcpy(binvecs_fname, file_name);
  strcat(binvecs_fname, ".vecs");

  latin_enc = latin;
  f = fopen(file_name, "rb");
  if (f == NULL) {
    printf("Input file %s not found\n", file_name);
    return -1;
  }
  fscanf(f, "%lld", &words);
  if (MAX_WORDS > 0 && words > MAX_WORDS) words = MAX_WORDS;
  fscanf(f, "%lld", &size);
  if ((binvecs_fd = open(binvecs_fname, O_RDONLY)) < 0 || (binwords_fd = open(binwords_fname, O_RDONLY)) < 0) {
    printf("Converting %s to memory mappable structures\n", file_name);
    vocab = (char *)malloc((long long)words * max_w * sizeof(char));
    M = (float *)malloc((long long)words * (long long)size * sizeof(float));
    if (M == NULL) {
      printf("Cannot allocate memory: %lld MB    %lld  %lld\n", (long long)words * size * sizeof(float) / 1048576, words, size);
      return -1;
    }
    if (strstr(file_name, ".txt")) {
      printf("%lld words in ascii vector file with vector size %lld\n", words, size);
      for (b = 0; b < words; b++) {
        a = 0;
        while (1) {
          vocab[b * max_w + a] = fgetc(f);
          if (feof(f) || (vocab[b * max_w + a] == ' ')) break;
          if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++;
        }
        vocab[b * max_w + a] = 0;
        len = 0;
        for (a = 0; a < size; a++) {
          fscanf(f, "%lf", &val);
          M[a + b * size] = val;
          len += val * val;
        }
        len = sqrt(len);
        for (a = 0; a < size; a++) M[a + b * size] /= len;
      }
    } else {
      for (b = 0; b < words; b++) {
        a = 0;
        while (1) {
          vocab[b * max_w + a] = fgetc(f);
          if (feof(f) || (vocab[b * max_w + a] == ' ')) break;
          if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++;
        }
        vocab[b * max_w + a] = 0;
        fread(&M[b * size], sizeof(float), size, f);
        len = 0;
        for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size];
        len = sqrt(len);
        for (a = 0; a < size; a++) M[a + b * size] /= len;
      }
    }
    if ((binvecs = fopen(binvecs_fname, "wb")) != NULL && (binwords = fopen(binwords_fname, "wb")) != NULL) {
      fwrite(M, sizeof(float), (long long)words * (long long)size, binvecs);
      fclose(binvecs);
      fwrite(vocab, sizeof(char), (long long)words * max_w, binwords);
      fclose(binwords);
    }
  }
  if ((binvecs_fd = open(binvecs_fname, O_RDONLY)) >= 0 && (binwords_fd = open(binwords_fname, O_RDONLY)) >= 0) {
    M = mmap(0, sizeof(float) * (long long)words * (long long)size, PROT_READ, MAP_SHARED, binvecs_fd, 0);
    vocab = mmap(0, sizeof(char) * (long long)words * max_w, PROT_READ, MAP_SHARED, binwords_fd, 0);
    if (M == MAP_FAILED || vocab == MAP_FAILED) {
      close(binvecs_fd);
      close(binwords_fd);
      fprintf(stderr, "Cannot mmap %s or %s\n", binwords_fname, binvecs_fname);
      exit(-1);
    }
  } else {
    fprintf(stderr, "Cannot open %s or %s\n", binwords_fname, binvecs_fname);
    exit(-1);
  }
  fclose(f);

  if (net_name && strlen(net_name) > 0) {
    if ((net_fd = open(net_name, O_RDONLY)) >= 0) {
      window = (lseek(net_fd, 0, SEEK_END) - sizeof(float) * words * size) / words / size / sizeof(float) / 2;
      //      lseek(net_fd, sizeof(float) * words * size, SEEK_SET);
      // munmap(M,  sizeof(float) * words * size);
      M2 = mmap(0, sizeof(float) * words * size + sizeof(float) * 2 * window * size * words, PROT_READ, MAP_SHARED, net_fd, 0);
      if (M2 == MAP_FAILED) {
        close(net_fd);
        fprintf(stderr, "Cannot mmap %s\n", net_name);
        exit(-1);
      }
      syn1neg_window = M2 + words * size;
    } else {
      fprintf(stderr, "Cannot open %s\n", net_name);
      exit(-1);
    }
    fprintf(stderr, "Successfully memmaped %s. Determined window size: %d\n", net_name, window);

    if (do_open_cdb) {
      char collocatordb_name[2048];
      strcpy(collocatordb_name, net_name);
      char *ext = rindex(collocatordb_name, '.');
      if (ext) {
        strcpy(ext, ".rocksdb");
        if (access(collocatordb_name, R_OK) == 0) {
          *ext = 0;
          fprintf(stderr, "Opening collocator DB	%s\n", collocatordb_name);
          cdb = open_collocatordb(collocatordb_name);
        } else {
           fprintf(stderr, "Cannot open collocator DB	%s\n", collocatordb_name);
        }
      }
    }
  }

  expTable = (float *)malloc((EXP_TABLE_SIZE + 1) * sizeof(float));
  for (i = 0; i < EXP_TABLE_SIZE; i++) {
    expTable[i] = exp((i / (float)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);  // Precompute the exp() table
    expTable[i] = expTable[i] / (expTable[i] + 1);                     // Precompute f(x) = x / (x + 1)
  }
  window_sums = malloc(sizeof(float) * (window + 1) * 2);

  return 0;
}

/* Read exactly want bytes from fd into dst, looping over short reads.
 * Returns 0 on success, -1 on error or premature end of file. */
static int merge_read_fully(int fd, void *dst, size_t want) {
  char *out = dst;
  while (want > 0) {
    long got = (long)read(fd, out, want);
    if (got <= 0)
      return -1;
    out += got;
    want -= (size_t)got;
  }
  return 0;
}

/*
 * Prepend the vectors of a second model to the currently loaded one.
 *
 * file_name names a vector file whose <file_name>.vecs / <file_name>.words
 * companions already exist.  Its header supplies the number of words and
 * the vector size, which must equal the global size.  The merged vectors
 * and vocabulary are copied into fresh heap buffers (merged entries first,
 * then the entries loaded so far); the previous mappings of M and vocab are
 * unmapped and the globals M, vocab, words, merge_words and merged_end are
 * updated.
 *
 * Returns merged_end, the number of merged entries.  Any failure terminates
 * the process, as before.
 */
long mergeVectors(char *file_name) {
  FILE *f;
  int binwords_fd, binvecs_fd, n;
  float *merge_vecs;
  char *merge_vocab;
  long long merge_size;
  char binvecs_fname[1024], binwords_fname[1024];

  n = snprintf(binwords_fname, sizeof binwords_fname, "%s.words", file_name);
  if (n < 0 || (size_t)n >= sizeof binwords_fname) {
    fprintf(stderr, "File name %s is too long\n", file_name);
    exit(-1);
  }
  n = snprintf(binvecs_fname, sizeof binvecs_fname, "%s.vecs", file_name);
  if (n < 0 || (size_t)n >= sizeof binvecs_fname) {
    fprintf(stderr, "File name %s is too long\n", file_name);
    exit(-1);
  }

  f = fopen(file_name, "rb");
  if (f == NULL) {
    printf("Input file %s not found\n", file_name);
    exit(-1);
  }
  if (fscanf(f, "%lld", &merge_words) != 1 || fscanf(f, "%lld", &merge_size) != 1 || merge_words <= 0) {
    fprintf(stderr, "Cannot read vocabulary and vector size from %s\n", file_name);
    exit(-1);
  }
  fclose(f);
  if (merge_size != size) {
    fprintf(stderr, "vectors must have the same length\n");
    exit(-1);
  }

  binvecs_fd = open(binvecs_fname, O_RDONLY);
  binwords_fd = open(binwords_fname, O_RDONLY);
  if (binvecs_fd < 0 || binwords_fd < 0) {
    fprintf(stderr, "Cannot open %s or %s\n", binwords_fname, binvecs_fname);
    exit(-1);
  }

  merge_vecs = malloc(sizeof(float) * (words + merge_words) * size);
  merge_vocab = malloc(sizeof(char) * (words + merge_words) * max_w);
  if (merge_vecs == NULL || merge_vocab == NULL) {
    close(binvecs_fd);
    close(binwords_fd);
    fprintf(stderr, "Cannot reserve memory for %s or %s\n", binwords_fname, binvecs_fname);
    exit(-1);
  }
  /* A single read() may return less than requested, so loop until the
   * whole file section has arrived. */
  if (merge_read_fully(binvecs_fd, merge_vecs, (size_t)merge_words * size * sizeof(float)) != 0 ||
      merge_read_fully(binwords_fd, merge_vocab, (size_t)merge_words * max_w) != 0) {
    fprintf(stderr, "Cannot read %s or %s\n", binwords_fname, binvecs_fname);
    exit(-1);
  }
  close(binvecs_fd);
  close(binwords_fd);

  printf("Successfully reallocated memory\nMerging...\n");
  fflush(stdout);
  memcpy(merge_vecs + merge_words * size, M, words * size * sizeof(float));
  memcpy(merge_vocab + merge_words * max_w, vocab, words * max_w);
  munmap(M, words * size * sizeof(float));
  munmap(vocab, words * max_w);
  M = merge_vecs;
  vocab = merge_vocab;
  merged_end = merge_words;
  words += merge_words;
  printf("merged_end: %lld, words: %lld\n", merged_end, words);
  //printBiggestMergedDifferences();
  return ((long)merged_end);
}

/*
 * Build the global garbage[] flag array: one byte per vocabulary entry,
 * set to 1 for entries that should be skipped by the searches.
 *
 * An entry is flagged when it starts with "quot" or when, scanning its
 * bytes, one of the following is seen:
 *   - an ASCII upper-case letter directly after an ASCII lower-case letter,
 *   - a byte with bit 0x20 set directly after '-',
 *   - 0xA4 or 0xB6 directly after 0xC2,
 *   - the sequence "quot",
 *   - the character '<'.
 *
 * If the flag array cannot be allocated, garbage is left unchanged and a
 * message is written to stderr.
 */
void filter_garbage() {
  long i;
  unsigned char *w, previous, c;
  char *flags = calloc(words > 0 ? (size_t)words : 1, 1);

  if (flags == NULL) {
    fprintf(stderr, "Cannot allocate memory for the garbage filter\n");
    return;
  }

  for (i = 0; i < words; i++) {
    w = (unsigned char *) vocab + i * max_w;
    if (strncmp("quot", (const char *)w, 4) == 0) {
      flags[i] = 1;
      continue;
    }
    previous = 0;
    while ((c = *w++)) {
      /* w already points at the byte after c here. */
      if (((c >= 'A' && c <= 'Z') && (previous >= 'a' && previous <= 'z')) ||
          (previous == '-' && (c & 32)) ||
          (previous == 0xc2 && (c == 0xa4 || c == 0xb6)) ||
          (previous == 'q' && c == 'u' && *(w) == 'o' && *(w + 1) == 't') || /* quot */
          c == '<') {
        flags[i] = 1;
        break;
      }
      previous = c;
    }
  }
  /* Publish only the completely filled array. */
  garbage = flags;
}

/*
 * Set up the parameters for a collocator search over all window positions.
 *
 * The actual call to getCollocators is disabled, so this function currently
 * only allocates, fills and releases its work buffers and always returns
 * NULL.  word and result are unused.
 *
 * number  number of candidates to keep (0 selects 20)
 * cutoff  number of vocabulary entries to scan (0 selects 300000)
 */
knn *simpleGetCollocators(int word, int number, long cutoff, int *result) {
  knn *syn_nbs = NULL;  // = (knn*) getCollocators(pars);
  knnpars *pars = calloc(1, sizeof(knnpars));
  float *my_window_sums = malloc(sizeof(float) * (window + 1) * 2);
  float *target_sums = NULL;

  (void)word;
  (void)result;

  if (pars == NULL || my_window_sums == NULL)
    goto done;

  pars->cutoff = (cutoff ? cutoff : 300000);
  /* The sums must actually be allocated before they are cleared; calloc
   * provides the zeroed array the old loop tried to produce. */
  if (pars->cutoff > 0) {
    target_sums = calloc((size_t)pars->cutoff, sizeof(float));
    if (target_sums == NULL)
      goto done;
  }
  pars->target_sums = target_sums;
  pars->window_sums = my_window_sums;
  pars->N = (number ? number : 20);
  pars->from = 0;
  pars->upto = window * 2 - 1;

done:
  free(pars);
  free(my_window_sums);
  free(target_sums);
  return syn_nbs;
}

/*
 * Thread entry point: score every vocabulary entry below pars->cutoff as a
 * collocate of the first query word, for the window positions in
 * [pars->from, pars->upto).
 *
 * For each position the dot product of the query row of M2 with the
 * target's block in syn1neg_window is mapped through expTable; targets
 * whose raw value lies outside [-MAX_EXP, MAX_EXP] are skipped.  The N
 * strongest activations above pars->threshold are kept, the sum over all
 * targets is stored in pars->window_sums[position] and the per-target sums
 * are added to the shared pars->target_sums.
 *
 * The thread result is a heap-allocated knn (result and its best array are
 * owned by the joiner).  NULL is returned when no net is loaded, the query
 * word is unknown, or memory cannot be allocated.
 */
void *getCollocators(void *args) {
  /* pars->target_sums is shared by all threads of one request. */
  static pthread_mutex_t target_sums_lock = PTHREAD_MUTEX_INITIALIZER;
  knnpars *pars = args;
  int N = pars->N;

  int cc = pars->wl->wordi[0];
  knn *nbs = NULL;
  long window_layer_size = size * window * 2;
  long a, b, c, d, window_offset, target, max_target = 0, maxmax_target;
  float f, max_f, maxmax_f;
  float *target_sums = NULL, worstbest, wpos_sum;
  collocator *best;

  if (M2 == NULL || cc == -1)
    return NULL;

  if (posix_memalign((void **)&target_sums, 128, pars->cutoff * sizeof(float)) != 0)
    return NULL;
  memset(target_sums, 0, pars->cutoff * sizeof(float));
  best = calloc((N > 200 ? N : 200), sizeof(collocator));
  if (best == NULL) {
    free(target_sums);
    return NULL;
  }
  worstbest = pars->threshold;

  for (b = 0; b < N; b++) {
    best[b].wordi = -1;
    best[b].probability = 1;
    best[b].activation = worstbest;
  }

  d = cc;
  maxmax_f = -1;
  maxmax_target = 0;

  for (a = pars->from; a < pars->upto; a++) {
    /* Position `window` is the query word itself, so slots at or beyond it
     * are shifted by one. */
    if (a >= window)
      a++;
    wpos_sum = 0;
    printf("window pos: %ld\n", a);
    if (a != window) {
      max_f = -1;
      window_offset = a * size;
      if (a > window)
        window_offset -= size;
      for (target = 0; target < pars->cutoff; target++) {
        if (garbage && garbage[target]) continue;
        if (target == d)
          continue;
        f = 0;
        for (c = 0; c < size; c++)
          f += M2[d * size + c] * syn1neg_window[target * window_layer_size + window_offset + c];
        if (f < -MAX_EXP)
          continue;
        else if (f > MAX_EXP)
          continue;
        else
          f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
        wpos_sum += f;

        target_sums[target] += f;
        if (f > worstbest) {
          /* Insert into the list, which is kept sorted by activation. */
          for (b = 0; b < N; b++) {
            if (f > best[b].activation) {
              memmove(best + b + 1, best + b, (N - b - 1) * sizeof(collocator));
              best[b].activation = f;
              best[b].wordi = target;
              best[b].position = window - a;
              break;
            }
          }
          if (b == N - 1)
            worstbest = best[N - 1].activation;
        }
      }
      printf("%ld %.2f\n", max_target, max_f);
      printf("%s (%.2f) ", &vocab[max_target * max_w], max_f);
      if (max_f > maxmax_f) {
        maxmax_f = max_f;
        maxmax_target = max_target;
      }
      for (b = 0; b < N; b++)
        if (best[b].position == window - a)
          best[b].cprobability = best[b].activation / wpos_sum;
    } else {
      /* Not reachable after the shift above; kept for reference. */
      printf("\x1b[1m%s\x1b[0m ", &vocab[d * max_w]);
    }
    pars->window_sums[a] = wpos_sum;
  }

  /* Several threads add into the same array; serialise the update. */
  pthread_mutex_lock(&target_sums_lock);
  for (b = 0; b < pars->cutoff; b++)
    pars->target_sums[b] += target_sums[b];  //(target_sums[b] / wpos_sum ) / (window * 2);
  pthread_mutex_unlock(&target_sums_lock);

  free(target_sums);
  /* Count the filled slots; the result length is derived from b. */
  for (b = 0; b < N && best[b].wordi >= 0; b++)
    ;
  nbs = malloc(sizeof(knn));
  if (nbs == NULL) {
    free(best);
    return NULL;
  }
  nbs->best = best;
  nbs->length = b - 1;
  pthread_exit(nbs);
}

/*
 * Return one weight from syn1neg_window.
 *
 * hidden           component index, 0 <= hidden < size
 * target           vocabulary index, 0 <= target < words
 * window_position  relative position, -window <= window_position <= window
 *                  and not 0
 *
 * Any argument outside its range terminates the process with a message on
 * stderr.
 */
float getOutputWeight(int hidden, long target, int window_position) {
  const long window_layer_size = size * window * 2;
  int a;

  if (window_position == 0 || window_position > window || window_position < -window) {
    fprintf(stderr, "window_position: %d - assert: -%d <= window_position <= %d && window_position != 0 failed.\n", window_position, window, window);
    exit(-1);
  }

  /* Negative indices would read in front of the weight block. */
  if (hidden < 0 || hidden >= size) {
    fprintf(stderr, "hidden: %d - assert: 0 <= hidden < %lld failed.\n", hidden, size);
    exit(-1);
  }

  if (target < 0 || target >= words) {
    fprintf(stderr, "target: %ld - assert: 0 <= target < %lld failed.\n", target, words);
    exit(-1);
  }

  /* Map [-window, -1] and [1, window] onto the 2 * window stored slots. */
  a = window_position + window;
  if (a > window) {
    --a;
  }
  long window_offset = a * size;
  return syn1neg_window[target * window_layer_size + window_offset + hidden];
}

/*
 * For every element of the Perl array, interpret it as a vocabulary index
 * and collect the corresponding row of M as a new Perl array of numbers.
 * Returns a new array holding one reference per fetched element.
 * Indices are not range-checked here.
 */
AV *getVecs(AV *array) {
  AV *result = newAV();
  int i = 0;

  while (i <= av_len(array)) {
    SV **elem = av_fetch(array, i, 0);
    i++;
    if (elem == NULL)
      continue;

    long j = (long)SvNV(*elem);
    const float *row = &M[j * size];
    AV *vector = newAV();
    int b;
    for (b = 0; b < size; b++)
      av_push(vector, newSVnv(row[b]));
    av_push(result, newRV_noinc(vector));
  }
  return result;
}

char *getSimilarProfiles(long node) {
  int i;
  char buffer[120000];
  char pair_buffer[2048];
  buffer[0] = '[';
  buffer[1] = 0;
  if (node >= sprofiles_qty) {
    printf("Not available in precomputed profile\n");
    return (strdup("[{\"w\":\"not available\", \"v\":0}]\n"));
  }

  printf("******* %s ******\n", &vocab[max_w * node]);

  for (i = 0; i < 100 && i < sprofiles[node].len; i++) {
    sprintf(pair_buffer, "{\"w\":\"%s\", \"v\":%f},", &vocab[max_w * (sprofiles[node].nbr[i].index)], sprofiles[node].nbr[i].value);
    strcat(buffer, pair_buffer);
  }
  buffer[strlen(buffer) - 1] = ']';
  strcat(buffer, "\n");
  printf("%s", buffer);
  return (strdup(buffer));
}

/*
 * Return the collocation scores of the pair (node, collocate) as JSON.
 * With an open collocator DB the result is a strdup'ed copy of the string
 * returned by the DB; without one it is the string literal "[]".
 */
char *getCollocationScores(long node, long collocate) {
    if (!cdb)
        return "[]";
    return strdup(get_collocation_scores_as_json(cdb, node, collocate));
}

/*
 * Return the collocators of `node` from the collocator DB as JSON.
 * With an open DB the result is a strdup'ed copy of the string returned by
 * the DB; without one it is the string literal "[]".
 */
char *getClassicCollocators(long node) {
  if (!cdb)
    return "[]";
  return strdup(get_collocators_as_json(cdb, node));
}

/*
 * Resolve a query string to a vocabulary index.
 *
 * As before, only the text after the last separating space of st1 is
 * looked up (a space resets the token being collected).  The lookup scans
 * forwards over [0, merge_words) - or the whole vocabulary when nothing was
 * merged - or, with search_backw set, backwards over [merge_words, words).
 *
 * Returns a heap-allocated wordlist owned by the caller, or NULL when it
 * cannot be allocated.  On success wordi[0] holds the index and length is
 * 1; for an unknown word wordi[0] is -1 and length is 0.
 */
wordlist *getTargetWords(char *st1, int search_backw) {
  /* calloc so that sep[] and the unused wordi[] slots are defined. */
  wordlist *wl = calloc(1, sizeof(wordlist));
  char token[max_size];
  size_t len = 0;
  long long b;
  int found = 0;

  if (wl == NULL)
    return NULL;

  /* Collect the token without reading past the terminator (empty input or
   * a trailing space) and without overrunning the buffer. */
  if (st1 != NULL && *st1 != 0) {
    const char *p = st1;
    for (;;) {
      if (len < sizeof token - 1)
        token[len++] = *p;
      p++;
      if (*p == 0) break;
      if (*p == ' ' /*|| *p == '-'*/) {
        len = 0;
        p++;
        if (*p == 0) break;
      }
    }
  }
  token[len] = 0;

  if (search_backw) {
    const long long lowest = (merge_words ? merge_words : 0);
    for (b = words - 1; b >= lowest; b--)
      if (strcmp(&vocab[b * max_w], token) == 0) {
        found = 1;
        break;
      }
  } else {
    const long long limit = (merge_words ? merge_words : words);
    for (b = 0; b < limit; b++)
      if (strcmp(&vocab[b * max_w], token) == 0) {
        found = 1;
        break;
      }
  }
  /* Decide on the match itself rather than on where the scan stopped: the
   * end index differs between the forward, backward and merged cases. */
  if (!found)
    b = -1;

  wl->wordi[0] = b;
  if (b == -1) {
    fprintf(stderr, "Out of dictionary word!\n");
    wl->length = 0;
  } else {
    fprintf(stderr, "Word: \"%s\"  Position in vocabulary: %lld\n", &vocab[wl->wordi[0] * max_w], wl->wordi[0]);
    wl->length = 1;
  }
  return (wl);
}

/*
 * Return the vocabulary index of `word`, or 0 when the word is unknown or
 * the lookup could not be performed (0 is also the index of the first
 * vocabulary entry, so the two cases cannot be told apart by the caller).
 */
long getWordNumber(char *word) {
  wordlist *wl = getTargetWords(word, 0);
  long index = 0;
  if (wl != NULL && wl->length > 0)
    index = wl->wordi[0];
  /* The wordlist is only needed for the lookup. */
  free(wl);
  return(index);
}

/*
 * Dot product of rows b and c of M.  With the unit-length rows produced by
 * init_net this is their cosine similarity.
 */
float get_distance(long b, long c) {
  const float *row_b = &M[b * size];
  const float *row_c = &M[c * size];
  float acc = 0;
  long k = 0;
  while (k < size) {
    acc += row_c[k] * row_b[k];
    k++;
  }
  return acc;
}

char *getBiggestMergedDifferences() {
  static char *result = NULL;
  float dist;
  long long a, c;
  int N = 1000;

  if (merged_end == 0)
    result = "[]";

  if (result != NULL)
    return result;

  printf("Looking for biggest distances between main and merged vectors ...\n");
  collocator *best;
  best = malloc(N * sizeof(collocator));
  memset(best, 0, N * sizeof(collocator));

  float worstbest = 1000000;

  for (a = 0; a < N; a++) best[a].activation = worstbest;

  for (c = 0; c < 500000; c++) {
    if (garbage && garbage[c]) continue;
    dist = 0;
    for (a = 0; a < size; a++) dist += M[a + c * size] * M[a + (c + merged_end) * size];
    if (dist < worstbest) {
      for (a = 0; a < N; a++) {
        if (dist < best[a].activation) {
          memmove(best + a + 1, best + a, (N - a - 1) * sizeof(collocator));
          best[a].activation = dist;
          best[a].wordi = c;
          break;
        }
      }
      worstbest = best[N - 1].activation;
    }
  }

  result = malloc(N * max_w);
  char *p = result;
  *p++ = '[';
  *p = 0;
  for (a = 0; a < N; a++) {
    p += sprintf(p, "{\"rank\":%lld,\"word\":\"%s\",\"dist\":%.3f},", a, &vocab[best[a].wordi * max_w], 1 - best[a].activation);
  }
  *--p = ']';
  return (result);
}

/*
 * Dot product of rows b and c of M; equals the cosine similarity for the
 * unit-length rows produced by init_net.
 */
float cos_similarity(long b, long c) {
  const float *lhs = &M[b * size];
  const float *rhs = &M[c * size];
  float sum = 0;
  long k;
  for (k = 0; k != size && k < size; ++k)
    sum += lhs[k] * rhs[k];
  return sum;
}

/*
 * Look up w1 and w2 and return their cosine similarity formatted with five
 * decimals in a heap-allocated string owned by the caller.  The value is
 * -1 when either word cannot be resolved.  Returns NULL when the result
 * string cannot be allocated.
 */
char *cos_similarity_as_json(char *w1, char *w2) {
  wordlist *a = getTargetWords(w1, 0);
  wordlist *b = getTargetWords(w2, 0);
  float res = -1;

  if (a != NULL && b != NULL && a->length == 1 && b->length == 1)
    res = cos_similarity(a->wordi[0], b->wordi[0]);
  /* Guard the diagnostics too: either list may be missing. */
  fprintf(stderr, "a: %lld b: %lld res:%f\n", a ? a->wordi[0] : -1LL, b ? b->wordi[0] : -1LL, res);
  free(a);
  free(b);

  char *json = malloc(16);
  if (json != NULL)
    snprintf(json, 16, "%.5f", res);
  return json;
}

/*
 * Thread entry point: find the N rows of M in [pars->from, pars->upto)
 * with the largest dot product against the query vector.
 *
 * The query vector is the sum of the rows of all resolved query words (a
 * row is subtracted when the preceding separator is '-'), scaled to unit
 * length.  Results are written into pars->best, sorted by descending
 * activation; entries flagged in garbage[] are skipped.  The thread result
 * is always NULL.
 *
 * Nothing is searched when the first query word is unknown, when N < 1, or
 * when the vector size exceeds the local buffer (max_size).
 */
void *_get_neighbours(void *arg) {
  knnpars *pars = arg;
  int N = pars->N;
  long from = pars->from;
  unsigned long upto = pars->upto;
  char *sep;
  float dist, len, vec[max_size];
  long long a, b, c, cn, *bi;
  knn *nbs = NULL;
  wordlist *wl = pars->wl;

  collocator *best = pars->best;

  float worstbest = -1;

  for (a = 0; a < N; a++) best[a].activation = 0;
  bi = wl->wordi;
  cn = wl->length;
  sep = wl->sep;
  b = bi[0];
  if (b == -1) {
    goto end;
  }
  /* vec[] holds max_size floats and best[N - 1] is read below. */
  if (N < 1 || size > max_size) {
    fprintf(stderr, "Cannot search neighbours: N=%d, vector size=%lld\n", N, size);
    goto end;
  }
  for (a = 0; a < size; a++) vec[a] = 0;
  for (b = 0; b < cn; b++) {
    if (bi[b] == -1) continue;
    if (b > 0 && sep[b - 1] == '-')
      for (a = 0; a < size; a++) vec[a] -= M[a + bi[b] * size];
    else
      for (a = 0; a < size; a++) vec[a] += M[a + bi[b] * size];
  }
  len = 0;
  for (a = 0; a < size; a++) len += vec[a] * vec[a];
  len = sqrt(len);
  for (a = 0; a < size; a++) vec[a] /= len;
  for (a = 0; a < N; a++) best[a].activation = -1;
  for (c = from; c < upto; c++) {
    if (garbage && garbage[c]) continue;
    // do not skip taget word
    dist = 0;
    for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size];
    if (dist > worstbest) {
      /* Insert into the list, which is kept sorted by activation. */
      for (a = 0; a < N; a++) {
        if (dist > best[a].activation) {
          memmove(best + a + 1, best + a, (N - a - 1) * sizeof(collocator));
          best[a].activation = dist;
          best[a].wordi = c;
          break;
        }
      }
      worstbest = best[N - 1].activation;
    }
  }

end:
  pthread_exit(nbs);
}

/*
 * qsort comparator for collocator entries: orders by activation, largest
 * first.  Returns 1 when b's activation is greater than a's, -1 when it is
 * smaller and 0 otherwise.
 */
int cmp_activation(const void *a, const void *b) {
  const float first = ((collocator *)a)->activation;
  const float second = ((collocator *)b)->activation;
  if (second > first)
    return 1;
  if (second < first)
    return -1;
  return 0;
}

/*
 * qsort comparator for collocator entries: orders by probability, largest
 * first.  Returns 1 when b's probability is greater than a's, -1 when it
 * is smaller and 0 otherwise.
 */
int cmp_probability(const void *a, const void *b) {
  const float first = ((collocator *)a)->probability;
  const float second = ((collocator *)b)->probability;
  if (second > first)
    return 1;
  if (second < first)
    return -1;
  return 0;
}

/*
 * Compute the strongest collocates of `word` separately for every window
 * position and return them as tab-separated lines
 * "position<TAB>word<TAB>activation".
 *
 * word       query string, resolved with getTargetWords
 * maxPerPos  number of candidates each position keeps
 * cutoff     number of vocabulary entries to scan; values < 1 or > words
 *            select the whole vocabulary
 * threshold  minimum activation
 *
 * One getCollocators thread is run per window position (none when no net
 * is loaded, at most MAX_THREADS).  Returns a heap-allocated string, or the
 * literal "" when the word is unknown or memory cannot be allocated.
 */
char *getPosWiseW2VCollocatorsAsTsv(char *word, long maxPerPos, long cutoff, float threshold) {
  float *target_sums = NULL;
  long a, b;
  knn *syn_nbs[MAX_THREADS];
  knnpars pars[MAX_THREADS];
  pthread_t pt[MAX_THREADS];
  int started[MAX_THREADS];
  wordlist *wl;
  int syn_threads = (M2 ? window * 2 : 0);
  int search_backw = 0;
  char *result = NULL;
  char *p;
  size_t lines = 0, left;

  /* The bookkeeping arrays above hold MAX_THREADS entries. */
  if (syn_threads > MAX_THREADS)
    syn_threads = MAX_THREADS;

  if (cutoff < 1 || cutoff > words)
    cutoff = words;

  wl = getTargetWords(word, search_backw);
  if (wl == NULL || wl->length < 1) {
    free(wl);
    return "";
  }

  if (cutoff < 1 || posix_memalign((void **)&target_sums, 128, cutoff * sizeof(float)) != 0) {
    free(wl);
    return "";
  }
  memset(target_sums, 0, cutoff * sizeof(float));
  memset(pars, 0, sizeof pars);

  printf("Starting %d threads\n", syn_threads);
  fflush(stdout);
  for (a = 0; a < syn_threads; a++) {
    pars[a].cutoff = cutoff;
    pars[a].target_sums = target_sums;
    pars[a].window_sums = window_sums;
    pars[a].wl = wl;
    pars[a].N = maxPerPos;
    pars[a].threshold = threshold;
    pars[a].from = a;
    pars[a].upto = a + 1;
    syn_nbs[a] = NULL;
    started[a] = (pthread_create(&pt[a], NULL, getCollocators, (void *)&pars[a]) == 0);
  }
  printf("Waiting for syn threads to join\n");
  fflush(stdout);
  for (a = 0; a < syn_threads; a++) {
    void *ret = NULL;
    /* Only join threads that were actually created. */
    if (started[a] && pthread_join(pt[a], &ret) == 0)
      syn_nbs[a] = ret;
    if (syn_nbs[a] != NULL && syn_nbs[a]->length > 0)
      lines += (size_t)syn_nbs[a]->length;
  }
  printf("Syn threads joint\n");
  fflush(stdout);

  /* Room per line: the word, two tabs, a newline and both numbers. */
  left = lines * (max_w + 64) + 1;
  result = malloc(left);
  if (result != NULL) {
    p = result;
    *p = 0;
    for (a = syn_threads - 1; a >= 0; a--) {
      /* A thread returns NULL when it could not run the search. */
      if (syn_nbs[a] == NULL)
        continue;
      for (b = 0; b < syn_nbs[a]->length; b++) {
        int n = snprintf(p, left, "%ld\t%s\t%f\n", syn_nbs[a]->best[b].position, &vocab[syn_nbs[a]->best[b].wordi * max_w], syn_nbs[a]->best[b].activation);
        if (n < 0 || (size_t)n >= left) {
          /* Drop a line that does not fit instead of overrunning. */
          *p = 0;
          break;
        }
        p += n;
        left -= (size_t)n;
      }
    }
  }

  for (a = 0; a < syn_threads; a++) {
    if (syn_nbs[a] != NULL) {
      free(syn_nbs[a]->best);
      free(syn_nbs[a]);
    }
  }
  free(target_sums);
  free(wl);
  return (result != NULL ? result : "");
}

SV *get_neighbours(char *st1, int N, int sort_by, int search_backw, long cutoff, int dedupe, int no_similar_profiles) {
  HV *result = newHV();
  float *target_sums = NULL;
  long a, b, c, d, slice;
  knn *para_nbs[MAX_THREADS];
  knn *syn_nbs[MAX_THREADS];
  knnpars pars[MAX_THREADS];
  pthread_t *pt = (pthread_t *)malloc((num_threads + 1) * sizeof(pthread_t));
  wordlist *wl;
  int syn_threads = (M2 ? window * 2 : 0);
  int para_threads = (no_similar_profiles ? 0 : num_threads - syn_threads);

  collocator *best = NULL;
  posix_memalign((void **)&best, 128, 10 * (N >= 200 ? N : 200) * sizeof(collocator));
  memset(best, 0, (N >= 200 ? N : 200) * sizeof(collocator));

  if (N > MAX_NEIGHBOURS) N = MAX_NEIGHBOURS;

  if (cutoff < 1 || cutoff > words)
    cutoff = words;

  wl = getTargetWords(st1, search_backw);
  if (wl == NULL || wl->length < 1)
    goto end;

  slice = cutoff / para_threads;

  a = posix_memalign((void **)&target_sums, 128, cutoff * sizeof(float));
  memset(target_sums, 0, cutoff * sizeof(float));

  printf("Starting %d threads for paradigmatic search\n", para_threads);
  fflush(stdout);
  for (a = 0; a < para_threads; a++) {
    pars[a].cutoff = cutoff;
    pars[a].token = st1;
    pars[a].wl = wl;
    pars[a].N = N;
    pars[a].best = &best[N * a];
    if (merge_words == 0 || search_backw == 0) {
      pars[a].from = a * slice;
      pars[a].upto = ((a + 1) * slice > cutoff ? cutoff : (a + 1) * slice);
    } else {
      pars[a].from = merge_words + a * slice;
      pars[a].upto = merge_words + ((a + 1) * slice > cutoff ? cutoff : (a + 1) * slice);
    }
    printf("From: %ld, Upto: %ld\n", pars[a].from, pars[a].upto);
    pthread_create(&pt[a], NULL, _get_neighbours, (void *)&pars[a]);
  }
  if (M2) {
    for (a = 0; a < syn_threads; a++) {
      pars[a + para_threads].cutoff = cutoff;
      pars[a + para_threads].target_sums = target_sums;
      pars[a + para_threads].window_sums = window_sums;
      pars[a + para_threads].wl = wl;
      pars[a + para_threads].N = N;
      pars[a + para_threads].threshold = MIN_RESP;
      pars[a + para_threads].from = a;
      pars[a + para_threads].upto = a + 1;
      pthread_create(&pt[a + para_threads], NULL, getCollocators, (void *)&pars[a + para_threads]);
    }
  }
  printf("Waiting for para threads to join\n");
  fflush(stdout);
  for (a = 0; a < para_threads; a++) pthread_join(pt[a], (void *)&para_nbs[a]);
  printf("Para threads joint\n");
  fflush(stdout);

  /* if(!syn_nbs[0]) */
  /* 	goto end; */

  qsort(best, N * para_threads, sizeof(collocator), cmp_activation);

  long long chosen[MAX_NEIGHBOURS];
  printf("N: %d\n", N);

  AV *array = newAV();
  int i, j;
  int l1_words = 0, l2_words = 0;

  for (a = 0, i = 0; i < N && a < N * para_threads; a++) {
    int filtered = 0;
    long long c = best[a].wordi;
    if ((merge_words && dedupe && i > 1) || (!merge_words && dedupe && i > 0)) {
      for (j = 0; j < i && !filtered; j++)
        if (strcasestr(&vocab[c * max_w], &vocab[chosen[j] * max_w]) ||
            strcasestr(&vocab[chosen[j] * max_w], &vocab[c * max_w])) {
          printf("filtering %s %s\n", &vocab[chosen[j] * max_w], &vocab[c * max_w]);
          filtered = 1;
        }
      if (filtered)
        continue;
    }

    if (0 && merge_words > 0) {
      if (c >= merge_words) {
        if (l1_words > N / 2)
          continue;
        else
          l1_words++;
      } else {
        if (l2_words > N / 2)
          continue;
        else
          l2_words++;
      }
    }

    //    printf("%s l1:%d l2:%d i:%d a:%ld\n", &vocab[c * max_w], l1_words, l2_words, i, a);
    //    fflush(stdout);
    HV *hash = newHV();
    SV *word = newSVpvf(&vocab[c * max_w], 0);
    chosen[i] = c;
    if (latin_enc == 0) SvUTF8_on(word);
    fflush(stdout);
    hv_store(hash, "word", strlen("word"), word, 0);
    hv_store(hash, "dist", strlen("dist"), newSVnv(best[a].activation), 0);
    hv_store(hash, "rank", strlen("rank"), newSVuv(best[a].wordi), 0);
    AV *vector = newAV();
    for (b = 0; b < size; b++) {
      av_push(vector, newSVnv(M[b + best[a].wordi * size]));
    }
    hv_store(hash, "vector", strlen("vector"), newRV_noinc((SV *)vector), 0);
    av_push(array, newRV_noinc((SV *)hash));
    i++;
  }
  hv_store(result, "paradigmatic", strlen("paradigmatic"), newRV_noinc((SV *)array), 0);

  for (b = 0; b < MAX_NEIGHBOURS; b++) {
    best[b].wordi = -1L;
    best[b].activation = 0;
    best[b].probability = 0;
    best[b].position = 0;
    best[b].activation_sum = 0;
    memset(best[b].heat, 0, sizeof(float) * 16);
  }

  float total_activation = 0;

  if (M2) {
    printf("Waiting for syn threads to join\n");
    fflush(stdout);
    for (a = 0; a < syn_threads; a++) pthread_join(pt[a + para_threads], (void *)&syn_nbs[a]);
    for (a = 0; a <= syn_threads; a++) {
      if (a == window) continue;
      total_activation += window_sums[a];
      printf("window pos: %ld, sum: %f\n", a, window_sums[a]);
    }
    printf("syn threads joint\n");
    fflush(stdout);

    for (b = 0; b < syn_nbs[0]->length; b++) {
      memcpy(best + b, &syn_nbs[0]->best[b], sizeof(collocator));
      best[b].position = -1;  //  syn_nbs[0]->pos[b];
      best[b].activation_sum = target_sums[syn_nbs[0]->best[b].wordi];
      best[b].max_activation = 0.0;
      best[b].average = 0.0;
      best[b].probability = 0.0;
      best[b].cprobability = syn_nbs[0]->best[b].cprobability;
      memset(best[b].heat, 0, sizeof(float) * 16);
    }

    float best_window_sum[MAX_NEIGHBOURS];
    int found_index = 0, i = 0, w;
    for (a = 0; a < syn_threads; a++) {
      for (b = 0; b < syn_nbs[a]->length; b++) {
        for (i = 0; i < found_index; i++)
          if (best[i].wordi == syn_nbs[a]->best[b].wordi)
            break;
        if (i >= found_index) {
          best[found_index].max_activation = 0.0;
          best[found_index].average = 0.0;
          best[found_index].probability = 0.0;
          memset(best[found_index].heat, 0, sizeof(float) * 16);
          best[found_index].cprobability = syn_nbs[a]->best[b].cprobability;
          best[found_index].activation_sum = target_sums[syn_nbs[a]->best[b].wordi];  // syn_nbs[a]->best[b].activation_sum;
          best[found_index++].wordi = syn_nbs[a]->best[b].wordi;
          //						printf("found: %s\n", &vocab[syn_nbs[a]->index[b] * max_w]);
        }
      }
    }
    sort_by = 0;                         // ALWAYS AUTO-FOCUS
    if (sort_by != 1 && sort_by != 2) {  // sort by auto focus mean
      printf("window: %d  -  syn_threads: %d, %d\n", window, syn_threads, (1 << syn_threads) - 1);
      int wpos;
      int bits_set = 0;
      for (i = 0; i < found_index; i++) {
        best[i].activation = best[i].probability = best[i].average = best[i].cprobability_sum = 0;
        for (w = 1; w < (1 << syn_threads); w++) {  // loop through all possible windows
          float word_window_sum = 0, word_window_average = 0, word_cprobability_sum = 0, word_activation_sum = 0, total_window_sum = 0;
          bits_set = 0;
          for (a = 0; a < syn_threads; a++) {
            if ((1 << a) & w) {
              wpos = (a >= window ? a + 1 : a);
              total_window_sum += window_sums[wpos];
            }
          }
          //					printf("%d window-sum %f\n", w, total_window_sum);
          for (a = 0; a < syn_threads; a++) {
            if ((1 << a) & w) {
              wpos = (a >= window ? a + 1 : a);
              bits_set++;
              for (b = 0; b < syn_nbs[a]->length; b++)
                if (best[i].wordi == syn_nbs[a]->best[b].wordi) {
                  //									float acti = syn_nbs[a]->best[b].activation / total_window_sum;
                  //                  word_window_sum += syn_nbs[a]->dist[b] *  syn_nbs[a]->norm[b]; // / window_sums[wpos];  // syn_nbs[a]->norm[b];
                  //                    word_window_sum += syn_nbs[a]->norm[b]; // / window_sums[wpos];  // syn_nbs[a]->norm[b];
                  //                  word_window_sum = (word_window_sum + syn_nbs[a]->norm[b]) - (word_window_sum * syn_nbs[a]->norm[b]);  // syn_nbs[a]->norm[b];

                  word_window_sum += syn_nbs[a]->best[b].activation;  // / window_sums[wpos];  // syn_nbs[a]->norm[b];
                                                                      //                  word_window_sum += acti - (word_window_sum * acti); syn_nbs[a]->best[b].activation; // / window_sums[wpos];  // syn_nbs[a]->norm[b];

                  word_window_average += syn_nbs[a]->best[b].activation;                                                                 // - word_window_average * syn_nbs[a]->best[b].activation;  // conormalied activation sum
                  word_cprobability_sum += syn_nbs[a]->best[b].cprobability - word_cprobability_sum * syn_nbs[a]->best[b].cprobability;  // conormalied column probability sum
                  word_activation_sum += syn_nbs[a]->best[b].activation;
                  if (syn_nbs[a]->best[b].activation > best[i].max_activation)
                    best[i].max_activation = syn_nbs[a]->best[b].activation;
                  if (syn_nbs[a]->best[b].activation > best[i].heat[wpos])
                    best[i].heat[wpos] = syn_nbs[a]->best[b].activation;
                }
            }
          }
          if (bits_set) {
            word_window_average /= bits_set;
            //						word_activation_sum /= bits_set;
            //						word_window_sum /= bits_set;
          }

          word_window_sum /= total_window_sum;

          if (word_window_sum > best[i].probability) {
            //						best[i].position = w;
            best[i].probability = word_window_sum;
          }

          if (word_cprobability_sum > best[i].cprobability_sum) {
            best[i].position = w;
            best[i].cprobability_sum = word_cprobability_sum;
          }

          best[i].average = word_window_average;
          //						best[i].activation = word_activation_sum;
        }
      }
      qsort(best, found_index, sizeof(collocator), cmp_probability);
      //      for(i=0; i < found_index; i++) {
      //				printf("found: %s - sum: %f - window: %d\n", &vocab[best[i].wordi * max_w], best[i].activation, best[i].position);
      //			}

    } else if (sort_by == 1) {  // responsiveness any window position
      int wpos;
      for (i = 0; i < found_index; i++) {
        float word_window_sum = 0, word_activation_sum = 0, total_window_sum = 0;
        for (a = 0; a < syn_threads; a++) {
          wpos = (a >= window ? a + 1 : a);
          for (b = 0; b < syn_nbs[a]->length; b++)
            if (best[i].wordi == syn_nbs[a]->best[b].wordi) {
              best[i].probability += syn_nbs[a]->best[b].probability;
              if (syn_nbs[a]->best[b].activation > 0.25)
                best[i].position |= 1 << wpos;
              if (syn_nbs[a]->best[b].activation > best[i].activation) {
                best[i].activation = syn_nbs[a]->best[b].activation;
              }
            }
        }
      }
      qsort(best, found_index, sizeof(collocator), cmp_activation);
    } else if (sort_by == 2) {  // single window position
      for (a = 1; a < syn_threads; a++) {
        for (b = 0; b < syn_nbs[a]->length; b++) {
          for (c = 0; c < MAX_NEIGHBOURS; c++) {
            if (syn_nbs[a]->best[b].activation > best[c].activation) {
              for (d = MAX_NEIGHBOURS - 1; d > c; d--) {
                memmove(best + d, best + d - 1, sizeof(collocator));
              }
              memcpy(best + c, &syn_nbs[a]->best[b], sizeof(collocator));
              best[c].position = 1 << (-syn_nbs[a]->best[b].position + window - (syn_nbs[a]->best[b].position < 0 ? 1 : 0));
              break;
            }
          }
        }
      }
    } else {  // sort by mean p
      for (a = 1; a < syn_threads; a++) {
        for (b = 0; b < syn_nbs[a]->length; b++) {
          for (c = 0; c < MAX_NEIGHBOURS; c++) {
            if (target_sums[syn_nbs[a]->best[b].wordi] > best[c].activation_sum) {
              for (d = MAX_NEIGHBOURS - 1; d > c; d--) {
                memmove(best + d, best + d - 1, sizeof(collocator));
              }
              memcpy(best + c, &syn_nbs[a]->best[b], sizeof(collocator));
              best[c].position = (1 << 2 * window) - 1;  // syn_nbs[a]->pos[b];
              best[c].activation_sum = target_sums[syn_nbs[a]->best[b].wordi];
              break;
            }
          }
        }
      }
    }
    array = newAV();
    for (a = 0, i = 0; a < MAX_NEIGHBOURS && best[a].wordi >= 0; a++) {
      long long c = best[a].wordi;
      /*
      if (dedupe) {
	  		int filtered=0;
        for (j=0; j<i; j++)
          if (strcasestr(&vocab[c * max_w], chosen[j]) ||
              strcasestr(chosen[j], &vocab[c * max_w])) {
						printf("filtering %s %s\n", chosen[j], &vocab[c * max_w]);
						filtered = 1;
					}
				if(filtered)
					continue;
			}
*/
      chosen[i++] = c;
      HV *hash = newHV();
      SV *word = newSVpvf(&vocab[best[a].wordi * max_w], 0);
      AV *heat = newAV();
      if (latin_enc == 0) SvUTF8_on(word);
      hv_store(hash, "word", strlen("word"), word, 0);
      hv_store(hash, "rank", strlen("rank"), newSVuv(best[a].wordi), 0);
      hv_store(hash, "average", strlen("average"), newSVnv(best[a].average), 0);
      hv_store(hash, "prob", strlen("prob"), newSVnv(best[a].probability), 0);
      hv_store(hash, "cprob", strlen("cprob"), newSVnv(best[a].cprobability_sum), 0);
      hv_store(hash, "max", strlen("max"), newSVnv(best[a].max_activation), 0);                             // newSVnv(target_sums[best[a].wordi]), 0);
      hv_store(hash, "overall", strlen("overall"), newSVnv(best[a].activation_sum / total_activation), 0);  // newSVnv(target_sums[best[a].wordi]), 0);
      hv_store(hash, "pos", strlen("pos"), newSVnv(best[a].position), 0);
      best[a].heat[5] = 0;
      for (i = 10; i >= 0; i--) av_push(heat, newSVnv(best[a].heat[i]));
      hv_store(hash, "heat", strlen("heat"), newRV_noinc((SV *)heat), 0);
      av_push(array, newRV_noinc((SV *)hash));
    }
    hv_store(result, "syntagmatic", strlen("syntagmatic"), newRV_noinc((SV *)array), 0);
  }
end:
  free(best);
  return newRV_noinc((SV *)result);
}

/*
 * Write the loaded vectors to `fname` as plain text.
 *
 * Output layout: a header line "<words> <size>", followed by one line per
 * vocabulary entry consisting of the word and its `size` components taken
 * from M, all separated by single spaces.
 *
 * fname: path of the file to create or truncate.
 *
 * Returns 0 on success, -1 if the file cannot be opened or if any write
 * (including the final flush performed by fclose) fails.
 */
int dump_vecs(char *fname) {
  long long i, j;
  FILE *f;
  int failed;

  if ((f = fopen(fname, "w")) == NULL) {
    fprintf(stderr, "cannot open %s for writing\n", fname);
    return (-1);
  }
  fprintf(f, "%lld %lld\n", words, size);
  // Stop early once the stream is in an error state (e.g. disk full).
  for (i = 0; i < words && !ferror(f); i++) {
    fprintf(f, "%s", &vocab[i * max_w]);
    // Emitting the separator before each value keeps the output identical
    // for size >= 1 and avoids reading M at all when size <= 0.
    for (j = 0; j < size; j++)
      fprintf(f, " %f", M[i * size + j]);
    fputc('\n', f);
  }
  // The error indicator is sticky, so one check covers all writes above;
  // fclose must still be called (and checked) to flush and release the file.
  failed = ferror(f);
  if (fclose(f) != 0)
    failed = 1;
  if (failed) {
    fprintf(stderr, "error writing %s\n", fname);
    return (-1);
  }
  return (0);
}

int dump_for_numpy(char *fname) {
  long i, j;
  FILE *f;
  int max = words; // 300000;

  if ((f = fopen(fname, "w")) == NULL) {
    fprintf(stderr, "cannot open %s for writing\n", fname);
    return (-1);
  }
  for (i = 0; i < max; i++) {
    for (j = 0; j < size - 1; j++)
      fprintf(f, "%f\t", M[i * size + j]);
    fprintf(f, "%f\n", M[i * size + j]);
    printf("%s\r\n", &vocab[i * max_w]);
  }
  if (merged_end > 0) {
    for (i = 0; i < max; i++) {
      for (j = 0; j < size - 1; j++)
        fprintf(f, "%f\t", M[(merged_end + i) * size + j]);
      fprintf(f, "%f\n", M[(merged_end + i) * size + j]);
      printf("_%s\r\n", &vocab[i * max_w]);
    }
  }
  fclose(f);
  return (0);
}
