/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.inference.operators.hmc.MinimumTravelInformation;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.TaskPool;

/**
 * Base class for zig-zag style particle operators.
 *
 * <p>A trajectory is integrated event by event. A velocity is drawn from the momentum. The operator
 * then repeatedly asks the subclass for the next event ({@link #getNextBounce}). If the remaining
 * travel time runs out first, the particle is moved for that remaining time. Otherwise the dynamics
 * are advanced to the event, the corresponding velocity component is flipped, and integration
 * continues with the leftover time.
 */
abstract class AbstractZigZagOperator extends AbstractParticleOperator implements Loggable {

    /**
     * Task pool created with the gradient dimension and {@code threadCount}; {@code null} when
     * {@code threadCount <= 1}.
     */
    final TaskPool taskPool;

    protected static final boolean DEBUG = false;

    /**
     * @param gradientProvider gradient provider; also supplies the dimension used to size {@link #taskPool}
     * @param productProvider  precision matrix-vector product provider, forwarded to the superclass
     * @param columnProvider   precision column provider, forwarded to the superclass
     * @param weight           forwarded unchanged to the superclass
     *                         (NOTE(review): presumably the operator weight — confirm)
     * @param options          runtime options, forwarded to the superclass
     * @param parameter        forwarded unchanged to the superclass
     * @param threadCount      passed to {@link TaskPool}; values {@code <= 1} leave {@link #taskPool} null
     */
    AbstractZigZagOperator(GradientWrtParameterProvider gradientProvider,
                           PrecisionMatrixVectorProductProvider productProvider,
                           PrecisionColumnProvider columnProvider,
                           double weight,
                           AbstractParticleOperator.Options options,
                           Parameter parameter,
                           int threadCount) {
        super(gradientProvider, productProvider, columnProvider, weight, options, parameter);
        this.taskPool = threadCount > 1 ? new TaskPool(gradientProvider.getDimension(), threadCount) : null;
    }

    /**
     * Integrates one trajectory, mutating {@code position} and {@code momentum} in place through the
     * subclass update hooks.
     *
     * @return always {@code 0.0}
     */
    @Override
    final double integrateTrajectory(WrappedVector position, WrappedVector momentum) {
        timer.startTimer("warmUp");
        // Order matters: the precision product is taken of the freshly drawn velocity.
        final WrappedVector velocity = drawInitialVelocity(momentum);
        final WrappedVector gradient = getInitialGradient();
        final WrappedVector precisionProduct = getPrecisionProduct(velocity);
        AbstractParticleOperator.BounceState bounceState =
                new AbstractParticleOperator.BounceState(drawTotalTravelTime());
        initializeNumEvent();
        timer.stopTimer("warmUp");

        timer.startTimer("integrateTrajectory");
        while (bounceState.isTimeRemaining()) {
            final MinimumTravelInformation nextEvent =
                    getNextBounce(position, velocity, precisionProduct, gradient, momentum);
            bounceState = doBounce(bounceState, nextEvent, position, velocity, precisionProduct, gradient, momentum);
            recordOneMoreEvent();
        }
        timer.stopTimer("integrateTrajectory");
        // NOTE(review): the constant 0.0 is presumably this integrator's log acceptance contribution — confirm.
        return 0.0;
    }

    /**
     * Advances the state either to the end of the remaining travel time or to {@code nextEvent},
     * whichever comes first.
     *
     * @return the new bounce state: {@code (NONE, -1, 0.0)} if the travel time ran out before the event,
     *         otherwise {@code (event type, event index, remaining - event time)}
     */
    private AbstractParticleOperator.BounceState doBounce(AbstractParticleOperator.BounceState bounceState,
                                                          MinimumTravelInformation nextEvent,
                                                          WrappedVector position,
                                                          WrappedVector velocity,
                                                          WrappedVector precisionProduct,
                                                          WrappedVector gradient,
                                                          WrappedVector momentum) {
        timer.startTimer("doBounce");
        final double remainingTime = bounceState.remainingTime;
        final double eventTime = nextEvent.time;
        final AbstractParticleOperator.BounceState result;
        if (remainingTime < eventTime) {
            // Travel budget exhausted before the next event: move for the remaining time only.
            updatePositionAndMomentum(position, velocity, precisionProduct, gradient, momentum, remainingTime);
            result = new AbstractParticleOperator.BounceState(AbstractParticleOperator.Type.NONE, -1, 0.0);
        } else {
            final AbstractParticleOperator.Type type = nextEvent.type;
            final int index = nextEvent.index;
            final WrappedVector column = getPrecisionColumn(index);
            updateDynamics(position, velocity, precisionProduct, gradient, momentum, column, eventTime, index, type);
            reflectVelocity(velocity, index);
            result = new AbstractParticleOperator.BounceState(type, index, remainingTime - eventTime);
        }
        timer.stopTimer("doBounce");
        return result;
    }

    /** Draws the initial velocity from the given momentum. */
    abstract WrappedVector drawInitialVelocity(WrappedVector momentum);

    /** Finds the next event (its time, type and coordinate index) from the current state. */
    abstract MinimumTravelInformation getNextBounce(WrappedVector position,
                                                    WrappedVector velocity,
                                                    WrappedVector precisionProduct,
                                                    WrappedVector gradient,
                                                    WrappedVector momentum);

    /** Moves the state forward by {@code time} when no event occurs before the trajectory ends. */
    abstract void updatePositionAndMomentum(WrappedVector position,
                                            WrappedVector velocity,
                                            WrappedVector precisionProduct,
                                            WrappedVector gradient,
                                            WrappedVector momentum,
                                            double time);

    /**
     * Moves the state forward by {@code time} up to an event at coordinate {@code index}.
     * {@code column} is the precision column for that index. The velocity component is reflected
     * by the caller afterwards.
     */
    abstract void updateDynamics(WrappedVector position,
                                 WrappedVector velocity,
                                 WrappedVector precisionProduct,
                                 WrappedVector gradient,
                                 WrappedVector momentum,
                                 WrappedVector column,
                                 double time,
                                 int index,
                                 AbstractParticleOperator.Type type);

    /**
     * Returns the smallest strictly positive {@code t} with {@code -0.5*a*t^2 + b*t + c = 0}, or
     * {@link Double#POSITIVE_INFINITY} if no such root exists.
     */
    static double findGradientRoot(double a, double b, double c) {
        return minimumPositiveRoot(-0.5 * a, b, c);
    }

    /**
     * Returns {@code |position / velocity|} if {@link #headingTowardsBoundary} reports that coordinate
     * {@code index} is heading towards its boundary, otherwise {@link Double#POSITIVE_INFINITY}.
     */
    double findBoundaryTime(int index, double position, double velocity) {
        if (!headingTowardsBoundary(position, velocity, index)) {
            return Double.POSITIVE_INFINITY;
        }
        return Math.abs(position / velocity);
    }

    /**
     * Returns the smallest strictly positive real root of {@code a*t^2 + b*t + c = 0}, or
     * {@link Double#POSITIVE_INFINITY} if there is none.
     *
     * <p>For {@code a != 0} the coefficients are multiplied by {@code sign(a)} so that the leading
     * coefficient is positive. Then {@code (-b - sqrt(disc)) / (2a)} is the smaller root and is tried
     * first. {@code a == 0} is handled separately as a linear equation. Scaling by {@code sign(0) = 0}
     * would otherwise zero every coefficient and produce {@code 0.0 / 0.0 = NaN}.
     */
    private static double minimumPositiveRoot(double a, double b, double c) {
        if (a == 0.0) {
            return linearPositiveRoot(b, c);
        }

        final double signA = sign(a);
        a *= signA;
        b *= signA;
        c *= signA;

        final double discriminant = b * b - 4.0 * a * c;
        if (discriminant < 0.0) {
            return Double.POSITIVE_INFINITY;
        }

        final double sqrtDiscriminant = Math.sqrt(discriminant);
        double root = (-b - sqrtDiscriminant) / (2.0 * a);
        if (root <= 0.0) {
            root = (-b + sqrtDiscriminant) / (2.0 * a);
        }
        return root <= 0.0 ? Double.POSITIVE_INFINITY : root;
    }

    /**
     * Returns the strictly positive root of {@code b*t + c = 0}, or {@link Double#POSITIVE_INFINITY}
     * if it is non-positive or not unique ({@code b == 0}).
     */
    private static double linearPositiveRoot(double b, double c) {
        if (b == 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        final double root = -c / b;
        return root > 0.0 ? root : Double.POSITIVE_INFINITY;
    }

    /** Negates {@code momentum[index]} and sets {@code zeroedVector[index]} to zero. */
    static void reflectMomentum(WrappedVector momentum, WrappedVector zeroedVector, int index) {
        momentum.set(index, -momentum.get(index));
        zeroedVector.set(index, 0.0);
    }

    /** Sets {@code momentum[index]} to zero. */
    static void setZeroMomentum(WrappedVector momentum, int index) {
        momentum.set(index, 0.0);
    }

    /** Negates {@code velocity[index]}. */
    private static void reflectVelocity(WrappedVector velocity, int index) {
        velocity.set(index, -velocity.get(index));
    }

    /**
     * Returns {@code true} if the arrays have equal length and every pair of entries has relative
     * difference {@code |x - y| / |x + y|} of at most {@code 1e-5}.
     */
    protected boolean close(double[] first, double[] second) {
        if (first.length != second.length) {
            // Avoids ArrayIndexOutOfBoundsException (or silently ignoring trailing entries).
            return false;
        }
        for (int i = 0; i < first.length; ++i) {
            final double relativeDifference = Math.abs((first[i] - second[i]) / (first[i] + second[i]));
            // Tested as "> tol" so that NaN (e.g. both entries zero, 0/0) is treated as close.
            if (relativeDifference > 1.0E-5) {
                return false;
            }
        }
        return true;
    }

    /** Returns 1, -1 or 0 for positive, negative or zero {@code x} (0 also for NaN). */
    static int sign(double x) {
        if (x > 0.0) {
            return 1;
        }
        if (x < 0.0) {
            return -1;
        }
        return 0;
    }

    /** Exposes the event counter of the last trajectory as a single log column. */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{
                new NumberColumn("number of event") {
                    @Override
                    public double getDoubleValue() {
                        return AbstractZigZagOperator.this.numEvents;
                    }
                }
        };
    }
}

