blob: e01ba9b8d7995cc65a333a4633e51b77a4b46fe0 [file] [log] [blame]
Marc Kupietzd6f9c712016-03-16 11:50:56 +01001// Copyright 2013 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <stdio.h>
16#include <string.h>
17#include <math.h>
18#include <stdlib.h>
19
// Compile-time limits shared by the whole program. They are enumeration
// constants rather than `const long long` objects so that they are integer
// constant expressions: arrays sized by them are then ordinary fixed-size
// arrays instead of variable-length arrays (VLAs are optional in C11 and not
// supported by every compiler).
enum {
  max_size = 2000,  // max length of strings
  N = 40,           // number of closest words that will be shown
  max_w = 50        // max length of vocabulary entries
};
23
#define MAX_STRING 100

// Reads the next token from fin into word. Tokens are separated by spaces,
// tabs and newlines.
//
// word must point to at least MAX_STRING - 1 writable bytes. Tokens longer
// than that are truncated to MAX_STRING - 2 characters. Carriage returns ('\r')
// are ignored. A newline that terminates a token is pushed back, so the
// following call returns the sentence marker "</s>". A newline met before any
// token character yields "</s>" directly. At end of input (or on a read error)
// the characters collected so far are returned, possibly as an empty string.
void ReadWord(char *word, FILE *fin) {
  int a = 0, ch;
  // Test the value returned by fgetc() rather than calling feof() first.
  // feof() only becomes true after a read has already failed, so the feof()
  // form would store (char)EOF in the token. It would also spin forever on a
  // read error, since that sets the error indicator, not the EOF indicator.
  while ((ch = fgetc(fin)) != EOF) {
    if (ch == '\r') continue;
    if (ch == ' ' || ch == '\t' || ch == '\n') {
      if (a > 0) {
        if (ch == '\n') ungetc(ch, fin);  // report "</s>" on the next call
        break;
      }
      if (ch == '\n') {
        strcpy(word, "</s>");
        return;
      }
      continue;  // skip leading blanks
    }
    word[a] = (char)ch;
    a++;
    if (a >= MAX_STRING - 1) a--;  // Truncate too long words
  }
  word[a] = 0;
}
46
47int main(int argc, char **argv) {
48 FILE *f;
49 char st1[max_size];
50 char *bestw[N];
51 char file_name[max_size], st[100][max_size];
52 float dist, len, bestd[N], vec[max_size];
53 long long words, size, a, b, c, d, cn, bi[100];
54 float *M;
55 char *vocab;
56 if (argc < 2) {
57 printf("Usage: ./distance <FILE>\nwhere FILE contains word projections in the BINARY FORMAT\n");
58 return 0;
59 }
60 strcpy(file_name, argv[1]);
61 f = fopen(file_name, "rb");
62 if (f == NULL) {
63 printf("Input file not found\n");
64 return -1;
65 }
66 fscanf(f, "%lld", &words);
67 fscanf(f, "%lld", &size);
68 vocab = (char *)malloc((long long)words * max_w * sizeof(char));
69 for (a = 0; a < N; a++) bestw[a] = (char *)malloc(max_size * sizeof(char));
70 M = (float *)malloc((long long)words * (long long)size * sizeof(float));
71 if (M == NULL) {
72 printf("Cannot allocate memory: %lld MB %lld %lld\n", (long long)words * size * sizeof(float) / 1048576, words, size);
73 return -1;
74 }
75 for (b = 0; b < words; b++) {
76 a = 0;
77 while (1) {
78 vocab[b * max_w + a] = fgetc(f);
79 if (feof(f) || (vocab[b * max_w + a] == ' ')) break;
80 if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++;
81 }
82 vocab[b * max_w + a] = 0;
83 for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f);
84 len = 0;
85 for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size];
86 len = sqrt(len);
87 for (a = 0; a < size; a++) M[a + b * size] /= len;
88 }
89 fclose(f);
90 while (1) {
91 for (a = 0; a < N; a++) bestd[a] = 0;
92 for (a = 0; a < N; a++) bestw[a][0] = 0;
93 printf("Enter word or sentence (EXIT to break): ");
94 a = 0;
95 while (1) {
96 st1[a] = fgetc(stdin);
97 if ((st1[a] == '\n') || (a >= max_size - 1)) {
98 st1[a] = 0;
99 break;
100 }
101 a++;
102 }
103 if (!strcmp(st1, "EXIT")) break;
104 cn = 0;
105 b = 0;
106 c = 0;
107 while (1) {
108 st[cn][b] = st1[c];
109 b++;
110 c++;
111 st[cn][b] = 0;
112 if (st1[c] == 0) break;
113 if (st1[c] == ' ') {
114 cn++;
115 b = 0;
116 c++;
117 }
118 }
119 cn++;
120 for (a = 0; a < cn; a++) {
121 for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st[a])) break;
122 if (b == words) b = -1;
123 bi[a] = b;
124 printf("\nWord: %s Position in vocabulary: %lld\n", st[a], bi[a]);
125 if (b == -1) {
126 printf("Out of dictionary word!\n");
127 break;
128 }
129 }
130 if (b == -1) continue;
131 printf("\n Word Cosine distance\n------------------------------------------------------------------------\n");
132 for (a = 0; a < size; a++) vec[a] = 0;
133 for (b = 0; b < cn; b++) {
134 if (bi[b] == -1) continue;
135 for (a = 0; a < size; a++) vec[a] += M[a + bi[b] * size];
136 }
137 len = 0;
138 for (a = 0; a < size; a++) len += vec[a] * vec[a];
139 len = sqrt(len);
140 for (a = 0; a < size; a++) vec[a] /= len;
141 for (a = 0; a < N; a++) bestd[a] = -1;
142 for (a = 0; a < N; a++) bestw[a][0] = 0;
143 for (c = 0; c < words; c++) {
144 a = 0;
145 for (b = 0; b < cn; b++) if (bi[b] == c) a = 1;
146 if (a == 1) continue;
147 dist = 0;
148 for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size];
149 for (a = 0; a < N; a++) {
150 if (dist > bestd[a]) {
151 for (d = N - 1; d > a; d--) {
152 bestd[d] = bestd[d - 1];
153 strcpy(bestw[d], bestw[d - 1]);
154 }
155 bestd[a] = dist;
156 strcpy(bestw[a], &vocab[c * max_w]);
157 break;
158 }
159 }
160 }
161 for (a = 0; a < N; a++) printf("%50s\t\t%f\n", bestw[a], bestd[a]);
162 }
163 return 0;
164}