w2v-server: split neighbour search threads across both merged vocabularies
diff --git a/w2v-server.pl b/w2v-server.pl
index 6cb8a7e..d0084fb 100755
--- a/w2v-server.pl
+++ b/w2v-server.pl
@@ -879,9 +879,7 @@
goto end;
old_words = cutoff;
- if(merge_words > 0)
- cutoff = merge_words * 1.25; /* HACK */
- slice = (para_threads? cutoff / para_threads : 0);
+ slice = (para_threads? cutoff / para_threads * (merge_words > 0? 2 : 1) : 0);
a = posix_memalign((void **) &target_sums, 128, cutoff * sizeof(float));
for(a = 0; a < cutoff; a++)
@@ -895,8 +893,14 @@
pars[a].wl = wl;
pars[a].N = N;
pars[a].best = &best[N*a];
- pars[a].from = a*slice;
- pars[a].upto = ((a+1)*slice > cutoff? cutoff:(a+1)*slice);
+ if(merge_words == 0 || a < para_threads / 2) {
+ pars[a].from = a*slice;
+ pars[a].upto = ((a+1)*slice > cutoff? cutoff : (a+1) * slice);
+ } else {
+ pars[a].from = merge_words + (a - para_threads / 2) * slice;
+ pars[a].upto = merge_words + ((a - para_threads / 2 + 1)*slice > cutoff? cutoff : (a - para_threads / 2 + 1) *slice);
+ }
+ printf("From: %ld, Upto: %ld\n", pars[a].from, pars[a].upto);
pthread_create(&pt[a], NULL, _get_neighbours, (void *) &pars[a]);
}
if(M2) {
@@ -943,6 +947,7 @@
if(filtered)
continue;
}
+/*
if(merge_words > 0) {
if(c >= merge_words) {
if(l1_words > N / 2)
@@ -956,9 +961,9 @@
l2_words++;
}
}
+*/
printf("%s l1:%d l2:%d i:%d a:%ld\n", &vocab[c * max_w], l1_words, l2_words, i, a);
fflush(stdout);
-
HV* hash = newHV();
SV* word = newSVpvf(&vocab[c * max_w], 0);
chosen[i] = c;