wang2vec: make sure nets are read correctly
diff --git a/word2vecExt.c b/word2vecExt.c
index 6aec264..1881e12 100644
--- a/word2vecExt.c
+++ b/word2vecExt.c
@@ -509,6 +509,8 @@
void InitNet() {
long long a, b;
unsigned long long next_random = 1;
+ long long read;
+
window_layer_size = layer1_size * window * 2;
a = posix_memalign((void **) &syn0, 128,
(long long) vocab_size * layer1_size * sizeof(real));
@@ -639,8 +641,23 @@
printf("Net parameter file not found\n");
exit(1);
}
- fread(syn0, sizeof(real), vocab_size * layer1_size, fnet);
- fread(syn1neg_window, sizeof(real), vocab_size * window_layer_size, fnet);
+ printf("vocab-size: %lld, layer1_size: %lld, window_layer_size %lld\n", (long long) vocab_size, (long long) layer1_size, (long long) window_layer_size); /* casts keep every argument in step with %lld */
+ read = fread(syn0, sizeof(real), vocab_size * layer1_size, fnet); /* fread returns the number of items read, not bytes */
+ if(read != (long long) vocab_size * layer1_size) {
+ fprintf(stderr, "read-net failed, read %lld, expected: %lld\n", read, (long long) vocab_size * layer1_size);
+ exit(1);
+ }
+ read = fread(syn1neg_window, sizeof(real), vocab_size * window_layer_size, fnet);
+ if(read != (long long) vocab_size * window_layer_size) {
+ fprintf(stderr, "read-net failed, read %lld, expected: %lld\n", read, /* both values are item counts */
+ (long long) vocab_size * window_layer_size);
+ exit(1);
+ }
+ fgetc(fnet); /* try to read one more byte so feof() can tell whether the file ended exactly here */
+ if(!feof(fnet)) {
+ fprintf(stderr, "Remaining bytes in net-file after read-net. File position: %ld\n", ftell(fnet));
+ exit(1);
+ }
fclose(fnet);
} else {
fprintf(stderr, "read-net only supported for type 3 with negative sampling\n");