/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Tree;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import gnu.trove.THashSet;
import gnu.trove.TIntObjectHashMap;
import gnu.trove.TIntObjectIterator;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

public class JunctionTree
extends Tree {
    private int numNodes;
    private TIntObjectHashMap sepsets;
    private Factor[] cpfs;

    public JunctionTree(int size) {
        this.numNodes = size;
        this.sepsets = new TIntObjectHashMap();
        this.cpfs = new Factor[size];
    }

    public void addNode(Object parent1, Object child1) {
        super.addNode(parent1, child1);
        VarSet parent = (VarSet)parent1;
        VarSet child = (VarSet)child1;
        VarSet sepset = parent.intersection(child);
        int id1 = this.lookupIndex(parent);
        int id2 = this.lookupIndex(child);
        this.putSepset(id1, id2, new Sepset(sepset, this.newSepsetPtl(sepset)));
    }

    private Factor newSepsetPtl(Set sepset) {
        if (sepset.isEmpty()) {
            return ConstantFactor.makeIdentityFactor();
        }
        return new TableFactor(sepset);
    }

    private int hashIdxIdx(int id1, int id2) {
        assert (id1 < 65536 && id2 < 65536);
        int id = id1 < id2 ? id1 << 16 | id2 : id2 << 16 | id1;
        return id;
    }

    private void putSepset(int id1, int id2, Sepset sepset) {
        int id = this.hashIdxIdx(id1, id2);
        this.sepsets.put(id, (Object)sepset);
    }

    private Sepset getSepset(int id1, int id2) {
        int id = this.hashIdxIdx(id1, id2);
        return (Sepset)this.sepsets.get(id);
    }

    public Factor getCPF(VarSet c) {
        return this.cpfs[this.lookupIndex(c)];
    }

    public void setCPF(VarSet c, Factor pot) {
        this.cpfs[this.lookupIndex((Object)c)] = pot;
    }

    void clearCPFs() {
        for (int i = 0; i < this.cpfs.length; ++i) {
            this.cpfs[i] = new TableFactor((VarSet)this.lookupVertex(i));
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Sepset sepset = (Sepset)it.value();
            sepset.ptl = this.newSepsetPtl(sepset.set);
        }
    }

    public Set sepsetPotentials() {
        THashSet set = new THashSet();
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset)it.value()).ptl;
            set.add((Object)ptl);
        }
        return set;
    }

    void setSepsetPot(Factor pot, VarSet v1, VarSet v2) {
        int id1 = this.lookupIndex(v1);
        int id2 = this.lookupIndex(v2);
        this.getSepset((int)id1, (int)id2).ptl = pot;
    }

    public Factor getSepsetPot(VarSet v1, VarSet v2) {
        int id1 = this.lookupIndex(v1);
        int id2 = this.lookupIndex(v2);
        return this.getSepset((int)id1, (int)id2).ptl;
    }

    public Collection clusterPotentials() {
        HashSet<Factor> h = new HashSet<Factor>();
        for (int i = 0; i < this.cpfs.length; ++i) {
            if (this.cpfs[i] == null) continue;
            h.add(this.cpfs[i]);
        }
        return h;
    }

    public Set getSepset(VarSet v1, VarSet v2) {
        int id1 = this.lookupIndex(v1);
        int id2 = this.lookupIndex(v2);
        return this.getSepset((int)id1, (int)id2).set;
    }

    public Factor lookupMarginal(Variable var) {
        VarSet c = this.findParentCluster(var);
        Factor pot = this.getCPF(c);
        return pot.marginalize(var);
    }

    public double lookupLogJoint(Assignment assn) {
        double accum = 0.0;
        for (int i = 0; i < this.cpfs.length; ++i) {
            if (this.cpfs[i] == null) continue;
            double phi = this.cpfs[i].logValue(assn);
            accum += phi;
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset)it.value()).ptl;
            double phi = ptl.logValue(assn);
            accum -= phi;
        }
        return accum;
    }

    public VarSet findParentCluster(Variable var) {
        int best = Integer.MAX_VALUE;
        VarSet retval = null;
        Iterator it = this.getVerticesIterator();
        while (it.hasNext()) {
            VarSet c = (VarSet)it.next();
            if (!c.contains(var) || c.weight() >= best) continue;
            retval = c;
            best = c.weight();
        }
        return retval;
    }

    public VarSet findParentCluster(Collection vars) {
        int best = Integer.MAX_VALUE;
        VarSet retval = null;
        Iterator it = this.getVerticesIterator();
        while (it.hasNext()) {
            VarSet c = (VarSet)it.next();
            if (!c.containsAll(vars) || c.weight() >= best) continue;
            retval = c;
            best = c.weight();
        }
        return retval;
    }

    public VarSet findCluster(Variable[] vars) {
        List<Variable> l = Arrays.asList(vars);
        Iterator it = this.getVerticesIterator();
        while (it.hasNext()) {
            VarSet c2 = (VarSet)it.next();
            if (!c2.containsAll(l) || !l.containsAll(c2)) continue;
            return c2;
        }
        return null;
    }

    public void normalizeAll() {
        int n = this.cpfs.length;
        for (int i = 0; i < n; ++i) {
            if (this.cpfs[i] == null) continue;
            this.cpfs[i].normalize();
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset)it.value()).ptl;
            ptl.normalize();
        }
    }

    int getId(VarSet c) {
        return this.lookupIndex(c);
    }

    public void dump() {
        int n = this.cpfs.length;
        System.out.println(this);
        System.out.println("Vertex CPFs");
        for (int i = 0; i < n; ++i) {
            if (this.cpfs[i] == null) continue;
            System.out.println("CPF " + i + " " + this.cpfs[i].dumpToString());
        }
        System.out.println("sepset CPFs");
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset)it.value()).ptl;
            System.out.println(ptl.dumpToString());
        }
        System.out.println("/End JT");
    }

    public double dumpLogJoint(Assignment assn) {
        double accum = 0.0;
        for (int i = 0; i < this.cpfs.length; ++i) {
            if (this.cpfs[i] == null) continue;
            double phi = this.cpfs[i].logValue(assn);
            System.out.println("CPF " + i + " accum = " + accum);
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset)it.value()).ptl;
            double phi = ptl.logValue(assn);
            System.out.println("Sepset " + ptl.varSet() + " accum " + accum);
        }
        return accum;
    }

    public boolean isNaN() {
        int n = this.cpfs.length;
        for (int i = 0; i < n; ++i) {
            if (!this.cpfs[i].isNaN()) continue;
            return true;
        }
        TIntObjectIterator it = this.sepsets.iterator();
        while (it.hasNext()) {
            it.advance();
            Factor ptl = ((Sepset)it.value()).ptl;
            if (!ptl.isNaN()) continue;
            return true;
        }
        return false;
    }

    public double entropy() {
        double entropy = 0.0;
        for (Factor ptl : this.clusterPotentials()) {
            entropy += ptl.entropy();
        }
        for (Factor ptl : this.sepsetPotentials()) {
            entropy -= ptl.entropy();
        }
        return entropy;
    }

    public void decompact() {
        this.cpfs = new Factor[this.numNodes];
        this.clearCPFs();
    }

    public void compact() {
        this.cpfs = null;
    }

    private static class Sepset {
        Set set;
        Factor ptl;

        Sepset(Set s, Factor p) {
            this.set = s;
            this.ptl = p;
        }
    }
}

