Return collocates as DataFrame
Change-Id: I6cde1b08782d9e7bf3bc6ad4d170cdb4d25b404f
diff --git a/Project.toml b/Project.toml
index 9a7e1f5..db84e29 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "DerekoVecs"
uuid = "d522d5f0-9ae6-4a1c-b144-42ce9f15cae4"
authors = ["Marc Kupietz"]
-version = "0.4.2"
+version = "0.5.0"
[deps]
ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
diff --git a/src/DerekoVecs.jl b/src/DerekoVecs.jl
index 66100a8..fa22a53 100644
--- a/src/DerekoVecs.jl
+++ b/src/DerekoVecs.jl
@@ -191,7 +191,7 @@
@ccall libcdb.open_collocatordb(path::Cstring)::Ptr{Cvoid}
end
-function get_collocates(cdb::Ptr{Nothing}, node::Int64)::Vector{collocate}
+function get_collocates(cdb::Ptr{Nothing}, node::Int64, max_vocab_index::Int64)::Vector{collocate}
res = @ccall libcdb.get_collocators(cdb::Ptr{collocate}, node::Cuint)::Ptr{collocate}
if res == Ptr{collocate}(C_NULL)
return Vector{collocate}()
@@ -199,7 +199,7 @@
i = 0
for c in unsafe_wrap(Vector{collocate}, res, 1000, own=false)
- if (c.w2 == 0)
+ if (c.w2 <= 0 || c.w2 > max_vocab_index)
break
end
i += 1
@@ -207,11 +207,15 @@
unsafe_wrap(Vector{collocate}, res, i - 1, own=false)
end
-function get_collocates(dv::d2vmodel, node::Int)::Vector{collocate}
- get_collocates(dv.cdb, node)
+function get_collocates(dv::d2vmodel, node::Int)::DataFrame
+ collocates = get_collocates(dv.cdb, node - 1, length(dv.vocab))
+ df = DataFrame(collocates)
+ df.w2 = map(x -> x+1, df.w2)
+ df.collocate = map(x -> dv.vocab[x], df.w2)
+ df
end
-function get_collocates(dv::d2vmodel, node::String)::Vector{collocate}
+function get_collocates(dv::d2vmodel, node::String)::DataFrame
get_collocates(dv, dv.vocabdict[node])
end
diff --git a/test/runtests.jl b/test/runtests.jl
index bc29554..f1e0388 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -44,10 +44,10 @@
@testset "DerekoVecs.jl: collocation analysis" begin
if (!isnothing(wpd19.cdb))
- println(wpd19.vocab[30])
- coll = get_collocates(wpd19, "werden")
- @test coll[1].ldaf > 10
- @test coll[1].ldaf > coll[3].ldaf
+ df = get_collocates(wpd19, "werden")
+ @test df.collocate[1] == "kann"
+ @test df.ldaf[1] > 10
+ @test df.ldaf[1] > df.ldaf[3]
end
end