Add method to compute measures for a specific node-collocate pair
Change-Id: Idbc59ad850dff5e33ba3c56215855667d9b7cf6e
diff --git a/collocatordb.cc b/collocatordb.cc
index 6970501..9a23173 100644
--- a/collocatordb.cc
+++ b/collocatordb.cc
@@ -390,6 +390,8 @@
void dump(uint32_t w1, uint32_t w2, int8_t dist);
vector<Collocator> get_collocators(uint32_t w1);
vector<Collocator> get_collocators(uint32_t w1, uint32_t max_w2);
+ vector<Collocator> get_collocation_scores(uint32_t w1, uint32_t w2);
+ vector<Collocator> get_collocators(uint32_t w1, uint32_t min_w2, uint32_t max_w2);
void applyCAMeasures(const uint32_t w1, const uint32_t w2, uint64_t *sumWindow, const uint64_t sum, const int usedPositions, int true_window_size, rocksdb::Collocator *result);
void dumpSparseLlr(uint32_t w1, uint32_t min_cooccur);
@@ -416,7 +418,7 @@
rocksdb::CollocatorDB::CollocatorDB(const char *db_name, bool read_only = false) {
// merge_option_.sync = true;
if(read_only)
- db_ = OpenDbForRead(db_name);
+ db_ = OpenDbForRead(strdup(db_name));
else
db_ = OpenDb(db_name);
assert(db_);
@@ -538,7 +540,10 @@
EncodeFixed64(prefixc, encodeCollocation(w1, w2, dist));
Iterator *it = db_->NewIterator(options);
CollocatorIterator *cit = new CollocatorIterator(it);
- cit->Seek(std::string(prefixc,3));// it->Valid() && it->key().starts_with(std::string(prefixc,3)); it->Next()) {
+ if (w2 > 0)
+ cit->Seek(std::string(prefixc, 6));
+ else
+ cit->Seek(std::string(prefixc, 3));
cit->setPrefix(prefixc);
return cit;
}
@@ -610,7 +615,7 @@
}
- std::vector<Collocator> rocksdb::CollocatorDB::get_collocators(uint32_t w1, uint32_t max_w2) {
+ std::vector<Collocator> rocksdb::CollocatorDB::get_collocators(uint32_t w1, uint32_t min_w2, uint32_t max_w2) {
std::vector<Collocator> collocators;
uint64_t w2, last_w2 = 0xffffffffffffffff;
uint64_t maxv = 0, sum = 0;
@@ -628,7 +633,7 @@
#endif
// #pragma omp parallel num_threads(40)
// #pragma omp single
- for ( auto it = std::unique_ptr<CollocatorIterator>(SeekIterator(w1, 0, 0)); it->isValid(); it->Next()) {
+ for ( auto it = std::unique_ptr<CollocatorIterator>(SeekIterator(w1, min_w2, 0)); it->isValid(); it->Next()) {
uint64_t value = it->intValue(),
key = it->intKey();
if((w2 = W2(key)) > max_w2)
@@ -653,6 +658,8 @@
maxv = value;
sum = value;
true_window_size = 1;
+ if (min_w2 == max_w2 && w2 != min_w2)
+ break;
} else {
sum += value;
if(value > maxv)
@@ -687,8 +694,13 @@
return collocators;
}
+
+ std::vector<Collocator> rocksdb::CollocatorDB::get_collocation_scores(uint32_t w1, uint32_t w2) {
+ return get_collocators(w1, w2, w2); // min_w2 == max_w2 == w2: query the single pair (w1, w2)
+ }
+
std::vector<Collocator> rocksdb::CollocatorDB::get_collocators(uint32_t w1) {
- return get_collocators(w1, UINT32_MAX);
+ return get_collocators(w1, 0, UINT32_MAX); // min_w2 = 0, max_w2 = UINT32_MAX
}
void rocksdb::CollocatorDB::dumpSparseLlr(uint32_t w1, uint32_t min_cooccur) {
@@ -799,11 +811,19 @@
db->get_collocators(w1);
}
+ void get_collocation_scores(COLLOCATORS *db, uint32_t w1, uint32_t w2) {
+ db->get_collocation_scores(w1, w2); // NOTE(review): the returned vector is discarded here
+ }
+
const char *get_word(COLLOCATORS *db, uint32_t w) {
- return db->getWord(w).c_str();
+ return strdup(db->getWord(w).c_str()); // strdup'd copy: the caller owns it and must free() it
}
const char *get_collocators_as_json(COLLOCATORS *db, uint32_t w1) {
return strdup(db->collocators2json(w1, db->get_collocators(w1)).c_str());
}
+
+ const char *get_collocation_scores_as_json(COLLOCATORS *db, uint32_t w1, uint32_t w2) {
+ return strdup(db->collocators2json(w1, db->get_collocation_scores(w1, w2)).c_str()); // strdup'd copy: the caller owns it and must free() it
+ }
}