In [5]:
using Random, Distributions
using Plots
using MAT
using StatsBase
using FFTW
using Roots
using Optim
# Seed the global random number generator with a fixed value so that every
# random draw made later in this session is reproducible from run to run.
Random.seed!(1);
In [14]:
# Four-component Gaussian mixture used as a test density.
# Each `Normal(μ, σ)` is one component; the second argument of `MixtureModel`
# holds the component weights (they sum to one).
mixture_gaussian = MixtureModel(Normal[
    Normal(-2.0, 1.2),
    Normal(0.0, 1.0),
    Normal(3.0, 0.5),
    Normal(4.0, 1)], [0.2, 0.2, 0.5, 0.1]);
# draw 5000 samples from the mixture
samples = rand(mixture_gaussian, 5000);
# Histogram to verify the draw. The call is deliberately left without a trailing
# semicolon: in a notebook a trailing `;` suppresses the cell output, so the plot
# would never be shown.
histogram(samples)
In [7]:
"""
    adaptive_kde(samples)

Estimate a one-dimensional probability density from `samples` by fitting a
Gaussian mixture with a regularized EM iteration.

The observed range is padded by 10% on each side and the data are rescaled to the
unit interval. A mixture with `ceil(n^(1/3)) + 20` components is initialised from a
random subset of the data and refined by repeated calls to `regEM` until the relative
change of the value returned as `ent` drops below `1e-5` or more than 200 iterations
have run. The fitted mixture is then evaluated with `probfun` on `2^12` equally
spaced grid points.

# Arguments
- `samples`: vector of real observations. At least 23 values are needed so that the
  number of mixture components does not exceed the number of observations, and the
  values must not all be identical.

# Returns
A tuple `(p, grid)` where `grid` holds the evaluation points in the ORIGINAL units of
`samples` and `p` holds the density values at those points (already divided by the
scaling factor, so `p` is a density with respect to the original units).

# Throws
- `ArgumentError` if all samples are identical (the range would be zero).
- `AssertionError` if there are fewer samples than mixture components.

Progress information is printed to standard output on every iteration.
"""
function adaptive_kde(samples)
    n = length(samples)  # number of observations

    # Range of the realizations.
    x_min, x_max = extrema(samples)
    scaling = x_max - x_min
    # A zero range would turn every rescaled value into NaN further down.
    scaling > 0 ||
        throw(ArgumentError("samples must contain at least two distinct values"))
    println(scaling)

    # Pad the range by 10% on both sides.
    x_max = x_max + scaling / 10
    x_min = x_min - scaling / 10
    scaling = x_max - x_min
    println(scaling)

    # Shift by the minimum and scale down to [0, 1].
    unit_samples = (samples .- x_min) ./ scaling

    # Evaluation grid. `range(...; length)` guarantees exactly `mesh_size` points,
    # which a floating-point step does not.
    mesh_size = 2^12
    grid = collect(range(x_min, x_max; length=mesh_size))
    # The mixture lives on the unit interval, so it is evaluated on the rescaled
    # grid; the caller gets `grid` in the original units.
    unit_grid = (grid .- x_min) ./ scaling

    # Default cost/accuracy trade-off: number of mixture components.
    gamma = ceil(Int, n^(1 / 3)) + 20
    @assert gamma <= n "need at least $gamma samples for $gamma components, got $n"

    # Initial bandwidth and a random subset of the data as initial centres.
    del = 0.2 / (n^0.2)
    idx = shuffle(1:n)
    idx_gamma = idx[1:gamma]
    mu = unit_samples[idx_gamma]

    # Random initial weights (normalised) and variances.
    w = rand(gamma)
    w = w ./ sum(w)
    sig = (del^2) .* rand(gamma)

    ent = -Inf
    maxiter = 10_000
    for iter in 1:maxiter
        Eold = ent
        w, mu, sig, del, ent = regEM(w, mu, sig, del, unit_samples)
        err = abs((ent - Eold) / ent)
        println("Iter. Tol. Bandwidth \n")
        println("------------------------------------------------------------\n")
        println("$iter $err $del \n")
        println("------------------------------------------------------------\n")
        # Stopping condition: converged, or iteration budget exhausted.
        if err < 1e-5 || iter > 200
            break
        end
    end

    # Undo the rescaling for the density values.
    p = probfun(unit_grid, w, mu, sig) ./ scaling
    return (p, grid)
end
Out[7]:
adaptive_kde (generic function with 1 method)
In [11]:
"""
    probfun(x, w, mu, sig)

Evaluate a one-dimensional Gaussian mixture at the points `x`.

# Arguments
- `x`: vector of evaluation points.
- `w`: component weights.
- `mu`: component means.
- `sig`: component variances (each term divides the squared distance by `sig[k]`).

# Returns
A `Vector{Float64}` with one entry per element of `x`, the sum over all components of
`w[k]` times the normal density with mean `mu[k]` and variance `sig[k]`.
"""
function probfun(x, w, mu, sig)
    density = zeros(length(x))
    half_log_2pi = log(2 * π) / 2
    for k in eachindex(mu)
        variance = sig[k]
        log_weight = log(w[k])
        half_log_var = 0.5 * log(variance)
        # Squared distance to the component centre, in units of its variance.
        z = ((x .- mu[k]) .^ 2) ./ variance
        # Each component is assembled in log space and exponentiated once.
        term = exp.(-0.5 * z .+ log_weight .- half_log_var .- half_log_2pi)
        density = density + term
    end
    return density
end
Out[11]:
probfun (generic function with 1 method)
In [12]:
"""
    regEM(w, mu, sig, del, samples)

Perform one regularized EM update of a one-dimensional Gaussian mixture.

# Arguments
- `w`: component weights.
- `mu`: component means. Updated IN PLACE.
- `sig`: component variances. Updated IN PLACE.
- `del`: current bandwidth; `del^2` is added to every updated variance and also enters
  the log-likelihood as a penalty term.
- `samples`: vector of observations.

# Returns
A tuple `(w, mu, sig, del, ent)`: the renormalised weights (a new vector), the same
`mu` and `sig` arrays that were passed in (now holding the updated values), the
updated bandwidth, and `ent`, the sum over all samples of the log of the mixture
value computed from the incoming parameters.
"""
function regEM(w, mu, sig, del, samples)
    ncomp = length(mu)
    nobs = length(samples)

    # Per-sample, per-component log terms.
    loglik = zeros(nobs, ncomp)
    logcurv = zeros(nobs, ncomp)
    for k = 1:ncomp
        variance = sig[k]
        resid = samples .- mu[k]
        quad = (resid .^ 2) ./ variance
        # eps keeps the logarithm below finite when a sample sits exactly on a centre.
        quad_scaled = (quad ./ variance) .+ eps(Float64)
        loglik[:, k] .= -0.5 * quad .- 0.5 * log(variance) .+ log(w[k]) .-
                        (log(2 * π) / 2) .- 0.5 * del^2 * (1 / variance)
        logcurv[:, k] .= loglik[:, k] .+ log.(quad_scaled)
    end

    # Log-sum-exp over components, shifted by the row maximum for numerical stability.
    row_max_lik = maximum(loglik, dims=2)
    row_max_curv = maximum(logcurv, dims=2)
    resp = exp.(loglik .- row_max_lik)
    curv_terms = exp.(logcurv .- row_max_curv)
    row_total = sum(resp, dims=2)
    curv_total = sum(curv_terms, dims=2)
    log_mix = log.(row_total) + row_max_lik
    log_curv_mix = log.(curv_total) + row_max_curv

    # Normalise the classification probabilities row by row.
    resp = resp ./ row_total

    ent = sum(log_mix)
    w = vec(sum(resp, dims=1))

    # Only components with positive total responsibility are updated.
    for k in findall(>(0), w)
        mu[k] = resp[:, k]' * samples / w[k]
        resid = samples .- mu[k]
        sig[k] = resp[:, k]' * (resid .^ 2) / w[k] + del^2
    end

    w = w ./ sum(w)
    # Bandwidth update from the averaged curvature-type term.
    curv = mean(exp.(log_curv_mix - log_mix))
    del = 1 / (4 * nobs * (4 * π)^(1 / 2) * curv)^(1 / 3)
    return (w, mu, sig, del, ent)
end
Out[12]:
regEM (generic function with 1 method)
In [13]:
# fast KDE with cosine transform
"""
    kde(samples)

Very fast one-dimensional kernel density estimator based on the discrete cosine
transform, with the bandwidth chosen as the root of the fixed-point equation
evaluated by `fixed_point`.

Requires the StatsBase package to compute the histogram bin counts.

# Arguments
- `samples`: vector of real observations containing at least two distinct values.

# Returns
A tuple `(density, xmesh, bandwidth)`:
- `density`: density values on the mesh (negative values are replaced by `eps()`),
- `xmesh`: the `2^14` equally spaced mesh points, covering the data range padded by
  half of that range on each side,
- `bandwidth`: the selected bandwidth in the units of `samples`.

# Throws
- `ArgumentError` if `samples` is empty or all values are identical.
"""
function kde(samples)
    isempty(samples) && throw(ArgumentError("samples must not be empty"))

    # number of mesh points (default)
    n = 2^14

    # default grid interval: data range padded by half the range on each side
    lo, hi = extrema(samples)
    scaling = hi - lo
    scaling > 0 ||
        throw(ArgumentError("samples must contain at least two distinct values"))
    MIN = lo - scaling / 2
    MAX = hi + scaling / 2
    R = MAX - MIN
    # `range(...; length)` guarantees exactly n mesh points.
    xmesh = range(MIN, MAX; length=n)
    N = length(unique(samples))

    # Bin the data uniformly. n edges give n-1 bins; a trailing zero is appended so
    # that there is one value per mesh point (the padding guarantees that no sample
    # coincides with the last edge).
    counts = fit(Histogram, samples, xmesh, closed=:left).weights
    initial_data = [counts; 0] ./ sum(counts)

    # apply discrete cosine transform
    a = dct1d(initial_data)

    # Squared frequencies as a Float64 column vector: integer values would overflow
    # when raised to the 7th power, and a row vector would broadcast against `a2`
    # into an n-by-n matrix.
    I = float.((1:(n - 1)) .^ 2)
    a2 = (a[2:end] ./ 2) .^ 2

    # find the fixed point, i.e. the root of t -> fixed_point(t, N, I, a2)
    t_star = find_root(t -> fixed_point(t, N, I, a2), N)

    # smooth the cosine coefficients with the selected bandwidth
    a_t = a .* exp.(-((0:(n - 1)) .^ 2) .* (π^2 * t_star / 2))

    # apply inverse DCT to recover the density on the mesh
    density = vec(idct1d(a_t)) ./ R
    bandwidth = sqrt(t_star) * R

    # the truncated cosine series can dip slightly below zero
    density[density .< 0] .= eps()

    return (density, collect(xmesh), bandwidth)
end
"""
    dct1d(data)

Apply the one-dimensional discrete cosine transform to `data` using a single FFT of
the same length.

The first output coefficient carries weight 1 and all others weight 2, so the first
coefficient equals `sum(data)`.

See: Anil K. Jain, Fundamentals of Image Processing.
Requires the FFTW package to compute the fast Fourier transform.

# Arguments
- `data`: real vector of any length (an empty vector yields an empty result).

# Returns
A `Vector{Float64}` with the same length as `data`.
"""
function dct1d(data::AbstractVector{<:Real})
    n = length(data)
    n == 0 && return Float64[]

    x = float.(data)
    # weights to multiply the DFT coefficients
    weights = [1; 2 .* cis.(-(1:(n - 1)) .* (π / (2 * n)))]
    # Reorder: odd-indexed entries in order, followed by the even-indexed entries in
    # reverse. Written with `reverse` so that odd lengths are handled correctly too
    # (`x[end:-2:2]` would pick the odd-indexed entries again when n is odd).
    reordered = [x[1:2:end]; reverse(x[2:2:end])]
    return real(weights .* fft(reordered))
end
"""
    idct1d(data)

Apply the one-dimensional inverse discrete cosine transform that undoes `dct1d`, up
to a factor of `length(data)` (the result is `n` times the original sequence).

See: Anil K. Jain, Fundamentals of Image Processing.
Requires the FFTW package to compute the fast Fourier transform.

# Arguments
- `data`: real vector of cosine coefficients, of any length.

# Returns
An `n × 1` `Matrix{Float64}` (column), where `n = length(data)`.
"""
function idct1d(data::AbstractVector{<:Real})
    n = length(data)
    out = zeros(n, 1)
    n == 0 && return out

    # Weights to multiply the coefficients. This must stay a plain (column) vector
    # and must not be conjugated: an adjoint here would both flip the sign of the
    # phase and turn the product below into an n-by-n matrix.
    weights = n .* cis.((0:(n - 1)) .* (π / (2 * n)))
    transformed = real(ifft(weights .* data))

    # Undo the reordering: the first ceil(n/2) values are the odd-indexed outputs,
    # the remaining values (read backwards) are the even-indexed outputs.
    half = cld(n, 2)
    out[1:2:n] .= transformed[1:half]
    out[2:2:n] .= transformed[n:-1:(half + 1)]
    return out
end
"""
    fixed_point(t, N, I, a2)

Helper function that implements `f(t) = t - zeta*gamma^[l](t)` with `l = 7`.

# Arguments
- `t`: point at which the function is evaluated.
- `N`: sample-size constant entering the recursion.
- `I`: squared frequency indices. May be integer or floating point, and a vector or a
  single row/column matrix; it is flattened and converted to floating point.
- `a2`: squared (halved) cosine coefficients, same number of entries as `I`.

# Returns
The scalar value `t - (2 * N * sqrt(π) * f)^(-2/5)`, where `f` is the result of the
recursion over `s = l-1, …, 2`.
"""
function fixed_point(t, N, I, a2)
    # Integer inputs must be promoted first: `I .^ 7` silently wraps around in Int64
    # for the index sizes used by `kde`. Flattening both inputs keeps the products
    # below element-wise even if one of them arrives as a row and the other as a
    # column (which would otherwise broadcast into an outer product).
    sq_idx = float.(vec(I))
    coef = vec(a2)

    l = 7
    f = 2 * π^(2 * l) * sum(sq_idx .^ l .* coef .* exp.(-sq_idx .* π^2 .* t))
    for s in (l - 1):-1:2
        K0 = prod(1:2:(2 * s - 1)) / sqrt(2 * π)
        cnst = (1 + (1 / 2)^(s + 1 / 2)) / 3
        Time = (2 * cnst * K0 / N / f)^(2 / (3 + 2 * s))
        f = 2 * π^(2 * s) * sum(sq_idx .^ s .* coef .* exp.(-sq_idx .* π^2 .* Time))
    end
    return t - (2 * N * sqrt(π) * f)^(-2 / 5)
end
"""
    find_root(f, N)

Helper function that finds a root of the scalar function `f(t)`, preferring the
smallest one when there is more than one.

Requires the Roots package for the bracketing search and the Optim package for the
fallback (bounded univariate minimisation, Brent's method by default).

# Arguments
- `f`: scalar function of one real argument.
- `N`: sample-size constant; it is clamped to `[50, 1050]` and sets the width of the
  first search interval.

# Returns
The root found by bisection on `[0, tol]`. If no bracketing interval is found before
`tol` has grown to `0.1`, the minimiser of `abs(f(t))` on `[0, 0.1]` is returned
instead.
"""
function find_root(f, N)
    # Initial search interval [0, tol], growing with the (clamped) sample size.
    N = clamp(N, 50, 1050)
    tol = 1e-12 + 0.01 * (N - 50) / 1000

    while true
        try
            # Bisection requires f to change sign on [0, tol]; it throws otherwise.
            return find_zero(f, (0.0, tol), Bisection())
        catch
            # No sign change (or f failed) on this interval: double it, capped at 0.1.
            tol = min(tol * 2, 0.1)
            @debug "find_root: widening search interval" tol
        end
        if tol == 0.1
            # If all else fails, settle for the point where |f| is smallest.
            result = optimize(t -> abs(f(t)), 0.0, 0.1)
            return Optim.minimizer(result)
        end
    end
end
Out[13]:
find_root (generic function with 1 method)