using NBInclude
nbinclude("Basic Functions.ipynb");
# d-ary Huffman coding: a recursive implementation I
# WITH sorting the pmf!
#
# Arguments
#   pmf : vector of nonnegative masses; it does not have to be normalized.
#         Entries that are not strictly positive are dropped first, so codeword i
#         belongs to the i-th positive entry of the input.
#   d   : size of the code alphabet; codeword digits are the strings "0", ..., "d-1".
#         NOTE(review): d < 2 is not validated; with d == 1 and more than one symbol
#         the recursion never shrinks the pmf.
#
# Returns
#   A vector of codeword strings. When d > 2 and the number of symbols is not of the
#   form k*(d-1)+1, zero-mass dummy symbols are appended to the pmf; the result then
#   has one codeword per entry of the padded pmf, with the dummies' codewords in the
#   trailing positions.
#   If fewer than two symbols remain, an error message is printed and 0 is returned.
function daryHuffman(pmf, d)
    # Masking yields a fresh vector, so the padding below never mutates the caller's pmf.
    pmf = pmf[pmf .> 0]
    if d > 2
        r = mod(length(pmf), d - 1)
        # Each merge replaces d symbols by one, so the symbol count must be of the
        # form k*(d-1)+1; otherwise append 0's at the end.
        if r != 1
            r = (r == 0 ? d - 1 : r)
            append!(pmf, zeros(d - r))
        end
    end
    # Count symbols AFTER padding: for 2 <= n < d the padded pmf has exactly d entries
    # and must be handled by the base case instead of the error branch.
    n = length(pmf)
    perm = sortperm(pmf, rev=true) # sort the pmf in descending order
    sorted_pmf = pmf[perm]
    list_range = map(string, 0:d-1)
    if n > d
        input_pmf = sorted_pmf[1:end-d]
        push!(input_pmf, sum(sorted_pmf[end-d+1:end])) # merge the d smallest masses
        code = daryHuffman(input_pmf, d) # solve the reduced problem recursively
        last_codeword = pop!(code) # the last codeword belongs to the merged mass
        # expand the merged codeword into d children, one per code digit
        append!(code, map(x -> "$(last_codeword)$x", list_range))
        # Undo the sort: result[perm[i]] == code[i]. Indexing with invperm is used
        # because ipermute! is not available in Julia >= 1.0.
        return code[invperm(perm)]
    elseif n == d
        return list_range[perm]
    else
        print("ERROR:: You have only one symbol!\n")
        return 0
    end
end
# d-ary Huffman coding: a recursive implementation II
# WITHOUT sorting the pmf!
#
# Arguments
#   pmf : vector of nonnegative masses; it does not have to be normalized.
#         Entries that are not strictly positive are dropped first.
#   d   : size of the code alphabet; codeword digits are the strings "0", ..., "d-1".
#
# Returns
#   A vector of codeword strings. When d > 2 and the number of symbols is not of the
#   form k*(d-1)+1, zero-mass dummy symbols are appended to the pmf before coding, so
#   the result has one codeword per entry of the padded pmf.
#   If fewer than two symbols remain, an error message is printed and 0 is returned.
#
# Relies on minValues, which is not defined in this file (presumably it comes from the
# included "Basic Functions.ipynb" notebook).
function daryHuffman2(pmf, d)
    # Masking yields a fresh vector, so the padding below never mutates the caller's pmf.
    pmf = pmf[pmf .> 0]
    if d > 2
        r = mod(length(pmf), d - 1)
        # Each merge replaces d symbols by one, so the symbol count must be of the
        # form k*(d-1)+1; otherwise append 0's at the end.
        if r != 1
            r = (r == 0 ? d - 1 : r)
            append!(pmf, zeros(d - r))
        end
    end
    # Count symbols AFTER padding: for 2 <= n < d the padded pmf has exactly d entries
    # and must be handled by the base case instead of the error branch.
    n = length(pmf)
    list_range = map(string, 0:d-1)
    if n > d
        input_pmf = copy(pmf)
        # Per the original author's note, minValues pops the d smallest values from
        # input_pmf and also returns their indices — TODO confirm against the helper.
        min_vals, idx_list = minValues(input_pmf, d)
        append!(input_pmf, sum(min_vals)) # merge the d smallest masses and append the sum
        code = daryHuffman2(input_pmf, d) # solve the reduced problem recursively
        last_codeword = pop!(code) # the last codeword belongs to the merged mass
        # Re-insert the d expanded codewords at the recorded indices, walking idx_list
        # from its last entry to its first.
        for i = d:-1:1
            insert!(code, idx_list[i], "$(last_codeword)$(list_range[i])")
        end
        return code
    elseif n == d
        return list_range
    else
        # Reached only when fewer than two symbols are left after filtering.
        print("ERROR:: You have less than ", d, " symbols!\n")
        return 0
    end
end