Import functions from IDS-KL single repo
Change-Id: Ibbbb8f58a1ee86e5404c2b8363fa73feb0ceb872
diff --git a/src/DerekoVecs.jl b/src/DerekoVecs.jl
index 795acc1..f2f2f54 100644
--- a/src/DerekoVecs.jl
+++ b/src/DerekoVecs.jl
@@ -1,5 +1,64 @@
module DerekoVecs
+export load, knn, cos_sim, d2vmodel
-# Write your package code here.
+using Mmap
+using DelimitedFiles
+using LinearAlgebra
+using Distances      # NOTE(review): nothing in this module currently uses Distances
+using Pkg.Artifacts  # NOTE(review): nothing in this module currently uses Artifacts
+
+"""
+    d2vmodel
+
+Word-embedding model as returned by [`load`](@ref).
+
+# Fields
+- `M::Matrix{Float32}`: embedding matrix with one column per vocabulary entry
+  (`m × n`; memory-mapped when built by `load`).
+- `m::Int64`: vector dimension (number of rows of `M`).
+- `n::Int64`: vocabulary size (number of columns of `M`).
+- `vocabdict::Dict{String, Int64}`: word → column index into `M`.
+- `vocab::Vector{String}`: column index → word.
+"""
+struct d2vmodel
+    M::Matrix{Float32}
+    m::Int64
+    n::Int64
+    vocabdict::Dict{String, Int64}
+    vocab::Vector{String}
+end
+
+"""
+    load(modelfn) -> d2vmodel
+
+Load a model from three files derived from the path `modelfn`:
+
+- `modelfn` itself: only its first line is read; it holds the vocabulary size `n`
+  followed by the vector dimension `d`, separated by whitespace.
+- `replace(modelfn, ".vecs" => ".vocab")`: read as an `n × 2` space-delimited
+  table whose first column is the word for the corresponding matrix column.
+  NOTE(review): if `modelfn` contains no ".vecs", this is `modelfn` itself.
+- `modelfn` with ".vecs" appended: raw `Float32` data, memory-mapped
+  read-only as a `d × n` matrix.
+
+# Throws
+- `SystemError` if the binary ".vecs" file does not exist.
+- `ArgumentError` if the binary file is smaller than `d * n` `Float32` values.
+"""
+function load(modelfn)::d2vmodel
+    # Split on any whitespace so stray/trailing spaces in the header are harmless.
+    n, d = parse.(Int, split(readline(modelfn)))
+    vocabfn = replace(modelfn, ".vecs" => ".vocab")
+
+    vocab = readdlm(vocabfn, ' ', String; dims = (n, 2), quotes = false)[:, 1]
+    vocabdict = Dict{String, Int64}(zip(vocab, 1:n))
+    vecsfn = "$(modelfn).vecs"
+    # Open explicitly read-only: `Mmap.mmap(path, ...)` would create a missing file
+    # and grow it with zeros, silently yielding an all-zero model. The mapping
+    # stays valid after the stream is closed.
+    M = open(io -> Mmap.mmap(io, Matrix{Float32}, (d, n)), vecsfn, "r")
+    return d2vmodel(M, d, n, vocabdict, vocab)
+end
+
+"""
+    load() -> d2vmodel
+
+Load the model whose path is held in the module-level name `defaultmodelfn`.
+
+NOTE(review): `defaultmodelfn` is not defined anywhere in this module, so unless
+it is defined before the call this throws an `ArgumentError` explaining that.
+"""
+function load()::d2vmodel
+    isdefined(@__MODULE__, :defaultmodelfn) || throw(ArgumentError(
+        "no default model configured (`defaultmodelfn` is not defined); use load(modelfn)"))
+    return load(defaultmodelfn)
+end
+
+"""
+    cos_sim(m1::d2vmodel, m2::d2vmodel, w1, w2)
+
+Return the dot product of `w1`'s vector in `m1` and `w2`'s vector in `m2`.
+This equals the cosine similarity only if the stored vectors have unit length
+(presumably true for these models — TODO confirm); no normalisation is done here.
+
+Returns the sentinel `-1.0f0` (same `Float32` type as a real score) when `w1` is
+not in `m1`'s vocabulary or `w2` is not in `m2`'s.
+
+# Throws
+- `DimensionMismatch` if the two models have different vector dimensions.
+"""
+function cos_sim(m1::d2vmodel, m2::d2vmodel, w1::AbstractString, w2::AbstractString)
+    i1 = get(m1.vocabdict, w1, nothing)
+    i2 = get(m2.vocabdict, w2, nothing)
+    # Only an unknown word maps to the sentinel; other errors are not swallowed.
+    (isnothing(i1) || isnothing(i2)) && return -1.0f0
+    return dot(view(m1.M, :, i1), view(m2.M, :, i2))
+end
+
+"""
+    cos_sim(m::d2vmodel, w1, w2)
+
+Similarity of two words within a single model; same as `cos_sim(m, m, w1, w2)`.
+"""
+cos_sim(m::d2vmodel, w1::AbstractString, w2::AbstractString) = cos_sim(m, m, w1, w2)
+
+"""
+    minus(m::d2vmodel, w1, w2) -> String
+
+Return the vocabulary word whose vector has the highest dot product with the
+L2-normalised difference `vec(w1) - vec(w2)`.
+
+# Throws
+- `KeyError` if `w1` or `w2` is not in the vocabulary.
+"""
+function minus(m::d2vmodel, w1::AbstractString, w2::AbstractString)
+    delta = @views m.M[:, m.vocabdict[w1]] .- m.M[:, m.vocabdict[w2]]
+    return first(knn(m, normalize(delta), 1))
+end
+
+"""
+    knn(m::d2vmodel, v::AbstractArray{<:Real}, k) -> Vector{String}
+
+Return the words of the `k` columns of `m.M` with the highest dot product with
+`v`, most similar first. `k` is clamped to `0:m.n`: asking for more neighbours
+than the vocabulary holds returns the whole vocabulary, `k <= 0` returns `String[]`.
+The relative order of exactly tied scores is unspecified.
+
+# Throws
+- `DimensionMismatch` if `length(v) != m.m`.
+"""
+function knn(m::d2vmodel, v::AbstractArray{<:Real}, k)
+    kk = clamp(k, 0, m.n)
+    kk == 0 && return String[]
+    # Views avoid copying each of the (possibly millions of) columns.
+    scores = [dot(v, view(m.M, :, j)) for j in 1:m.n]
+    # Only the top `kk` entries need ordering, not all `m.n` scores.
+    top = partialsortperm(scores, 1:kk; rev = true)
+    return m.vocab[top]
+end
+
+"""
+    knn(m::d2vmodel, w1index::Integer, k::Integer) -> Vector{String}
+
+Nearest neighbours of the vector stored in column `w1index` of `m.M`. The query
+word itself is part of the candidate set and is not filtered out.
+"""
+knn(m::d2vmodel, w1index::Integer, k::Integer) = knn(m, view(m.M, :, w1index), k)
+
+"""
+    knn(m::d2vmodel, w1::AbstractString, k) -> Vector{String}
+
+Nearest neighbours of the word `w1` (including `w1` itself).
+
+# Throws
+- `KeyError` if `w1` is not in the vocabulary.
+"""
+knn(m::d2vmodel, w1::AbstractString, k) = knn(m, m.vocabdict[w1], k)
end
+