Provide all missing cos_sim variants
Change-Id: I86bc78e7e3fc031a94568d4f5e7f0a2f91a59dac
diff --git a/README.md b/README.md
index ca3f88e..39c7d7b 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,11 @@
knn(model, "interessant", 10)
-cos_dist(model, "gut", "besser")
+cos_sim(model, "gut", "besser")
+
+
+model2 = load("model2.vecs")
+cos_sim(model, model2, "also")
```
## License
diff --git a/src/DerekoVecs.jl b/src/DerekoVecs.jl
index c279884..e0ffb02 100644
--- a/src/DerekoVecs.jl
+++ b/src/DerekoVecs.jl
@@ -26,15 +26,19 @@
d2vmodel(M, d, n, vocabdict, vocab)
end
-function cos_sim(m1::d2vmodel, m2::d2vmodel, w1::String, w2::String)
+function cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64, w2index::Int64)
try
- dot(m1.M[:, m1.vocabdict[w1]], m2.M[:, m2.vocabdict[w2]])
+ dot(m1.M[:, w1index], m2.M[:, w2index])
catch error
-1.0
end
end
cos_sim(m::d2vmodel, w1::String, w2::String) = cos_sim(m, m, w1,w2)
+cos_sim(m1::d2vmodel, m2::d2vmodel, w1::String, w2::String) = cos_sim(m1, m2, m1.vocabdict[w1], m2.vocabdict[w2])
+cos_sim(m1::d2vmodel, m2::d2vmodel, w::String) = cos_sim(m1, m2, m1.vocabdict[w], m2.vocabdict[w])
+cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64) = cos_sim(m1, m2, w1index, w1index)
+cos_sim(m::d2vmodel, w1index::Int64, w2index::Int64) = cos_sim(m, m, w1index, w2index)
function minus(m::d2vmodel, w1::String, w2::String)
knn(m, normalize(m.M[:,m.vocabdict[w1]] - m.M[:,m.vocabdict[w2]]), 1)[1]
diff --git a/test/runtests.jl b/test/runtests.jl
index c2834ad..7dfe844 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -17,6 +17,9 @@
@test cos_sim(wpd19, "wurden", "war") == cos_sim(wpd19, "war", "wurden")
@test cos_sim(wpd19, "wurde", "wurden") > cos_sim(wpd19, "wurde", "ich")
@test cos_sim(wpd19, "wurden", "war") == cos_sim(wpd19, wpd19, "war", "wurden")
+ @test isapprox(cos_sim(wpd19, wpd19, "war"), 1)
+ @test isapprox(cos_sim(wpd19, wpd19, 50), 1)
+ @test isapprox(cos_sim(wpd19, 50, 50), 1)
end
@testset "DerekoVecs.jl: knn" begin