%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%                                %%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%      DECLIPPING MAIN FILE      %%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%                                %%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% uses LTFAT toolbox
% http://ltfat.github.io/


close all
clear variables


%% input file settings

% Base name (without extension) of the WAV file loaded from the Sounds/ folder.
audio_file = 'acoustic_guitar'; % 'acoustic_guitar'
                                % 'double_bass'
                                % 'sample'
                                % 'speech'
                                % 'vivaldi_storm'

% The name is passed as an argument instead of being spliced into the format
% string, so a '%' or '\' in the file name cannot corrupt the message.
fprintf('Loading audio ''%s.wav''\n', audio_file);
[data, fs] = audioread(['Sounds/' audio_file '.wav']);

% normalization
fprintf('Normalizing audio ')
% audioread returns one column per channel, so the peak is taken over all
% elements; max(abs(data)) alone would give one value per channel and turn
% the division below into a matrix right-division.
maxAbsVal = max(abs(data(:)));
if isempty(data) || maxAbsVal == 0
    % dividing by an empty or zero peak would produce an error or NaN samples
    error('Audio file ''%s.wav'' is empty or silent and cannot be normalized.', audio_file);
end
data = data/maxAbsVal;
fprintf('(The original maximum was %f)\n', maxAbsVal)

% signal length (number of samples per channel)
param.Ls = size(data, 1);

%% settings
fprintf('Setting up parameters\n')

% symmetric clipping level; handed to hard_clip below as -theta and +theta
param.theta = 0.3;

% DGT (Gabor transform) settings
param.wtype = 'hann';       % window type
param.w     = 2^10;         % window length (1024)
param.a     = param.w / 4;  % window shift (a quarter of the window length)
param.M     = 2^10;         % number of frequency channels (1024)

% solver used to compute the declipped signal; the dispatch further below
% accepts 'Condat' (or 'C') and 'Douglas-Rachford' (or 'DR')
param.algorithm = 'DR';

% when nonzero, the reliable samples of the result are overwritten with the
% corresponding samples of the clipped signal after the solver has finished
param.replaceReliable = 1;


%% clipping
% Clip the loaded signal with hard_clip, passing -param.theta and
% param.theta as the two limits. The second output is stored in
% param.masks; its field Mr is used further below to index the reliable
% samples.
fprintf('Generating clipped signal\n')
[data_clipped, param.masks] = hard_clip(data, -param.theta, param.theta); % clipping of original signal


%% construction of frame
fprintf('Creating the frame\n')

% Build the requested 'dgtreal' frame, convert it to a Parseval tight frame
% and precompute it for the fixed signal length, in a single expression.
param.F = frameaccel( ...
    frametight(frame('dgtreal', {param.wtype, param.w}, param.a, param.M)), ...
    param.Ls);

% report the frame bounds and warn if their ratio deviates from 1
[A, B] = framebounds(param.F, param.Ls);
fprintf('Framebounds are %g and %g\n', A, B)
if abs(A/B - 1) >= 10e-6    % note: 10e-6 is 1e-5
    warning('Frame is not tight!')
end
fprintf('\n')


%% Proximal algorithm
fprintf('Starting the proximal algorithm\n')

% Run the solver selected by param.algorithm (long name or short alias) and
% measure its wall-clock time.
tic;
if any(strcmp(param.algorithm, {'Condat', 'C'}))
    [data_rec, dSDR_process, objective_process] = condat(data_clipped, param, data);
elseif any(strcmp(param.algorithm, {'Douglas-Rachford', 'DR'}))
    [data_rec, dSDR_process, objective_process] = douglas_rachford(data_clipped, param, data);
else
    error('Invalid algorithm is set!');
end
time = toc;

% Optionally restore the reliable samples from the clipped signal, reporting
% how far the solver output was from them before the replacement.
if param.replaceReliable
    reliableIdx = param.masks.Mr;
    reliableDifference = norm(data_clipped(reliableIdx) - data_rec(reliableIdx));
    data_rec(reliableIdx) = data_clipped(reliableIdx);
    fprintf('l2-norm of difference on reliable samples after proximal algorithm is %4.3f \n', reliableDifference);
end


%% Evaluation of the result
fprintf('Computing final SDR and plotting some nice figures\n')

% elapsed solver time
fprintf('Result obtained in %4.3f seconds.\n', time);

% SDR of the clipped signal with respect to the original
sdr_clip = sdr(data, data_clipped);
fprintf('SDR of the clipped signal is %4.3f dB.\n', sdr_clip);

% SDR of the reconstruction with respect to the original
sdr_rec = sdr(data, data_rec);
fprintf('SDR of the reconstructed signal is %4.3f dB.\n', sdr_rec);

% improvement achieved by the declipping
dsdr = sdr_rec - sdr_clip;
fprintf('SDR improvement is %4.3f dB.\n', dsdr);

% Plot of declipped signal
% Sample n is taken at time (n-1)/fs, so the time axis uses a step of exactly
% 1/fs. linspace(0, N/fs, N) would stretch the step to N/((N-1)*fs) and put
% the last sample one sampling period too late.
t = (0:length(data) - 1) / fs;
figure
plot(t, data);
hold on
plot(t, data_rec);
plot(t, data_clipped);
xlabel('time (s)')
legend({'original', 'reconstructed', 'clipped'})

% Plot of dSDR and objective function
% Both curves are drawn against a time axis that spreads the measured solver
% time evenly over the recorded iterations; every 100th iteration is marked
% with an 'x'.
% Only the first len entries are plotted, where len is the number of non-NaN
% values in dSDR_process. Plotting the full vectors against a time axis of
% length len fails with a length mismatch as soon as NaN entries are present.
% NOTE(review): assumes any NaN entries are trailing padding for iterations
% that were not run, and that objective_process has at least len entries --
% confirm against the solvers.
figure
len = nnz(~isnan(dSDR_process));
t = linspace(0, time, len);
markIdx = 1:100:len;    % iterations highlighted with a marker
yyaxis left
    p1 = plot(t, dSDR_process(1:len));
    hold on
    px1 = plot(t(markIdx), dSDR_process(markIdx), 'x');
    ylabel('{\Delta}SDR (dB)');

yyaxis right
    p2 = plot(t, objective_process(1:len));
    px2 = plot(t(markIdx), objective_process(markIdx), 'x');
    xlabel('time (s)');
    ylabel('Objective function');
    grid on




