blob: bc78ba7f7b894dcebfbe158e94d73bcd6ee1401b [file] [log] [blame]
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include <stdio.h>
16#include <string.h>
17#include <math.h>
18#include <stdlib.h>
19
// Fixed limits used to size the program's buffers and its output.
const long long max_w = 50;      // bytes reserved per vocabulary entry (terminator included)
const long long N = 40;          // number of nearest vocabulary entries printed per query
const long long max_size = 2000; // capacity of the fixed-size string buffers
23
24int main(int argc, char **argv) {
25 FILE *f;
26 char st1[max_size];
27 char bestw[N][max_size];
28 char file_name[max_size], st[100][max_size];
29 float dist, len, bestd[N], vec[max_size];
30 long long words, size, a, b, c, d, cn, bi[100];
31 float *M;
32 char *vocab;
33 if (argc < 2) {
34 printf("Usage: ./word-analogy <FILE>\nwhere FILE contains word projections in the BINARY FORMAT\n");
35 return 0;
36 }
37 strcpy(file_name, argv[1]);
38 f = fopen(file_name, "rb");
39 if (f == NULL) {
40 printf("Input file not found\n");
41 return -1;
42 }
43 fscanf(f, "%lld", &words);
44 fscanf(f, "%lld", &size);
45 vocab = (char *)malloc((long long)words * max_w * sizeof(char));
46 M = (float *)malloc((long long)words * (long long)size * sizeof(float));
47 if (M == NULL) {
48 printf("Cannot allocate memory: %lld MB %lld %lld\n", (long long)words * size * sizeof(float) / 1048576, words, size);
49 return -1;
50 }
51 for (b = 0; b < words; b++) {
52 a = 0;
53 while (1) {
54 vocab[b * max_w + a] = fgetc(f);
55 if (feof(f) || (vocab[b * max_w + a] == ' ')) break;
56 if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++;
57 }
58 vocab[b * max_w + a] = 0;
59 for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f);
60 len = 0;
61 for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size];
62 len = sqrt(len);
63 for (a = 0; a < size; a++) M[a + b * size] /= len;
64 }
65 fclose(f);
66 while (1) {
67 for (a = 0; a < N; a++) bestd[a] = 0;
68 for (a = 0; a < N; a++) bestw[a][0] = 0;
69 printf("Enter three words (EXIT to break): ");
70 a = 0;
71 while (1) {
72 st1[a] = fgetc(stdin);
73 if ((st1[a] == '\n') || (a >= max_size - 1)) {
74 st1[a] = 0;
75 break;
76 }
77 a++;
78 }
79 if (!strcmp(st1, "EXIT")) break;
80 cn = 0;
81 b = 0;
82 c = 0;
83 while (1) {
84 st[cn][b] = st1[c];
85 b++;
86 c++;
87 st[cn][b] = 0;
88 if (st1[c] == 0) break;
89 if (st1[c] == ' ') {
90 cn++;
91 b = 0;
92 c++;
93 }
94 }
95 cn++;
96 if (cn < 3) {
97 printf("Only %lld words were entered.. three words are needed at the input to perform the calculation\n", cn);
98 continue;
99 }
100 for (a = 0; a < cn; a++) {
101 for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st[a])) break;
102 if (b == words) b = 0;
103 bi[a] = b;
104 printf("\nWord: %s Position in vocabulary: %lld\n", st[a], bi[a]);
105 if (b == 0) {
106 printf("Out of dictionary word!\n");
107 break;
108 }
109 }
110 if (b == 0) continue;
111 printf("\n Word Distance\n------------------------------------------------------------------------\n");
112 for (a = 0; a < size; a++) vec[a] = M[a + bi[1] * size] - M[a + bi[0] * size] + M[a + bi[2] * size];
113 len = 0;
114 for (a = 0; a < size; a++) len += vec[a] * vec[a];
115 len = sqrt(len);
116 for (a = 0; a < size; a++) vec[a] /= len;
117 for (a = 0; a < N; a++) bestd[a] = 0;
118 for (a = 0; a < N; a++) bestw[a][0] = 0;
119 for (c = 0; c < words; c++) {
120 if (c == bi[0]) continue;
121 if (c == bi[1]) continue;
122 if (c == bi[2]) continue;
123 a = 0;
124 for (b = 0; b < cn; b++) if (bi[b] == c) a = 1;
125 if (a == 1) continue;
126 dist = 0;
127 for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size];
128 for (a = 0; a < N; a++) {
129 if (dist > bestd[a]) {
130 for (d = N - 1; d > a; d--) {
131 bestd[d] = bestd[d - 1];
132 strcpy(bestw[d], bestw[d - 1]);
133 }
134 bestd[a] = dist;
135 strcpy(bestw[a], &vocab[c * max_w]);
136 break;
137 }
138 }
139 }
140 for (a = 0; a < N; a++) printf("%50s\t\t%f\n", bestw[a], bestd[a]);
141 }
142 return 0;
143}