Allow passing max number of collocates to get_collocates
Change-Id: Ida073dde970549974adff899891ae1f4b968b83a
diff --git a/Project.toml b/Project.toml
index cf2f458..bf88fcd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "DerekoVecs"
uuid = "d522d5f0-9ae6-4a1c-b144-42ce9f15cae4"
authors = ["Marc Kupietz"]
-version = "0.5.1"
+version = "0.5.2"
[deps]
ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
diff --git a/src/DerekoVecs.jl b/src/DerekoVecs.jl
index cef4cec..ae728c2 100644
--- a/src/DerekoVecs.jl
+++ b/src/DerekoVecs.jl
@@ -191,32 +191,33 @@
@ccall libcdb.open_collocatordb(path::Cstring)::Ptr{Cvoid}
end
-function get_collocates(cdb::Ptr{Nothing}, node::Int64, max_vocab_index::Int64)::Vector{collocate}
+function get_collocates(cdb::Ptr{Nothing}, node::Int64, max_vocab_index::Int64, max::Int64)::Vector{collocate}
res = @ccall libcdb.get_collocators(cdb::Ptr{collocate}, node::Cuint)::Ptr{collocate}
if res == Ptr{collocate}(C_NULL)
return Vector{collocate}()
end
i = 0
- for c in unsafe_wrap(Vector{collocate}, res, 1000, own=false)
+ for c in unsafe_wrap(Vector{collocate}, res, max, own=false)
+ i += 1
if (c.w2 <= 0 || c.w2 > max_vocab_index || c.f2 <= 1 || c.pmi < 0.01)
+ i -= 1
break
end
- i += 1
end
- unsafe_wrap(Vector{collocate}, res, i - 1, own=false)
+ unsafe_wrap(Vector{collocate}, res, i, own = false)
end
-function get_collocates(dv::d2vmodel, node::Int)::DataFrame
- collocates = get_collocates(dv.cdb, node - 1, length(dv.vocab))
+function get_collocates(dv::d2vmodel, node::Int64, max = 200)::DataFrame
+ collocates = get_collocates(dv.cdb, node - 1, length(dv.vocab), max)
df = DataFrame(collocates)
df.w2 = map(x -> x+1, df.w2)
df.collocate = map(x -> dv.vocab[x], df.w2)
df
end
-function get_collocates(dv::d2vmodel, node::String)::DataFrame
- get_collocates(dv, dv.vocabdict[node])
+function get_collocates(dv::d2vmodel, node::String, max = 200)::DataFrame
+ get_collocates(dv, dv.vocabdict[node], max)
end
end
diff --git a/test/Project.toml b/test/Project.toml
index 2a9125d..e18ca85 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,6 @@
[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/runtests.jl b/test/runtests.jl
index f1e0388..66e3609 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
using Artifacts
using DerekoVecs
+using DataFrames
using Test
@@ -48,6 +49,8 @@
@test df.collocate[1] == "kann"
@test df.ldaf[1] > 10
@test df.ldaf[1] > df.ldaf[3]
+ @test nrow(get_collocates(wpd19, 3, 1)) == 1
+ @test nrow(get_collocates(wpd19, 3, 2)) == 2
end
end