/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;
import java.util.ArrayList;
import java.util.Collection;
import java.util.logging.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A list of {@link Classification}s that were all produced by one {@link Classifier},
 * together with evaluation statistics over them: accuracy, per-label precision,
 * recall and F1, and the average rank of the true label.
 *
 * <p>The {@code add} and {@code addAll} methods reject any Classification whose
 * {@link Classification#getClassifier()} is not the very same object (reference
 * equality) as this trial's classifier.
 */
public class Trial
extends ArrayList<Classification> {
    private static final Logger logger = Logger.getLogger(Trial.class.getName());
    Classifier classifier;

    /**
     * Classifies every instance of {@code ilist} with {@code c} and stores the results
     * in iteration order.
     *
     * @param c the classifier under evaluation
     * @param ilist the instances to classify
     * @throws IllegalArgumentException if {@code c.classify} returns a Classification
     *         that does not report {@code c} as its classifier
     */
    public Trial(Classifier c, InstanceList ilist) {
        super(ilist.size());
        this.classifier = c;
        for (Instance instance : ilist) {
            this.add(c.classify(instance));
        }
    }

    /**
     * Appends {@code c} after verifying it came from this trial's classifier.
     *
     * @throws IllegalArgumentException if {@code c} came from a different classifier
     */
    @Override
    public boolean add(Classification c) {
        this.checkClassifier(c);
        return super.add(c);
    }

    /**
     * Inserts {@code c} at {@code index} after verifying it came from this trial's classifier.
     *
     * @throws IllegalArgumentException if {@code c} came from a different classifier
     */
    @Override
    public void add(int index, Classification c) {
        this.checkClassifier(c);
        super.add(index, c);
    }

    /**
     * Appends every element of {@code collection}. All elements are validated before
     * any is added, so a rejected element leaves this trial unchanged.
     *
     * @return {@code true} if this trial changed (i.e. {@code collection} was non-empty)
     * @throws IllegalArgumentException if any element came from a different classifier
     */
    @Override
    public boolean addAll(Collection<? extends Classification> collection) {
        this.checkAllClassifiers(collection);
        // ArrayList.addAll snapshots the argument via toArray(), so trial.addAll(trial) is safe.
        return super.addAll(collection);
    }

    /**
     * Inserts every element of {@code collection} starting at {@code index}. All elements
     * are validated before any is inserted.
     *
     * @return {@code true} if this trial changed
     * @throws IllegalArgumentException if any element came from a different classifier
     * @throws IndexOutOfBoundsException if {@code index} is out of range
     */
    @Override
    public boolean addAll(int index, Collection<? extends Classification> collection) {
        this.checkAllClassifiers(collection);
        return super.addAll(index, collection);
    }

    /** Throws if {@code c} was not produced by this trial's classifier (reference comparison). */
    private void checkClassifier(Classification c) {
        if (c.getClassifier() != this.classifier) {
            throw new IllegalArgumentException("Trying to add Classification from a different Classifier.");
        }
    }

    /** Applies {@link #checkClassifier} to every element before anything is mutated. */
    private void checkAllClassifiers(Collection<? extends Classification> collection) {
        for (Classification classification : collection) {
            this.checkClassifier(classification);
        }
    }

    /** Returns the classifier whose output this trial holds. */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Returns the fraction of classifications whose best label is correct, as reported by
     * {@link Classification#bestLabelIsCorrect()}. Returns NaN (and logs a warning) for an
     * empty trial.
     */
    public double getAccuracy() {
        int numCorrect = 0;
        for (Classification classification : this) {
            if (classification.bestLabelIsCorrect()) {
                ++numCorrect;
            }
        }
        if (this.isEmpty()) {
            logger.warning("No classifications: dividing by 0");
        }
        return (double) numCorrect / (double) this.size();
    }

    /**
     * Resolves a label given either as a {@link Labeling} (its best index is used) or as a
     * label-alphabet entry (looked up without adding it to the alphabet).
     *
     * @throws IllegalArgumentException if the entry is not in the classifier's label alphabet
     */
    private int resolveLabelIndex(Object labelEntry) {
        int index = labelEntry instanceof Labeling
                ? ((Labeling) labelEntry).getBestIndex()
                : this.classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
        if (index == -1) {
            throw new IllegalArgumentException("Label " + labelEntry.toString() + " is not a valid label.");
        }
        return index;
    }

    /**
     * Precision for the label denoted by {@code labelEntry}; see {@link #getPrecision(int)}.
     *
     * @throws IllegalArgumentException if {@code labelEntry} is not a known label
     */
    public double getPrecision(Object labelEntry) {
        return this.getPrecision(this.resolveLabelIndex(labelEntry));
    }

    /** Precision for the best label of {@code label}; see {@link #getPrecision(int)}. */
    public double getPrecision(Labeling label) {
        return this.getPrecision(label.getBestIndex());
    }

    /**
     * Among classifications whose predicted best label is {@code index}, returns the fraction
     * whose instance's true best label is also {@code index}. Returns NaN (and logs a
     * warning) if no classification predicts {@code index}.
     */
    public double getPrecision(int index) {
        int numCorrect = 0;
        int numInstances = 0;
        for (Classification classification : this) {
            int trueLabel = classification.getInstance().getLabeling().getBestIndex();
            int classLabel = classification.getLabeling().getBestIndex();
            if (classLabel == index) {
                ++numInstances;
                if (trueLabel == index) {
                    ++numCorrect;
                }
            }
        }
        if (numInstances == 0) {
            logger.warning("No class instances: dividing by 0");
        }
        return (double) numCorrect / (double) numInstances;
    }

    /**
     * Recall for the label denoted by {@code labelEntry}; see {@link #getRecall(int)}.
     *
     * @throws IllegalArgumentException if {@code labelEntry} is not a known label
     */
    public double getRecall(Object labelEntry) {
        return this.getRecall(this.resolveLabelIndex(labelEntry));
    }

    /** Recall for the best label of {@code label}; see {@link #getRecall(int)}. */
    public double getRecall(Labeling label) {
        return this.getRecall(label.getBestIndex());
    }

    /**
     * Among classifications whose instance's true best label is {@code labelIndex}, returns
     * the fraction that the classifier also labeled {@code labelIndex}. Returns NaN (and logs
     * a warning) if no instance has true label {@code labelIndex}.
     */
    public double getRecall(int labelIndex) {
        int numCorrect = 0;
        int numInstances = 0;
        for (Classification classification : this) {
            int trueLabel = classification.getInstance().getLabeling().getBestIndex();
            int classLabel = classification.getLabeling().getBestIndex();
            if (trueLabel == labelIndex) {
                ++numInstances;
                if (classLabel == labelIndex) {
                    ++numCorrect;
                }
            }
        }
        if (numInstances == 0) {
            logger.warning("No class instances: dividing by 0");
        }
        return (double) numCorrect / (double) numInstances;
    }

    /**
     * F1 for the label denoted by {@code labelEntry}; see {@link #getF1(int)}.
     *
     * @throws IllegalArgumentException if {@code labelEntry} is not a known label
     */
    public double getF1(Object labelEntry) {
        return this.getF1(this.resolveLabelIndex(labelEntry));
    }

    /** F1 for the best label of {@code label}; see {@link #getF1(int)}. */
    public double getF1(Labeling label) {
        return this.getF1(label.getBestIndex());
    }

    /**
     * Returns the harmonic mean of {@link #getPrecision(int)} and {@link #getRecall(int)}.
     * When both are exactly 0 there are no true positives, so F1 is 0 and 0.0 is returned
     * instead of the NaN that the 0/0 harmonic mean would produce. NaN from an undefined
     * precision or recall still propagates.
     */
    public double getF1(int index) {
        double precision = this.getPrecision(index);
        double recall = this.getRecall(index);
        if (precision == 0.0 && recall == 0.0) {
            logger.warning("Precision and recall are 0: F1 is 0");
            return 0.0;
        }
        return 2.0 * precision * recall / (precision + recall);
    }

    /**
     * Returns the mean, over all classifications, of the rank that the classifier's labeling
     * assigns to the instance's target label (presumably 0 for the top-ranked label). Returns
     * NaN for an empty trial.
     *
     * <p>NOTE(review): assumes every instance target is a {@link Label}; otherwise a
     * ClassCastException is thrown.
     */
    public double getAverageRank() {
        double rankSum = 0.0;
        for (Classification classification : this) {
            Label target = (Label) classification.getInstance().getTarget();
            rankSum += (double) classification.getLabeling().getRank(target);
        }
        return rankSum / (double) this.size();
    }
}

