blob: 444a0301d672b0e79c18a0b788c6388e641f0edc [file] [log] [blame]
// Symbol-visibility helpers. GCC and Clang accept only "default", "hidden",
// "protected" or "internal" as visibility values; "default" is the one that
// makes a symbol visible outside the shared object.
#define EXPORT __attribute__((visibility("default")))
#define IMPORT
Marc Kupietz12af0192021-03-13 18:05:14 +01003
Marc Kupietz39887082024-11-22 18:06:20 +01004#include "config.h"
5#include "export.h"
6#include "merge_operators.h"
Marc Kupietz28cc53e2017-12-23 17:24:55 +01007#include "rocksdb/db.h"
8#include "rocksdb/env.h"
Marc Kupietzc8ddf452018-01-07 21:33:12 +01009#include "rocksdb/table.h"
Marc Kupietze889cec2024-11-23 12:08:42 +010010#include "rocksdb/slice.h"
Marc Kupietz39887082024-11-22 18:06:20 +010011#include <algorithm>
12#include <cassert>
13#include <cmath>
14#include <cstdint>
15#include <iostream>
16#include <memory>
Marc Kupietz28cc53e2017-12-23 17:24:55 +010017#include <rocksdb/merge_operator.h>
Marc Kupietzc8ddf452018-01-07 21:33:12 +010018#include <rocksdb/slice_transform.h>
Marc Kupietz39887082024-11-22 18:06:20 +010019#include <sstream> // for ostringstream
20#include <string>
21#include <vector>
Marc Kupietz28cc53e2017-12-23 17:24:55 +010022
// Per-side window width: the per-position count arrays in this file hold
// 2 * WINDOW_SIZE slots.
#define WINDOW_SIZE 5
// Co-occurrence counts below this value are not scored.
#define FREQUENCY_THRESHOLD 5
// Run-time byte-order probe: reads the two bytes 0x00 0xff as a uint16_t.
#define IS_BIG_ENDIAN (*(uint16_t *)"\0\xff" < 0x100)
// Packs a collocation into a 64-bit key: w1 in bits 0..23, w2 from bit 24
// upwards, dist in the top byte (bits 56..63). Every macro argument is
// parenthesized so that expressions can be passed safely.
#define encodeCollocation(w1, w2, dist) \
  (((uint64_t)(dist) << 56) | ((uint64_t)(w2) << 24) | (w1))
// Field extractors for keys built by encodeCollocation().
#define W1(key) ((uint64_t)((key) & 0xffffff))
#define W2(key) ((uint64_t)(((key) >> 24) & 0xffffff))
#define DIST(key) ((int8_t)((uint64_t)(((key) >> 56) & 0xff)))
Marc Kupietzc8ddf452018-01-07 21:33:12 +010031
// C-style vocabulary record: a frequency and a raw pointer to the word text.
struct vocab_entry {
  uint64_t freq;
  char *word;
};
Marc Kupietzc8ddf452018-01-07 21:33:12 +010036
37// typedef struct Collocator {
38// uint64_t w2;
39// uint64_t sum;
40// };
41
Marc Kupietz28cc53e2017-12-23 17:24:55 +010042using namespace rocksdb;
Marc Kupietzc8ddf452018-01-07 21:33:12 +010043using namespace std;
Marc Kupietz28cc53e2017-12-23 17:24:55 +010044
Marc Kupietz4b799e92018-01-02 11:04:56 +010045namespace rocksdb {
// Association scores for one collocate (w2) of a node word.
// Plain aggregate: the member order matters because instances are filled with
// positional brace initialization (see CollocatorDB::applyCAMeasures).
struct Collocator {
  uint32_t w2;        // id of the collocate
  uint64_t f2;        // vocabulary frequency of the collocate
  uint64_t raw;       // summed co-occurrence count over the used positions
  double pmi;         // pointwise mutual information
  double npmi;        // normalized pmi
  double llr;         // log-likelihood ratio
  double lfmd;        // log2(o^3 / e)
  double md;          // log2(o^2 / e)
  double md_nws;      // md computed with a nominal window of 2 * WINDOW_SIZE
  uint64_t left_raw;  // count stored in window slot WINDOW_SIZE
  uint64_t right_raw; // count stored in window slot WINDOW_SIZE - 1
  double left_pmi;    // pmi of left_raw with window size 1
  double right_pmi;   // pmi of right_raw with window size 1
  double dice;
  double logdice;
  double ldaf;        // best logDice found over subsets of the used positions
  int window;         // bitmask of the window positions that had counts
  int af_window;      // bitmask of the position subset that produced ldaf
};
Marc Kupietz06c9a9f2018-01-02 16:56:43 +010067
// Number of CountMergeOperator::Merge invocations since the last reset.
size_t num_merge_operator_calls = 0;

// Sets the merge-call counter back to zero.
void resetNumMergeOperatorCalls() {
  num_merge_operator_calls = 0;
}
71
// Counter for partial-merge calls; this file only ever resets it.
size_t num_partial_merge_calls = 0;

// Sets the partial-merge counter back to zero.
void resetNumPartialMergeCalls() {
  num_partial_merge_calls = 0;
}
75
// Writes `value` into buf[0..7] in little-endian byte order.
// Shift-based, so the output is the same on every host and no run-time
// byte-order probe is needed. `buf` must have room for 8 bytes.
inline void EncodeFixed64(char *buf, uint64_t value) {
  for (unsigned i = 0; i < sizeof(value); ++i) {
    buf[i] = static_cast<char>((value >> (8 * i)) & 0xff);
  }
}
90
// Reads a little-endian 32-bit value from ptr[0..3].
// Assembled byte by byte, so it needs no alignment and no byte-order probe.
inline uint32_t DecodeFixed32(const char *ptr) {
  const auto *bytes = reinterpret_cast<const unsigned char *>(ptr);
  return static_cast<uint32_t>(bytes[0]) |
         (static_cast<uint32_t>(bytes[1]) << 8) |
         (static_cast<uint32_t>(bytes[2]) << 16) |
         (static_cast<uint32_t>(bytes[3]) << 24);
}
104
// Reads a little-endian 64-bit value from ptr[0..7].
// Assembled byte by byte, so it needs no alignment and no byte-order probe.
inline uint64_t DecodeFixed64(const char *ptr) {
  const auto *bytes = reinterpret_cast<const unsigned char *>(ptr);
  uint64_t result = 0;
  for (unsigned i = 0; i < sizeof(result); ++i) {
    result |= static_cast<uint64_t>(bytes[i]) << (8 * i);
  }
  return result;
}
117
118static inline double ca_pmi(uint64_t f1, uint64_t f2, uint64_t f12,
119 uint64_t total, double window_size) {
120 double r1 = f1 * window_size, c1 = f2, e = r1 * c1 / total, o = f12;
121 if (f12 < FREQUENCY_THRESHOLD)
122 return -1.0;
123 else
124 return log2(o / e);
125}
126
127// Bouma, Gerlof (2009): <a
128// href="https://svn.spraakdata.gu.se/repos/gerlof/pub/www/Docs/npmi-pfd.pdf">
129// Normalized (pointwise) mutual information in collocation extraction</a>. In
130// Proceedings of GSCL.
131static double ca_npmi(uint64_t f1, uint64_t f2, uint64_t f12,
132 uint64_t total, double window_size) {
133 double r1 = f1 * window_size, c1 = f2, e = r1 * c1 / total, o = f12;
134 if (f12 < FREQUENCY_THRESHOLD)
135 return -1.0;
136 else
137 return log2(o / e) / (-log2(o / total / window_size));
138}
139
// Thanopoulos, A., Fakotakis, N., Kokkinakis, G.: Comparative evaluation of
// collocation extraction metrics. In: International Conference on Language
// Resources and Evaluation (LREC-2002). (2002) 620–625.
// Earlier inline formulation, kept for reference:
//   double md = log2(pow((double)max * window_size / total, 2) /
//                    (window_size * ((double)_vocab[w1].freq / total) *
//                     ((double)_vocab[last_w2].freq / total)));
// Mutual dependency: log2(o^2 / e), with o = f12 and
// e = f1 * window_size * f2 / total. No guard for f12 == 0.
static double ca_md(uint64_t f1, uint64_t f2, uint64_t f12,
                    uint64_t total, double window_size) {
  const double observed = static_cast<double>(f12);
  const double expected = (f1 * window_size) * static_cast<double>(f2) /
                          static_cast<double>(total);
  return std::log2((observed * observed) / expected);
}
153
// Log-frequency-biased mutual dependency: log2(o^3 / e).
// Returns 0 when f12 is 0.
static double ca_lfmd(uint64_t f1, uint64_t f2, uint64_t f12,
                      uint64_t total, double window_size) {
  if (f12 == 0) {
    return 0;
  }
  const double observed = static_cast<double>(f12);
  const double expected = (f1 * window_size) * static_cast<double>(f2) /
                          static_cast<double>(total);
  return std::log2(observed * observed * observed / expected);
}
161
162// Evert, Stefan (2004): The Statistics of Word Cooccurrences: Word Pairs and
163// Collocations. PhD dissertation, IMS, University of Stuttgart. Published in
164// 2005, URN urn:nbn:de:bsz:93-opus-23714. Free PDF available from
165// http://purl.org/stefan.evert/PUB/Evert2004phd.pdf
// Log-likelihood ratio over the 2x2 contingency table built from
// row total r1 = w1 * window_size, column total c1 = w2, joint count w12 and
// sample size n. Cells with a non-positive observed count contribute 0.
static double ca_ll(uint64_t w1, uint64_t w2, uint64_t w12, uint64_t n,
                    uint64_t window_size) {
  // Marginals.
  const double r1 = (double)w1 * window_size;
  const double r2 = (double)n - r1;
  const double c1 = w2;
  const double c2 = n - c1;
  // Observed cells.
  const double o11 = w12;
  const double o12 = r1 - o11;
  const double o21 = c1 - w12;
  const double o22 = r2 - o21;
  // Expected cells under independence.
  const double e11 = r1 * c1 / n;
  const double e12 = r1 * c2 / n;
  const double e21 = r2 * c1 / n;
  const double e22 = r2 * c2 / n;

  const auto cell = [](double observed, double expected) -> double {
    if (observed > 0) {
      return observed * std::log(observed / expected);
    }
    return 0;
  };
  return 2 * (cell(o11, e11) + cell(o12, e12) + cell(o21, e21) +
              cell(o22, e22));
}
177
// Dice coefficient: 2 * w12 / (w2 + w1 * window_size).
// `n` is accepted for signature parity with the other measures; it is unused.
static double ca_dice(uint64_t w1, uint64_t w2, uint64_t w12, uint64_t n,
                      uint64_t window_size) {
  const double row_total = (double)w1 * window_size;
  const double col_total = w2;
  const double denominator = col_total + row_total;
  return 2 * w12 / denominator;
}
183
184// Rychlý, Pavel (2008): <a
185// href="http://www.fi.muni.cz/usr/sojka/download/raslan2008/13.pdf">A
186// lexicographer-friendly association score.</a> In Proceedings of Recent
187// Advances in Slavonic Natural Language Processing, RASLAN, 6–9.
// logDice: 14 + log2(2 * w12 / (w2 + w1 * window_size)).
// `n` is accepted for signature parity with the other measures; it is unused.
static double ca_logdice(uint64_t w1, uint64_t w2, uint64_t w12,
                         uint64_t n, uint64_t window_size) {
  const double row_total = (double)w1 * window_size;
  const double col_total = w2;
  const double dice = 2 * w12 / (col_total + row_total);
  return 14 + std::log2(dice);
}
193
194class CountMergeOperator : public AssociativeMergeOperator {
195public:
196 CountMergeOperator() {
197 mergeOperator_ = MergeOperators::CreateUInt64AddOperator();
198 }
199
200 bool Merge(const Slice &key, const Slice *existing_value,
201 const Slice &value, std::string *new_value,
202 Logger *logger) const override {
203 assert(new_value->empty());
204 ++num_merge_operator_calls;
205 if (existing_value == nullptr) {
206 new_value->assign(value.data(), value.size());
207 return true;
208 }
209
210 return mergeOperator_->PartialMerge(key, *existing_value, value, new_value,
211 logger);
212 }
213
214 const char *Name() const override { return "UInt64AddOperator"; }
215
216private:
217 std::shared_ptr<MergeOperator> mergeOperator_;
218};
219
220class CollocatorIterator : public Iterator {
221 char prefixc[sizeof(uint64_t)]{};
222 Iterator *base_iterator_;
223
224public:
225 explicit CollocatorIterator(Iterator *base_iterator) : base_iterator_(base_iterator) {}
226
227 void setPrefix(char *prefix) { memcpy(prefixc, prefix, sizeof(uint64_t)); }
228
229 void SeekToFirst() override { base_iterator_->SeekToFirst(); }
230
231 void SeekToLast() override { base_iterator_->SeekToLast(); }
232
233 void Seek(const rocksdb::Slice &s) override { base_iterator_->Seek(s); }
234
235 void SeekForPrev(const rocksdb::Slice &s) override {
236 base_iterator_->SeekForPrev(s);
237 }
238
239 void Prev() override { base_iterator_->Prev(); }
240
241 void Next() override { base_iterator_->Next(); }
242
243 Slice key() const override;
244
245 Slice value() const override;
246
247 Status status() const override;
248
249 bool Valid() const override;
250
251 bool isValid();
252
253 uint64_t intValue();
254
255 uint64_t intKey();
256};
257
258// rocksdb::CollocatorIterator::CollocatorIterator(Iterator* base_iterator) {}
259
260bool CollocatorIterator::Valid() const {
261 return base_iterator_->Valid() && key().starts_with(std::string(prefixc, 3));
262}
263
264bool CollocatorIterator::isValid() {
265 return base_iterator_->Valid() && key().starts_with(std::string(prefixc, 3));
266 // return key().starts_with(std::string(prefixc,3));
267}
268
269uint64_t CollocatorIterator::intKey() {
270 return DecodeFixed64(base_iterator_->key().data());
271}
272
273uint64_t CollocatorIterator::intValue() {
274 return DecodeFixed64(base_iterator_->value().data());
275}
276
// One vocabulary line: the word form and its corpus frequency.
// Member order matters: entries are created with {word, freq}.
struct VocabEntry {
  std::string word;
  uint64_t freq;
};
282
283class CollocatorDB {
284 WriteOptions merge_option_; // for merge
285 char _one[sizeof(uint64_t)]{};
286 Slice _one_slice;
287 vector<VocabEntry> _vocab;
288 uint64_t total = 0;
289 uint64_t sentences = 0;
290 float avg_window_size = 8.0;
291
292protected:
293 std::shared_ptr<DB> db_;
294
295 WriteOptions put_option_;
296 ReadOptions get_option_;
297 WriteOptions delete_option_;
298
299 uint64_t default_{};
300
301 std::shared_ptr<DB> OpenDb(const char *dbname);
302
303 std::shared_ptr<DB> OpenDbForRead(const char *dbname);
304
305public:
306 virtual ~CollocatorDB() = default;
307 void readVocab(const string& fname);
308 string getWord(uint32_t w1);
309
310 uint64_t getWordId(const char *word) const;
311
Marc Kupietzd26b1052024-12-10 16:56:39 +0100312 uint64_t getCorpusSize() const;
313
Marc Kupietz39887082024-11-22 18:06:20 +0100314 CollocatorDB(const char *db_name, bool read_only);
315
316 // public interface of CollocatorDB.
317 // All four functions return false
318 // if the underlying level db operation failed.
319
320 // mapped to a levedb Put
321 bool set(const std::string &key, uint64_t value) {
322 // just treat the internal rep of int64 as the string
323 char buf[sizeof(value)];
324 EncodeFixed64(buf, value);
325 Slice slice(buf, sizeof(value));
326 auto s = db_->Put(put_option_, key, slice);
327
328 if (s.ok()) {
329 return true;
330 } else {
331 std::cerr << s.ToString() << std::endl;
332 return false;
333 }
334 }
335
336 DB *getDb() { return db_.get(); }
337
338 // mapped to a rocksdb Delete
339 bool remove(const std::string &key) {
340 auto s = db_->Delete(delete_option_, key);
341
342 if (s.ok()) {
343 return true;
344 } else {
345 std::cerr << s.ToString() << std::endl;
346 return false;
347 }
348 }
349
350 // mapped to a rocksdb Get
351 bool get(const std::string &key, uint64_t *value) {
352 std::string str;
353 auto s = db_->Get(get_option_, key, &str);
354
355 if (s.IsNotFound()) {
356 // return default value if not found;
357 *value = default_;
358 return true;
359 } else if (s.ok()) {
360 // deserialization
361 if (str.size() != sizeof(uint64_t)) {
362 std::cerr << "value corruption\n";
363 return false;
364 }
365 *value = DecodeFixed64(&str[0]);
366 return true;
367 } else {
368 std::cerr << s.ToString() << std::endl;
369 return false;
370 }
371 }
372
373 uint64_t get(const uint32_t w1, const uint32_t w2, const int8_t dist) {
374 char encoded_key[sizeof(uint64_t)];
375 EncodeFixed64(encoded_key, encodeCollocation(w1, w2, dist));
376 uint64_t value = default_;
377 get(std::string(encoded_key, 8), &value);
378 return value;
379 }
380
381 virtual void inc(const std::string &key) {
382 db_->Merge(merge_option_, key, _one_slice);
383 }
384
385 void inc(const uint64_t key) {
386 char encoded_key[sizeof(uint64_t)];
387 EncodeFixed64(encoded_key, key);
388 db_->Merge(merge_option_, std::string(encoded_key, 8), _one_slice);
389 }
390
391 virtual void inc(uint32_t w1, uint32_t w2, uint8_t dist);
392
393 void dump(uint32_t w1, uint32_t w2, int8_t dist) const;
394
395 vector<Collocator> get_collocators(uint32_t w1);
396
397 vector<Collocator> get_collocators(uint32_t w1, uint32_t max_w2);
398
399 vector<Collocator> get_collocation_scores(uint32_t w1, uint32_t w2);
400
401 vector<Collocator> get_collocators(uint32_t w1, uint32_t min_w2,
402 uint32_t max_w2);
403
404 void applyCAMeasures(uint32_t w1, uint32_t w2,
405 uint64_t *sumWindow, uint64_t sum,
406 int usedPositions, int true_window_size,
407 Collocator *result) const;
408
409 void dumpSparseLlr(uint32_t w1, uint32_t min_cooccur);
410
411 string collocators2json(uint32_t w1, const vector<Collocator>& collocators);
412
413 // mapped to a rocksdb Merge operation
414 virtual bool add(const std::string &key, uint64_t value) {
415 char encoded[sizeof(uint64_t)];
416 EncodeFixed64(encoded, value);
417 Slice slice(encoded, sizeof(uint64_t));
418 auto s = db_->Merge(merge_option_, key, slice);
419
420 if (s.ok()) {
421 return true;
422 } else {
423 std::cerr << s.ToString() << std::endl;
424 return false;
425 }
426 }
427
428 CollocatorIterator *SeekIterator(uint64_t w1, uint64_t w2, int8_t dist) const;
429};
430
431CollocatorDB::CollocatorDB(const char *db_name,
432 bool read_only = false) {
433 // merge_option_.sync = true;
434 if (read_only)
435 db_ = OpenDbForRead(strdup(db_name));
436 else
437 db_ = OpenDb(db_name);
438 assert(db_);
439 uint64_t one = 1;
440 EncodeFixed64(_one, one);
441 _one_slice = Slice(_one, sizeof(uint64_t));
442}
443
444void CollocatorDB::inc(const uint32_t w1, const uint32_t w2,
445 const uint8_t dist) {
446 inc(encodeCollocation(w1, w2, dist));
447}
448
449void CollocatorDB::readVocab(const string& fname) {
450 char strbuf[2048];
451 uint64_t freq;
452 FILE *fin = fopen(fname.c_str(), "rb");
453 if (fin == nullptr) {
454 cout << "Vocabulary file " << fname << " not found\n";
455 exit(1);
456 }
457 uint64_t i = 0;
458 while (fscanf(fin, "%s %lu", strbuf, &freq) == 2) {
459 _vocab.push_back({strbuf, freq});
460 total += freq;
461 i++;
462 }
463 fclose(fin);
464
465 char size_fname[256];
466 strcpy(size_fname, fname.c_str());
467 char *pos = strstr(size_fname, ".vocab");
468 if (pos) {
469 *pos = 0;
470 strcat(size_fname, ".size");
471 FILE *fp = fopen(size_fname, "r");
472 if (fp != nullptr) {
473 fscanf(fp, "%lu", &sentences);
474 fscanf(fp, "%lu", &total);
475 float sl = (float)total / (float)sentences;
476 float w = WINDOW_SIZE;
477 avg_window_size =
478 ((sl > 2 * w ? (sl - 2 * w) * 2 * w : 0) + (double)w * (3 * w - 1)) /
479 sl;
480 fprintf(stdout,
481 "Size corrections found: corpus size: %lu tokens in %lu "
482 "sentences, avg. sentence size: %f, avg. window size: %f\n",
483 total, sentences, sl, avg_window_size);
484 fclose(fp);
485 } else {
486 // std::cout << "size file " << size_fname << " not found\n";
487 }
488 } else {
489 std::cout << "cannot determine size file " << size_fname << "\n";
490 }
491}
492
493std::shared_ptr<DB> CollocatorDB::OpenDbForRead(const char *name) {
494 DB *db;
495 Options options;
496 options.env->SetBackgroundThreads(4);
497 options.create_if_missing = true;
498 options.merge_operator = std::make_shared<CountMergeOperator>();
499 options.max_successive_merges = 0;
500 // options.prefix_extractor.reset(NewFixedPrefixTransform(8));
501 options.IncreaseParallelism();
502 options.OptimizeLevelStyleCompaction();
503 options.prefix_extractor.reset(NewFixedPrefixTransform(3));
504 ostringstream dbname, vocabname;
505 dbname << name << ".rocksdb";
506 auto s = DB::OpenForReadOnly(options, dbname.str(), &db);
507 if (!s.ok()) {
508 std::cerr << s.ToString() << std::endl;
509 assert(false);
510 }
511 vocabname << name << ".vocab";
512 readVocab(vocabname.str());
513 return std::shared_ptr<DB>(db);
514}
515
516std::shared_ptr<DB> CollocatorDB::OpenDb(const char *dbname) {
517 DB *db;
518 Options options;
519
520 options.env->SetBackgroundThreads(4);
521 options.create_if_missing = true;
522 options.merge_operator = std::make_shared<CountMergeOperator>();
523 options.max_successive_merges = 0;
524 // options.prefix_extractor.reset(NewFixedPrefixTransform(8));
525 options.IncreaseParallelism();
526 options.OptimizeLevelStyleCompaction();
527 // options.max_write_buffer_number = 48;
528 // options.max_background_jobs = 48;
529 // options.allow_concurrent_memtable_write=true;
530 // options.memtable_factory.reset(NewHashLinkListRepFactory(200000));
531 // options.enable_write_thread_adaptive_yield = 1;
532 // options.allow_concurrent_memtable_write = 1;
533 // options.memtable_factory.reset(new SkipListFactory);
534 // options.write_buffer_size = 1 << 22;
535 // options.allow_mmap_reads = true;
536 // options.allow_mmap_writes = true;
537 // options.max_background_compactions = 40;
538 // BlockBasedTableOptions table_options;
539 // table_options.filter_policy.reset(NewBloomFilterPolicy(24, false));
540 // options.bloom_locality = 1;
541 // std::shared_ptr<Cache> cache = NewLRUCache(512 * 1024 * 1024);
542 // table_options.block_cache = cache;
543 // options.table_factory.reset(NewBlockBasedTableFactory(table_options));
544 Status s;
545 // DestroyDB(dbname, Options());
546 s = DB::Open(options, dbname, &db);
547 if (!s.ok()) {
548 std::cerr << s.ToString() << std::endl;
549 assert(false);
550 }
551 total = 1000;
552 return std::shared_ptr<DB>(db);
553}
554
555CollocatorIterator *
556CollocatorDB::SeekIterator(uint64_t w1, uint64_t w2, int8_t dist) const {
557 ReadOptions options;
558 options.prefix_same_as_start = true;
559 char prefixc[sizeof(uint64_t)];
560 EncodeFixed64(prefixc, encodeCollocation(w1, w2, dist));
561 Iterator *it = db_->NewIterator(options);
562 auto *cit = new CollocatorIterator(it);
563 if (w2 > 0)
564 cit->Seek(std::string(prefixc, 6));
565 else
566 cit->Seek(std::string(prefixc, 3));
567 cit->setPrefix(prefixc);
568 return cit;
569}
570
571void CollocatorDB::dump(uint32_t w1, uint32_t w2, int8_t dist) const {
572 auto it = std::unique_ptr<CollocatorIterator>(SeekIterator(w1, w2, dist));
573 for (; it->isValid(); it->Next()) {
574 uint64_t value = it->intValue();
575 uint64_t key = it->intKey();
576 std::cout << "w1:" << W1(key) << ", w2:" << W2(key)
577 << ", dist:" << (int32_t)DIST(key) << " - count:" << value
578 << std::endl;
579 }
580 std::cout << "ready dumping\n";
581}
582
583bool sortByNpmi(const Collocator &lhs, const Collocator &rhs) {
584 return lhs.npmi > rhs.npmi;
585}
586
587bool sortByLfmd(const Collocator &lhs, const Collocator &rhs) {
588 return lhs.lfmd > rhs.lfmd;
589}
590
591bool sortByLlr(const Collocator &lhs, const Collocator &rhs) {
592 return lhs.llr > rhs.llr;
593}
594
595bool sortByLogDice(const Collocator &lhs, const Collocator &rhs) {
596 return lhs.logdice > rhs.logdice;
597}
598
599bool sortByLogDiceAF(const Collocator &lhs, const Collocator &rhs) {
600 return lhs.ldaf > rhs.ldaf;
601}
602
603void CollocatorDB::applyCAMeasures(
604 const uint32_t w1, const uint32_t w2, uint64_t *sumWindow,
605 const uint64_t sum, const int usedPositions, int true_window_size,
606 Collocator *result) const {
607 uint64_t f1 = _vocab[w1].freq, f2 = _vocab[w2].freq;
608 double o = sum, r1 = f1 * true_window_size, c1 = f2, e = r1 * c1 / total,
609 pmi = log2(o / e), md = log2(o * o / e), lfmd = log2(o * o * o / e),
Marc Kupietze889cec2024-11-23 12:08:42 +0100610 llr = ca_ll(f1, f2, sum, total, true_window_size),
611 md_nws = ca_md(f1, f2, sum, total, 2 * WINDOW_SIZE),
612 ld = ca_logdice(f1, f2, sum, total, true_window_size);
Marc Kupietz39887082024-11-22 18:06:20 +0100613
614 int bestWindow = usedPositions;
615 double bestAF = ld;
616 // if(f1<75000000)
617 // #pragma omp parallel for reduction(max:bestAF)
618 // #pragma omp target teams distribute parallel for reduction(max:bestAF)
619 // map(tofrom:bestAF,currentAF,bestWindow,usedPositions)
620 for (int bitmask = 1; bitmask < (1 << (2 * WINDOW_SIZE)); bitmask++) {
621 if ((bitmask & usedPositions) == 0 || (bitmask & ~usedPositions) > 0)
622 continue;
623 uint64_t currentWindowSum = 0;
624 // #pragma omp target teams distribute parallel for
625 // reduction(+:currentWindowSum) map(tofrom:bitmask,usedPositions)
626 for (int pos = 0; pos < 2 * WINDOW_SIZE; pos++) {
627 if (((1 << pos) & bitmask & usedPositions) != 0)
628 currentWindowSum += sumWindow[pos];
629 }
630 double currentAF = ca_logdice(f1, f2, currentWindowSum, total,
631 __builtin_popcount(bitmask));
632 if (currentAF > bestAF) {
633 bestAF = currentAF;
634 bestWindow = bitmask;
635 }
636 }
637
638 *result = {w2,
639 f2,
640 sum,
641 pmi,
642 pmi / (-log2(o / total / true_window_size)),
643 llr,
644 lfmd,
645 md,
Marc Kupietze889cec2024-11-23 12:08:42 +0100646 md_nws,
Marc Kupietz39887082024-11-22 18:06:20 +0100647 sumWindow[WINDOW_SIZE],
648 sumWindow[WINDOW_SIZE - 1],
649 ca_pmi(f1, f2, sumWindow[WINDOW_SIZE], total, 1),
650 ca_pmi(f1, f2, sumWindow[WINDOW_SIZE - 1], total, 1),
651 ca_dice(f1, f2, sum, total, true_window_size),
652 ld,
653 bestAF,
654 usedPositions,
655 bestWindow};
656}
657
658std::vector<Collocator>
659CollocatorDB::get_collocators(uint32_t w1, uint32_t min_w2,
660 uint32_t max_w2) {
661 std::vector<Collocator> collocators;
662 uint64_t w2, last_w2 = 0xffffffffffffffff;
663 uint64_t maxv = 0, sum = 0;
664 auto *sumWindow =
665 static_cast<uint64_t *>(malloc(sizeof(uint64_t) * 2 * WINDOW_SIZE));
666 memset(sumWindow, 0, sizeof(uint64_t) * 2 * WINDOW_SIZE);
667 int true_window_size = 1;
668 int usedPositions = 0;
669
670 if (w1 > _vocab.size()) {
671 std::cout << w1 << "> vocabulary size " << _vocab.size() << "\n";
672 w1 -= _vocab.size();
673 }
674#ifdef DEBUG
675 std::cout << "Searching for collocates of " << _vocab[w1].word << "\n";
676#endif
677 // #pragma omp parallel num_threads(40)
678 // #pragma omp single
679 for (auto it =
680 std::unique_ptr<CollocatorIterator>(SeekIterator(w1, min_w2, 0));
681 it->isValid(); it->Next()) {
682 uint64_t value = it->intValue(), key = it->intKey();
683 if ((w2 = W2(key)) > max_w2)
684 continue;
685 if (last_w2 == 0xffffffffffffffff)
686 last_w2 = w2;
687 if (w2 != last_w2) {
688 if (sum >= FREQUENCY_THRESHOLD) {
689 collocators.push_back({});
690 Collocator *result = &(collocators[collocators.size() - 1]);
691 // #pragma omp task firstprivate(last_w2, sumWindow, sum, usedPositions,
692 // true_window_size) shared(w1, result) if(sum > 1000000)
693 {
694 // uint64_t *nsw = (uint64_t *)malloc(sizeof(uint64_t) * 2
695 // *WINDOW_SIZE); memcpy(nsw, sumWindow, sizeof(uint64_t) * 2
696 // *WINDOW_SIZE);
697 applyCAMeasures(w1, last_w2, sumWindow, sum, usedPositions,
698 true_window_size, result);
699 // free(nsw);
700 }
701 }
702 memset(sumWindow, 0, 2 * WINDOW_SIZE * sizeof(uint64_t));
703 usedPositions = 1 << (-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0));
704 sumWindow[-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0)] = value;
705 last_w2 = w2;
706 maxv = value;
707 sum = value;
708 true_window_size = 1;
709 if (min_w2 == max_w2 && w2 != min_w2)
710 break;
711 } else {
712 sum += value;
713 if (value > maxv)
714 maxv = value;
715 usedPositions |=
716 1 << (-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0));
717 sumWindow[-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0)] = value;
718 true_window_size++;
719 }
720 }
721
722 // #pragma omp taskwait
723 sort(collocators.begin(), collocators.end(), sortByLogDiceAF);
724
725#ifdef DEBUG
726 int i = 0;
727 for (Collocator c : collocators) {
728 if (i++ > 10)
729 break;
730 std::cout << "w1:" << _vocab[w1].word << ", w2: *" << _vocab[c.w2].word
731 << "*"
732 << "\t f(w1):" << _vocab[w1].freq
733 << "\t f(w2):" << _vocab[c.w2].freq << "\t f(w1, w2):" << c.raw
734 << "\t pmi:" << c.pmi << "\t npmi:" << c.npmi
735 << "\t llr:" << c.llr << "\t md:" << c.md << "\t lfmd:" << c.lfmd
736 << "\t total:" << total << std::endl;
737 }
738#endif
739
740 return collocators;
741}
742
743std::vector<Collocator>
744CollocatorDB::get_collocation_scores(uint32_t w1, uint32_t w2) {
745 return get_collocators(w1, w2, w2);
746}
747
748std::vector<Collocator> CollocatorDB::get_collocators(uint32_t w1) {
749 return get_collocators(w1, 0, UINT32_MAX);
750}
751
752void CollocatorDB::dumpSparseLlr(uint32_t w1, uint32_t min_cooccur) {
753 std::vector<Collocator> collocators;
754 std::stringstream stream;
755 uint64_t w2, last_w2 = 0xffffffffffffffff;
756 uint64_t maxv = 0, total_w1 = 0;
757 bool first = true;
758 for (auto it = std::unique_ptr<CollocatorIterator>(SeekIterator(w1, 0, 0));
759 it->isValid(); it->Next()) {
760 uint64_t value = it->intValue(), key = it->intKey();
761 w2 = W2(key);
762 total_w1 += value;
763 if (last_w2 == 0xffffffffffffffff)
764 last_w2 = w2;
765 if (w2 != last_w2) {
766 if (maxv >= min_cooccur) {
767 double llr =
768 ca_ll(_vocab[w1].freq, _vocab[last_w2].freq, maxv, total, 1);
769 if (first)
770 first = false;
771 else
772 stream << " ";
773 stream << w2 << " " << llr;
774 }
775 last_w2 = w2;
776 maxv = value;
777 } else {
778 if (value > maxv)
779 maxv = value;
780 }
781 }
782 if (first)
783 stream << "1 0.0";
784 stream << "\n";
785 std::cout << stream.str();
786}
787
788Slice CollocatorIterator::key() const {
789 return base_iterator_->key();
790}
791
792Slice CollocatorIterator::value() const {
793 return base_iterator_->value();
794}
795
796Status CollocatorIterator::status() const {
797 return base_iterator_->status();
798}
799
800}; // namespace rocksdb
801
802string CollocatorDB::getWord(uint32_t w1) { return _vocab[w1].word; }
803
804uint64_t CollocatorDB::getWordId(const char *word) const {
Marc Kupietz979580e2024-11-21 18:05:07 +0100805 for (uint64_t i = 0; i < _vocab.size(); i++) {
806 if (strcmp(_vocab[i].word.c_str(), word) == 0)
807 return i;
808 }
809 return 0;
810}
811
Marc Kupietzd26b1052024-12-10 16:56:39 +0100812uint64_t CollocatorDB::getCorpusSize() const {
813 return total;
814}
815
Marc Kupietz39887082024-11-22 18:06:20 +0100816string CollocatorDB::collocators2json(uint32_t w1,
817 const vector<Collocator>& collocators) {
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100818 ostringstream s;
Marc Kupietz0dd86ef2018-01-11 22:23:17 +0100819 int i = 0;
Marc Kupietz39887082024-11-22 18:06:20 +0100820 s << " { \"f1\": " << _vocab[w1].freq << "," << R"("w1":")"
821 << string(_vocab[w1].word) << "\", " << "\"N\": " << total << ", "
822 << "\"collocates\": [";
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100823 bool first = true;
824 for (Collocator c : collocators) {
Marc Kupietz39887082024-11-22 18:06:20 +0100825 if (strncmp(_vocab[c.w2].word.c_str(), "quot", 4) == 0)
826 continue;
Marc Kupietz0dd86ef2018-01-11 22:23:17 +0100827 if (i++ > 200)
828 break;
Marc Kupietz12af0192021-03-13 18:05:14 +0100829 if (!first)
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100830 s << ",\n";
831 else
832 first = false;
833 s << "{"
Marc Kupietz39887082024-11-22 18:06:20 +0100834 "\"word\":\""
835 << (string(_vocab[c.w2].word) == "<num>"
836 ? string("###")
837 : string(_vocab[c.w2].word))
838 << "\"," << "\"f2\":" << c.f2 << "," << "\"f\":" << c.raw << ","
839 << "\"npmi\":" << c.npmi << "," << "\"pmi\":" << c.pmi << ","
840 << "\"llr\":" << c.llr << "," << "\"lfmd\":" << c.lfmd << ","
Marc Kupietze889cec2024-11-23 12:08:42 +0100841 << "\"md\":" << c.md << "," << "\"md_nws\":" << c.md_nws << "," << "\"dice\":" << c.dice << ","
Marc Kupietz39887082024-11-22 18:06:20 +0100842 << "\"ld\":" << c.logdice << "," << "\"ln_count\":" << c.left_raw << ","
843 << "\"rn_count\":" << c.right_raw << "," << "\"ln_pmi\":" << c.left_pmi
844 << "," << "\"rn_pmi\":" << c.right_pmi << "," << "\"ldaf\":" << c.ldaf
845 << "," << "\"win\":" << c.window << "," << "\"afwin\":" << c.af_window
846 << "}";
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100847 }
Marc Kupietze9627152019-02-04 12:32:12 +0100848 s << "]}\n";
Marc Kupietz0421d092021-03-13 18:05:14 +0100849 // std::cout << s.str();
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100850 return s.str();
851}
852
Marc Kupietz39887082024-11-22 18:06:20 +0100853typedef CollocatorDB COLLOCATORS;
Marc Kupietz06c9a9f2018-01-02 16:56:43 +0100854
855extern "C" {
Marc Kupietz12af0192021-03-13 18:05:14 +0100856#ifdef __clang__
857#pragma clang diagnostic push
858#pragma ide diagnostic ignored "OCUnusedGlobalDeclarationInspection"
859#endif
Marc Kupietz39887082024-11-22 18:06:20 +0100860DLL_EXPORT COLLOCATORS *open_collocatordb_for_write(char *dbname) {
861 return new CollocatorDB(dbname, false);
862}
Marc Kupietz12af0192021-03-13 18:05:14 +0100863
Marc Kupietz39887082024-11-22 18:06:20 +0100864DLL_EXPORT COLLOCATORS *open_collocatordb(char *dbname) {
865 return new CollocatorDB(dbname, true);
866}
Marc Kupietz06c9a9f2018-01-02 16:56:43 +0100867
Marc Kupietz39887082024-11-22 18:06:20 +0100868DLL_EXPORT void inc_collocator(COLLOCATORS *db, uint32_t w1, uint32_t w2,
869 int8_t dist) {
870 db->inc(w1, w2, dist);
871}
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100872
Marc Kupietz39887082024-11-22 18:06:20 +0100873DLL_EXPORT void dump_collocators(COLLOCATORS *db, uint32_t w1, uint32_t w2,
874 int8_t dist) {
875 db->dump(w1, w2, dist);
876}
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100877
Marc Kupietz39887082024-11-22 18:06:20 +0100878DLL_EXPORT COLLOCATORS *get_collocators(COLLOCATORS *db, uint32_t w1) {
879 std::vector<Collocator> c = db->get_collocators(w1);
880 if (c.empty())
881 return nullptr;
882 uint64_t size = c.size() + sizeof c[0];
883 auto *p = (COLLOCATORS *)malloc(size);
884 memcpy(p, c.data(), size);
885 return p;
886}
Marc Kupietz88d116b2021-03-13 18:05:14 +0100887
Marc Kupietz39887082024-11-22 18:06:20 +0100888DLL_EXPORT COLLOCATORS *get_collocation_scores(COLLOCATORS *db, uint32_t w1,
889 uint32_t w2) {
890 std::vector<Collocator> c = db->get_collocation_scores(w1, w2);
891 if (c.empty())
892 return nullptr;
893 uint64_t size = c.size() + sizeof c[0];
894 auto *p = (COLLOCATORS *)malloc(size);
895 memcpy(p, c.data(), size);
896 return p;
897}
Marc Kupietzca3a52e2018-06-05 14:16:23 +0200898
Marc Kupietz39887082024-11-22 18:06:20 +0100899DLL_EXPORT char *get_word(COLLOCATORS *db, uint32_t w) {
900 return strdup(db->getWord(w).c_str());
901}
Marc Kupietz979580e2024-11-21 18:05:07 +0100902
Marc Kupietz39887082024-11-22 18:06:20 +0100903DLL_EXPORT uint64_t get_word_id(COLLOCATORS *db, char *word) {
904 return db->getWordId(word);
905}
Marc Kupietzb4a683c2021-03-14 09:19:44 +0100906
Marc Kupietz39887082024-11-22 18:06:20 +0100907DLL_EXPORT void read_vocab(COLLOCATORS *db, char *fname) {
908 std::string fName(fname);
909 db->readVocab(fName);
910}
Marc Kupietz88d116b2021-03-13 18:05:14 +0100911
Marc Kupietz39887082024-11-22 18:06:20 +0100912DLL_EXPORT const char *get_collocators_as_json(COLLOCATORS *db, uint32_t w1) {
913 return strdup(db->collocators2json(w1, db->get_collocators(w1)).c_str());
914}
Marc Kupietzb4a683c2021-03-14 09:19:44 +0100915
Marc Kupietz39887082024-11-22 18:06:20 +0100916DLL_EXPORT const char *
917get_collocation_scores_as_json(COLLOCATORS *db, uint32_t w1, uint32_t w2) {
918 return strdup(
919 db->collocators2json(w1, db->get_collocation_scores(w1, w2)).c_str());
920}
921
922DLL_EXPORT const char *get_version() { return PROJECT_VERSION; }
Marc Kupietz6208fd72024-11-15 15:46:19 +0100923
Marc Kupietzd26b1052024-12-10 16:56:39 +0100924DLL_EXPORT uint64_t get_corpus_size(COLLOCATORS *db) { return db->getCorpusSize(); };
925
// Restores the diagnostic state saved by the matching push at the top of the
// extern "C" block.
#ifdef __clang__
#pragma clang diagnostic pop
#endif
Marc Kupietz06c9a9f2018-01-02 16:56:43 +0100929}