import ndarray, { NdArray } from "ndarray";
// NOTE: @types/ndarray-ops exists, but is not accurate.
//@ts-ignore
import ops from 'ndarray-ops';
//@ts-ignore
import linspace from 'ndarray-linspace';
//@ts-ignore
import pool from 'ndarray-scratch';
//@ts-ignore
import fft from 'ndarray-fft';
//@ts-ignore
import cops from 'ndarray-complex';

import { LatencyResult } from "wos-types/LatencyResult";

/**
 * Serialize the elements of an ndarray to a JSON array string.
 *
 * Elements are read with a single index, `arr.get(i)` for i in [0, arr.size),
 * so this assumes a one-dimensional array.
 */
export function serializeArray(arr: NdArray): string {
    const values = Array.from({ length: arr.size }, (_, i) => arr.get(i));
    return JSON.stringify(values);
}

/**
 * Correlation threshold for latency measurements.
 *
 * NOTE(review): not referenced in this file; presumably callers compare the
 * `maxCorr` returned by `calculateLatencyMillis` against it to decide whether
 * a measurement is trustworthy — confirm at the call sites.
 */
export const LATENCY_CORR_THRESHOLD = 0.1;

/** Return the sum of `a` and `b`. */
export function addTwo(a: number, b: number) {
    const total = a + b;
    return total;
}

/** Arithmetic mean of all elements of `x` (NaN for an empty array). */
function mean(x: NdArray): number {
    const total: number = ops.sum(x);
    const count = x.size;
    return total / count;
}

/**
 * Sample standard deviation of `x`: the sum of squared deviations from the
 * mean divided by `x.size - 1` (Bessel's correction), then square-rooted.
 *
 * The squared deviations are staged in a temporary Float32 buffer. For a
 * single element the result is 0 / 0 = NaN.
 */
function stddev(x: NdArray): number {
    const n = x.size;
    const deviations = ndarray(new Float32Array(n));
    ops.subs(deviations, x, mean(x));
    // square each deviation in place
    ops.muleq(deviations, deviations);
    const variance = ops.sum(deviations) / (n - 1);
    return Math.sqrt(variance);
}

/**
 * Position of the maximum element of a 1-D array.
 *
 * ndarray-ops' `argmax` returns an array of indices (hence the `[0]`), which
 * its TypeScript binding does not describe, so the type check is suppressed.
 */
export function argmax(x: NdArray): number {
    //@ts-ignore
    const position: number = ops.argmax(x)[0];
    return position;
}

// Shift so that mean=0, stddev=1
/**
 * Standardize `arr`: return a new 1-D Float32 ndarray equal to
 * `(arr - mean(arr)) / stddev(arr)`, using the sample standard deviation.
 * `arr` itself is not modified.
 *
 * When the standard deviation is not positive, the division is skipped and
 * the mean-centered values are returned. This happens for an all-zero
 * (silent) signal, where it is exactly 0, and for a single sample, where
 * `stddev` yields NaN. Dividing would otherwise fill the output with
 * NaN/±Infinity, and those would then propagate into every correlation value
 * computed from it.
 */
export function normalize(arr: NdArray): NdArray {
    const mu = mean(arr);
    const sigma = stddev(arr);

    const out = ndarray(new Float32Array(arr.size));
    ops.assign(out, arr);
    ops.subseq(out, mu);
    // `sigma > 0` is false both for 0 and for NaN.
    if (sigma > 0) {
        ops.divseq(out, sigma);
    }
    return out;
}

// Time offset for cross-correlation
// inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlation_lags.html
/**
 * Lag (in samples) for each output index of a "full" cross-correlation of an
 * N-sample signal with an M-sample signal. The result holds the integers
 * `-(M - 1), …, N - 1` (N + M - 1 values), the same as
 * `scipy.signal.correlation_lags(N, M, mode="full")`.
 *
 * The lags are written directly from their index rather than interpolated
 * with `linspace`. Every value is therefore exact, and the one-point case
 * (N = M = 1) does not depend on how linspace treats a zero-width step.
 *
 * @param N length of the first signal
 * @param M length of the second signal
 * @returns 1-D Float32 ndarray of length N + M - 1
 */
export function correlationLags(N: number, M: number): NdArray {
    const nz = N + M - 1;

    const data = new Float32Array(nz);
    for (let k = 0; k < nz; k++) {
        // index k of the full correlation corresponds to lag k - (M - 1)
        data[k] = k - M + 1;
    }

    return ndarray(data);
}


// Similar to the ndarray-convolve js package,
// but matching the convention of scipy.signal.correlate
/**
 * Full (linear) cross-correlation of `a1` (length N) with `a2` (length M),
 * computed with FFTs.
 *
 * Returns N + M - 1 values. Output index k corresponds to lag k - (M - 1),
 * the same ordering produced by `correlationLags(N, M)`:
 *
 *     out[k] = (1 / min(N, M)) * sum_n a1[n + k - (M - 1)] * a2[n]
 *
 * Terms whose a1 index falls outside [0, N) are zero.
 *
 * NOTE(review): the returned array is allocated from the ndarray-scratch pool
 * (`pool.malloc`) and is not released here; ownership passes to the caller.
 * The other pooled buffers are not released if an exception is thrown
 * part-way through — consider a try/finally.
 */
export function correlate(a1: NdArray, a2: NdArray): NdArray {
    const N = a1.size;
    const M = a2.size;
    // Length of the full correlation. Padding both inputs to this length makes
    // the circular correlation computed via FFT equal the linear one (no wrap-around).
    const nz = N + M - 1;

    // zero-pad real inputs
    //@ts-ignore
    const r1: NdArray = pool.zeros([nz]);
    //@ts-ignore
    const r2: NdArray = pool.zeros([nz]);

    // zeros in front: r1 = [M - 1 zeros, a1...]
    ops.assign(r1.lo(nz - N), a1);
    // zeros in back: r2 = [a2..., N - 1 zeros]
    ops.assign(r2.hi(M), a2);

    // allocate imaginary input parts (all zero: the inputs are real)
    const i1: NdArray = pool.zeros([nz]);
    const i2: NdArray = pool.zeros([nz]);

    // forward fourier transforms (in place: r/i now hold the spectra)
    fft(1, r1, i1);
    fft(1, r2, i2);

    // complex conjugate of the second spectrum (negate its imaginary part)
    ops.mulseq(i2, -1);

    // allocate result; `malloc` rather than `zeros` because cops.mul below
    // writes every element
    const rOut: NdArray = pool.malloc([nz]);
    const iOut: NdArray = pool.malloc([nz]);

    // elementwise complex multiplication
    // in frequency domain: OUT = FFT(r1) * conj(FFT(r2))
    cops.mul(rOut, iOut, r1, i1, r2, i2);

    // inverse fourier transform
    // NOTE(review): assumes ndarray-fft's inverse (direction -1) already
    // applies the 1/nz scaling — confirm against the library.
    fft(-1, rOut, iOut);

    // Scale by number of elements
    ops.divseq(rOut, Math.min(N, M));

    // NOTE: discarding final imaginary parts
    // (they should all be zero up to rounding, since both inputs are real)

    // Deallocate arrays (rOut is returned, so it is not freed)
    pool.free(r1);
    pool.free(i1);
    pool.free(r2);
    pool.free(i2);
    pool.free(iOut);

    return rOut;
}

/**
 * Direct O(N·M) cross-correlation of the standardized inputs, for the
 * non-positive lags only.
 *
 * With N = a1.size and M = a2.size, returns N values:
 *
 *     corr[n] = (1 / min(N, M)) * sum_m norm1[m] * norm2[m + n],   n = 0 … N-1
 *
 * Terms with `m + n >= M` are treated as zero (zero padding). For n < M,
 * corr[n] is the correlation at lag -n in the convention of `correlate` /
 * `correlationLags`; for n >= M it is 0.
 *
 * NOTE(review): not used in this file; presumably kept as a reference for
 * checking the FFT-based `correlate`.
 */
export function manualCrossCorr(a1: NdArray, a2: NdArray): NdArray {
    const N = a1.size;
    const M = a2.size;
    const norm1 = normalize(a1);
    const norm2 = normalize(a2);
    const corr = ndarray(new Float32Array(N));
    for (let n = 0; n < N; n++) {
        // m must index a1 (m < N) and m + n must index a2 (m + n < M);
        // everything outside both ranges is zero padding.
        const mEnd = Math.min(N, M - n);
        let el = 0;
        for (let m = 0; m < mEnd; m++) {
            el += norm1.get(m) * norm2.get(m + n);
        }
        corr.set(n, el);
    }
    ops.divseq(corr, Math.min(N, M));
    return corr;
}

/** Result of `normCrossCorr`: correlation values paired with their lags. */
export interface CorrelationResult {
    /** Lag (in samples) for each entry of `corr`, as produced by `correlationLags`. */
    lags: NdArray;
    /** Normalized cross-correlation value at the corresponding lag. */
    corr: NdArray;
}

// Normalized cross-correlation
// https://en.wikipedia.org/wiki/Cross-correlation#Normalization
/**
 * Standardize both signals with `normalize`, then cross-correlate them with
 * `correlate`. `lags[k]` is the lag (in samples) of `corr[k]`.
 */
export function normCrossCorr(a1: NdArray, a2: NdArray): CorrelationResult {
    const [std1, std2] = [normalize(a1), normalize(a2)];
    const result: CorrelationResult = {
        lags: correlationLags(a1.size, a2.size),
        corr: correlate(std1, std2),
    };
    return result;
}

/**
 * Estimate the delay of `recorded` relative to `played`, in milliseconds.
 *
 * The lag with the largest normalized cross-correlation is taken as the
 * alignment. Its sign is flipped, so a recording that trails the played
 * signal yields a positive offset. The offset in samples is divided by
 * `sampleRate` (presumably samples per second — confirm) and rounded to whole
 * milliseconds. Intermediate values are logged to the console.
 *
 * @returns `millis`, the rounded latency, and `maxCorr`, the peak correlation value.
 */
export function calculateLatencyMillis(played: NdArray, recorded: NdArray, sampleRate: number): LatencyResult {
    console.log('start cross corr');

    // Normalized cross-correlation between the played and recorded signals
    const xcorr = normCrossCorr(played, recorded);

    // console.log('lags', serializeArray(xcorr.lags));
    // console.log('corr', serializeArray(xcorr.corr));

    // Find maximum correlation and the lag it occurs at
    const peak = argmax(xcorr.corr);
    const offset = -xcorr.lags.get(peak);
    const maxCorr = xcorr.corr.get(peak);

    console.log('corr:', xcorr.corr);
    console.log('offset (# samples): ', offset);
    console.log('maxCorr:', maxCorr);

    // samples -> seconds -> whole milliseconds
    const seconds = offset / sampleRate;
    const millis = Math.round(seconds * 1000);
    console.log('latency', millis);

    return { millis, maxCorr };
}