/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.recommender.svd.AbstractFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.SparseMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factorizes the preference matrix by alternating regularized least-squares.
 *
 * <p>Each iteration holds one side's feature matrix fixed and re-solves every row of the other
 * side. It then switches sides (see {@link #finishProcessing()}). {@code numIterations} therefore
 * counts half-sweeps: iteration 0 solves users, iteration 1 solves items, and so on.
 *
 * <p>For one row with 0/1 diagonal confidence matrix {@code C} (1 on observed entries) and
 * preference row {@code p}, the new feature vector is
 * {@code (Y^T C Y + preventOverfitting * I)^-1 * Y^T C * p^T}, where {@code Y} is the fixed
 * feature matrix. {@code Y^T C Y} is evaluated as {@code Y^T Y + Y^T (C - I) Y} using the cached
 * Gram matrix.
 *
 * <p>Row solves are queued in batches and executed on a fixed thread pool sized to the available
 * processors. Results are written back under the instance lock. {@link #factorize()} mutates
 * instance state, so an instance should not be factorized concurrently from several threads.
 */
public final class ImplicitLinearRegressionFactorizer
extends AbstractFactorizer {
    private static final Logger log = LoggerFactory.getLogger(ImplicitLinearRegressionFactorizer.class);
    /** Number of queued row solves per processor that triggers execution of a batch. */
    private static final int CALLABLES_PER_PROCESSOR = 200;

    /** Weight of the identity term added to every normal-equation matrix (ridge regularization). */
    private final double preventOverfitting;
    /** Number of latent features per user and per item. */
    private final int numFeatures;
    /** Number of half-sweeps; each one solves either all users or all items, alternating. */
    private final int numIterations;
    private final DataModel dataModel;
    /** User feature rows, indexed by {@code userIndex(userID)}. */
    private double[][] userMatrix;
    /** Item feature rows, indexed by {@code itemIndex(itemID)}. */
    private double[][] itemMatrix;
    /** Cached {@code U^T U}; refreshed before every item half-sweep. */
    private Matrix userTransUser;
    /** Cached {@code V^T V}; refreshed before every user half-sweep. */
    private Matrix itemTransItem;
    /** Row solves queued for the current batch. */
    Collection<Callable<Void>> fVectorCallables;
    /** {@code true} while user rows are being solved, {@code false} while item rows are. */
    private boolean recomputeUserFeatures;
    /** Cosine similarity between old and new feature rows, averaged over one half-sweep. */
    private RunningAverage avrChange;

    /**
     * Creates a factorizer with 64 features, 10 half-sweeps and a regularization weight of 0.1.
     *
     * @param dataModel the preferences to factorize
     * @throws TasteException if the data model cannot be read
     */
    public ImplicitLinearRegressionFactorizer(DataModel dataModel) throws TasteException {
        this(dataModel, 64, 10, 0.1);
    }

    /**
     * Creates a factorizer.
     *
     * @param dataModel          the preferences to factorize
     * @param numFeatures        number of latent features per user and item
     * @param numIterations      number of alternating half-sweeps (users, items, users, ...)
     * @param preventOverfitting weight of the identity term added to each normal-equation matrix
     * @throws TasteException if the data model cannot be read
     */
    public ImplicitLinearRegressionFactorizer(DataModel dataModel, int numFeatures, int numIterations, double preventOverfitting) throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.numFeatures = numFeatures;
        this.numIterations = numIterations;
        this.preventOverfitting = preventOverfitting;
        this.fVectorCallables = Lists.newArrayList();
        this.avrChange = new FullRunningAverage();
    }

    /**
     * Randomly initialises both feature matrices, runs {@link #train()} and packages the result.
     *
     * @return the learned factorization
     * @throws TasteException if the data model cannot be read
     */
    @Override
    public Factorization factorize() throws TasteException {
        Random random = RandomUtils.getRandom();
        int numUsers = dataModel.getNumUsers();
        int numItems = dataModel.getNumItems();
        userMatrix = new double[numUsers][numFeatures];
        itemMatrix = new double[numItems][numFeatures];
        recomputeUserFeatures = true;

        double average = getAveragePreference();
        double prefInterval = dataModel.getMaxPreference() - dataModel.getMinPreference();
        // The grid average includes every unrated cell as 0. For sparse data with a wide value
        // range it is easily below 10% of that range, which would make the radicand negative and
        // every initial feature NaN. Clamp at 0 so initialisation degrades to small noise instead.
        double defaultValue = Math.sqrt(Math.max(0.0, (average - prefInterval * 0.1) / numFeatures));
        double interval = prefInterval * 0.1 / numFeatures;

        // Fill feature-by-feature (all users, then all items) so the sequence of random draws is
        // the same as it has always been for a given seed.
        for (int feature = 0; feature < numFeatures; ++feature) {
            for (int u = 0; u < numUsers; ++u) {
                userMatrix[u][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * random.nextDouble();
            }
            for (int i = 0; i < numItems; ++i) {
                itemMatrix[i][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * random.nextDouble();
            }
        }
        train();
        return createFactorization(userMatrix, itemMatrix);
    }

    /**
     * Runs {@code numIterations} alternating half-sweeps.
     *
     * <p>Each half-sweep refreshes the fixed side's Gram matrix, queues one row solve per user
     * (or item) and flushes the queue. Flushing also flips which side is solved next.
     *
     * @throws TasteException if the data model cannot be read
     */
    public void train() throws TasteException {
        for (int i = 0; i < numIterations; ++i) {
            if (recomputeUserFeatures) {
                LongPrimitiveIterator userIds = dataModel.getUserIDs();
                log.info("Calculating Y^TY");
                reCalculateTrans(recomputeUserFeatures);
                log.info("Building callables for users.");
                while (userIds.hasNext()) {
                    long userId = userIds.nextLong();
                    // Resolve the index before building the matrices, as the original code did.
                    int userIdx = userIndex(userId);
                    buildCallables(buildConfidenceMatrixForUser(userId), buildPreferenceVectorForUser(userId), userIdx);
                }
            } else {
                LongPrimitiveIterator itemIds = dataModel.getItemIDs();
                log.info("Calculating X^TX");
                reCalculateTrans(recomputeUserFeatures);
                log.info("Building callables for items.");
                while (itemIds.hasNext()) {
                    long itemId = itemIds.nextLong();
                    int itemIdx = itemIndex(itemId);
                    buildCallables(buildConfidenceMatrixForItem(itemId), buildPreferenceVectorForItem(itemId), itemIdx);
                }
            }
            finishProcessing();
        }
    }

    /**
     * Builds a 1 x numItems row holding the user's preference values at their item indices.
     *
     * @param realId the user ID
     * @return a sparse row vector of the user's preferences
     * @throws TasteException if the user's preferences cannot be read
     */
    public Matrix buildPreferenceVectorForUser(long realId) throws TasteException {
        SparseMatrix ids = new SparseMatrix(1, dataModel.getNumItems());
        for (Preference pref : dataModel.getPreferencesFromUser(realId)) {
            int itemidx = itemIndex(pref.getItemID());
            ids.setQuick(0, itemidx, (double) pref.getValue());
        }
        return ids;
    }

    /**
     * Builds the numUsers x numUsers diagonal confidence matrix for an item.
     *
     * <p>The diagonal holds 1.0 for every user who has a preference for the item and 0.0 otherwise.
     * It is built directly from a {@code double[]} diagonal instead of materialising a square sparse
     * matrix first.
     */
    private Matrix buildConfidenceMatrixForItem(long itemId) throws TasteException {
        double[] confidence = new double[dataModel.getNumUsers()];
        for (Preference pref : dataModel.getPreferencesForItem(itemId)) {
            confidence[userIndex(pref.getUserID())] = 1.0;
        }
        return new DiagonalMatrix(confidence);
    }

    /**
     * Builds the numItems x numItems diagonal confidence matrix for a user.
     *
     * <p>The diagonal holds 1.0 for every item the user has a preference for and 0.0 otherwise.
     */
    private Matrix buildConfidenceMatrixForUser(long userId) throws TasteException {
        double[] confidence = new double[dataModel.getNumItems()];
        for (Preference pref : dataModel.getPreferencesFromUser(userId)) {
            confidence[itemIndex(pref.getItemID())] = 1.0;
        }
        return new DiagonalMatrix(confidence);
    }

    /** Builds a 1 x numUsers row holding the item's preference values at their user indices. */
    private Matrix buildPreferenceVectorForItem(long realId) throws TasteException {
        SparseMatrix ids = new SparseMatrix(1, dataModel.getNumUsers());
        for (Preference pref : dataModel.getPreferencesForItem(realId)) {
            int useridx = userIndex(pref.getUserID());
            ids.setQuick(0, useridx, (double) pref.getValue());
        }
        return ids;
    }

    /** Returns a size x size diagonal matrix with 1.0 on the diagonal. */
    private Matrix ones(int size) {
        double[] diagonal = new double[size];
        for (int i = 0; i < size; ++i) {
            diagonal[i] = 1.0;
        }
        return new DiagonalMatrix(diagonal);
    }

    /**
     * Returns the mean over the full users x items grid, with every unrated cell counted as 0.
     *
     * <p>This is computed as one sum divided by the number of cells, rather than feeding each
     * implicit zero into a running average. That avoids O(users x items) work and the overflow
     * of an {@code int} sample count on large grids.
     *
     * <p>Users for which the model throws {@link NoSuchUserException} are skipped entirely and
     * contribute no cells.
     *
     * @return the grid average, or NaN when no user contributed any cells
     */
    private double getAveragePreference() throws TasteException {
        int numItems = dataModel.getNumItems();
        double sum = 0.0;
        long numCells = 0L;
        LongPrimitiveIterator it = dataModel.getUserIDs();
        while (it.hasNext()) {
            PreferenceArray prefs;
            try {
                prefs = dataModel.getPreferencesFromUser(it.nextLong());
            } catch (NoSuchUserException ex) {
                continue;
            }
            int numPrefs = 0;
            for (Preference pref : prefs) {
                sum += pref.getValue();
                ++numPrefs;
            }
            // Rated cells plus one implicit zero per unrated item.
            numCells += Math.max(numItems, numPrefs);
        }
        return sum / numCells;
    }

    /**
     * Refreshes the Gram matrix of the side that stays fixed during the next half-sweep.
     *
     * @param recomputeUserFeatures {@code true} to refresh {@code V^T V} (users are about to be
     *                              solved), {@code false} to refresh {@code U^T U}
     */
    public void reCalculateTrans(boolean recomputeUserFeatures) {
        if (!recomputeUserFeatures) {
            DenseMatrix uMatrix = new DenseMatrix(userMatrix);
            userTransUser = uMatrix.transpose().times(uMatrix);
        } else {
            DenseMatrix iMatrix = new DenseMatrix(itemMatrix);
            itemTransItem = iMatrix.transpose().times(iMatrix);
        }
    }

    /**
     * Replaces row {@code id} of the matrix currently being solved with column vector {@code m}.
     *
     * <p>The cosine similarity between the old and new row is recorded in {@code avrChange}.
     * A NaN cosine (e.g. an all-zero row) is logged instead of recorded.
     */
    private synchronized void updateMatrix(int id, Matrix m) {
        double[] row = recomputeUserFeatures ? userMatrix[id] : itemMatrix[id];
        double normA = 0.0;
        double normB = 0.0;
        double aTb = 0.0;
        for (int feature = 0; feature < numFeatures; ++feature) {
            double oldValue = row[feature];
            double newValue = m.get(feature, 0);
            normA += oldValue * oldValue;
            normB += newValue * newValue;
            aTb += oldValue * newValue;
            row[feature] = newValue;
        }
        double cosine = aTb / (Math.sqrt(normA) * Math.sqrt(normB));
        if (Double.isNaN(cosine)) {
            log.info("Cosine similarity is NaN, recomputeUserFeatures={} id={}", recomputeUserFeatures, id);
        } else {
            avrChange.addDatum(cosine);
        }
    }

    /** Discards all queued row solves by starting a fresh queue. */
    public void resetCallables() {
        fVectorCallables = Lists.newArrayList();
    }

    /** Logs the average cosine change of the finished half-sweep and starts a fresh average. */
    private void resetAvrChange() {
        log.info("Avr Change: {}", avrChange.getAverage());
        avrChange = new FullRunningAverage();
    }

    /**
     * Queues one row solve and executes the queue once it reaches the batch size.
     *
     * @param C          diagonal confidence matrix for the row
     * @param prefVector 1 x n preference row
     * @param id         index of the row being solved
     * @throws TasteException propagated from {@link #execute(Collection)}
     */
    public void buildCallables(Matrix C, Matrix prefVector, int id) throws TasteException {
        fVectorCallables.add(new FeatureVectorCallable(C, prefVector, id));
        // '>=' rather than '%': still flushes if availableProcessors() shrinks between calls.
        if (fVectorCallables.size() >= CALLABLES_PER_PROCESSOR * Runtime.getRuntime().availableProcessors()) {
            execute(fVectorCallables);
            resetCallables();
        }
    }

    /**
     * Executes any remaining queued solves and ends the current half-sweep.
     *
     * <p>It logs a notice when fewer rows were recorded than the side has; failed or NaN rows
     * cause this. It then logs the average change and flips which side is solved next.
     *
     * @throws TasteException propagated from {@link #execute(Collection)}
     */
    public void finishProcessing() throws TasteException {
        if (fVectorCallables != null) {
            execute(fVectorCallables);
        }
        resetCallables();
        int expected = recomputeUserFeatures ? userMatrix.length : itemMatrix.length;
        if (avrChange.getCount() != expected) {
            log.info("Matrix length is not equal to count");
        }
        resetAvrChange();
        recomputeUserFeatures = !recomputeUserFeatures;
    }

    /** Returns a size x size identity matrix (as a diagonal matrix). */
    public Matrix identityV(int size) {
        return ones(size);
    }

    /**
     * Runs the given callables on a fresh fixed thread pool and waits for all of them.
     *
     * <p>Task failures are logged and skipped (best effort), so the affected rows keep their
     * previous features. On interruption the thread's interrupt flag is restored. The pool is
     * always shut down.
     *
     * @param callables the row solves to run; an empty collection is a no-op
     * @throws TasteException declared for callers; not currently thrown
     */
    void execute(Collection<Callable<Void>> callables) throws TasteException {
        if (callables.isEmpty()) {
            return;
        }
        Collection<Callable<Void>> wrapped = wrapWithStatsCallables(callables);
        int numProcessors = Runtime.getRuntime().availableProcessors();
        ExecutorService executor = Executors.newFixedThreadPool(numProcessors);
        log.info("Starting timing of {} tasks in {} threads", wrapped.size(), numProcessors);
        try {
            List<Future<Void>> futures = executor.invokeAll(wrapped);
            for (Future<Void> future : futures) {
                try {
                    future.get();
                } catch (ExecutionException ee) {
                    // Log every failed task, not just the first one.
                    log.warn("error in factorization", ee);
                }
            }
        } catch (InterruptedException ie) {
            // Preserve the interruption for callers instead of silently clearing it.
            Thread.currentThread().interrupt();
            log.warn("error in factorization", ie);
        } finally {
            // Always release the pool's non-daemon threads, even on unexpected exceptions.
            executor.shutdown();
        }
    }

    /**
     * Wraps each callable so its run time is recorded in a shared timing average.
     *
     * <p>Every 1000th wrapper also logs timing and memory statistics.
     */
    private Collection<Callable<Void>> wrapWithStatsCallables(Collection<Callable<Void>> callables) {
        ArrayList<Callable<Void>> wrapped = Lists.newArrayListWithExpectedSize(callables.size());
        int count = 1;
        FullRunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
        for (Callable<Void> callable : callables) {
            boolean logStats = count++ % 1000 == 0;
            wrapped.add(new StatsCallable(callable, logStats, timing));
        }
        return wrapped;
    }

    /** Solves {@code A * X = y} for {@code X} via QR decomposition. */
    private Matrix solve(Matrix A, Matrix y) {
        return new QRDecomposition(A).solve(y);
    }

    /** Delegating callable that records the delegate's wall-clock time and optionally logs stats. */
    private static class StatsCallable
    implements Callable<Void> {
        private final Callable<Void> delegate;
        private final boolean logStats;
        private final RunningAverageAndStdDev timing;

        private StatsCallable(Callable<Void> delegate, boolean logStats, RunningAverageAndStdDev timing) {
            this.delegate = delegate;
            this.logStats = logStats;
            this.timing = timing;
        }

        @Override
        public Void call() throws Exception {
            long start = System.currentTimeMillis();
            delegate.call();
            long end = System.currentTimeMillis();
            timing.addDatum(end - start);
            if (logStats) {
                Runtime runtime = Runtime.getRuntime();
                int average = (int) timing.getAverage();
                log.info("Average time per task: {}ms", average);
                long totalMemory = runtime.totalMemory();
                long memory = totalMemory - runtime.freeMemory();
                log.info("Approximate memory used: {}MB / {}MB", memory / 1000000L, totalMemory / 1000000L);
            }
            return null;
        }
    }

    /**
     * Solves the regularized least-squares problem for one user or item row.
     *
     * <p>Which side is solved is decided by {@code recomputeUserFeatures} when the task runs.
     */
    private class FeatureVectorCallable
    implements Callable<Void> {
        private final Matrix C;
        private final Matrix prefVector;
        private final int id;

        private FeatureVectorCallable(Matrix C, Matrix prefVector, int id) {
            this.C = C;
            this.prefVector = prefVector;
            this.id = id;
        }

        @Override
        public Void call() throws Exception {
            boolean solvingUsers = recomputeUserFeatures;
            // When solving user rows the item factors are fixed, and vice versa.
            int numFixedRows = solvingUsers ? dataModel.getNumItems() : dataModel.getNumUsers();
            Matrix fixedGram = solvingUsers ? itemTransItem : userTransUser;
            double[][] fixedFactors = solvingUsers ? itemMatrix : userMatrix;

            Matrix identityN = identityV(numFixedRows);
            Matrix identityK = identityV(numFeatures);
            DenseMatrix y = new DenseMatrix(fixedFactors);
            Matrix yT = y.transpose();
            // Y^T C Y = Y^T Y + Y^T (C - I) Y, reusing the cached Gram matrix Y^T Y.
            Matrix yTCy = fixedGram.plus(yT.times(C.minus(identityN)).times(y));
            Matrix inverse = solve(yTCy.plus(identityK.times(preventOverfitting)), identityK);
            Matrix results = inverse.times(yT.times(C)).times(prefVector.transpose());
            updateMatrix(id, results);
            return null;
        }
    }
}

