Allow passing max number of collocates to get_collocates

Change-Id: Ida073dde970549974adff899891ae1f4b968b83a
diff --git a/Project.toml b/Project.toml
index cf2f458..bf88fcd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "DerekoVecs"
 uuid = "d522d5f0-9ae6-4a1c-b144-42ce9f15cae4"
 authors = ["Marc Kupietz"]
-version = "0.5.1"
+version = "0.5.2"
 
 [deps]
 ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
diff --git a/src/DerekoVecs.jl b/src/DerekoVecs.jl
index cef4cec..ae728c2 100644
--- a/src/DerekoVecs.jl
+++ b/src/DerekoVecs.jl
@@ -191,32 +191,33 @@
     @ccall libcdb.open_collocatordb(path::Cstring)::Ptr{Cvoid}
 end
 
-function get_collocates(cdb::Ptr{Nothing}, node::Int64, max_vocab_index::Int64)::Vector{collocate}
+function get_collocates(cdb::Ptr{Nothing}, node::Int64, max_vocab_index::Int64, max::Int64)::Vector{collocate}
     res = @ccall libcdb.get_collocators(cdb::Ptr{collocate}, node::Cuint)::Ptr{collocate}
     if res == Ptr{collocate}(C_NULL)
         return Vector{collocate}()
     end
 
     i = 0
-    for c in unsafe_wrap(Vector{collocate}, res, 1000, own=false)
+    for c in unsafe_wrap(Vector{collocate}, res, max, own=false)
+        i += 1
         if (c.w2 <= 0 || c.w2 > max_vocab_index || c.f2 <= 1 || c.pmi < 0.01)
+            i -= 1
             break
         end
-        i += 1
     end
-    unsafe_wrap(Vector{collocate}, res, i - 1, own=false)
+    unsafe_wrap(Vector{collocate}, res, i, own = false)
 end
 
-function get_collocates(dv::d2vmodel, node::Int)::DataFrame
-    collocates = get_collocates(dv.cdb, node - 1, length(dv.vocab))
+function get_collocates(dv::d2vmodel, node::Int64, max = 200)::DataFrame
+    collocates = get_collocates(dv.cdb, node - 1, length(dv.vocab), max)
     df = DataFrame(collocates)
     df.w2 = map(x -> x+1, df.w2)
     df.collocate = map(x -> dv.vocab[x], df.w2)
     df
 end
 
-function get_collocates(dv::d2vmodel, node::String)::DataFrame
-    get_collocates(dv, dv.vocabdict[node])
+function get_collocates(dv::d2vmodel, node::String, max = 200)::DataFrame
+    get_collocates(dv, dv.vocabdict[node], max)
 end
 
 end
diff --git a/test/Project.toml b/test/Project.toml
index 2a9125d..e18ca85 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,6 @@
 [deps]
 Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/runtests.jl b/test/runtests.jl
index f1e0388..66e3609 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
 using Artifacts
 using DerekoVecs
+using DataFrames
 using Test
 
 
@@ -48,6 +49,8 @@
             @test df.collocate[1] == "kann"
             @test df.ldaf[1] > 10
             @test df.ldaf[1] > df.ldaf[3]
+            @test nrow(get_collocates(wpd19, 3, 1)) == 1
+            @test nrow(get_collocates(wpd19, 3, 2)) == 2
         end
     end