Add Kullback-Leibler divergence calculation to the package
diff --git a/Manifest.toml b/Manifest.toml
index 524ad72..9a44136 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -1,10 +1,12 @@
# This file is machine-generated - editing it directly is not advised
-julia_version = "1.7.3"
+julia_version = "1.8.0-rc3"
manifest_format = "2.0"
+project_hash = "af668ce292bd766579fbcfeae7d9729d80ba7ba5"
[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
+version = "1.1.1"
[[deps.ArtifactUtils]]
deps = ["Downloads", "Git", "HTTP", "Pkg", "ProgressLogging", "SHA", "TOML", "gh_cli_jll"]
@@ -18,9 +20,67 @@
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+[[deps.BinaryProvider]]
+deps = ["Libdl", "Logging", "SHA"]
+git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
+uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
+version = "0.5.10"
+
+[[deps.ChainRulesCore]]
+deps = ["Compat", "LinearAlgebra", "SparseArrays"]
+git-tree-sha1 = "80ca332f6dcb2508adba68f22f551adb2d00a624"
+uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+version = "1.15.3"
+
+[[deps.ChangesOfVariables]]
+deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
+git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8"
+uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
+version = "0.1.4"
+
+[[deps.CodeTracking]]
+deps = ["InteractiveUtils", "UUIDs"]
+git-tree-sha1 = "6d4fa04343a7fc9f9cb9cff9558929f3d2752717"
+uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
+version = "1.0.9"
+
+[[deps.Compat]]
+deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
+git-tree-sha1 = "9be8be1d8a6f44b96482c8af52238ea7987da3e3"
+uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
+version = "3.45.0"
+
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
+version = "0.5.2+0"
+
+[[deps.Crayons]]
+git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
+uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
+version = "4.1.1"
+
+[[deps.DataAPI]]
+git-tree-sha1 = "fb5f5316dd3fd4c5e7c30a24d50643b73e37cd40"
+uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+version = "1.10.0"
+
+[[deps.DataFrames]]
+deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
+git-tree-sha1 = "daa21eb85147f72e41f6352a57fccea377e310a9"
+uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+version = "1.3.4"
+
+[[deps.DataStructures]]
+deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
+git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0"
+uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+version = "0.18.13"
+
+[[deps.DataValueInterfaces]]
+git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
+uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
+version = "1.0.0"
[[deps.Dates]]
deps = ["Printf"]
@@ -36,9 +96,20 @@
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.7"
+[[deps.Distributed]]
+deps = ["Random", "Serialization", "Sockets"]
+uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+
+[[deps.DocStringExtensions]]
+deps = ["LibGit2"]
+git-tree-sha1 = "5158c2b41018c5f7eb1470d558127ac274eca0c9"
+uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+version = "0.9.1"
+
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+version = "1.6.0"
[[deps.Expat_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -49,6 +120,16 @@
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+[[deps.Formatting]]
+deps = ["Printf"]
+git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8"
+uuid = "59287772-0a20-5a39-b81b-1366585eb4c0"
+version = "0.4.2"
+
+[[deps.Future]]
+deps = ["Random"]
+uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+
[[deps.Gettext_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"]
git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046"
@@ -82,19 +163,48 @@
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+[[deps.InverseFunctions]]
+deps = ["Test"]
+git-tree-sha1 = "b3364212fb5d870f724876ffcd34dd8ec6d98918"
+uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
+version = "0.1.7"
+
+[[deps.InvertedIndices]]
+git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f"
+uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
+version = "1.1.0"
+
+[[deps.IrrationalConstants]]
+git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
+uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
+version = "0.1.1"
+
+[[deps.IteratorInterfaceExtensions]]
+git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
+uuid = "82899510-4779-5014-852e-03e436cf321d"
+version = "1.0.0"
+
[[deps.JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.4.1"
+[[deps.JuliaInterpreter]]
+deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
+git-tree-sha1 = "1101d9e5a062963612e8d2bd5bd653d73ae033f4"
+uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
+version = "0.9.14"
+
[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
+version = "0.6.3"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
+version = "7.83.1+1"
[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
@@ -103,6 +213,7 @@
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
+version = "1.10.2+0"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -117,9 +228,21 @@
deps = ["Libdl", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+[[deps.LogExpFunctions]]
+deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
+git-tree-sha1 = "7c88f63f9f0eb5929f15695af9a4d7d3ed278a91"
+uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+version = "0.3.16"
+
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+[[deps.LoweredCodeUtils]]
+deps = ["JuliaInterpreter"]
+git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7"
+uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
+version = "2.2.2"
+
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -133,19 +256,29 @@
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
+version = "2.28.0+0"
+
+[[deps.Missings]]
+deps = ["DataAPI"]
+git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f"
+uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
+version = "1.0.2"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
+version = "2022.2.1"
[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
+version = "1.2.0"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
+version = "0.3.20+0"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -153,13 +286,32 @@
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
version = "1.1.17+0"
+[[deps.OrderedCollections]]
+git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
+uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+version = "1.4.1"
+
+[[deps.PCRE2]]
+deps = ["BinaryProvider", "Libdl"]
+git-tree-sha1 = "fc4205405f792d5e3a0eac7a49ae40e55ec6e04b"
+uuid = "c9310f65-a42c-5928-aca3-d34f64192029"
+version = "1.0.2"
+
[[deps.PCRE2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15"
+version = "10.40.0+0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+version = "1.8.0"
+
+[[deps.PooledArrays]]
+deps = ["DataAPI", "Future"]
+git-tree-sha1 = "a6062fe4063cdafe78f4a0a81cfffb89721b30e7"
+uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
+version = "1.4.2"
[[deps.Preferences]]
deps = ["TOML"]
@@ -167,6 +319,12 @@
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.3.0"
+[[deps.PrettyTables]]
+deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"]
+git-tree-sha1 = "dfb54c4e414caa595a1f2ed759b160f5a3ddcba5"
+uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+version = "1.3.1"
+
[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -185,15 +343,43 @@
deps = ["SHA", "Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+[[deps.Reexport]]
+git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
+uuid = "189a3867-3050-52da-a836-e630ba90ab69"
+version = "1.2.2"
+
+[[deps.Requires]]
+deps = ["UUIDs"]
+git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
+uuid = "ae029012-a4dd-5104-9daa-d747884805df"
+version = "1.3.0"
+
+[[deps.Revise]]
+deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
+git-tree-sha1 = "c73149ff75d4efb19b6d77411d293ae8fb55c58e"
+uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
+version = "3.3.4"
+
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+[[deps.SharedArrays]]
+deps = ["Distributed", "Mmap", "Random", "Serialization"]
+uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
+
[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+[[deps.SortingAlgorithms]]
+deps = ["DataStructures"]
+git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508"
+uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
+version = "1.0.1"
+
[[deps.SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -208,13 +394,37 @@
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.4.0"
+[[deps.StatsBase]]
+deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
+git-tree-sha1 = "0005d75f43ff23688914536c5e9d5ac94f8077f7"
+uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+version = "0.33.20"
+
[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
+version = "1.0.0"
+
+[[deps.TableTraits]]
+deps = ["IteratorInterfaceExtensions"]
+git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
+uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
+version = "1.0.1"
+
+[[deps.Tables]]
+deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"]
+git-tree-sha1 = "5ce79ce186cc678bbb5c5681ca3379d1ddae11a1"
+uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+version = "1.7.0"
[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
+version = "1.10.0"
+
+[[deps.Test]]
+deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
+uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.URIs]]
git-tree-sha1 = "e59ecc5a41b000fa94423a578d29290c7266fc10"
@@ -237,6 +447,7 @@
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
+version = "1.2.12+3"
[[deps.gh_cli_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -247,11 +458,14 @@
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
+version = "5.1.1+0"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
+version = "1.47.0+0"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
+version = "17.4.0+0"
diff --git a/Project.toml b/Project.toml
index 1920e79..9e99411 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,11 +6,15 @@
[deps]
ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
+PCRE2 = "c9310f65-a42c-5928-aca3-d34f64192029"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
julia = "1"
diff --git a/README.md b/README.md
index 1c1a710..7e11afa 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,8 @@
model2 = load("model2.vecs")
cos_sim(model, model2, "also")
+
+kld(model2, model)
```
## License
diff --git a/src/DerekoVecs.jl b/src/DerekoVecs.jl
index e0ffb02..7b66d09 100644
--- a/src/DerekoVecs.jl
+++ b/src/DerekoVecs.jl
@@ -1,64 +1,160 @@
module DerekoVecs
-export load, knn, cos_sim, d2vmodel
+export load, knn, cos_sim, d2vmodel, kld, kldResult
using Mmap
using DelimitedFiles
using LinearAlgebra
using Distances
+using DataFrames
using Pkg.Artifacts
+using StatsBase
+using Base.PCRE
struct d2vmodel
    M::Matrix{Float32}
    m::Int64
    n::Int64
-    vocabdict::Dict{String, Int64}
+    vocabdict::Dict{String,Int64}  # word => position in vocab (load builds it from zip(vocab, 1:n))
    vocab::Array{String}
+    freqs::Array{Float64}  # per-word relative frequency: count / total_tokens, as computed by load
+    total_tokens::Int64  # token total: second line of the .size file if present, else the sum of counts
end
-function load(modelfn) ::d2vmodel
- (n, d) = map(s->parse(Int, s), split(readline(modelfn), " "))
- vocabfn = replace(modelfn, ".vecs" => ".vocab")
+struct kldResult  # outcome of kld(): per-type table plus summary figures
+    df::DataFrame  # one row per shared type; columns type, pkld, freq, rank
+    token_count::Array{Int64}  # corpus sizes exactly as passed to kld
+    common_type_count::Int64  # number of types present in every vocabulary
+    common_token_count::Int64  # smallest per-corpus token count covered by the shared types
+    common_type_share::Float64  # common_type_count as a percentage of the smallest vocabulary
+    common_token_share::Float64  # common_token_count as a percentage of the smallest corpus
+    kld::Float64  # sum of the pkld column
+end
- vocab = readdlm(vocabfn, ' ', String, dims=(n,2), quotes = false)[:,1]
- vocabdict = Dict{String, Int64}(zip(vocab, 1:n))
+"""
+    load(modelfn)::d2vmodel
+
+Load a vector model or a plain frequency list from `modelfn`.
+
+If `modelfn` ends in `.vecs`, its first line is parsed as `n d`, the vocabulary
+(word and count per line, space-separated) is read from the sibling `.vocab` file,
+and the total token count comes from the second line of the sibling `.size` file
+when that file exists, otherwise from the sum of the counts. The vectors are
+memory-mapped from `modelfn` with a further `.vecs` appended, if that file exists.
+
+Any other `modelfn` is read as a tab- or space-delimited frequency list (word,
+count). Whenever no vector file is mapped, the result has an empty matrix and
+`m == 0`.
+"""
+function load(modelfn)::d2vmodel
+    # endswith rather than the regex r".vecs$": its unescaped dot matched any
+    # character, so a frequency list named e.g. "wordvecs" was treated as a model.
+    isvecs = endswith(modelfn, ".vecs")
+    if isvecs
+        (n, d) = map(s -> parse(Int, s), split(readline(modelfn), " "))
+        vocabfn = replace(modelfn, ".vecs" => ".vocab")
+        sizefn = replace(modelfn, ".vecs" => ".size")
+        file = readdlm(vocabfn, ' ', String, dims=(n, 2), quotes=false)
+        vocab = file[:, 1]
+        total = if isfile(sizefn) # .size-file with corrected token count?
+            open(sizefn) do io
+                readline(io) # skip the first line; the token count is on the second
+                parse(Int, readline(io))
+            end
+        else
+            sum(map(x -> parse(Int64, x), file[:, 2]))
+        end
+    else
+        # Frequency list only: the delimiter is detected from the first line.
+        delim = ('\t' in readline(modelfn) ? '\t' : ' ')
+        file = readdlm(modelfn, delim, String, quotes=false)
+        vocab = file[:, 1]
+        n = length(vocab)
+        total = sum(map(x -> parse(Int64, x), file[:, 2]))
+    end
+
+    freqs = map(x -> parse(Float64, x) / total, file[:, 2])
+    vocabdict = Dict{String,Int64}(zip(vocab, 1:n))
    vecsfn = "$(modelfn).vecs"
-    M = Mmap.mmap(vecsfn, Matrix{Float32}, (d,n))
-    d2vmodel(M, d, n, vocabdict, vocab)
+    if isvecs && isfile(vecsfn)
+        M = Mmap.mmap(vecsfn, Matrix{Float32}, (d, n))
+        d2vmodel(M, d, n, vocabdict, vocab, freqs, total)
+    else
+        # No vectors available: keep the vocabulary and frequencies, with an empty matrix.
+        d2vmodel(Matrix{Float32}(undef, 2, 0), 0, n, vocabdict, vocab, freqs, total)
+    end
end
function cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64, w2index::Int64)
    try
        dot(m1.M[:, w1index], m2.M[:, w2index])
-    catch error
-        -1.0
+    catch error
+        -1.0  # any error from the column lookups or dot() is reported as a score of -1.0
    end
end
-cos_sim(m::d2vmodel, w1::String, w2::String) = cos_sim(m, m, w1,w2)
+cos_sim(m::d2vmodel, w1::String, w2::String) = cos_sim(m, m, w1, w2)  # two words within one model; forwards to the two-model method
cos_sim(m1::d2vmodel, m2::d2vmodel, w1::String, w2::String) = cos_sim(m1, m2, m1.vocabdict[w1], m2.vocabdict[w2])
cos_sim(m1::d2vmodel, m2::d2vmodel, w::String) = cos_sim(m1, m2, m1.vocabdict[w], m2.vocabdict[w])
cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64) = cos_sim(m1, m2, w1index, w1index)
cos_sim(m::d2vmodel, w1index::Int64, w2index::Int64) = cos_sim(m, m, w1index, w2index)
function minus(m::d2vmodel, w1::String, w2::String)
-    knn(m, normalize(m.M[:,m.vocabdict[w1]] - m.M[:,m.vocabdict[w2]]), 1)[1]
+    knn(m, normalize(m.M[:, m.vocabdict[w1]] - m.M[:, m.vocabdict[w2]]), 1)[1]  # nearest word to the normalized difference of the two word vectors
end
function knn(m::d2vmodel, v::Array{Float32}, k)
    # dist = Array{Float64}(undef, size(m.M)[2])
-    #@time Threads.@threads for i in 1:size(M)[2] dist[i]=dot(v, M[:,i]) end
+    #@time Threads.@threads for i in 1:size(M)[2] dist[i]=dot(v, M[:,i]) end
    #knn = sortperm(dist, rev=true)[1:k]
-    knn = sortperm(map(x->dot(v, m.M[:,x]), 1:m.n), rev=true)[1:k]
-    map(x->m.vocab[x], knn)
+    knn = sortperm(map(x -> dot(v, m.M[:, x]), 1:m.n), rev=true)[1:k]  # column indices ordered by descending dot product with v; first k kept
+    map(x -> m.vocab[x], knn)  # translate the indices back into words
end
function knn(m::d2vmodel, w1index::Int64, k::Int)
-    knn(m, m.M[:,w1index], k)
+    knn(m, m.M[:, w1index], k)  # neighbours of the vector stored in column w1index
end
function knn(m::d2vmodel, w1::String, k)
knn(m, m.vocabdict[w1], k)
end
+# Contribution of a single type to the Kullback-Leibler divergence: p * log(p / q).
+# A type with p == 0 contributes 0 (the limit of p * log(p) as p -> 0); evaluating the
+# formula directly would give 0 * -Inf == NaN and turn the summed divergence into NaN.
+kldc(p, q) = iszero(p) ? zero(float(p)) : p * log(p / q)
+
+"""
+    kld(dictp::Array{Dict}, total::Array{Int64})::kldResult
+
+Compare relative-frequency dictionaries on the types (keys) they all share.
+
+`dictp[i]` maps each type to its relative frequency in corpus `i`, and `total[i]` is
+that corpus' token count. The first dictionary is the target, the last one the
+background: every shared type contributes `kldc(p_target, p_background)`, and the
+`freq` and `rank` columns of the resulting table describe the background (rank 1 is
+the most frequent shared type).
+
+Throws an `ArgumentError` when fewer than two dictionaries are given.
+"""
+function kld(dictp::Array{Dict}, total::Array{Int64})::kldResult
+    # With fewer than two distributions there is nothing to compare; previously the
+    # result columns were returned uninitialised in that case.
+    length(dictp) >= 2 ||
+        throw(ArgumentError("kld needs at least two frequency dictionaries"))
+
+    min_vocab_size = minimum(map(length, dictp))
+    min_token_size = minimum(total)
+
+    common_types = Vector{String}(collect(reduce(intersect, map(keys, dictp))))
+    common_type_share = length(common_types) * 100 / min_vocab_size
+
+    # One row per distribution (this was hard-coded to two rows).
+    nd = length(dictp)
+    p = Array{Float64,2}(undef, nd, length(common_types))
+    common_tokens = zeros(Int64, nd)
+    for i in 1:nd
+        for (j, w) in enumerate(common_types)
+            p[i, j] = get(dictp[i], w, 0)
+            # relative frequency * corpus size is only approximately an integer in
+            # floating point; adding it unrounded to an Int64 slot can throw InexactError.
+            common_tokens[i] += round(Int64, p[i, j] * total[i])
+        end
+    end
+
+    common_token_share = minimum(common_tokens) * 100.0 / min_token_size
+
+    # Target is row 1, background is the last row.
+    pkld = Float64[kldc(p[1, k], p[nd, k]) for k in eachindex(common_types)]
+    frq = p[nd, :]
+
+    # Typed columns directly, instead of going through a Matrix{Any} and converting back.
+    df = DataFrame(type=common_types, pkld=pkld, freq=frq)
+    transform!(df, :freq => (x -> competerank(x, rev=true)) => :rank)
+    kldResult(df, total, length(common_types), minimum(common_tokens), common_type_share, common_token_share, sum(df.pkld))
+end
+
+"""
+    kld(target::d2vmodel, bg::d2vmodel)::kldResult
+
+Calculate contributions to the Kullback-Leibler divergence from the target language
+model to the background language model.
+
+Each model is reduced to a dictionary from word to relative frequency (`vocab` zipped
+with `freqs`), and the two dictionaries are passed on together with the models'
+`total_tokens`.
+"""
+function kld(target::d2vmodel, bg::d2vmodel)::kldResult
+    # The dictionary-based method dispatches on Array{Dict}, hence the explicit
+    # Dict element type on the literal.
+    dictp = Dict[Dict(zip(target.vocab, target.freqs)), Dict(zip(bg.vocab, bg.freqs))]
+    kld(dictp, [target.total_tokens, bg.total_tokens])
+end
+
+# Convenience method: load both files, then compare the resulting models.
+# (The closing parenthesis of the first load call was misplaced, so load was called
+# with two arguments and kld with one.)
+kld(targetfn::String, bgfn::String)::kldResult = kld(load(targetfn), load(bgfn))
+
end
diff --git a/test/runtests.jl b/test/runtests.jl
index 7dfe844..e0ff466 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,8 +8,8 @@
wpd19 = load(joinpath(artifact"wpd19_10000", "wpd19_10000", "wpd19_10000.vecs"))
    @testset "DerekoVecs.jl: loading" begin
-        @test wpd19.m == 200;
-        @test wpd19.n >= 10000;
+        @test wpd19.m == 200  # m holds the vector dimensionality read from the .vecs header
+        @test wpd19.n >= 10000  # n holds the number of vocabulary entries
    end
@testset "DerekoVecs.jl: similarities" begin
@@ -27,4 +27,18 @@
@test "wurden" in knn(wpd19, "wurde", 3)
end
+    @testset "DerekoVecs.jl: kld" begin
+        # A model compared with itself shares every type and token and has zero divergence.
+        mykld = kld(wpd19, wpd19)
+        @test mykld.common_type_count == length(wpd19.vocabdict)
+        @test isapprox(mykld.common_type_share, 100)
+        # isapprox against 0 with the default (purely relative) tolerance only accepts
+        # an exact 0, so an absolute tolerance is given.
+        @test isapprox(mykld.kld, 0, atol=1e-12)
+        @test wpd19.total_tokens == mykld.common_token_count
+        @test isapprox(mykld.common_token_share, 100)
+    end
+
+    @testset "DerekoVecs.jl: load freq list only" begin
+        # A name that does not end in .vecs makes load take its frequency-list branch.
+        wpd19_freqlist = load(joinpath(artifact"wpd19_10000", "wpd19_10000", "wpd19_10000.vocab"))
+        @test wpd19.total_tokens == wpd19_freqlist.total_tokens
+        # Absolute tolerance: the default relative tolerance only accepts an exact 0.
+        @test isapprox(kld(wpd19_freqlist, wpd19).kld, 0, atol=1e-12)
+    end
end