blob: 57a01d9d154a00c118f36f9c2942e5339fbec109 [file] [log] [blame]
// Marks a symbol as exported from the shared library. GCC/Clang accept only
// "default", "hidden", "protected" or "internal" as visibility values, so the
// former "visible" was rejected wherever the macro was used.
#define EXPORT __attribute__((visibility("default")))
// Intentionally empty.
#define IMPORT
Marc Kupietz12af0192021-03-13 18:05:14 +01003
Marc Kupietz39887082024-11-22 18:06:20 +01004#include "config.h"
5#include "export.h"
6#include "merge_operators.h"
Marc Kupietz28cc53e2017-12-23 17:24:55 +01007#include "rocksdb/db.h"
8#include "rocksdb/env.h"
Marc Kupietzc8ddf452018-01-07 21:33:12 +01009#include "rocksdb/table.h"
Marc Kupietze889cec2024-11-23 12:08:42 +010010#include "rocksdb/slice.h"
Marc Kupietz39887082024-11-22 18:06:20 +010011#include <algorithm>
12#include <cassert>
13#include <cmath>
14#include <cstdint>
15#include <iostream>
16#include <memory>
Marc Kupietz28cc53e2017-12-23 17:24:55 +010017#include <rocksdb/merge_operator.h>
Marc Kupietzc8ddf452018-01-07 21:33:12 +010018#include <rocksdb/slice_transform.h>
Marc Kupietz39887082024-11-22 18:06:20 +010019#include <sstream> // for ostringstream
20#include <string>
21#include <vector>
Marc Kupietz28cc53e2017-12-23 17:24:55 +010022
// Number of positions considered on each side of w1; per-pair window arrays
// hold 2 * WINDOW_SIZE slots.
#define WINDOW_SIZE 5
// Minimum co-occurrence count for a pair to be scored.
#define FREQUENCY_THRESHOLD 5
// Runtime host byte-order probe (reads a string literal through a uint16_t).
#define IS_BIG_ENDIAN (*(uint16_t *)"\0\xff" < 0x100)
// Packs a collocation into one 64-bit key:
//   bits 0-23 = w1, bits 24-47 = w2, bits 56-63 = dist (two's complement).
// w1 and w2 must fit in 24 bits; larger ids spill into the neighbouring field.
// Every macro argument is parenthesized so that expression arguments such as
// `a | b` or `c ? x : y` expand with the intended precedence.
#define encodeCollocation(w1, w2, dist)                                        \
  ((((uint64_t)(dist)) << 56) | (((uint64_t)(w2)) << 24) | (uint64_t)(w1))
#define W1(key) ((uint64_t)((key) & 0xffffff))
#define W2(key) ((uint64_t)(((key) >> 24) & 0xffffff))
#define DIST(key) ((int8_t)((uint64_t)(((key) >> 56) & 0xff)))
Marc Kupietzc8ddf452018-01-07 21:33:12 +010031
// C-style vocabulary record: a frequency and a borrowed word pointer.
struct vocab_entry {
  uint64_t freq;
  char *word;
};
Marc Kupietzc8ddf452018-01-07 21:33:12 +010036
37// typedef struct Collocator {
38// uint64_t w2;
39// uint64_t sum;
40// };
41
Marc Kupietz28cc53e2017-12-23 17:24:55 +010042using namespace rocksdb;
Marc Kupietzc8ddf452018-01-07 21:33:12 +010043using namespace std;
Marc Kupietz28cc53e2017-12-23 17:24:55 +010044
Marc Kupietz4b799e92018-01-02 11:04:56 +010045namespace rocksdb {
// Association scores of one collocate w2 for a fixed node word w1, as filled
// in by CollocatorDB::applyCAMeasures(). Plain aggregate: the member order is
// relied upon by brace initialization and by the raw memcpy in the C API.
struct Collocator {
  uint32_t w2;        // collocate word id
  uint64_t f2;        // vocabulary frequency of w2
  uint64_t raw;       // co-occurrence count summed over all window slots
  double pmi;         // pointwise mutual information
  double npmi;        // normalized PMI
  double llr;         // log-likelihood ratio
  double lfmd;        // log2(O^3 / E)
  double md;          // log2(O^2 / E)
  double md_nws;      // MD computed with a window size of 2 * WINDOW_SIZE
  uint64_t left_raw;  // count in window slot WINDOW_SIZE (distance -1)
  uint64_t right_raw; // count in window slot WINDOW_SIZE - 1 (distance +1)
  double left_pmi;    // PMI computed from left_raw alone
  double right_pmi;   // PMI computed from right_raw alone
  double dice;        // Dice coefficient
  double logdice;     // 14 + log2(Dice)
  double ldaf;        // best logDice over all sub-windows of occupied slots
  int window;         // bitmask of occupied window slots
  int af_window;      // bitmask of the sub-window that produced ldaf
};
Marc Kupietz06c9a9f2018-01-02 16:56:43 +010067
// Diagnostic counter, bumped by CountMergeOperator::Merge on every call.
size_t num_merge_operator_calls;

// Sets num_merge_operator_calls back to zero.
void resetNumMergeOperatorCalls() { num_merge_operator_calls = size_t{0}; }

// Counter for partial merges (not incremented in this file).
size_t num_partial_merge_calls;

// Sets num_partial_merge_calls back to zero.
void resetNumPartialMergeCalls() { num_partial_merge_calls = size_t{0}; }
75
// Stores `value` into buf[0..7] in little-endian byte order, whatever the host
// byte order is. Writing byte by byte removes the need for the IS_BIG_ENDIAN
// probe (which reads a char array through a uint16_t lvalue, an aliasing
// violation). GCC/Clang fold this loop into a single store on little-endian
// hosts.
inline void EncodeFixed64(char *buf, uint64_t value) {
  for (int i = 0; i < 8; ++i) {
    buf[i] = static_cast<char>((value >> (8 * i)) & 0xff);
  }
}
90
// Reads a 32-bit unsigned integer stored little-endian at ptr[0..3].
// Assembling it from bytes is portable and needs no host byte-order probe;
// compilers reduce it to a plain load on little-endian targets.
inline uint32_t DecodeFixed32(const char *ptr) {
  return static_cast<uint32_t>(static_cast<unsigned char>(ptr[0])) |
         (static_cast<uint32_t>(static_cast<unsigned char>(ptr[1])) << 8) |
         (static_cast<uint32_t>(static_cast<unsigned char>(ptr[2])) << 16) |
         (static_cast<uint32_t>(static_cast<unsigned char>(ptr[3])) << 24);
}
104
// Reads a 64-bit unsigned integer stored little-endian at ptr[0..7]
// (the layout written by EncodeFixed64). Byte-wise assembly is independent of
// host byte order and needs no type-punning probe.
inline uint64_t DecodeFixed64(const char *ptr) {
  uint64_t result = 0;
  for (int i = 7; i >= 0; --i) {
    result = (result << 8) | static_cast<unsigned char>(ptr[i]);
  }
  return result;
}
117
118static inline double ca_pmi(uint64_t f1, uint64_t f2, uint64_t f12,
119 uint64_t total, double window_size) {
120 double r1 = f1 * window_size, c1 = f2, e = r1 * c1 / total, o = f12;
121 if (f12 < FREQUENCY_THRESHOLD)
122 return -1.0;
123 else
124 return log2(o / e);
125}
126
127// Bouma, Gerlof (2009): <a
128// href="https://svn.spraakdata.gu.se/repos/gerlof/pub/www/Docs/npmi-pfd.pdf">
129// Normalized (pointwise) mutual information in collocation extraction</a>. In
130// Proceedings of GSCL.
131static double ca_npmi(uint64_t f1, uint64_t f2, uint64_t f12,
132 uint64_t total, double window_size) {
133 double r1 = f1 * window_size, c1 = f2, e = r1 * c1 / total, o = f12;
134 if (f12 < FREQUENCY_THRESHOLD)
135 return -1.0;
136 else
137 return log2(o / e) / (-log2(o / total / window_size));
138}
139
// Mutual dependency, MD = log2(O^2 / E) with E = f1 * window_size * f2 / total.
// Thanopoulos, A., Fakotakis, N., Kokkinakis, G. (2002): Comparative
// evaluation of collocation extraction metrics. In: Proceedings of LREC 2002,
// 620-625.
static double ca_md(uint64_t f1, uint64_t f2, uint64_t f12,
                    uint64_t total, double window_size) {
  const double observed = static_cast<double>(f12);
  const double expected = f1 * window_size * static_cast<double>(f2) / total;
  return log2(observed * observed / expected);
}
153
// log2(O^3 / E) with E = f1 * window_size * f2 / total; 0 when the pair never
// co-occurs.
static double ca_lfmd(uint64_t f1, uint64_t f2, uint64_t f12,
                      uint64_t total, double window_size) {
  if (f12 == 0)
    return 0;
  const double observed = static_cast<double>(f12);
  const double expected = f1 * window_size * static_cast<double>(f2) / total;
  return log2(observed * observed * observed / expected);
}
161
// Log-likelihood ratio G^2 over the 2x2 contingency table of (w1 scaled by the
// window size) x (w2), after Evert, Stefan (2004): The Statistics of Word
// Cooccurrences: Word Pairs and Collocations. PhD dissertation, IMS,
// University of Stuttgart. http://purl.org/stefan.evert/PUB/Evert2004phd.pdf
// Cells with a non-positive observed count contribute 0.
static double ca_ll(uint64_t w1, uint64_t w2, uint64_t w12, uint64_t n,
                    uint64_t window_size) {
  const double row1 = (double)w1 * window_size;
  const double row2 = (double)n - row1;
  const double col1 = w2;
  const double col2 = n - col1;
  const double o11 = w12;
  const double o12 = row1 - o11;
  const double o21 = col1 - w12;
  const double o22 = row2 - o21;
  auto term = [](double observed, double expected) {
    return observed > 0 ? observed * log(observed / expected) : 0.0;
  };
  return 2 * (term(o11, row1 * col1 / n) + term(o12, row1 * col2 / n) +
              term(o21, row2 * col1 / n) + term(o22, row2 * col2 / n));
}
177
// Dice coefficient 2 * f12 / (f2 + f1 * window_size); `n` is unused.
static double ca_dice(uint64_t w1, uint64_t w2, uint64_t w12, uint64_t n,
                      uint64_t window_size) {
  const double scaled_w1 = (double)w1 * window_size;
  const double denominator = static_cast<double>(w2) + scaled_w1;
  return 2 * w12 / denominator;
}
183
// logDice = 14 + log2(Dice), after Rychly, Pavel (2008): A lexicographer-
// friendly association score. In Proceedings of RASLAN 2008, 6-9.
// http://www.fi.muni.cz/usr/sojka/download/raslan2008/13.pdf
// `n` is unused.
static double ca_logdice(uint64_t w1, uint64_t w2, uint64_t w12,
                         uint64_t n, uint64_t window_size) {
  const double scaled_w1 = (double)w1 * window_size;
  const double dice = 2 * w12 / (static_cast<double>(w2) + scaled_w1);
  return 14 + log2(dice);
}
193
194class CountMergeOperator : public AssociativeMergeOperator {
195public:
196 CountMergeOperator() {
197 mergeOperator_ = MergeOperators::CreateUInt64AddOperator();
198 }
199
200 bool Merge(const Slice &key, const Slice *existing_value,
201 const Slice &value, std::string *new_value,
202 Logger *logger) const override {
203 assert(new_value->empty());
204 ++num_merge_operator_calls;
205 if (existing_value == nullptr) {
206 new_value->assign(value.data(), value.size());
207 return true;
208 }
209
210 return mergeOperator_->PartialMerge(key, *existing_value, value, new_value,
211 logger);
212 }
213
214 const char *Name() const override { return "UInt64AddOperator"; }
215
216private:
217 std::shared_ptr<MergeOperator> mergeOperator_;
218};
219
220class CollocatorIterator : public Iterator {
221 char prefixc[sizeof(uint64_t)]{};
222 Iterator *base_iterator_;
223
224public:
225 explicit CollocatorIterator(Iterator *base_iterator) : base_iterator_(base_iterator) {}
226
227 void setPrefix(char *prefix) { memcpy(prefixc, prefix, sizeof(uint64_t)); }
228
229 void SeekToFirst() override { base_iterator_->SeekToFirst(); }
230
231 void SeekToLast() override { base_iterator_->SeekToLast(); }
232
233 void Seek(const rocksdb::Slice &s) override { base_iterator_->Seek(s); }
234
235 void SeekForPrev(const rocksdb::Slice &s) override {
236 base_iterator_->SeekForPrev(s);
237 }
238
239 void Prev() override { base_iterator_->Prev(); }
240
241 void Next() override { base_iterator_->Next(); }
242
243 Slice key() const override;
244
245 Slice value() const override;
246
247 Status status() const override;
248
249 bool Valid() const override;
250
251 bool isValid();
252
253 uint64_t intValue();
254
255 uint64_t intKey();
256};
257
258// rocksdb::CollocatorIterator::CollocatorIterator(Iterator* base_iterator) {}
259
260bool CollocatorIterator::Valid() const {
261 return base_iterator_->Valid() && key().starts_with(std::string(prefixc, 3));
262}
263
264bool CollocatorIterator::isValid() {
265 return base_iterator_->Valid() && key().starts_with(std::string(prefixc, 3));
266 // return key().starts_with(std::string(prefixc,3));
267}
268
269uint64_t CollocatorIterator::intKey() {
270 return DecodeFixed64(base_iterator_->key().data());
271}
272
273uint64_t CollocatorIterator::intValue() {
274 return DecodeFixed64(base_iterator_->value().data());
275}
276
// One vocabulary record; its index in CollocatorDB's vocabulary is the word id.
struct VocabEntry {
  std::string word;
  uint64_t freq;
};
282
283class CollocatorDB {
284 WriteOptions merge_option_; // for merge
285 char _one[sizeof(uint64_t)]{};
286 Slice _one_slice;
287 vector<VocabEntry> _vocab;
288 uint64_t total = 0;
289 uint64_t sentences = 0;
290 float avg_window_size = 8.0;
291
292protected:
293 std::shared_ptr<DB> db_;
294
295 WriteOptions put_option_;
296 ReadOptions get_option_;
297 WriteOptions delete_option_;
298
299 uint64_t default_{};
300
301 std::shared_ptr<DB> OpenDb(const char *dbname);
302
303 std::shared_ptr<DB> OpenDbForRead(const char *dbname);
304
305public:
306 virtual ~CollocatorDB() = default;
307 void readVocab(const string& fname);
308 string getWord(uint32_t w1);
309
310 uint64_t getWordId(const char *word) const;
311
Marc Kupietzd26b1052024-12-10 16:56:39 +0100312 uint64_t getCorpusSize() const;
313
Marc Kupietz21b964c2024-12-10 17:10:50 +0100314 uint64_t getWordFrequency(uint64_t w1);
315
Marc Kupietz39887082024-11-22 18:06:20 +0100316 CollocatorDB(const char *db_name, bool read_only);
317
318 // public interface of CollocatorDB.
319 // All four functions return false
320 // if the underlying level db operation failed.
321
322 // mapped to a levedb Put
323 bool set(const std::string &key, uint64_t value) {
324 // just treat the internal rep of int64 as the string
325 char buf[sizeof(value)];
326 EncodeFixed64(buf, value);
327 Slice slice(buf, sizeof(value));
328 auto s = db_->Put(put_option_, key, slice);
329
330 if (s.ok()) {
331 return true;
332 } else {
333 std::cerr << s.ToString() << std::endl;
334 return false;
335 }
336 }
337
338 DB *getDb() { return db_.get(); }
339
340 // mapped to a rocksdb Delete
341 bool remove(const std::string &key) {
342 auto s = db_->Delete(delete_option_, key);
343
344 if (s.ok()) {
345 return true;
346 } else {
347 std::cerr << s.ToString() << std::endl;
348 return false;
349 }
350 }
351
352 // mapped to a rocksdb Get
353 bool get(const std::string &key, uint64_t *value) {
354 std::string str;
355 auto s = db_->Get(get_option_, key, &str);
356
357 if (s.IsNotFound()) {
358 // return default value if not found;
359 *value = default_;
360 return true;
361 } else if (s.ok()) {
362 // deserialization
363 if (str.size() != sizeof(uint64_t)) {
364 std::cerr << "value corruption\n";
365 return false;
366 }
367 *value = DecodeFixed64(&str[0]);
368 return true;
369 } else {
370 std::cerr << s.ToString() << std::endl;
371 return false;
372 }
373 }
374
375 uint64_t get(const uint32_t w1, const uint32_t w2, const int8_t dist) {
376 char encoded_key[sizeof(uint64_t)];
377 EncodeFixed64(encoded_key, encodeCollocation(w1, w2, dist));
378 uint64_t value = default_;
379 get(std::string(encoded_key, 8), &value);
380 return value;
381 }
382
383 virtual void inc(const std::string &key) {
384 db_->Merge(merge_option_, key, _one_slice);
385 }
386
387 void inc(const uint64_t key) {
388 char encoded_key[sizeof(uint64_t)];
389 EncodeFixed64(encoded_key, key);
390 db_->Merge(merge_option_, std::string(encoded_key, 8), _one_slice);
391 }
392
393 virtual void inc(uint32_t w1, uint32_t w2, uint8_t dist);
394
395 void dump(uint32_t w1, uint32_t w2, int8_t dist) const;
396
397 vector<Collocator> get_collocators(uint32_t w1);
398
399 vector<Collocator> get_collocators(uint32_t w1, uint32_t max_w2);
400
401 vector<Collocator> get_collocation_scores(uint32_t w1, uint32_t w2);
402
403 vector<Collocator> get_collocators(uint32_t w1, uint32_t min_w2,
404 uint32_t max_w2);
405
406 void applyCAMeasures(uint32_t w1, uint32_t w2,
407 uint64_t *sumWindow, uint64_t sum,
408 int usedPositions, int true_window_size,
409 Collocator *result) const;
410
411 void dumpSparseLlr(uint32_t w1, uint32_t min_cooccur);
412
413 string collocators2json(uint32_t w1, const vector<Collocator>& collocators);
414
415 // mapped to a rocksdb Merge operation
416 virtual bool add(const std::string &key, uint64_t value) {
417 char encoded[sizeof(uint64_t)];
418 EncodeFixed64(encoded, value);
419 Slice slice(encoded, sizeof(uint64_t));
420 auto s = db_->Merge(merge_option_, key, slice);
421
422 if (s.ok()) {
423 return true;
424 } else {
425 std::cerr << s.ToString() << std::endl;
426 return false;
427 }
428 }
429
430 CollocatorIterator *SeekIterator(uint64_t w1, uint64_t w2, int8_t dist) const;
431};
432
433CollocatorDB::CollocatorDB(const char *db_name,
434 bool read_only = false) {
435 // merge_option_.sync = true;
436 if (read_only)
437 db_ = OpenDbForRead(strdup(db_name));
438 else
439 db_ = OpenDb(db_name);
440 assert(db_);
441 uint64_t one = 1;
442 EncodeFixed64(_one, one);
443 _one_slice = Slice(_one, sizeof(uint64_t));
444}
445
446void CollocatorDB::inc(const uint32_t w1, const uint32_t w2,
447 const uint8_t dist) {
448 inc(encodeCollocation(w1, w2, dist));
449}
450
451void CollocatorDB::readVocab(const string& fname) {
452 char strbuf[2048];
453 uint64_t freq;
454 FILE *fin = fopen(fname.c_str(), "rb");
455 if (fin == nullptr) {
456 cout << "Vocabulary file " << fname << " not found\n";
457 exit(1);
458 }
459 uint64_t i = 0;
460 while (fscanf(fin, "%s %lu", strbuf, &freq) == 2) {
461 _vocab.push_back({strbuf, freq});
462 total += freq;
463 i++;
464 }
465 fclose(fin);
466
467 char size_fname[256];
468 strcpy(size_fname, fname.c_str());
469 char *pos = strstr(size_fname, ".vocab");
470 if (pos) {
471 *pos = 0;
472 strcat(size_fname, ".size");
473 FILE *fp = fopen(size_fname, "r");
474 if (fp != nullptr) {
475 fscanf(fp, "%lu", &sentences);
476 fscanf(fp, "%lu", &total);
477 float sl = (float)total / (float)sentences;
478 float w = WINDOW_SIZE;
479 avg_window_size =
480 ((sl > 2 * w ? (sl - 2 * w) * 2 * w : 0) + (double)w * (3 * w - 1)) /
481 sl;
482 fprintf(stdout,
483 "Size corrections found: corpus size: %lu tokens in %lu "
484 "sentences, avg. sentence size: %f, avg. window size: %f\n",
485 total, sentences, sl, avg_window_size);
486 fclose(fp);
487 } else {
488 // std::cout << "size file " << size_fname << " not found\n";
489 }
490 } else {
491 std::cout << "cannot determine size file " << size_fname << "\n";
492 }
493}
494
495std::shared_ptr<DB> CollocatorDB::OpenDbForRead(const char *name) {
496 DB *db;
497 Options options;
498 options.env->SetBackgroundThreads(4);
499 options.create_if_missing = true;
500 options.merge_operator = std::make_shared<CountMergeOperator>();
501 options.max_successive_merges = 0;
502 // options.prefix_extractor.reset(NewFixedPrefixTransform(8));
503 options.IncreaseParallelism();
504 options.OptimizeLevelStyleCompaction();
505 options.prefix_extractor.reset(NewFixedPrefixTransform(3));
506 ostringstream dbname, vocabname;
507 dbname << name << ".rocksdb";
508 auto s = DB::OpenForReadOnly(options, dbname.str(), &db);
509 if (!s.ok()) {
510 std::cerr << s.ToString() << std::endl;
511 assert(false);
512 }
513 vocabname << name << ".vocab";
514 readVocab(vocabname.str());
515 return std::shared_ptr<DB>(db);
516}
517
518std::shared_ptr<DB> CollocatorDB::OpenDb(const char *dbname) {
519 DB *db;
520 Options options;
521
522 options.env->SetBackgroundThreads(4);
523 options.create_if_missing = true;
524 options.merge_operator = std::make_shared<CountMergeOperator>();
525 options.max_successive_merges = 0;
526 // options.prefix_extractor.reset(NewFixedPrefixTransform(8));
527 options.IncreaseParallelism();
528 options.OptimizeLevelStyleCompaction();
529 // options.max_write_buffer_number = 48;
530 // options.max_background_jobs = 48;
531 // options.allow_concurrent_memtable_write=true;
532 // options.memtable_factory.reset(NewHashLinkListRepFactory(200000));
533 // options.enable_write_thread_adaptive_yield = 1;
534 // options.allow_concurrent_memtable_write = 1;
535 // options.memtable_factory.reset(new SkipListFactory);
536 // options.write_buffer_size = 1 << 22;
537 // options.allow_mmap_reads = true;
538 // options.allow_mmap_writes = true;
539 // options.max_background_compactions = 40;
540 // BlockBasedTableOptions table_options;
541 // table_options.filter_policy.reset(NewBloomFilterPolicy(24, false));
542 // options.bloom_locality = 1;
543 // std::shared_ptr<Cache> cache = NewLRUCache(512 * 1024 * 1024);
544 // table_options.block_cache = cache;
545 // options.table_factory.reset(NewBlockBasedTableFactory(table_options));
546 Status s;
547 // DestroyDB(dbname, Options());
548 s = DB::Open(options, dbname, &db);
549 if (!s.ok()) {
550 std::cerr << s.ToString() << std::endl;
551 assert(false);
552 }
553 total = 1000;
554 return std::shared_ptr<DB>(db);
555}
556
557CollocatorIterator *
558CollocatorDB::SeekIterator(uint64_t w1, uint64_t w2, int8_t dist) const {
559 ReadOptions options;
560 options.prefix_same_as_start = true;
561 char prefixc[sizeof(uint64_t)];
562 EncodeFixed64(prefixc, encodeCollocation(w1, w2, dist));
563 Iterator *it = db_->NewIterator(options);
564 auto *cit = new CollocatorIterator(it);
565 if (w2 > 0)
566 cit->Seek(std::string(prefixc, 6));
567 else
568 cit->Seek(std::string(prefixc, 3));
569 cit->setPrefix(prefixc);
570 return cit;
571}
572
573void CollocatorDB::dump(uint32_t w1, uint32_t w2, int8_t dist) const {
574 auto it = std::unique_ptr<CollocatorIterator>(SeekIterator(w1, w2, dist));
575 for (; it->isValid(); it->Next()) {
576 uint64_t value = it->intValue();
577 uint64_t key = it->intKey();
578 std::cout << "w1:" << W1(key) << ", w2:" << W2(key)
579 << ", dist:" << (int32_t)DIST(key) << " - count:" << value
580 << std::endl;
581 }
582 std::cout << "ready dumping\n";
583}
584
585bool sortByNpmi(const Collocator &lhs, const Collocator &rhs) {
586 return lhs.npmi > rhs.npmi;
587}
588
589bool sortByLfmd(const Collocator &lhs, const Collocator &rhs) {
590 return lhs.lfmd > rhs.lfmd;
591}
592
593bool sortByLlr(const Collocator &lhs, const Collocator &rhs) {
594 return lhs.llr > rhs.llr;
595}
596
597bool sortByLogDice(const Collocator &lhs, const Collocator &rhs) {
598 return lhs.logdice > rhs.logdice;
599}
600
601bool sortByLogDiceAF(const Collocator &lhs, const Collocator &rhs) {
602 return lhs.ldaf > rhs.ldaf;
603}
604
605void CollocatorDB::applyCAMeasures(
606 const uint32_t w1, const uint32_t w2, uint64_t *sumWindow,
607 const uint64_t sum, const int usedPositions, int true_window_size,
608 Collocator *result) const {
609 uint64_t f1 = _vocab[w1].freq, f2 = _vocab[w2].freq;
610 double o = sum, r1 = f1 * true_window_size, c1 = f2, e = r1 * c1 / total,
611 pmi = log2(o / e), md = log2(o * o / e), lfmd = log2(o * o * o / e),
Marc Kupietze889cec2024-11-23 12:08:42 +0100612 llr = ca_ll(f1, f2, sum, total, true_window_size),
613 md_nws = ca_md(f1, f2, sum, total, 2 * WINDOW_SIZE),
614 ld = ca_logdice(f1, f2, sum, total, true_window_size);
Marc Kupietz39887082024-11-22 18:06:20 +0100615
616 int bestWindow = usedPositions;
617 double bestAF = ld;
618 // if(f1<75000000)
619 // #pragma omp parallel for reduction(max:bestAF)
620 // #pragma omp target teams distribute parallel for reduction(max:bestAF)
621 // map(tofrom:bestAF,currentAF,bestWindow,usedPositions)
622 for (int bitmask = 1; bitmask < (1 << (2 * WINDOW_SIZE)); bitmask++) {
623 if ((bitmask & usedPositions) == 0 || (bitmask & ~usedPositions) > 0)
624 continue;
625 uint64_t currentWindowSum = 0;
626 // #pragma omp target teams distribute parallel for
627 // reduction(+:currentWindowSum) map(tofrom:bitmask,usedPositions)
628 for (int pos = 0; pos < 2 * WINDOW_SIZE; pos++) {
629 if (((1 << pos) & bitmask & usedPositions) != 0)
630 currentWindowSum += sumWindow[pos];
631 }
632 double currentAF = ca_logdice(f1, f2, currentWindowSum, total,
633 __builtin_popcount(bitmask));
634 if (currentAF > bestAF) {
635 bestAF = currentAF;
636 bestWindow = bitmask;
637 }
638 }
639
640 *result = {w2,
641 f2,
642 sum,
643 pmi,
644 pmi / (-log2(o / total / true_window_size)),
645 llr,
646 lfmd,
647 md,
Marc Kupietze889cec2024-11-23 12:08:42 +0100648 md_nws,
Marc Kupietz39887082024-11-22 18:06:20 +0100649 sumWindow[WINDOW_SIZE],
650 sumWindow[WINDOW_SIZE - 1],
651 ca_pmi(f1, f2, sumWindow[WINDOW_SIZE], total, 1),
652 ca_pmi(f1, f2, sumWindow[WINDOW_SIZE - 1], total, 1),
653 ca_dice(f1, f2, sum, total, true_window_size),
654 ld,
655 bestAF,
656 usedPositions,
657 bestWindow};
658}
659
660std::vector<Collocator>
661CollocatorDB::get_collocators(uint32_t w1, uint32_t min_w2,
662 uint32_t max_w2) {
663 std::vector<Collocator> collocators;
664 uint64_t w2, last_w2 = 0xffffffffffffffff;
665 uint64_t maxv = 0, sum = 0;
666 auto *sumWindow =
667 static_cast<uint64_t *>(malloc(sizeof(uint64_t) * 2 * WINDOW_SIZE));
668 memset(sumWindow, 0, sizeof(uint64_t) * 2 * WINDOW_SIZE);
669 int true_window_size = 1;
670 int usedPositions = 0;
671
672 if (w1 > _vocab.size()) {
673 std::cout << w1 << "> vocabulary size " << _vocab.size() << "\n";
674 w1 -= _vocab.size();
675 }
676#ifdef DEBUG
677 std::cout << "Searching for collocates of " << _vocab[w1].word << "\n";
678#endif
679 // #pragma omp parallel num_threads(40)
680 // #pragma omp single
681 for (auto it =
682 std::unique_ptr<CollocatorIterator>(SeekIterator(w1, min_w2, 0));
683 it->isValid(); it->Next()) {
684 uint64_t value = it->intValue(), key = it->intKey();
685 if ((w2 = W2(key)) > max_w2)
686 continue;
687 if (last_w2 == 0xffffffffffffffff)
688 last_w2 = w2;
689 if (w2 != last_w2) {
690 if (sum >= FREQUENCY_THRESHOLD) {
691 collocators.push_back({});
692 Collocator *result = &(collocators[collocators.size() - 1]);
693 // #pragma omp task firstprivate(last_w2, sumWindow, sum, usedPositions,
694 // true_window_size) shared(w1, result) if(sum > 1000000)
695 {
696 // uint64_t *nsw = (uint64_t *)malloc(sizeof(uint64_t) * 2
697 // *WINDOW_SIZE); memcpy(nsw, sumWindow, sizeof(uint64_t) * 2
698 // *WINDOW_SIZE);
699 applyCAMeasures(w1, last_w2, sumWindow, sum, usedPositions,
700 true_window_size, result);
701 // free(nsw);
702 }
703 }
704 memset(sumWindow, 0, 2 * WINDOW_SIZE * sizeof(uint64_t));
705 usedPositions = 1 << (-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0));
706 sumWindow[-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0)] = value;
707 last_w2 = w2;
708 maxv = value;
709 sum = value;
710 true_window_size = 1;
711 if (min_w2 == max_w2 && w2 != min_w2)
712 break;
713 } else {
714 sum += value;
715 if (value > maxv)
716 maxv = value;
717 usedPositions |=
718 1 << (-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0));
719 sumWindow[-DIST(key) + WINDOW_SIZE - (DIST(key) < 0 ? 1 : 0)] = value;
720 true_window_size++;
721 }
722 }
723
724 // #pragma omp taskwait
725 sort(collocators.begin(), collocators.end(), sortByLogDiceAF);
726
727#ifdef DEBUG
728 int i = 0;
729 for (Collocator c : collocators) {
730 if (i++ > 10)
731 break;
732 std::cout << "w1:" << _vocab[w1].word << ", w2: *" << _vocab[c.w2].word
733 << "*"
734 << "\t f(w1):" << _vocab[w1].freq
735 << "\t f(w2):" << _vocab[c.w2].freq << "\t f(w1, w2):" << c.raw
736 << "\t pmi:" << c.pmi << "\t npmi:" << c.npmi
737 << "\t llr:" << c.llr << "\t md:" << c.md << "\t lfmd:" << c.lfmd
738 << "\t total:" << total << std::endl;
739 }
740#endif
741
742 return collocators;
743}
744
745std::vector<Collocator>
746CollocatorDB::get_collocation_scores(uint32_t w1, uint32_t w2) {
747 return get_collocators(w1, w2, w2);
748}
749
750std::vector<Collocator> CollocatorDB::get_collocators(uint32_t w1) {
751 return get_collocators(w1, 0, UINT32_MAX);
752}
753
754void CollocatorDB::dumpSparseLlr(uint32_t w1, uint32_t min_cooccur) {
755 std::vector<Collocator> collocators;
756 std::stringstream stream;
757 uint64_t w2, last_w2 = 0xffffffffffffffff;
758 uint64_t maxv = 0, total_w1 = 0;
759 bool first = true;
760 for (auto it = std::unique_ptr<CollocatorIterator>(SeekIterator(w1, 0, 0));
761 it->isValid(); it->Next()) {
762 uint64_t value = it->intValue(), key = it->intKey();
763 w2 = W2(key);
764 total_w1 += value;
765 if (last_w2 == 0xffffffffffffffff)
766 last_w2 = w2;
767 if (w2 != last_w2) {
768 if (maxv >= min_cooccur) {
769 double llr =
770 ca_ll(_vocab[w1].freq, _vocab[last_w2].freq, maxv, total, 1);
771 if (first)
772 first = false;
773 else
774 stream << " ";
775 stream << w2 << " " << llr;
776 }
777 last_w2 = w2;
778 maxv = value;
779 } else {
780 if (value > maxv)
781 maxv = value;
782 }
783 }
784 if (first)
785 stream << "1 0.0";
786 stream << "\n";
787 std::cout << stream.str();
788}
789
790Slice CollocatorIterator::key() const {
791 return base_iterator_->key();
792}
793
794Slice CollocatorIterator::value() const {
795 return base_iterator_->value();
796}
797
798Status CollocatorIterator::status() const {
799 return base_iterator_->status();
800}
801
802}; // namespace rocksdb
803
804string CollocatorDB::getWord(uint32_t w1) { return _vocab[w1].word; }
805
806uint64_t CollocatorDB::getWordId(const char *word) const {
Marc Kupietz979580e2024-11-21 18:05:07 +0100807 for (uint64_t i = 0; i < _vocab.size(); i++) {
808 if (strcmp(_vocab[i].word.c_str(), word) == 0)
809 return i;
810 }
811 return 0;
812}
813
Marc Kupietzd26b1052024-12-10 16:56:39 +0100814uint64_t CollocatorDB::getCorpusSize() const {
815 return total;
816}
817
Marc Kupietz21b964c2024-12-10 17:10:50 +0100818uint64_t CollocatorDB::getWordFrequency(uint64_t w1) {
819 return _vocab[w1].freq;
820}
821
Marc Kupietz39887082024-11-22 18:06:20 +0100822string CollocatorDB::collocators2json(uint32_t w1,
823 const vector<Collocator>& collocators) {
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100824 ostringstream s;
Marc Kupietz0dd86ef2018-01-11 22:23:17 +0100825 int i = 0;
Marc Kupietz39887082024-11-22 18:06:20 +0100826 s << " { \"f1\": " << _vocab[w1].freq << "," << R"("w1":")"
827 << string(_vocab[w1].word) << "\", " << "\"N\": " << total << ", "
828 << "\"collocates\": [";
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100829 bool first = true;
830 for (Collocator c : collocators) {
Marc Kupietz39887082024-11-22 18:06:20 +0100831 if (strncmp(_vocab[c.w2].word.c_str(), "quot", 4) == 0)
832 continue;
Marc Kupietz0dd86ef2018-01-11 22:23:17 +0100833 if (i++ > 200)
834 break;
Marc Kupietz12af0192021-03-13 18:05:14 +0100835 if (!first)
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100836 s << ",\n";
837 else
838 first = false;
839 s << "{"
Marc Kupietz39887082024-11-22 18:06:20 +0100840 "\"word\":\""
841 << (string(_vocab[c.w2].word) == "<num>"
842 ? string("###")
843 : string(_vocab[c.w2].word))
844 << "\"," << "\"f2\":" << c.f2 << "," << "\"f\":" << c.raw << ","
845 << "\"npmi\":" << c.npmi << "," << "\"pmi\":" << c.pmi << ","
846 << "\"llr\":" << c.llr << "," << "\"lfmd\":" << c.lfmd << ","
Marc Kupietze889cec2024-11-23 12:08:42 +0100847 << "\"md\":" << c.md << "," << "\"md_nws\":" << c.md_nws << "," << "\"dice\":" << c.dice << ","
Marc Kupietz39887082024-11-22 18:06:20 +0100848 << "\"ld\":" << c.logdice << "," << "\"ln_count\":" << c.left_raw << ","
849 << "\"rn_count\":" << c.right_raw << "," << "\"ln_pmi\":" << c.left_pmi
850 << "," << "\"rn_pmi\":" << c.right_pmi << "," << "\"ldaf\":" << c.ldaf
851 << "," << "\"win\":" << c.window << "," << "\"afwin\":" << c.af_window
852 << "}";
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100853 }
Marc Kupietze9627152019-02-04 12:32:12 +0100854 s << "]}\n";
Marc Kupietz0421d092021-03-13 18:05:14 +0100855 // std::cout << s.str();
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100856 return s.str();
857}
858
Marc Kupietz39887082024-11-22 18:06:20 +0100859typedef CollocatorDB COLLOCATORS;
Marc Kupietz06c9a9f2018-01-02 16:56:43 +0100860
861extern "C" {
Marc Kupietz12af0192021-03-13 18:05:14 +0100862#ifdef __clang__
863#pragma clang diagnostic push
864#pragma ide diagnostic ignored "OCUnusedGlobalDeclarationInspection"
865#endif
Marc Kupietz39887082024-11-22 18:06:20 +0100866DLL_EXPORT COLLOCATORS *open_collocatordb_for_write(char *dbname) {
867 return new CollocatorDB(dbname, false);
868}
Marc Kupietz12af0192021-03-13 18:05:14 +0100869
Marc Kupietz39887082024-11-22 18:06:20 +0100870DLL_EXPORT COLLOCATORS *open_collocatordb(char *dbname) {
871 return new CollocatorDB(dbname, true);
872}
Marc Kupietz06c9a9f2018-01-02 16:56:43 +0100873
Marc Kupietz39887082024-11-22 18:06:20 +0100874DLL_EXPORT void inc_collocator(COLLOCATORS *db, uint32_t w1, uint32_t w2,
875 int8_t dist) {
876 db->inc(w1, w2, dist);
877}
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100878
Marc Kupietz39887082024-11-22 18:06:20 +0100879DLL_EXPORT void dump_collocators(COLLOCATORS *db, uint32_t w1, uint32_t w2,
880 int8_t dist) {
881 db->dump(w1, w2, dist);
882}
Marc Kupietzc8ddf452018-01-07 21:33:12 +0100883
Marc Kupietz39887082024-11-22 18:06:20 +0100884DLL_EXPORT COLLOCATORS *get_collocators(COLLOCATORS *db, uint32_t w1) {
885 std::vector<Collocator> c = db->get_collocators(w1);
886 if (c.empty())
887 return nullptr;
888 uint64_t size = c.size() + sizeof c[0];
889 auto *p = (COLLOCATORS *)malloc(size);
890 memcpy(p, c.data(), size);
891 return p;
892}
Marc Kupietz88d116b2021-03-13 18:05:14 +0100893
Marc Kupietz39887082024-11-22 18:06:20 +0100894DLL_EXPORT COLLOCATORS *get_collocation_scores(COLLOCATORS *db, uint32_t w1,
895 uint32_t w2) {
896 std::vector<Collocator> c = db->get_collocation_scores(w1, w2);
897 if (c.empty())
898 return nullptr;
899 uint64_t size = c.size() + sizeof c[0];
900 auto *p = (COLLOCATORS *)malloc(size);
901 memcpy(p, c.data(), size);
902 return p;
903}
Marc Kupietzca3a52e2018-06-05 14:16:23 +0200904
Marc Kupietz39887082024-11-22 18:06:20 +0100905DLL_EXPORT char *get_word(COLLOCATORS *db, uint32_t w) {
906 return strdup(db->getWord(w).c_str());
907}
Marc Kupietz979580e2024-11-21 18:05:07 +0100908
Marc Kupietz39887082024-11-22 18:06:20 +0100909DLL_EXPORT uint64_t get_word_id(COLLOCATORS *db, char *word) {
910 return db->getWordId(word);
911}
Marc Kupietzb4a683c2021-03-14 09:19:44 +0100912
Marc Kupietz39887082024-11-22 18:06:20 +0100913DLL_EXPORT void read_vocab(COLLOCATORS *db, char *fname) {
914 std::string fName(fname);
915 db->readVocab(fName);
916}
Marc Kupietz88d116b2021-03-13 18:05:14 +0100917
Marc Kupietz39887082024-11-22 18:06:20 +0100918DLL_EXPORT const char *get_collocators_as_json(COLLOCATORS *db, uint32_t w1) {
919 return strdup(db->collocators2json(w1, db->get_collocators(w1)).c_str());
920}
Marc Kupietzb4a683c2021-03-14 09:19:44 +0100921
Marc Kupietz39887082024-11-22 18:06:20 +0100922DLL_EXPORT const char *
923get_collocation_scores_as_json(COLLOCATORS *db, uint32_t w1, uint32_t w2) {
924 return strdup(
925 db->collocators2json(w1, db->get_collocation_scores(w1, w2)).c_str());
926}
927
928DLL_EXPORT const char *get_version() { return PROJECT_VERSION; }
Marc Kupietz6208fd72024-11-15 15:46:19 +0100929
Marc Kupietzd26b1052024-12-10 16:56:39 +0100930DLL_EXPORT uint64_t get_corpus_size(COLLOCATORS *db) { return db->getCorpusSize(); };
931
Marc Kupietz21b964c2024-12-10 17:10:50 +0100932DLL_EXPORT uint64_t get_word_frequency(COLLOCATORS *db, uint64_t w1) {
933 return db->getWordFrequency(w1);
934}
935
Marc Kupietz12af0192021-03-13 18:05:14 +0100936#ifdef __clang__
937#pragma clang diagnostic push
938#endif
Marc Kupietz06c9a9f2018-01-02 16:56:43 +0100939}