import moment from "moment";
import { AbsoluteStrikefromDeltaKAnalytic, BlackDelta, BlackDigitalPV, BlackGamma, BlackVega } from "./blackFunctions";
import { BrentsMethodSolve } from "./brentSolver";
import { CalculateYearFraction, DayCountBasis } from "./dayCountBasis";
import { IInterpolator1D } from "./iInterpolator1d";
import { InterpolatorFactory1d } from "./interpolation/interpolatorFactory";
import { Interpolator1DType } from "./interpolator1DType";
import { MultiDimToJaggedArray } from "./multiDimArray";
import { StrikeType } from "./strikeType";
import { TO_GridVolSurface } from "./to_GridVolSurface";
import { binarySearch } from "../utils/math";

/**
 * Grid-based volatility surface: a matrix of vols with one row per entry of `Expiries` and one column
 * per entry of `Strikes`.
 *
 * `Build` creates one strike interpolator per row (`_interpolators`). A vol at an arbitrary maturity is
 * obtained by evaluating every row's smile at a grid coordinate and interpolating the resulting term
 * structure over `ExpiriesDouble` (year fractions from `OriginDate` under `TimeBasis`) with
 * `TimeInterpolatorType` (see `InterpolateVol`).
 *
 * When `StrikeType` is `forwardDelta`, public delta arguments are negated before lookup, and the
 * absolute-strike solver searches deltas in [-hi, -lo] while evaluating the grid at positive coordinates.
 * NOTE(review): so grid strikes appear to be positive delta magnitudes and callers pass signed
 * (negative, put-style) deltas — confirm against the data that produces these surfaces.
 */
export default class GridVolSurface {
    private to: TO_GridVolSurface;

    public OriginDate: Date;
    public Strikes: number[];
    public StrikeType: StrikeType;
    public StrikeInterpolatorType: Interpolator1DType = Interpolator1DType.linearFlatExtrap;
    public ExpiriesDouble: number[];
    public TimeInterpolatorType: Interpolator1DType = Interpolator1DType.linearInVariance;
    public Volatilities: number[][];
    public Expiries: Date[];
    public PillarLabels: string[];
    public TimeBasis: DayCountBasis = DayCountBasis.act365F;

    /**
     * When true, delta-grid lookups made while solving for an absolute strike are clamped to
     * [FlatDeltaPoint, 1 - FlatDeltaPoint], holding the smile flat beyond those points.
     */
    public FlatDeltaSmileInExtreme: boolean;
    public FlatDeltaPoint: number = 0.001;

    public Currency: string;
    public AssetId: string;

    /** One strike interpolator per expiry row of `Volatilities`; populated by `Build`. */
    _interpolators: IInterpolator1D[];

    /**
     * @param transportObject serialised surface; when supplied the surface is built from it immediately.
     *   When omitted the instance is left empty and `init`/`Build` must be called before use
     *   (previously an omitted argument caused a TypeError despite the parameter being optional).
     */
    constructor(transportObject?: TO_GridVolSurface) {
        this.to = transportObject;
        if (transportObject == null)
            return;

        this.init(transportObject.originDate, transportObject.strikes, transportObject.expiries,
            MultiDimToJaggedArray(transportObject.volatilities), transportObject.strikeType,
            transportObject.strikeInterpolatorType, transportObject.timeInterpolatorType,
            transportObject.timeBasis, transportObject.pillarLabels);
    }

    /**
     * Sets the surface conventions and builds the grid.
     *
     * @param pillarLabels labels per expiry; when null/undefined they default to each expiry formatted
     *   as an ISO date (moment tokens `YYYY-MM-DD` — the previous `yyyy-MM-dd` were .NET tokens and
     *   `dd` renders the weekday name in moment).
     */
    init(originDate: Date, strikes: number[], expiries: Date[], vols: number[][],
        strikeType: StrikeType, strikeInterpType: Interpolator1DType, timeInterpType: Interpolator1DType,
        timeBasis: DayCountBasis, pillarLabels: string[] = null) {
        this.StrikeType = strikeType;
        this.StrikeInterpolatorType = strikeInterpType;
        this.TimeInterpolatorType = timeInterpType;
        this.TimeBasis = timeBasis;

        this.PillarLabels = pillarLabels == null
            ? expiries.map(x => moment(x).format("YYYY-MM-DD"))
            : pillarLabels;

        this.Build(originDate, strikes, expiries, vols);
    }

    /**
     * (Re)populates the grid and rebuilds derived data: expiry year fractions under `TimeBasis` and one
     * strike interpolator (of type `StrikeInterpolatorType`) per row of `vols`.
     */
    Build(originDate: Date, strikes: number[], expiries: Date[], vols: number[][]) {
        this.OriginDate = originDate;
        this.Strikes = strikes;
        this.Expiries = expiries;
        this.Volatilities = vols;
        this.ExpiriesDouble = expiries.map(t => CalculateYearFraction(moment(originDate), moment(t), this.TimeBasis));
        this._interpolators = vols.map(row => InterpolatorFactory1d(strikes, row, this.StrikeInterpolatorType));
    }

    /**
     * Evaluates every expiry row's smile at the grid coordinate `gridStrike`, interpolates that term
     * structure across `ExpiriesDouble` with `TimeInterpolatorType`, and returns the value at `maturity`.
     */
    private InterpolateVol(gridStrike: number, maturity: number): number {
        const termStructure = this._interpolators.map(x => x.Interpolate(gridStrike));
        return InterpolatorFactory1d(this.ExpiriesDouble, termStructure, this.TimeInterpolatorType).Interpolate(maturity);
    }

    /**
     * Vol for an absolute strike at `maturity` (a year fraction).
     *
     * Absolute grids are read directly. Delta grids solve (Brent, tol 1e-12) for the delta whose implied
     * absolute strike equals `strike`, then read the grid at that delta.
     */
    public GetVolForAbsoluteStrike(strike: number, maturity: number, forward: number): number {
        // TODO: the C# original cached results here (_absVolCache); not ported.
        if (this.StrikeType === StrikeType.absolute)
            return this.InterpolateVol(strike, maturity);

        const hiK = this.FlatDeltaSmileInExtreme ? 1.0 - this.FlatDeltaPoint : 0.999999999;
        const loK = this.FlatDeltaSmileInExtreme ? this.FlatDeltaPoint : 0.000000001;

        const testFunc = (deltaK: number): number => {
            // deltaK lies in the negative bracket [-hiK, -loK]; clamp in that domain so the grid
            // coordinate -dkModified stays in [FlatDeltaPoint, 1 - FlatDeltaPoint]. The previous clamp
            // compared the negative deltaK with the positive FlatDeltaPoint and so always collapsed to it.
            const dkModified = this.FlatDeltaSmileInExtreme
                ? Math.max(-(1.0 - this.FlatDeltaPoint), Math.min(deltaK, -this.FlatDeltaPoint))
                : deltaK;
            const vol2 = this.InterpolateVol(-dkModified, maturity);
            const absK = AbsoluteStrikefromDeltaKAnalytic(forward, deltaK, 0, maturity, vol2);
            return absK - strike;
        };

        let solvedStrike = -BrentsMethodSolve(testFunc, -hiK, -loK, 1e-12);
        if (solvedStrike === loK || solvedStrike === hiK) { // solver pinned to a bracket end: out of bounds
            const upperK = testFunc(-loK);
            const lowerK = testFunc(-hiK);
            // NOTE(review): testFunc already returns (absK - strike); also subtracting `forward` looks
            // unintended (|upperK| vs |lowerK| would pick the closer bound). Kept as-is pending confirmation.
            solvedStrike = Math.abs(upperK - forward) < Math.abs(lowerK - forward) ? loK : hiK;
        }
        return this.InterpolateVol(solvedStrike, maturity);
    }

    /** `GetVolForAbsoluteStrike` with the maturity given as a date (year fraction from `OriginDate`). */
    public GetVolForAbsoluteStrikeExpAsDate(strike: number, expiry: Date, forward: number): number {
        return this.GetVolForAbsoluteStrike(strike, CalculateYearFraction(moment(this.OriginDate), moment(expiry), this.TimeBasis), forward);
    }

    /**
     * Vol for a delta strike at `maturity` (a year fraction).
     *
     * Delta grids are read at `-deltaStrike`. Absolute grids solve (Brent over (1e-9, 10 * forward),
     * tol 1e-8) for the absolute strike whose Black delta matches `deltaStrike`.
     *
     * @throws Error when `deltaStrike` is outside [-1, 1].
     */
    GetVolForDeltaStrike(deltaStrike: number, maturity: number, forward: number): number {
        if (deltaStrike > 1.0 || deltaStrike < -1.0)
            throw new Error(`Delta strike must be in range -1.0 < x < 1.0 - value was ${deltaStrike}`);

        // TODO: the C# original cached results here (_deltaVolCache); not ported.
        if (this.StrikeType === StrikeType.forwardDelta)
            return this.InterpolateVol(-deltaStrike, maturity); // strikes stored as 0.1, 0.25, 0.5 (etc) are put deltas

        const isCall = deltaStrike >= 0;
        const testFunc = (absK: number): number => {
            const vol2 = this.InterpolateVol(absK, maturity);
            const deltaK = BlackDelta(forward, absK, 0, maturity, vol2, isCall);
            // NOTE(review): compares against |deltaStrike|, while GetAbsStrikeForDelta uses the signed
            // value; which is correct depends on BlackDelta's sign convention for puts — confirm.
            return deltaK - Math.abs(deltaStrike);
        };

        const solvedStrike = BrentsMethodSolve(testFunc, 0.000000001, 10 * forward, 1e-8);
        return this.InterpolateVol(solvedStrike, maturity);
    }

    /** `GetVolForDeltaStrike` with the maturity given as a date (year fraction from `OriginDate`). */
    public GetVolForDeltaStrikeExpAsDate(strike: number, expiry: Date, forward: number): number {
        return this.GetVolForDeltaStrike(strike, CalculateYearFraction(moment(this.OriginDate), moment(expiry), this.TimeBasis), forward);
    }

    /**
     * Absolute strike whose Black delta (using surface vols) equals `deltaStrike`; Brent over
     * (1e-9, 50 * fwd), tol 1e-8.
     */
    private GetAbsStrikeForDelta(fwd: number, deltaStrike: number, maturity: number): number {
        const isCall = deltaStrike >= 0;
        const testFunc = (absK: number): number => {
            const pillarVols = this.ExpiriesDouble.map(e => this.GetVolForAbsoluteStrike(absK, e, fwd));
            const vol2 = InterpolatorFactory1d(this.ExpiriesDouble, pillarVols, this.TimeInterpolatorType).Interpolate(maturity);
            const deltaK = BlackDelta(fwd, absK, 0, maturity, vol2, isCall);
            return deltaK - deltaStrike;
        };

        return BrentsMethodSolve(testFunc, 0.000000001, 50 * fwd, 1e-8);
    }

    /**
     * Delta (returned as a positive grid coordinate) whose implied absolute strike equals `strike`.
     * Brent over the negative bracket [-hi, -lo], tol 1e-8. If the solver returns a bracket end, the
     * bound is chosen by the same comparison as in GetVolForAbsoluteStrike.
     */
    private GetDeltaStrikeForAbs(fwd: number, strike: number, maturity: number): number {
        const lo = 0.00000000001;
        const hi = 0.99999999999;

        const testFunc = (deltaK: number): number => {
            const pillarVols = this.ExpiriesDouble.map(e => this.GetVolForDeltaStrike(deltaK, e, fwd));
            const vol2 = InterpolatorFactory1d(this.ExpiriesDouble, pillarVols, this.TimeInterpolatorType).Interpolate(maturity);
            const absK = AbsoluteStrikefromDeltaKAnalytic(fwd, deltaK, 0, maturity, vol2);
            return absK - strike;
        };

        let solvedStrike = -BrentsMethodSolve(testFunc, -hi, -lo, 1e-8);
        if (solvedStrike === lo || solvedStrike === hi) { // out of bounds
            const upperK = testFunc(-lo);
            const lowerK = testFunc(-hi);
            // NOTE(review): same questionable `- fwd` comparison as in GetVolForAbsoluteStrike.
            solvedStrike = Math.abs(upperK - fwd) < Math.abs(lowerK - fwd) ? lo : hi;
        }
        return solvedStrike;
    }

    /**
     * Returns `BlackDigitalPV(..., false) + vega * dσ/dK` at `strike`, where vega is
     * `BlackVega / 0.01`.
     *
     * NOTE(review): presumably the smile-adjusted risk-neutral probability P(S_T < strike), with the
     * /0.01 converting a per-1%-vol vega to per-unit-vol — confirm against blackFunctions.
     */
    public CDF(expiry: Date, fwd: number, strike: number): number {
        const t = CalculateYearFraction(moment(this.OriginDate), moment(expiry), this.TimeBasis);
        const vol = this.GetVolForAbsoluteStrike(strike, t, fwd);
        const vega = BlackVega(fwd, strike, 0.0, t, vol) / 0.01;
        const digi = BlackDigitalPV(fwd, strike, 0.0, t, vol, false);
        const dvdk = this.Dvdk(strike, t, fwd);
        return digi + vega * dvdk;
    }

    /**
     * Slope of the smile with respect to absolute strike at maturity `expiry` (a year fraction).
     *
     * Delta grids: derivative of the smile w.r.t. the delta coordinate, multiplied by Black gamma.
     * Absolute grids: derivative of a strike interpolator built from surface vols at `expiry`.
     */
    public Dvdk(strike: number, expiry: number, fwd: number): number {
        if (this.StrikeType === StrikeType.forwardDelta) {
            // NOTE(review): assumes binarySearch returns a non-negative index only for an exact hit.
            const pillarIx = binarySearch(this.ExpiriesDouble, expiry);
            // Exact pillar hit: reuse that row's smile (was `> 0`, which skipped pillar 0). Otherwise build
            // the smile at `expiry` on the grid coordinates. The old fallback went through
            // GetVolForDeltaStrike(k), which reads the grid at -k, i.e. at negated coordinates.
            const deltaSmile = pillarIx >= 0
                ? this._interpolators[pillarIx]
                : InterpolatorFactory1d(this.Strikes, this.Strikes.map(k => this.InterpolateVol(k, expiry)), this.StrikeInterpolatorType);

            const deltaK = this.GetDeltaStrikeForAbs(fwd, strike, expiry);
            const vol = this.GetVolForAbsoluteStrike(strike, expiry, fwd);
            const gamma = BlackGamma(fwd, strike, 0.0, expiry, vol);
            return deltaSmile.FirstDerivative(deltaK) * gamma;
        }

        const absSmile = InterpolatorFactory1d(this.Strikes,
            this.Strikes.map(k => this.GetVolForAbsoluteStrike(k, expiry, fwd)),
            this.StrikeInterpolatorType);
        return absSmile.FirstDerivative(strike);
    }
}

