//  Copyright 2013 Google Inc. All Rights Reserved.
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Compile-time limits used throughout the program.
//
// These are enumeration constants rather than `const long long` objects so
// that they are true integer constant expressions: arrays sized by them are
// ordinary fixed-size arrays (not variable-length arrays), and the names do
// not create objects with external linkage.
enum {
  max_size = 2000,  // max length of strings
  N = 40,           // number of closest words that will be shown
  max_w = 50        // max length of vocabulary entries
};

|  | 24 | int main(int argc, char **argv) { | 
|  | 25 | FILE *f; | 
|  | 26 | char st1[max_size]; | 
|  | 27 | char bestw[N][max_size]; | 
|  | 28 | char file_name[max_size], st[100][max_size]; | 
|  | 29 | float dist, len, bestd[N], vec[max_size]; | 
|  | 30 | long long words, size, a, b, c, d, cn, bi[100]; | 
|  | 31 | float *M; | 
|  | 32 | char *vocab; | 
|  | 33 | if (argc < 2) { | 
|  | 34 | printf("Usage: ./word-analogy <FILE>\nwhere FILE contains word projections in the BINARY FORMAT\n"); | 
|  | 35 | return 0; | 
|  | 36 | } | 
|  | 37 | strcpy(file_name, argv[1]); | 
|  | 38 | f = fopen(file_name, "rb"); | 
|  | 39 | if (f == NULL) { | 
|  | 40 | printf("Input file not found\n"); | 
|  | 41 | return -1; | 
|  | 42 | } | 
|  | 43 | fscanf(f, "%lld", &words); | 
|  | 44 | fscanf(f, "%lld", &size); | 
|  | 45 | vocab = (char *)malloc((long long)words * max_w * sizeof(char)); | 
|  | 46 | M = (float *)malloc((long long)words * (long long)size * sizeof(float)); | 
|  | 47 | if (M == NULL) { | 
|  | 48 | printf("Cannot allocate memory: %lld MB    %lld  %lld\n", (long long)words * size * sizeof(float) / 1048576, words, size); | 
|  | 49 | return -1; | 
|  | 50 | } | 
|  | 51 | for (b = 0; b < words; b++) { | 
|  | 52 | a = 0; | 
|  | 53 | while (1) { | 
|  | 54 | vocab[b * max_w + a] = fgetc(f); | 
|  | 55 | if (feof(f) || (vocab[b * max_w + a] == ' ')) break; | 
|  | 56 | if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++; | 
|  | 57 | } | 
|  | 58 | vocab[b * max_w + a] = 0; | 
|  | 59 | for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f); | 
|  | 60 | len = 0; | 
|  | 61 | for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size]; | 
|  | 62 | len = sqrt(len); | 
|  | 63 | for (a = 0; a < size; a++) M[a + b * size] /= len; | 
|  | 64 | } | 
|  | 65 | fclose(f); | 
|  | 66 | while (1) { | 
|  | 67 | for (a = 0; a < N; a++) bestd[a] = 0; | 
|  | 68 | for (a = 0; a < N; a++) bestw[a][0] = 0; | 
|  | 69 | printf("Enter three words (EXIT to break): "); | 
|  | 70 | a = 0; | 
|  | 71 | while (1) { | 
|  | 72 | st1[a] = fgetc(stdin); | 
|  | 73 | if ((st1[a] == '\n') || (a >= max_size - 1)) { | 
|  | 74 | st1[a] = 0; | 
|  | 75 | break; | 
|  | 76 | } | 
|  | 77 | a++; | 
|  | 78 | } | 
|  | 79 | if (!strcmp(st1, "EXIT")) break; | 
|  | 80 | cn = 0; | 
|  | 81 | b = 0; | 
|  | 82 | c = 0; | 
|  | 83 | while (1) { | 
|  | 84 | st[cn][b] = st1[c]; | 
|  | 85 | b++; | 
|  | 86 | c++; | 
|  | 87 | st[cn][b] = 0; | 
|  | 88 | if (st1[c] == 0) break; | 
|  | 89 | if (st1[c] == ' ') { | 
|  | 90 | cn++; | 
|  | 91 | b = 0; | 
|  | 92 | c++; | 
|  | 93 | } | 
|  | 94 | } | 
|  | 95 | cn++; | 
|  | 96 | if (cn < 3) { | 
|  | 97 | printf("Only %lld words were entered.. three words are needed at the input to perform the calculation\n", cn); | 
|  | 98 | continue; | 
|  | 99 | } | 
|  | 100 | for (a = 0; a < cn; a++) { | 
|  | 101 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st[a])) break; | 
|  | 102 | if (b == words) b = 0; | 
|  | 103 | bi[a] = b; | 
|  | 104 | printf("\nWord: %s  Position in vocabulary: %lld\n", st[a], bi[a]); | 
|  | 105 | if (b == 0) { | 
|  | 106 | printf("Out of dictionary word!\n"); | 
|  | 107 | break; | 
|  | 108 | } | 
|  | 109 | } | 
|  | 110 | if (b == 0) continue; | 
|  | 111 | printf("\n                                              Word              Distance\n------------------------------------------------------------------------\n"); | 
|  | 112 | for (a = 0; a < size; a++) vec[a] = M[a + bi[1] * size] - M[a + bi[0] * size] + M[a + bi[2] * size]; | 
|  | 113 | len = 0; | 
|  | 114 | for (a = 0; a < size; a++) len += vec[a] * vec[a]; | 
|  | 115 | len = sqrt(len); | 
|  | 116 | for (a = 0; a < size; a++) vec[a] /= len; | 
|  | 117 | for (a = 0; a < N; a++) bestd[a] = 0; | 
|  | 118 | for (a = 0; a < N; a++) bestw[a][0] = 0; | 
|  | 119 | for (c = 0; c < words; c++) { | 
|  | 120 | if (c == bi[0]) continue; | 
|  | 121 | if (c == bi[1]) continue; | 
|  | 122 | if (c == bi[2]) continue; | 
|  | 123 | a = 0; | 
|  | 124 | for (b = 0; b < cn; b++) if (bi[b] == c) a = 1; | 
|  | 125 | if (a == 1) continue; | 
|  | 126 | dist = 0; | 
|  | 127 | for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size]; | 
|  | 128 | for (a = 0; a < N; a++) { | 
|  | 129 | if (dist > bestd[a]) { | 
|  | 130 | for (d = N - 1; d > a; d--) { | 
|  | 131 | bestd[d] = bestd[d - 1]; | 
|  | 132 | strcpy(bestw[d], bestw[d - 1]); | 
|  | 133 | } | 
|  | 134 | bestd[a] = dist; | 
|  | 135 | strcpy(bestw[a], &vocab[c * max_w]); | 
|  | 136 | break; | 
|  | 137 | } | 
|  | 138 | } | 
|  | 139 | } | 
|  | 140 | for (a = 0; a < N; a++) printf("%50s\t\t%f\n", bestw[a], bestd[a]); | 
|  | 141 | } | 
|  | 142 | return 0; | 
|  | 143 | } |