/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.ConfidencePredictingClassifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.ConfusionMatrix;
import cc.mallet.pipe.Classification2ConfidencePredictingFeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.InstanceList;
import cc.mallet.types.PerLabelInfoGain;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Trainer that produces a {@link ConfidencePredictingClassifier} in two stages.
 *
 * <ol>
 *   <li>The wrapped {@code underlyingClassifierTrainer} is trained on the list passed to
 *       {@link #train(InstanceList)}.</li>
 *   <li>The resulting classifier is run over the validation set; every resulting
 *       {@link Classification} is added to a new {@link InstanceList} built on
 *       {@code confidencePredictingPipe}, and a {@link MaxEnt} model is trained on that list.</li>
 * </ol>
 *
 * A validation set is mandatory: {@link #train(InstanceList)} throws if none has been supplied.
 */
public class ConfidencePredictingClassifierTrainer
extends ClassifierTrainer<ConfidencePredictingClassifier>
implements Boostable {
    private static final Logger logger = MalletLogger.getLogger(ConfidencePredictingClassifierTrainer.class.getName());

    /** Trainer for the classifier whose predictions are being assessed. */
    ClassifierTrainer underlyingClassifierTrainer;

    /** Trainer for the second-stage model built from the validation-set classifications. */
    MaxEntTrainer confidencePredictingClassifierTrainer;

    /** Pipe through which validation-set {@link Classification}s are added to the second-stage training list. */
    Pipe confidencePredictingPipe;

    /**
     * Confusion matrix of the underlying classifier evaluated on the training list.
     * This field is static, so every call to {@link #train(InstanceList)} on any instance overwrites it.
     */
    static ConfusionMatrix confusionMatrix = null;

    /** Result of the most recent completed {@link #train(InstanceList)} call; {@code null} before that. */
    ConfidencePredictingClassifier classifier;

    /**
     * Returns the classifier built by the most recent completed call to {@link #train(InstanceList)}.
     *
     * @return the last trained classifier, or {@code null} if training has not completed yet
     */
    @Override
    public ConfidencePredictingClassifier getClassifier() {
        return this.classifier;
    }

    /**
     * Creates a trainer with an explicit confidence-predicting pipe. The second-stage trainer is
     * always a freshly constructed {@link MaxEntTrainer}.
     *
     * @param underlyingClassifierTrainer trainer for the first-stage classifier
     * @param validationSet instances the first-stage classifier is evaluated on to create the
     *        second-stage training data; must be non-null by the time {@link #train(InstanceList)} runs
     * @param confidencePredictingPipe pipe used to build the second-stage {@link InstanceList}
     */
    public ConfidencePredictingClassifierTrainer(ClassifierTrainer underlyingClassifierTrainer, InstanceList validationSet, Pipe confidencePredictingPipe) {
        this.confidencePredictingPipe = confidencePredictingPipe;
        this.confidencePredictingClassifierTrainer = new MaxEntTrainer();
        this.validationSet = validationSet;
        this.underlyingClassifierTrainer = underlyingClassifierTrainer;
    }

    /**
     * Creates a trainer that uses a new {@link Classification2ConfidencePredictingFeatureVector}
     * as the confidence-predicting pipe.
     *
     * @param underlyingClassifierTrainer trainer for the first-stage classifier
     * @param validationSet instances the first-stage classifier is evaluated on to create the
     *        second-stage training data
     */
    public ConfidencePredictingClassifierTrainer(ClassifierTrainer underlyingClassifierTrainer, InstanceList validationSet) {
        this(underlyingClassifierTrainer, validationSet, new Classification2ConfidencePredictingFeatureVector());
    }

    /**
     * Trains the underlying classifier on {@code trainList}, then trains the confidence-predicting
     * {@link MaxEnt} model on the underlying classifier's classifications of the validation set.
     *
     * <p>Side effects: overwrites the static {@code confusionMatrix} and this trainer's
     * {@code classifier} field.
     *
     * @param trainList training instances for the underlying classifier
     * @return the combined classifier wrapping the underlying classifier and the MaxEnt model
     * @throws IllegalStateException if no validation set has been supplied
     */
    @Override
    public ConfidencePredictingClassifier train(InstanceList trainList) {
        // Checked explicitly and before any training: an `assert` is a no-op unless the JVM
        // runs with -ea, and failing late would waste the whole first-stage training run.
        if (this.validationSet == null) {
            throw new IllegalStateException("This ClassifierTrainer requires a validation set.");
        }

        logger.fine("Training underlying classifier");
        // The trainer field is a raw type, so the result is cast once here.
        Classifier underlying = (Classifier)this.underlyingClassifierTrainer.train(trainList);
        confusionMatrix = new ConfusionMatrix(new Trial(underlying, trainList));

        // Classify the validation set; each Classification becomes one second-stage instance.
        Trial validationTrial = new Trial(underlying, this.validationSet);
        InstanceList confidencePredictionTraining = new InstanceList(this.confidencePredictingPipe);
        logger.fine("Creating confidence prediction instance list");
        for (int i = 0; i < validationTrial.size(); ++i) {
            Classification classification = (Classification)validationTrial.get(i);
            // Target is passed as null; name and source are carried over from the classified instance.
            confidencePredictionTraining.add(classification, null, classification.getInstance().getName(), classification.getInstance().getSource());
        }

        logger.info("Begin training ConfidencePredictingClassifier . . . ");
        MaxEnt cpc = this.confidencePredictingClassifierTrainer.train(confidencePredictionTraining);
        logger.info("Accuracy at predicting correct/incorrect in training = " + cpc.getAccuracy(confidencePredictionTraining));

        this.classifier = new ConfidencePredictingClassifier(underlying, cpc);
        return this.classifier;
    }
}

