module DerekoVecs
export load, knn, cos_sim, d2vmodel, kld, kldResult, get_collocates

using Mmap
using DelimitedFiles
using LinearAlgebra
using Distances
using DataFrames
using Pkg.Artifacts
using StatsBase
using Base.PCRE
using Libdl

# Location of the native collocatordb library as resolved by Libdl.find_library,
# trying the current directory and the usual system library directories listed
# below; an empty string means it was not found (load() checks for "" before
# opening a collocation database).
# NOTE(review): as a module-level const this is evaluated once when the module
# is loaded/precompiled, not on each use.
const libcdb = Libdl.find_library(
    "libcollocatordb",
    [
        ".",
        "/usr/local/lib64",
        "/usr/lib64",
        "/usr/local/lib",
        "/usr/lib",
    ],
)

"""
    d2vmodel

A word model as assembled by `load`: vocabulary, relative frequencies, an
optional word-vector matrix and an optional collocation-database handle.
"""
struct d2vmodel
    # Word vectors, one column per vocabulary index (shape (m, n)); `load`
    # stores a 2×0 placeholder when no vectors are available.
    M::Matrix{Float32}
    # Vector dimension (the `d` read from the .vecs header); 0 without vectors.
    m::Int64
    # Number of vocabulary entries.
    n::Int64
    # Word → vocabulary index.
    vocabdict::Dict{String,Int64}
    # Words in vocabulary-index order.
    vocab::Array{String}
    # Relative frequency of each word: count / total_tokens.
    freqs::Array{Float64}
    # Token total the frequencies were divided by.
    total_tokens::Int64
    # Handle returned by open_collocatordb, or `nothing` if no database was opened.
    cdb::Union{Ptr{Nothing},Nothing}
end

"""
    kldResult

Result of `kld`: per-type divergence contributions plus overlap statistics of
the two compared vocabularies.
"""
struct kldResult
    # One row per common type with columns type, pkld (KLD contribution),
    # freq (relative frequency in the second distribution) and rank.
    df::DataFrame
    # Token totals of the compared corpora, as passed to `kld`.
    token_count::Array{Int64}
    # Number of types occurring in both distributions.
    common_type_count::Int64
    # Smallest (over the corpora) token count of the common types.
    common_token_count::Int64
    # Common types as a percentage of the smaller vocabulary.
    common_type_share::Float64
    # common_token_count as a percentage of the smaller token total.
    common_token_share::Float64
    # Sum of all per-type contributions (sum(df.pkld)).
    kld::Float64
end

"""
    collocate

One record of the array returned by the native `get_collocators` call of
libcollocatordb. `get_collocates` reinterprets that native memory in place with
`unsafe_wrap`, so the field order and types here must match the native record
layout exactly — do not reorder, add or retype fields.
"""
struct collocate
    # Word id of the collocate; `get_collocates` treats an entry with w2 == 0
    # as the end of the list.
    w2::UInt32
    # NOTE(review): the remaining fields are filled by the native library. Their
    # names suggest the collocate frequency (f2), the raw co-occurrence count and
    # association scores (PMI, normalized PMI, log-likelihood ratio, LFMD, MD,
    # Dice, logDice, LDAF), with left_/right_ variants per context side and the
    # window settings used — confirm against the collocatordb sources.
    f2::UInt64
    raw::UInt64
    pmi::Float64
    npmi::Float64
    llr::Float64
    lfmd::Float64
    md::Float64
    left_raw::UInt64
    right_raw::UInt64
    left_pmi::Float64
    right_pmi::Float64
    dice::Float64
    logdice::Float64
    ldaf::Float64
    window::Int32
    af_window::Int32
end


"""
    load(modelfn) -> d2vmodel

Load a model described by the file `modelfn`.

* If `modelfn` ends in `.vecs`:
  - its first line is parsed as `"<n> <d>"`;
  - the vocabulary (word and count per line, space separated) is read from
    `<base>.vocab`, where `<base>` is `modelfn` without the trailing `.vecs`;
  - the vectors are memory-mapped as a `(d, n)` `Float32` matrix from
    `"\$(modelfn).vecs"` when that file exists;
  - the collocation database `<base>` is opened when `<base>.rocksdb/CURRENT`
    exists and the native library was found.
* Otherwise `modelfn` itself is read as a word/count list. It is taken as tab
  separated if its first line contains a tab, and space separated otherwise.
  No vectors are loaded.

Relative frequencies are `count / total`. `total` comes from the second line of
`<modelfn without its extension>.size` if that file exists, and is the sum of
all counts otherwise.
"""
function load(modelfn)::d2vmodel
    cdb = nothing
    # endswith rather than the unescaped regex r".vecs$", which also matched
    # names such as "modelvecs".
    isvecs = endswith(modelfn, ".vecs")
    if isvecs
        (n, d) = map(s -> parse(Int, s), split(readline(modelfn)))
        # Strip only the trailing ".vecs"; replace() rewrote every occurrence,
        # including ones inside directory names.
        base = String(chop(modelfn; tail = length(".vecs")))
        file = readdlm(base * ".vocab", ' ', String, dims=(n, 2), quotes=false)

        if isfile(base * ".rocksdb/CURRENT") && libcdb != ""
            cdb = open_collocatordb(base)
        end
    else
        delim = ('\t' in readline(modelfn) ? '\t' : ' ')
        file = readdlm(modelfn, delim, String, quotes=false)
    end

    vocab = file[:, 1]
    counts = file[:, 2]
    n = length(vocab)

    # splitext replaces only the final extension of the file name; the former
    # regex rewrote every ".xxx" segment of the path (e.g. "./m.vecs" became
    # ".size.size"), so the .size file was silently never found.
    sizefn = first(splitext(modelfn)) * ".size"
    total = if isfile(sizefn) # .size file with a corrected token count on line 2
        open(sizefn) do io
            readline(io)
            parse(Int, readline(io))
        end
    else
        sum(parse.(Int64, counts))
    end

    freqs = parse.(Float64, counts) ./ total
    vocabdict = Dict{String,Int64}(zip(vocab, 1:n))
    vecsfn = "$(modelfn).vecs"
    if isvecs && isfile(vecsfn)
        M = Mmap.mmap(vecsfn, Matrix{Float32}, (d, n))
        return d2vmodel(M, d, n, vocabdict, vocab, freqs, total, cdb)
    end
    return d2vmodel(Matrix{Float32}(undef, 2, 0), 0, n, vocabdict, vocab, freqs, total, cdb)
end

"""
    cos_sim(m1, m2, w1index, w2index)
    cos_sim(m, w1, w2)
    cos_sim(m1, m2, w1, w2)
    cos_sim(m1, m2, w)
    cos_sim(m1, m2, w1index)
    cos_sim(m, w1index, w2index)

Return the dot product of the vector of word `w1index` in `m1` and the vector
of word `w2index` in `m2` (columns of the models' `M` matrices). This equals
the cosine similarity only if the stored vectors have unit length.
NOTE(review): that normalization is assumed, not checked here.

Returns `-1.0` when either index lies outside its model's matrix, or when the
two models' vectors differ in length. Word arguments are looked up in the
models' `vocabdict`s, so unknown words raise a `KeyError`.
"""
function cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64, w2index::Int64)
    # Test the expected failure cases explicitly. The previous catch-all `try`
    # also swallowed unrelated errors such as an InterruptException.
    A, B = m1.M, m2.M
    if !checkbounds(Bool, A, :, w1index) || !checkbounds(Bool, B, :, w2index) ||
       size(A, 1) != size(B, 1)
        return -1.0
    end
    # Views avoid copying the two columns.
    return dot(view(A, :, w1index), view(B, :, w2index))
end

cos_sim(m::d2vmodel, w1::String, w2::String) = cos_sim(m, m, w1, w2)
cos_sim(m1::d2vmodel, m2::d2vmodel, w1::String, w2::String) = cos_sim(m1, m2, m1.vocabdict[w1], m2.vocabdict[w2])
cos_sim(m1::d2vmodel, m2::d2vmodel, w::String) = cos_sim(m1, m2, m1.vocabdict[w], m2.vocabdict[w])
cos_sim(m1::d2vmodel, m2::d2vmodel, w1index::Int64) = cos_sim(m1, m2, w1index, w1index)
cos_sim(m::d2vmodel, w1index::Int64, w2index::Int64) = cos_sim(m, m, w1index, w2index)

"""
    minus(m, w1, w2) -> String

Return the single vocabulary word that `knn` ranks highest for the normalized
difference between the vectors of `w1` and `w2`. Unknown words raise a
`KeyError`.
"""
function minus(m::d2vmodel, w1::String, w2::String)
    i = m.vocabdict[w1]
    a = m.M[:, i]
    j = m.vocabdict[w2]
    b = m.M[:, j]
    direction = normalize(a - b)
    nearest = knn(m, direction, 1)
    return nearest[1]
end

"""
    knn(m, v, k) -> Vector{String}
    knn(m, w1index, k)
    knn(m, w1, k)

Return the `k` vocabulary words whose vectors (columns of `m.M`) have the
largest dot product with the query vector `v`, best first. At most `m.n` words
are returned; a larger `k` is clamped instead of raising a `BoundsError`.

The index and word forms use that word's own vector as the query. The word
itself is therefore scored as well; it is not excluded from the results.
"""
function knn(m::d2vmodel, v::Array{Float32}, k)
    # One matrix–vector product scores every word at once. partialsortperm
    # then orders only the best k, instead of sorting all n scores.
    scores = transpose(view(m.M, :, 1:m.n)) * vec(v)
    best = partialsortperm(scores, 1:min(k, m.n), rev=true)
    return m.vocab[best]
end

function knn(m::d2vmodel, w1index::Int64, k::Int)
    knn(m, m.M[:, w1index], k)
end

function knn(m::d2vmodel, w1::String, k)
    knn(m, m.vocabdict[w1], k)
end

"""
    kldc(p, q)

Return one outcome's contribution `p * log(p / q)` to the Kullback–Leibler
divergence of P from Q. It uses the standard convention that the term is 0
when `p` is 0.
"""
function kldc(p, q)
    term = p * log(p / q)
    # Evaluated literally, p == 0 gives 0 * log(0) = 0 * -Inf = NaN, which
    # would turn the whole divergence sum into NaN.
    return iszero(p) ? zero(term) : term
end

"""
    kld(dictp, total) -> kldResult

Compute per-type contributions to the Kullback–Leibler divergence of the
distribution `dictp[1]` from `dictp[2]`, restricted to the types that occur in
both.

# Arguments
- `dictp`: exactly two dictionaries mapping each type (word) to its relative
  frequency. `dictp[1]` plays the role of P and `dictp[2]` that of Q.
- `total`: the token totals of the two corpora, in the same order.

# Returns
A `kldResult` whose `df` has one row per common type, with columns:
- `type`;
- `pkld`: the type's contribution `kldc(p, q)`;
- `freq`: the type's relative frequency in `dictp[2]`;
- `rank`: competition rank of `freq`, highest first.

The type and token shares are percentages of the smaller vocabulary and the
smaller token total, respectively. `kld` is the sum of all contributions.

# Throws
- `ArgumentError` unless `dictp` and `total` both have exactly two entries.
"""
function kld(dictp::AbstractArray{<:AbstractDict}, total::Array{Int64})::kldResult
    length(dictp) == 2 || throw(ArgumentError(
        "kld compares exactly two distributions, got $(length(dictp))"))
    length(total) == length(dictp) || throw(ArgumentError(
        "expected one token total per distribution, got $(length(total))"))

    min_vocab_size = minimum(length, dictp)
    min_token_size = minimum(total)

    common_types = collect(reduce(intersect, map(keys, dictp)))
    ntypes = length(common_types)
    common_type_share = ntypes * 100 / min_vocab_size

    # p[i, k] is the relative frequency of common type k in distribution i.
    p = Matrix{Float64}(undef, length(dictp), ntypes)
    for (i, d) in enumerate(dictp), (k, t) in enumerate(common_types)
        p[i, k] = get(d, t, 0)
    end

    # Token counts of the common types per corpus, reconstructed as relative
    # frequency × total. Sum in floating point and round once at the end:
    # frequency × total is often not exactly integral (e.g. (1/49) * 49 ==
    # 0.9999999999999999). Adding such a value into an Int64 array throws
    # InexactError.
    common_tokens = [round(Int64, sum(view(p, i, :)) * total[i]) for i in axes(p, 1)]
    common_token_share = minimum(common_tokens) * 100.0 / min_token_size

    pkld = Float64[kldc(p[1, k], p[2, k]) for k in 1:ntypes]
    freq = p[2, :]

    df = DataFrame(type = Vector{String}(common_types), pkld = pkld, freq = freq)
    transform!(df, :freq => (x -> competerank(x, rev=true)) => :rank)
    return kldResult(df, total, ntypes, minimum(common_tokens),
                     common_type_share, common_token_share, sum(df.pkld))
end

"""
    kld(target::d2vmodel, bg::d2vmodel) -> kldResult

Calculate contributions to the Kullback–Leibler divergence from the target
language model to the background language model. The relative word
frequencies of `target` are used as P and those of `bg` as Q. The token totals
are taken from both models' `total_tokens`.
"""
function kld(target::d2vmodel, bg::d2vmodel)::kldResult
    # Element type `Dict` keeps this a Vector{Dict}, as the two-distribution
    # method has historically been declared for Array{Dict}.
    dictp = Dict[Dict(zip(target.vocab, target.freqs)),
                 Dict(zip(bg.vocab, bg.freqs))]
    return kld(dictp, [target.total_tokens, bg.total_tokens])
end

"""
    kld(targetfn, bgfn) -> kldResult

Load both models with `load` and compare them like `kld(target, bg)`.
"""
kld(targetfn::AbstractString, bgfn::AbstractString)::kldResult = kld(load(targetfn), load(bgfn))

"""
    open_collocatordb(path) -> Ptr{Cvoid}

Open the collocation database `path` through the native `open_collocatordb`
function of `libcdb` and return the raw handle it yields.
NOTE(review): the handle is passed through unchecked, so a C_NULL result is
not detected here.
"""
function open_collocatordb(path::String)::Ptr{Cvoid}
    handle = ccall((:open_collocatordb, libcdb), Ptr{Cvoid}, (Cstring,), path)
    return handle
end

"""
    get_collocates(cdb, node) -> Vector{collocate}

Query the collocation database handle `cdb` (as returned by
`open_collocatordb`) for the collocates of the word with id `node`.

The native result array is scanned for the first entry whose `w2` is 0, which
marks the end of the list; at most 1000 entries are examined. All entries
before that marker are returned. The vector wraps the native buffer without
copying it and without taking ownership (`own=false`).
NOTE(review): nothing here ever frees that buffer — confirm its ownership on
the C side.
"""
function get_collocates(cdb::Ptr{Nothing}, node::Int64)::Vector{collocate}
    # The first argument is the database handle, not a collocate array.
    res = @ccall libcdb.get_collocators(cdb::Ptr{Cvoid}, node::Cuint)::Ptr{collocate}
    # Wrapping and iterating a NULL result would dereference a null pointer.
    res == C_NULL && return collocate[]
    count = 0
    for c in unsafe_wrap(Vector{collocate}, res, 1000, own=false)
        c.w2 == 0 && break
        count += 1
    end
    # `count` is exactly the number of entries before the terminator. The former
    # `count - 1` dropped the last collocate, and an empty list (count == 0)
    # threw InexactError (negative length).
    return unsafe_wrap(Vector{collocate}, res, count, own=false)
end

"""
    get_collocates(dv, node) -> Vector{collocate}

Return the collocates of the word with vocabulary index `node`, or of the word
`node` itself, using the collocation database stored in `dv.cdb`.

# Throws
- `ArgumentError` if `dv` has no collocation database (`dv.cdb === nothing`).
- `KeyError` if a word is not in `dv.vocabdict`.
"""
function get_collocates(dv::d2vmodel, node::Int)::Vector{collocate}
    cdb = dv.cdb
    # Without this check a missing database surfaced as an opaque MethodError.
    cdb === nothing && throw(ArgumentError(
        "model has no collocation database (cdb is nothing)"))
    return get_collocates(cdb, node)
end

function get_collocates(dv::d2vmodel, node::AbstractString)::Vector{collocate}
    return get_collocates(dv, dv.vocabdict[node])
end

end

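# Example (site-specific path). open_collocatordb is not exported, hence the module
# prefix. The argument is the model base name, without the ".rocksdb" suffix:
# cdb = DerekoVecs.open_collocatordb("/vol/work/kupietz/Work2/kl/trunk/Analysemethoden/word2vec/models/dereko-2017-ii")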
