// Copyright 2013 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <stdio.h>
16#include <stdlib.h>
17#include <string.h>
18#include <math.h>
19#include <pthread.h>
20
// Size in bytes of the fixed buffers that hold a word or a file name
#define MAX_STRING 60

// Number of slots in the vocabulary hash table (maximum 500M entries)
const int vocab_hash_size = 500000000;

// Floating-point type used for the phrase score and the threshold
typedef float real;

// One vocabulary entry: a token (unigram or "left_right" bigram) and its count
struct vocab_word {
  long long cn;
  char *word;
};

// Paths given on the command line via -train and -output
char train_file[MAX_STRING];
char output_file[MAX_STRING];

// Vocabulary array and the hash table mapping a slot to an index in `vocab`
// (-1 marks an empty slot)
struct vocab_word *vocab;
int *vocab_hash;

// Run-time options and their defaults
int debug_mode = 2;
int min_count = 5;
int min_reduce = 1;
real threshold = 100;

// Allocated capacity of `vocab`, number of entries in use, and number of
// training tokens counted so far
long long vocab_max_size = 10000;
long long vocab_size = 0;
long long train_words = 0;

unsigned long long next_random = 1;
40
// Reads a single word from a file, assuming space + tab + EOL to be word boundaries.
//
//   word - output buffer of at least MAX_STRING bytes; always NUL-terminated on return
//   fin  - stream to read from
//
// Carriage returns (byte 13) are ignored. A newline met before any word
// character yields the sentence marker "</s>"; a newline that terminates a
// word is pushed back so that the next call returns "</s>". Over-long words
// are truncated to MAX_STRING - 2 characters. When the stream ends, whatever
// has been collected so far (possibly the empty string) is returned and the
// stream's end-of-file indicator is set.
void ReadWord(char *word, FILE *fin) {
  int a = 0, ch;
  // Test the value returned by fgetc() itself: checking feof() before the read
  // would let the EOF sentinel be stored into the word as a stray byte.
  while ((ch = fgetc(fin)) != EOF) {
    if (ch == 13) continue;
    if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
      if (a > 0) {
        // End of the current word; keep the newline for the next call
        if (ch == '\n') ungetc(ch, fin);
        break;
      }
      if (ch == '\n') {
        strcpy(word, "</s>");
        return;
      }
      continue;  // Leading blank or tab
    }
    word[a] = (char)ch;
    a++;
    if (a >= MAX_STRING - 1) a--;  // Truncate too long words
  }
  word[a] = 0;
}
63
// Returns the hash-table slot of a word: the characters are folded into a
// base-257 polynomial seeded with 1, and the result is reduced modulo
// vocab_hash_size.
int GetWordHash(char *word) {
  unsigned long long hash = 1;
  char *cursor;
  for (cursor = word; *cursor != 0; cursor++) {
    hash = hash * 257 + *cursor;
  }
  return hash % vocab_hash_size;
}
71
72// Returns position of a word in the vocabulary; if the word is not found, returns -1
73int SearchVocab(char *word) {
74 unsigned int hash = GetWordHash(word);
75 while (1) {
76 if (vocab_hash[hash] == -1) return -1;
77 if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash];
78 hash = (hash + 1) % vocab_hash_size;
79 }
80 return -1;
81}
82
83// Reads a word and returns its index in the vocabulary
84int ReadWordIndex(FILE *fin) {
85 char word[MAX_STRING];
86 ReadWord(word, fin);
87 if (feof(fin)) return -1;
88 return SearchVocab(word);
89}
90
91// Adds a word to the vocabulary
92int AddWordToVocab(char *word) {
93 unsigned int hash, length = strlen(word) + 1;
94 if (length > MAX_STRING) length = MAX_STRING;
95 vocab[vocab_size].word = (char *)calloc(length, sizeof(char));
96 strcpy(vocab[vocab_size].word, word);
97 vocab[vocab_size].cn = 0;
98 vocab_size++;
99 // Reallocate memory if needed
100 if (vocab_size + 2 >= vocab_max_size) {
101 vocab_max_size += 10000;
102 vocab=(struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word));
103 }
104 hash = GetWordHash(word);
105 while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
106 vocab_hash[hash]=vocab_size - 1;
107 return vocab_size - 1;
108}
109
// qsort() comparator ordering vocabulary entries by descending count.
// Returns a positive value when `b` has the larger count, a negative value
// when `a` has the larger count, and 0 when the counts are equal.
int VocabCompare(const void *a, const void *b) {
  long long count_a = ((const struct vocab_word *)a)->cn;
  long long count_b = ((const struct vocab_word *)b)->cn;
  // Compare instead of subtracting: the long long difference does not fit in
  // the int return value and could come out with the wrong sign or as zero.
  return (count_b > count_a) - (count_b < count_a);
}
114
115// Sorts the vocabulary by frequency using word counts
116void SortVocab() {
117 int a;
118 unsigned int hash;
119 // Sort the vocabulary and keep </s> at the first position
120 qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare);
121 for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
122 for (a = 0; a < vocab_size; a++) {
123 // Words occuring less than min_count times will be discarded from the vocab
124 if (vocab[a].cn < min_count) {
125 vocab_size--;
126 free(vocab[vocab_size].word);
127 } else {
128 // Hash will be re-computed, as after the sorting it is not actual
129 hash = GetWordHash(vocab[a].word);
130 while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
131 vocab_hash[hash] = a;
132 }
133 }
134 vocab = (struct vocab_word *)realloc(vocab, vocab_size * sizeof(struct vocab_word));
135}
136
137// Reduces the vocabulary by removing infrequent tokens
138void ReduceVocab() {
139 int a, b = 0;
140 unsigned int hash;
141 for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) {
142 vocab[b].cn = vocab[a].cn;
143 vocab[b].word = vocab[a].word;
144 b++;
145 } else free(vocab[a].word);
146 vocab_size = b;
147 for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
148 for (a = 0; a < vocab_size; a++) {
149 // Hash will be re-computed, as it is not actual
150 hash = GetWordHash(vocab[a].word);
151 while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
152 vocab_hash[hash] = a;
153 }
154 fflush(stdout);
155 min_reduce++;
156}
157
158void LearnVocabFromTrainFile() {
159 char word[MAX_STRING], last_word[MAX_STRING], bigram_word[MAX_STRING * 2];
160 FILE *fin;
161 long long a, i, start = 1;
162 for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
163 fin = fopen(train_file, "rb");
164 if (fin == NULL) {
165 printf("ERROR: training data file not found!\n");
166 exit(1);
167 }
168 vocab_size = 0;
169 AddWordToVocab((char *)"</s>");
170 while (1) {
171 ReadWord(word, fin);
172 if (feof(fin)) break;
173 if (!strcmp(word, "</s>")) {
174 start = 1;
175 continue;
176 } else start = 0;
177 train_words++;
178 if ((debug_mode > 1) && (train_words % 100000 == 0)) {
179 printf("Words processed: %lldK Vocab size: %lldK %c", train_words / 1000, vocab_size / 1000, 13);
180 fflush(stdout);
181 }
182 i = SearchVocab(word);
183 if (i == -1) {
184 a = AddWordToVocab(word);
185 vocab[a].cn = 1;
186 } else vocab[i].cn++;
187 if (start) continue;
188 sprintf(bigram_word, "%s_%s", last_word, word);
189 bigram_word[MAX_STRING - 1] = 0;
190 strcpy(last_word, word);
191 i = SearchVocab(bigram_word);
192 if (i == -1) {
193 a = AddWordToVocab(bigram_word);
194 vocab[a].cn = 1;
195 } else vocab[i].cn++;
196 if (vocab_size > vocab_hash_size * 0.7) ReduceVocab();
197 }
198 SortVocab();
199 if (debug_mode > 0) {
200 printf("\nVocab size (unigrams + bigrams): %lld\n", vocab_size);
201 printf("Words in train file: %lld\n", train_words);
202 }
203 fclose(fin);
204}
205
206void TrainModel() {
207 long long pa = 0, pb = 0, pab = 0, oov, i, li = -1, cn = 0;
208 char word[MAX_STRING], last_word[MAX_STRING], bigram_word[MAX_STRING * 2];
209 real score;
210 FILE *fo, *fin;
211 printf("Starting training using file %s\n", train_file);
212 LearnVocabFromTrainFile();
213 fin = fopen(train_file, "rb");
214 fo = fopen(output_file, "wb");
215 word[0] = 0;
216 while (1) {
217 strcpy(last_word, word);
218 ReadWord(word, fin);
219 if (feof(fin)) break;
220 if (!strcmp(word, "</s>")) {
221 fprintf(fo, "\n");
222 continue;
223 }
224 cn++;
225 if ((debug_mode > 1) && (cn % 100000 == 0)) {
226 printf("Words written: %lldK%c", cn / 1000, 13);
227 fflush(stdout);
228 }
229 oov = 0;
230 i = SearchVocab(word);
231 if (i == -1) oov = 1; else pb = vocab[i].cn;
232 if (li == -1) oov = 1;
233 li = i;
234 sprintf(bigram_word, "%s_%s", last_word, word);
235 bigram_word[MAX_STRING - 1] = 0;
236 i = SearchVocab(bigram_word);
237 if (i == -1) oov = 1; else pab = vocab[i].cn;
238 if (pa < min_count) oov = 1;
239 if (pb < min_count) oov = 1;
240 if (oov) score = 0; else score = (pab - min_count) / (real)pa / (real)pb * (real)train_words;
241 if (score > threshold) {
242 fprintf(fo, "_%s", word);
243 pb = 0;
244 } else fprintf(fo, " %s", word);
245 pa = pb;
246 }
247 fclose(fo);
248 fclose(fin);
249}
250
// Finds a command-line option.
//
//   str  - option name to look for, e.g. "-train"
//   argc - number of entries in argv
//   argv - argument vector; argv[0] is not examined
//
// Returns the index of the first argument equal to `str`, or -1 if there is
// none. If the option is the very last argument, so that no value can follow
// it, a message is printed and the program exits with status 1.
int ArgPos(char *str, int argc, char **argv) {
  int index = 1;
  while (index < argc) {
    if (strcmp(str, argv[index]) == 0) {
      if (index == argc - 1) {
        printf("Argument missing for %s\n", str);
        exit(1);
      }
      return index;
    }
    index++;
  }
  return -1;
}
262
263int main(int argc, char **argv) {
264 int i;
265 if (argc == 1) {
266 printf("WORD2PHRASE tool v0.1a\n\n");
267 printf("Options:\n");
268 printf("Parameters for training:\n");
269 printf("\t-train <file>\n");
270 printf("\t\tUse text data from <file> to train the model\n");
271 printf("\t-output <file>\n");
272 printf("\t\tUse <file> to save the resulting word vectors / word clusters / phrases\n");
273 printf("\t-min-count <int>\n");
274 printf("\t\tThis will discard words that appear less than <int> times; default is 5\n");
275 printf("\t-threshold <float>\n");
276 printf("\t\t The <float> value represents threshold for forming the phrases (higher means less phrases); default 100\n");
277 printf("\t-debug <int>\n");
278 printf("\t\tSet the debug mode (default = 2 = more info during training)\n");
279 printf("\nExamples:\n");
280 printf("./word2phrase -train text.txt -output phrases.txt -threshold 100 -debug 2\n\n");
281 return 0;
282 }
283 if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
284 if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]);
285 if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
286 if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]);
287 if ((i = ArgPos((char *)"-threshold", argc, argv)) > 0) threshold = atof(argv[i + 1]);
288 vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word));
289 vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int));
290 TrainModel();
291 return 0;
292}