# blob: 66100a86f4aa39bf6d964aae83be21b9636f9d42 [file] [log] [blame]
module DerekoVecs
export load, knn, cos_sim, d2vmodel, kld, kldResult, get_collocates, collocate
using Mmap
using DelimitedFiles
using LinearAlgebra
using Distances
using DataFrames
using Pkg.Artifacts
using StatsBase
using Base.PCRE
using Libdl
# Path of the optional libcollocatordb shared library, searched in the working directory
# and the usual system library directories; `load` only opens a collocation database when
# this is non-empty.
const libcdb = let search_dirs = [".", "/usr/local/lib64", "/usr/lib64", "/usr/local/lib", "/usr/lib"]
    Libdl.find_library("libcollocatordb", search_dirs)
end
"""
    d2vmodel

A model as constructed by `load`: word vectors (possibly absent), vocabulary statistics
and an optional handle to a collocation database.

`vocab` and `freqs` are concrete `Vector`s so that field access is type-stable.
"""
struct d2vmodel
    M::Matrix{Float32}                 # vectors, one column per vocabulary entry (m × n); empty 2×0 when no vectors were loaded
    m::Int64                           # vector dimension (rows of M); 0 when no vectors were loaded
    n::Int64                           # vocabulary size
    vocabdict::Dict{String,Int64}      # word => vocabulary/column index
    vocab::Vector{String}              # words in vocabulary order
    freqs::Vector{Float64}             # relative frequency of each word (count / total_tokens)
    total_tokens::Int64                # token total the frequencies are relative to
    cdb::Union{Ptr{Nothing},Nothing}   # collocation database handle, or nothing if none was opened
end
"""
    kldResult

Result of `kld`: per-type Kullback-Leibler contributions plus statistics on how much the
compared frequency distributions overlap.
"""
struct kldResult
    df::DataFrame                  # one row per common type; columns type, pkld, freq, rank
    token_count::Array{Int64}      # token totals of the compared distributions, as passed to kld
    common_type_count::Int64       # number of types present in all distributions
    common_token_count::Int64      # smallest per-distribution token count covered by the common types
    common_type_share::Float64     # common types as a percentage of the smallest vocabulary
    common_token_share::Float64    # common_token_count as a percentage of the smallest token total
    kld::Float64                   # sum of df.pkld
end
"""
    collocate

One collocation record as returned by `libcollocatordb`'s `get_collocators`.

`get_collocates` reinterprets C memory as an array of this struct, so field order and
field types must match the C-side record layout exactly — do not reorder, add, remove or
retype fields. `get_collocates` treats a record with `w2 == 0` as the end of the array.
"""
struct collocate
    w2::UInt32            # presumably the vocabulary index of the collocate; 0 terminates the array
    f2::UInt64            # presumably the collocate's frequency — confirm against libcollocatordb
    raw::UInt64           # presumably the raw co-occurrence count — confirm
    # Association scores; semantics are defined by libcollocatordb (names suggest PMI,
    # normalised PMI, log-likelihood ratio, LFMD, MD, Dice, logDice, ...) — TODO confirm.
    pmi::Float64
    npmi::Float64
    llr::Float64
    lfmd::Float64
    md::Float64
    left_raw::UInt64      # presumably the count with the collocate to the left of the node — confirm
    right_raw::UInt64     # presumably the count with the collocate to the right of the node — confirm
    left_pmi::Float64
    right_pmi::Float64
    dice::Float64
    logdice::Float64
    ldaf::Float64
    window::Int32         # NOTE(review): window encoding is defined on the C side — confirm
    af_window::Int32
end
"""
    load(modelfn) -> d2vmodel

Load a model from `modelfn`.

If `modelfn` ends in `.vecs`:
- its first line must contain the vocabulary size and the vector dimension, separated by
  a space;
- the vocabulary (word and count per line, space separated) is read from the file with
  the trailing `.vecs` replaced by `.vocab`;
- a collocation database is opened if `<base>.rocksdb/CURRENT` exists and
  `libcollocatordb` was found;
- the vectors are memory-mapped as a `Float32` matrix (dimension × vocabulary size) from
  the file `modelfn * ".vecs"`, if that file exists.

Any other `modelfn` is read as a vocabulary file with tab- or space-separated columns
(word, count). The returned model then has no vectors: `M` is an empty 2×0 matrix and
`m == 0`.

Relative frequencies are computed against the token total in a sibling `.size` file, if
it exists. That file has the same name with its last extension replaced by `.size`, and
the total is on its second line. Otherwise they are computed against the sum of the
vocabulary counts.
"""
function load(modelfn)::d2vmodel
    cdb = nothing
    d = 0                                   # vector dimension; only known for .vecs models
    # literal suffix test (a regex `.` would match any character)
    isvecs = endswith(modelfn, ".vecs")
    if isvecs
        (n, d) = map(s -> parse(Int, s), split(readline(modelfn), " "))
        # strip only the trailing ".vecs", not every occurrence anywhere in the path
        base = String(chop(modelfn; tail=length(".vecs")))
        file = readdlm(base * ".vocab", ' ', String, dims=(n, 2), quotes=false)
        if isfile(base * ".rocksdb/CURRENT") && libcdb != ""
            cdb = open_collocatordb(base)
        end
    else
        delim = ('\t' in readline(modelfn) ? '\t' : ' ')
        file = readdlm(modelfn, delim, String, quotes=false)
    end
    vocab = file[:, 1]
    n = length(vocab)
    # .size file with a corrected token count: replace only the last extension, so dots in
    # directory names are left alone and an extension-less name does not map onto itself
    sizefn = first(splitext(modelfn)) * ".size"
    total = if isfile(sizefn)
        open(sizefn) do io
            readline(io)                    # the first line is skipped
            parse(Int, readline(io))
        end
    else
        sum(map(x -> parse(Int64, x), file[:, 2]))
    end
    freqs = map(x -> parse(Float64, x) / total, file[:, 2])
    vocabdict = Dict{String,Int64}(zip(vocab, 1:n))
    vecsfn = "$(modelfn).vecs"
    if isvecs && isfile(vecsfn)
        M = Mmap.mmap(vecsfn, Matrix{Float32}, (d, n))
        return d2vmodel(M, d, n, vocabdict, vocab, freqs, total, cdb)
    end
    return d2vmodel(Matrix{Float32}(undef, 2, 0), 0, n, vocabdict, vocab, freqs, total, cdb)
end
"""
    cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64, w2index::Int64)

Dot product of column `w1index` of `m1.M` and column `w2index` of `m2.M`. This equals
cosine similarity only if the stored vectors are unit length, which is not checked here.

Returns `-1.0` when either index is outside its matrix (e.g. a model loaded without
vectors) or when the two models have different vector dimensions.
"""
function cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64, w2index::Int64)
    # explicit checks instead of a catch-all try/catch, which also swallowed interrupts
    if !checkbounds(Bool, m1.M, :, w1index) || !checkbounds(Bool, m2.M, :, w2index) ||
       size(m1.M, 1) != size(m2.M, 1)
        return -1.0
    end
    # views avoid copying both columns on every call
    return @views dot(m1.M[:, w1index], m2.M[:, w2index])
end
# Convenience methods: resolve words to vocabulary indices (a KeyError is thrown for unknown
# words) and/or reuse one model for both sides, then delegate to the index-based method.
cos_sim(m::d2vmodel, w1::String, w2::String) = cos_sim(m, m, w1, w2)

function cos_sim(m1::d2vmodel, m2::d2vmodel, w1::String, w2::String)
    i1 = m1.vocabdict[w1]
    i2 = m2.vocabdict[w2]
    return cos_sim(m1, m2, i1, i2)
end

# Same word looked up in both models.
cos_sim(m1::d2vmodel, m2::d2vmodel, w::String) = cos_sim(m1, m2, w, w)

# Same vocabulary index in both models.
function cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64)
    return cos_sim(m1, m2, w1index, w1index)
end

# Two indices within a single model.
function cos_sim(m::d2vmodel, w1index::Int64, w2index::Int64)
    return cos_sim(m, m, w1index, w2index)
end
"""
    minus(m::d2vmodel, w1::String, w2::String)

Return the single vocabulary entry nearest (by `knn`) to the normalised difference
between the vectors of `w1` and `w2`.
"""
function minus(m::d2vmodel, w1::String, w2::String)
    a = m.M[:, m.vocabdict[w1]]
    b = m.M[:, m.vocabdict[w2]]
    nearest = knn(m, normalize(a - b), 1)
    return first(nearest)
end
"""
    knn(m::d2vmodel, v::Array{Float32}, k)

Return the `k` vocabulary entries whose vectors have the largest dot product with `v`,
best match first; ties are ordered by vocabulary index. A `k` larger than the vocabulary
is clamped to the vocabulary size.
"""
function knn(m::d2vmodel, v::Array{Float32}, k)
    # score every vocabulary entry; views avoid copying each column of M
    scores = Float32[dot(v, view(m.M, :, i)) for i in 1:m.n]
    # only the top k positions need ordering, so a partial sort suffices
    best = partialsortperm(scores, 1:min(k, m.n); rev=true)
    return map(i -> m.vocab[i], best)
end
"""
    knn(m::d2vmodel, w1index::Int64, k::Int)

Return the `k` vocabulary entries nearest to the vector stored for index `w1index`.
"""
function knn(m::d2vmodel, w1index::Int64, k::Int)
    query = m.M[:, w1index]   # a copied column, so the Array{Float32} method is selected
    return knn(m, query, k)
end

"""
    knn(m::d2vmodel, w1::String, k)

Return the `k` vocabulary entries nearest to the word `w1`; throws `KeyError` if `w1` is
not in the vocabulary.
"""
function knn(m::d2vmodel, w1::String, k)
    idx = m.vocabdict[w1]
    return knn(m, idx, k)
end
"""
    kldc(p, q)

Contribution `p * log(p / q)` of one event to a Kullback-Leibler divergence. Uses the
standard convention that an event with `p == 0` contributes zero instead of `NaN`.
"""
function kldc(p, q)
    c = p * log(p / q)
    # 0 * log(0 / q) evaluates to NaN (0 * -Inf); zero(c) keeps the result type unchanged
    return iszero(p) ? zero(c) : c
end
"""
    kld(dictp::Array{Dict}, total::Array{Int64}) -> kldResult

Per-type contributions to the Kullback-Leibler divergence of the relative-frequency
distribution `dictp[1]` from `dictp[2]`, restricted to the types (keys) both share.
`total[i]` is the token count behind `dictp[i]`.

The result's `DataFrame` has one row per common type, with columns:
- `type`
- `pkld`, the contribution `p * log(p / q)`
- `freq`, the relative frequency in `dictp[2]`
- `rank`, the competition rank of `freq`, highest first

Throws `ArgumentError` unless exactly two distributions and two totals are given.
"""
function kld(dictp::Array{Dict}, total::Array{Int64})::kldResult
    length(dictp) == 2 || throw(ArgumentError(
        "kld expects exactly two frequency dictionaries, got $(length(dictp))"))
    length(total) == length(dictp) || throw(ArgumentError(
        "kld expects one token total per dictionary, got $(length(total))"))
    min_vocab_size = minimum(length, dictp)
    min_token_size = minimum(total)
    common_types = collect(reduce(intersect, map(keys, dictp)))
    common_type_share = length(common_types) * 100 / min_vocab_size
    ntypes = length(common_types)

    # p[i, j]: relative frequency of common_types[j] in dictp[i]
    p = Array{Float64,2}(undef, length(dictp), ntypes)
    common_tokens = zeros(Int64, length(dictp))
    for i in 1:length(dictp), j in 1:ntypes
        p[i, j] = get(dictp[i], common_types[j], 0)
        # frequency * total recovers the absolute count only up to floating-point error
        # (e.g. (1/49)*49 != 1), so round before adding to the integer accumulator
        common_tokens[i] += round(Int64, p[i, j] * total[i])
    end
    common_token_share = minimum(common_tokens) * 100.0 / min_token_size

    pkld = Float64[kldc(p[1, k], p[2, k]) for k in 1:ntypes]
    frq = Float64[get(dictp[2], common_types[k], 0) for k in 1:ntypes]
    df = DataFrame(type=Vector{String}(common_types), pkld=pkld, freq=frq)
    transform!(df, :freq => (x -> competerank(x, rev=true)) => :rank)
    return kldResult(df, total, ntypes, minimum(common_tokens), common_type_share,
                     common_token_share, sum(df.pkld))
end
"""
    kld(target::d2vmodel, bg::d2vmodel) -> kldResult

Calculate contributions to the Kullback-Leibler divergence from the target language model
to the background language model. The calculation uses each model's relative word
frequencies (`freqs`) and token totals (`total_tokens`).
"""
function kld(target::d2vmodel, bg::d2vmodel)::kldResult
    # Array{Dict} element type is required by the dispatching kld method
    dictp = Dict[Dict(zip(target.vocab, target.freqs)), Dict(zip(bg.vocab, bg.freqs))]
    return kld(dictp, [target.total_tokens, bg.total_tokens])
end

"""
    kld(targetfn::String, bgfn::String) -> kldResult

Load both models with `load` and compare them as `kld(target, bg)` does.
"""
kld(targetfn::String, bgfn::String)::kldResult = kld(load(targetfn), load(bgfn))
"""
    open_collocatordb(path::String) -> Ptr{Cvoid}

Open the collocation database at `path` through `libcollocatordb` and return its opaque
handle.
"""
function open_collocatordb(path::String)::Ptr{Cvoid}
    # NOTE(review): a NULL handle is not checked here — confirm how the C library reports failure
    handle = ccall((:open_collocatordb, libcdb), Ptr{Cvoid}, (Cstring,), path)
    return handle
end
"""
    get_collocates(cdb::Ptr{Nothing}, node::Int64) -> Vector{collocate}

Query the collocation database handle `cdb` (from `open_collocatordb`) for the collocates
of the word with index `node`.

The C side returns a pointer to an array of `collocate` records. The array is scanned for
the first record with `w2 == 0`, examining at most 1000 records. An empty vector is
returned when the C side returns `NULL` or no usable records precede the terminator. The
result wraps the C memory without copying or taking ownership (`own=false`).
"""
function get_collocates(cdb::Ptr{Nothing}, node::Int64)::Vector{collocate}
    # the handle is opaque, so declare it as Ptr{Cvoid} rather than Ptr{collocate}
    res = @ccall libcdb.get_collocators(cdb::Ptr{Cvoid}, node::Cuint)::Ptr{collocate}
    if res == Ptr{collocate}(C_NULL)
        return Vector{collocate}()
    end
    # count the records preceding the terminator (w2 == 0), reading at most 1000
    i = 0
    while i < 1000 && unsafe_load(res, i + 1).w2 != 0
        i += 1
    end
    # NOTE(review): only i - 1 records are wrapped, i.e. the last record before the
    # terminator is dropped — confirm against libcollocatordb whether that is intended.
    # max(…, 0) prevents a negative length (which throws) when the first record is the terminator.
    return unsafe_wrap(Vector{collocate}, res, max(i - 1, 0), own=false)
end
"""
    get_collocates(dv::d2vmodel, node::Int) -> Vector{collocate}

Collocates of the word with index `node`, looked up in the collocation database attached
to `dv`. Throws `ArgumentError` if `dv` has no collocation database (its `cdb` field is
`nothing`).
"""
function get_collocates(dv::d2vmodel, node::Int)::Vector{collocate}
    cdb = dv.cdb
    cdb === nothing && throw(ArgumentError("model has no collocation database attached"))
    return get_collocates(cdb, node)
end

"""
    get_collocates(dv::d2vmodel, node::String) -> Vector{collocate}

Collocates of the word `node`; throws `KeyError` if the word is not in the vocabulary.
"""
function get_collocates(dv::d2vmodel, node::String)::Vector{collocate}
    return get_collocates(dv, dv.vocabdict[node])
end
end
# cdb = open_collocatordb("/vol/work/kupietz/Work2/kl/trunk/Analysemethoden/word2vec/models/dereko-2017-ii")