/*
 * Decompiled with CFR 0.152.
 */
package org.spaceroots.mantissa.optimization;

import java.util.Arrays;
import java.util.Comparator;
import org.spaceroots.mantissa.optimization.ConvergenceChecker;
import org.spaceroots.mantissa.optimization.CostException;
import org.spaceroots.mantissa.optimization.CostFunction;
import org.spaceroots.mantissa.optimization.NoConvergenceException;
import org.spaceroots.mantissa.optimization.PointCostPair;
import org.spaceroots.mantissa.random.CorrelatedRandomVectorGenerator;
import org.spaceroots.mantissa.random.NotPositiveDefiniteMatrixException;
import org.spaceroots.mantissa.random.RandomVectorGenerator;
import org.spaceroots.mantissa.random.UncorrelatedRandomVectorGenerator;
import org.spaceroots.mantissa.random.UniformRandomGenerator;
import org.spaceroots.mantissa.random.VectorialSampleStatistics;

/**
 * Base class for direct search optimizers that minimize a {@link CostFunction}
 * by iterating on a simplex of {@link PointCostPair} vertices.
 *
 * <p>This class builds the initial simplex, drives the iteration loop, counts
 * cost evaluations and handles optional multi-start runs. Subclasses supply the
 * actual simplex transformation in {@link #iterateSimplex()}.</p>
 *
 * <p>Instances keep mutable state (current simplex, evaluation counter, minima)
 * in plain fields without any synchronization.</p>
 */
public abstract class DirectSearchOptimizer {
    /**
     * Orders {@link PointCostPair} instances by increasing cost.
     *
     * <p>{@code null} elements are sorted after every non-null element. Costs
     * are compared with {@link Double#compare(double, double)} so that the
     * ordering is a consistent total order: two distinct pairs with equal cost
     * compare as equal, and NaN costs are sorted after all other costs. A
     * consistent comparator is required by {@link Arrays#sort(Object[], Comparator)}.</p>
     */
    private static final Comparator pointCostPairComparator = new Comparator(){

        public int compare(Object o1, Object o2) {
            if (o1 == null) {
                return o2 == null ? 0 : 1;
            }
            if (o2 == null) {
                return -1;
            }
            double cost1 = ((PointCostPair)o1).cost;
            double cost2 = ((PointCostPair)o2).cost;
            return Double.compare(cost1, cost2);
        }
    };

    /**
     * Current simplex. After {@link #evaluateSimplex()} has run it is sorted
     * by increasing cost, so index 0 holds the best vertex.
     */
    protected PointCostPair[] simplex;

    /** Cost function being minimized by the current run. */
    private CostFunction f;

    /** Number of cost evaluations performed for the current start. */
    private int evaluations;

    /** Number of starts to perform (1 for a single-start run). */
    private int starts;

    /** Generator used to rebuild the simplex between starts; null for single-start runs. */
    private RandomVectorGenerator generator;

    /** Result of each start, sorted by increasing cost once a run has completed; null before any run. */
    private PointCostPair[] minima;

    /** Does nothing; the optimizer is configured by the {@code minimizes} methods. */
    protected DirectSearchOptimizer() {
    }

    /**
     * Minimizes a cost function with a single start, using a simplex built
     * from two vertices.
     *
     * <p>The simplex has {@code vertexA.length + 1} vertices; vertex number
     * {@code i} takes its first {@code i} coordinates from {@code vertexB} and
     * the remaining ones from {@code vertexA}.</p>
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed
     * @param checker object deciding when the simplex has converged
     * @param vertexA first vertex
     * @param vertexB second vertex
     * @return the best point found together with its cost
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if convergence is not reached within the evaluation budget
     */
    public PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker, double[] vertexA, double[] vertexB) throws CostException, NoConvergenceException {
        this.buildSimplex(vertexA, vertexB);
        this.setSingleStart();
        return this.minimizes(f, maxEvaluations, checker);
    }

    /**
     * Minimizes a cost function with several starts, using a first simplex
     * built from two vertices.
     *
     * <p>The simplexes of the subsequent starts are drawn from an
     * {@link UncorrelatedRandomVectorGenerator} configured with the midpoint of
     * the two vertices as mean and half the absolute coordinate difference as
     * standard deviation.</p>
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed for each start
     * @param checker object deciding when the simplex has converged
     * @param vertexA first vertex
     * @param vertexB second vertex
     * @param starts number of starts to perform (values below 2 give a single-start run)
     * @param seed seed handed to the {@link UniformRandomGenerator} driving the restarts
     * @return the lowest-cost result among the starts that converged
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if none of the starts converged
     */
    public PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker, double[] vertexA, double[] vertexB, int starts, int[] seed) throws CostException, NoConvergenceException {
        this.buildSimplex(vertexA, vertexB);
        double[] mean = new double[vertexA.length];
        double[] standardDeviation = new double[vertexA.length];
        for (int i = 0; i < vertexA.length; ++i) {
            mean[i] = 0.5 * (vertexA[i] + vertexB[i]);
            standardDeviation[i] = 0.5 * Math.abs(vertexA[i] - vertexB[i]);
        }
        UncorrelatedRandomVectorGenerator rvg = new UncorrelatedRandomVectorGenerator(mean, standardDeviation, new UniformRandomGenerator(seed));
        this.setMultiStart(starts, rvg);
        return this.minimizes(f, maxEvaluations, checker);
    }

    /**
     * Minimizes a cost function with a single start, using the given vertices
     * as the initial simplex.
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed
     * @param checker object deciding when the simplex has converged
     * @param vertices vertices of the initial simplex, one array per vertex
     * @return the best point found together with its cost
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if convergence is not reached within the evaluation budget
     */
    public PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker, double[][] vertices) throws CostException, NoConvergenceException {
        this.buildSimplex(vertices);
        this.setSingleStart();
        return this.minimizes(f, maxEvaluations, checker);
    }

    /**
     * Minimizes a cost function with several starts, using the given vertices
     * as the first simplex.
     *
     * <p>The simplexes of the subsequent starts are drawn from a
     * {@link CorrelatedRandomVectorGenerator} configured with the mean and the
     * covariance matrix computed from the given vertices.</p>
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed for each start
     * @param checker object deciding when the simplex has converged
     * @param vertices vertices of the initial simplex, one array per vertex
     * @param starts number of starts to perform (values below 2 give a single-start run)
     * @param seed seed handed to the {@link UniformRandomGenerator} driving the restarts
     * @return the lowest-cost result among the starts that converged
     * @throws NotPositiveDefiniteMatrixException presumably when the covariance
     *         matrix of the vertices is rejected by the correlated generator
     *         (raised by code outside this class)
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if none of the starts converged
     */
    public PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker, double[][] vertices, int starts, int[] seed) throws NotPositiveDefiniteMatrixException, CostException, NoConvergenceException {
        this.buildSimplex(vertices);
        VectorialSampleStatistics statistics = new VectorialSampleStatistics();
        for (int i = 0; i < vertices.length; ++i) {
            statistics.add(vertices[i]);
        }
        CorrelatedRandomVectorGenerator rvg = new CorrelatedRandomVectorGenerator(statistics.getMean(), statistics.getCovarianceMatrix(null), new UniformRandomGenerator(seed));
        this.setMultiStart(starts, rvg);
        return this.minimizes(f, maxEvaluations, checker);
    }

    /**
     * Minimizes a cost function with a single start, using a simplex whose
     * vertices are drawn from a random vector generator.
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed
     * @param checker object deciding when the simplex has converged
     * @param generator generator providing the simplex vertices
     * @return the best point found together with its cost
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if convergence is not reached within the evaluation budget
     */
    public PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker, RandomVectorGenerator generator) throws CostException, NoConvergenceException {
        this.buildSimplex(generator);
        this.setSingleStart();
        return this.minimizes(f, maxEvaluations, checker);
    }

    /**
     * Minimizes a cost function with several starts, every simplex being drawn
     * from the same random vector generator.
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed for each start
     * @param checker object deciding when the simplex has converged
     * @param generator generator providing the simplex vertices of every start
     * @param starts number of starts to perform (values below 2 give a single-start run)
     * @return the lowest-cost result among the starts that converged
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if none of the starts converged
     */
    public PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker, RandomVectorGenerator generator, int starts) throws CostException, NoConvergenceException {
        this.buildSimplex(generator);
        this.setMultiStart(starts, generator);
        return this.minimizes(f, maxEvaluations, checker);
    }

    /**
     * Builds a simplex of {@code n + 1} unevaluated vertices from two vertices
     * of dimension {@code n}: vertex {@code i} is made of the first {@code i}
     * coordinates of {@code vertexB} followed by the last {@code n - i}
     * coordinates of {@code vertexA}.
     *
     * @param vertexA first vertex (also vertex 0 of the simplex)
     * @param vertexB second vertex (also vertex n of the simplex)
     */
    private void buildSimplex(double[] vertexA, double[] vertexB) {
        int n = vertexA.length;
        this.simplex = new PointCostPair[n + 1];
        for (int i = 0; i <= n; ++i) {
            double[] vertex = new double[n];
            if (i > 0) {
                System.arraycopy(vertexB, 0, vertex, 0, i);
            }
            if (i < n) {
                System.arraycopy(vertexA, i, vertex, i, n - i);
            }
            // NaN marks the vertex as not evaluated yet (see evaluateSimplex)
            this.simplex[i] = new PointCostPair(vertex, Double.NaN);
        }
    }

    /**
     * Builds a simplex of unevaluated vertices from the given points. The
     * arrays are handed to {@link PointCostPair} as they are; no copy is made
     * by this method.
     *
     * @param vertices vertices of the simplex, one array per vertex
     */
    private void buildSimplex(double[][] vertices) {
        int n = vertices.length - 1;
        this.simplex = new PointCostPair[n + 1];
        for (int i = 0; i <= n; ++i) {
            this.simplex[i] = new PointCostPair(vertices[i], Double.NaN);
        }
    }

    /**
     * Builds a simplex of unevaluated vertices drawn from a generator. The
     * length of the first generated vector gives the dimension {@code n};
     * {@code n + 1} vectors are drawn in total.
     *
     * @param generator generator providing the vertices
     */
    private void buildSimplex(RandomVectorGenerator generator) {
        double[] vertex = generator.nextVector();
        int n = vertex.length;
        this.simplex = new PointCostPair[n + 1];
        this.simplex[0] = new PointCostPair(vertex, Double.NaN);
        for (int i = 1; i <= n; ++i) {
            this.simplex[i] = new PointCostPair(generator.nextVector(), Double.NaN);
        }
    }

    /** Configures a single-start run and discards any previously stored minima. */
    private void setSingleStart() {
        this.starts = 1;
        this.generator = null;
        this.minima = null;
    }

    /**
     * Configures the number of starts and the generator used to rebuild the
     * simplex between starts, and discards any previously stored minima.
     *
     * <p>Every public {@code minimizes} method of this class sets the start
     * configuration itself before running, replacing whatever was set here.</p>
     *
     * @param starts number of starts to perform; values below 2 select a
     *        single-start run and the generator is then ignored
     * @param generator generator used to build the simplex of each start after
     *        the first one
     */
    public void setMultiStart(int starts, RandomVectorGenerator generator) {
        if (starts < 2) {
            this.setSingleStart();
        } else {
            this.starts = starts;
            this.generator = generator;
            this.minima = null;
        }
    }

    /**
     * Returns the results of the starts of the last run.
     *
     * <p>The array has one element per start and is sorted by increasing cost;
     * starts that did not converge are represented by {@code null} elements
     * placed at the end. The array itself is a copy, its elements are not.</p>
     *
     * @return a copy of the minima array, or an empty array if no run has been
     *         performed since the optimizer was last configured
     */
    public PointCostPair[] getMinima() {
        if (this.minima == null) {
            // no run yet: report "no minima" instead of failing with a NullPointerException
            return new PointCostPair[0];
        }
        return (PointCostPair[])this.minima.clone();
    }

    /**
     * Runs the configured number of starts on the current simplex and returns
     * the best converged result.
     *
     * @param f cost function to minimize
     * @param maxEvaluations maximal number of cost evaluations allowed for each start
     * @param checker object deciding when the simplex has converged
     * @return the lowest-cost result among the starts that converged
     * @throws CostException if the cost function or the simplex iteration throws one
     * @throws NoConvergenceException if none of the starts converged
     */
    private PointCostPair minimizes(CostFunction f, int maxEvaluations, ConvergenceChecker checker) throws CostException, NoConvergenceException {
        this.f = f;
        this.minima = new PointCostPair[this.starts];
        for (int i = 0; i < this.starts; ++i) {
            // the evaluation budget applies to each start separately
            this.evaluations = 0;
            this.evaluateSimplex();
            while (true) {
                if (checker.converged(this.simplex)) {
                    // the simplex is sorted, its first vertex is the best one
                    this.minima[i] = this.simplex[0];
                    break;
                }
                if (this.evaluations >= maxEvaluations) {
                    // budget exhausted: this start is recorded as a failure
                    this.minima[i] = null;
                    break;
                }
                this.iterateSimplex();
            }
            if (i < this.starts - 1) {
                // prepare a fresh random simplex for the next start
                this.buildSimplex(this.generator);
            }
        }
        // converged results first (by increasing cost), failed starts (null) last
        Arrays.sort(this.minima, pointCostPairComparator);
        if (this.minima[0] == null) {
            throw new NoConvergenceException("none of the {0} start points lead to convergence", new String[]{Integer.toString(this.starts)});
        }
        return this.minima[0];
    }

    /**
     * Performs one iteration of the search on {@link #simplex}.
     *
     * <p>The driving loop stops a start either on convergence or when the
     * evaluation counter reaches the budget; that counter only advances
     * through {@link #evaluateCost(double[])}.</p>
     *
     * @throws CostException if the cost function throws one
     */
    protected abstract void iterateSimplex() throws CostException;

    /**
     * Evaluates the cost function at a point and counts the evaluation.
     *
     * @param x point at which the cost is evaluated
     * @return cost at the given point
     * @throws CostException if the cost function throws one
     */
    protected double evaluateCost(double[] x) throws CostException {
        ++this.evaluations;
        return this.f.cost(x);
    }

    /**
     * Evaluates every vertex of the simplex whose cost is still NaN, then
     * sorts the simplex by increasing cost.
     *
     * @throws CostException if the cost function throws one
     */
    protected void evaluateSimplex() throws CostException {
        for (int i = 0; i < this.simplex.length; ++i) {
            PointCostPair pair = this.simplex[i];
            if (!Double.isNaN(pair.cost)) continue;
            this.simplex[i] = new PointCostPair(pair.point, this.evaluateCost(pair.point));
        }
        Arrays.sort(this.simplex, pointCostPairComparator);
    }

    /**
     * Inserts a point in the sorted simplex and drops the last (worst) vertex.
     *
     * <p>The new point is placed before the first vertex with a strictly
     * greater cost and the following vertices are shifted by one position. If
     * no vertex among the first {@code n} has a greater cost, the new point
     * simply takes the last position.</p>
     *
     * @param pointCostPair point to insert
     */
    protected void replaceWorstPoint(PointCostPair pointCostPair) {
        int n = this.simplex.length - 1;
        for (int i = 0; i < n; ++i) {
            if (!(this.simplex[i].cost > pointCostPair.cost)) continue;
            // swap: the displaced vertex is carried along to be inserted further down
            PointCostPair tmp = this.simplex[i];
            this.simplex[i] = pointCostPair;
            pointCostPair = tmp;
        }
        this.simplex[n] = pointCostPair;
    }
}

