package dbn;

import au.com.bytecode.opencsv.CSVWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;

/* loaded from: input_file:dbn/MultiNet.class */
/**
 * A set of {@link DynamicBayesNet}s, one per cluster, together with an
 * assignment of every subject in an {@link Observations} set to one of them.
 *
 * <p>The constructor builds randomly scored networks and an initial
 * assignment; {@link #clust()} then alternates between re-learning one
 * network per cluster and re-assigning subjects until the score returned by
 * {@link #getScore(List, double[][][], boolean)} stops improving.
 *
 * <p>The assignment is stored as {@code clustering[transition][subject][cluster]}
 * and, as produced by this class, holds 1.0 for the assigned cluster and 0.0
 * elsewhere, repeated for every transition.
 */
public class MultiNet {
    /** Number of EM iterations after which {@link #clust()} switches to classification EM. */
    private static final int STOCHASTIC_EM_ITERATIONS = 100;

    /**
     * Likelihood ratio (relative to the best cluster, before division by the
     * number of clusters) below which a cluster's weight is truncated to zero.
     */
    private static final double MIN_RELATIVE_LIKELIHOOD = Math.pow(10.0d, -5.0d);

    /** Scale used to round weights up to 15 decimal places when computing the normaliser. */
    private static final double ROUNDING_SCALE = 1.0E15d;

    private List<DynamicBayesNet> networks;
    private final Observations o;
    private double[][][] clustering;
    private final boolean is_bcDBN;
    private final boolean is_cDBN;
    private final boolean spanning;
    private final int intra_ind;
    private final int root;
    private final int maxParents;
    private final boolean stationaryProcess;
    private final boolean multithread;

    /** Shared source of randomness for the stochastic assignment step. */
    private final Random random = new Random();

    /**
     * Creates {@code numClusters} randomly scored networks with generated
     * parameters and draws an initial stochastic assignment of subjects.
     *
     * @param observations      data set to cluster
     * @param numClusters       number of networks (clusters) to create
     * @param bcDbn             if true, networks are built with {@code Scores.to_bcDBN}
     * @param cDbn              if true (and {@code bcDbn} is false), networks are built with {@code Scores.to_cDBN}
     * @param spanning          forwarded to {@code Scores.toDBN} when neither flag above is set
     * @param intraInd          forwarded to {@code to_bcDBN} / {@code to_cDBN}
     * @param root              forwarded to {@code Scores.toDBN}
     * @param maxParents        forwarded to the {@code Scores} constructor
     * @param stationaryProcess if true, a single transition network (index 0) is used for every transition
     * @param multithread       forwarded to the {@code Scores} constructor
     */
    public MultiNet(Observations observations, int numClusters, boolean bcDbn, boolean cDbn, boolean spanning, int intraInd, int root, int maxParents, boolean stationaryProcess, boolean multithread) {
        this.o = observations;
        this.is_bcDBN = bcDbn;
        this.is_cDBN = cDbn;
        this.spanning = spanning;
        this.intra_ind = intraInd;
        this.root = root;
        this.maxParents = maxParents;
        this.stationaryProcess = stationaryProcess;
        this.multithread = multithread;
        this.networks = new ArrayList<DynamicBayesNet>(numClusters);
        // All-zero assignment: getAlpha() turns this into uniform cluster weights
        // for the first computeClusters() call below.
        this.clustering = new double[observations.getNumTransitions()][observations.getNumSubjects()][numClusters];
        for (int cluster = 0; cluster < numClusters; cluster++) {
            Scores scores = new Scores(observations, this.maxParents, stationaryProcess, true, multithread);
            scores.evaluate(new RandomScoringFunction());
            DynamicBayesNet net;
            if (this.is_bcDBN) {
                net = scores.to_bcDBN(new RandomScoringFunction(), this.intra_ind);
            } else if (this.is_cDBN) {
                net = scores.to_cDBN(new RandomScoringFunction(), this.intra_ind);
            } else {
                net = scores.toDBN(this.root, this.spanning);
            }
            net.generateParameters();
            this.networks.add(net);
        }
        this.clustering = computeClusters(this.networks, stationaryProcess, false);
    }

    /**
     * Assigns every subject to exactly one of the given networks.
     *
     * <p>For each subject the log-likelihood under every network is computed,
     * converted to a weight relative to the best network (weights below the
     * truncation floor become zero), multiplied by the cluster weight from
     * {@link #getAlpha(double[][][])} and normalised. The subject is then
     * assigned either to the highest-likelihood network or to a network drawn
     * at random according to those normalised weights.
     *
     * @param nets           candidate networks, one per cluster
     * @param stationary     if true, transition network 0 is used for every transition
     * @param hardAssignment if true, pick the highest-likelihood network; otherwise sample
     * @return a new {@code [transition][subject][cluster]} array holding 1.0 for
     *         the chosen cluster of each subject and 0.0 elsewhere
     */
    private double[][][] computeClusters(List<DynamicBayesNet> nets, boolean stationary, boolean hardAssignment) {
        int[][][] observed = this.o.getObservationsMatrix();
        int numSubjects = this.o.getNumSubjects();
        int numTransitions = this.o.getNumTransitions();
        int numAttributes = this.o.getAttributes().size();
        int numClusters = nets.size();
        double[][][] assignment = new double[numTransitions][numSubjects][numClusters];
        double logFloor = Math.log(MIN_RELATIVE_LIKELIHOOD) - Math.log(numClusters);
        // NOTE(review): the weights always come from the clustering stored on this
        // instance, which clust() does not update while it iterates - confirm intended.
        double[] alpha = getAlpha(this.clustering);

        for (int subject = 0; subject < numSubjects; subject++) {
            double[] logLik = new double[numClusters];
            double bestLogLik = Double.NEGATIVE_INFINITY;
            int bestCluster = 0;
            int cluster = 0;
            for (DynamicBayesNet net : nets) {
                logLik[cluster] = subjectLogLikelihood(net, observed, subject, numTransitions, numAttributes, stationary);
                if (bestLogLik < logLik[cluster]) {
                    bestLogLik = logLik[cluster];
                    bestCluster = cluster;
                }
                cluster++;
            }

            int chosen;
            if (hardAssignment) {
                chosen = bestCluster;
            } else {
                // Likelihood relative to the best cluster; negligible ratios are cut to zero.
                double[] posterior = new double[numClusters];
                double normaliser = 0.0d;
                for (int k = 0; k < numClusters; k++) {
                    double relative = logLik[k] - bestLogLik;
                    posterior[k] = relative >= logFloor ? Math.exp(relative) : 0.0d;
                    normaliser += (alpha[k] * Math.ceil(posterior[k] * ROUNDING_SCALE)) / ROUNDING_SCALE;
                }
                for (int k = 0; k < numClusters; k++) {
                    posterior[k] = posterior[k] * (alpha[k] / normaliser);
                }
                // If nothing can be drawn (rounding shortfall, or NaN weights from a
                // zero normaliser) use the most likely cluster instead of an arbitrary one.
                chosen = sampleCluster(posterior, bestCluster);
            }

            for (int t = 0; t < numTransitions; t++) {
                for (int k = 0; k < numClusters; k++) {
                    assignment[t][subject][k] = (k == chosen) ? 1.0d : 0.0d;
                }
            }
        }
        return assignment;
    }

    /**
     * Sums, over every transition and attribute, the log of the first value
     * returned by the transition network's {@code getParameters} for the
     * subject's observed configuration.
     */
    private double subjectLogLikelihood(DynamicBayesNet net, int[][][] observed, int subject, int numTransitions, int numAttributes, boolean stationary) {
        double logLik = 0.0d;
        for (int t = 0; t < numTransitions; t++) {
            int netIndex = stationary ? 0 : t;
            for (int attribute = 0; attribute < numAttributes; attribute++) {
                logLik += Math.log(net.transitionNets.get(netIndex).getParameters(attribute, observed[t][subject]).get(0).doubleValue());
            }
        }
        return logLik;
    }

    /**
     * Draws a cluster index by walking the cumulative sum of {@code posterior}
     * until it reaches a uniform random draw in [0, 1).
     *
     * @param posterior per-cluster weights, expected to sum to about 1
     * @param fallback  index returned when the cumulative sum never reaches the draw
     */
    private int sampleCluster(double[] posterior, int fallback) {
        double draw = this.random.nextDouble();
        double cumulative = 0.0d;
        for (int k = 0; k < posterior.length; k++) {
            cumulative += posterior[k];
            if (cumulative >= draw) {
                return k;
            }
        }
        return fallback;
    }

    /**
     * Computes the weight of each cluster as the sum of its assignment values
     * at transition 0 divided by the number of subjects.
     *
     * @param dArr assignment array indexed {@code [transition][subject][cluster]}
     * @return one weight per network; uniform {@code 1 / size} when all weights sum to zero
     */
    private double[] getAlpha(double[][][] dArr) {
        int numSubjects = this.o.getNumSubjects();
        int numClusters = this.networks.size();
        double[] alpha = new double[numClusters];
        double total = 0.0d;
        for (int cluster = 0; cluster < numClusters; cluster++) {
            double mass = 0.0d;
            for (int subject = 0; subject < numSubjects; subject++) {
                mass += dArr[0][subject][cluster];
            }
            alpha[cluster] = mass / numSubjects;
            total += alpha[cluster];
        }
        if (total == 0.0d) {
            // Nothing assigned yet: treat every cluster as equally likely.
            for (int cluster = 0; cluster < numClusters; cluster++) {
                alpha[cluster] = 1.0d / numClusters;
            }
        }
        return alpha;
    }

    /**
     * Extracts the {@code [transition][subject]} slice of the assignment array
     * for a single cluster.
     *
     * @param dArr assignment array indexed {@code [transition][subject][cluster]}
     * @param i    cluster index
     */
    private double[][] selectCluster(double[][][] dArr, int i) {
        int numSubjects = this.o.getNumSubjects();
        int numTransitions = this.o.getNumTransitions();
        double[][] slice = new double[numTransitions][numSubjects];
        for (int t = 0; t < numTransitions; t++) {
            for (int subject = 0; subject < numSubjects; subject++) {
                slice[t][subject] = dArr[t][subject][i];
            }
        }
        return slice;
    }

    /** Returns the current list of networks, one per cluster (the internal list, not a copy). */
    public List<DynamicBayesNet> getNetworks() {
        return this.networks;
    }

    /**
     * Learns one network per cluster from the observations weighted by that
     * cluster's slice of the given assignment, using {@code LLScoringFunction}.
     *
     * @param dArr assignment array indexed {@code [transition][subject][cluster]}
     * @return a new list with one freshly learned network per cluster
     */
    private List<DynamicBayesNet> trainNetworks(double[][][] dArr) {
        int numClusters = this.networks.size();
        int[][][] observed = this.o.getObservationsMatrix();
        List<Attribute> attributes = this.o.getAttributes();
        List<DynamicBayesNet> trained = new ArrayList<DynamicBayesNet>(numClusters);
        for (int cluster = 0; cluster < numClusters; cluster++) {
            Observations clusterObservations = new Observations(attributes, observed, selectCluster(dArr, cluster));
            Scores scores = new Scores(clusterObservations, this.maxParents, this.stationaryProcess, false, this.multithread);
            scores.evaluate(new LLScoringFunction());
            DynamicBayesNet net;
            if (this.is_bcDBN) {
                net = scores.to_bcDBN(new LLScoringFunction(), this.intra_ind);
            } else if (this.is_cDBN) {
                net = scores.to_cDBN(new LLScoringFunction(), this.intra_ind);
            } else {
                net = scores.toDBN(this.root, this.spanning);
            }
            net.learnParameters(clusterObservations, this.stationaryProcess);
            trained.add(net);
        }
        return trained;
    }

    /**
     * Runs the clustering loop: re-learn the networks from the current
     * assignment, re-assign the subjects, and repeat while the score strictly
     * improves. Assignment is stochastic for the first
     * {@value #STOCHASTIC_EM_ITERATIONS} iterations and hard (best cluster)
     * afterwards. On exit the networks from the last improving iteration are
     * kept and a final assignment is computed for them.
     */
    public void clust() {
        List<DynamicBayesNet> previous = this.networks;
        List<DynamicBayesNet> current = this.networks;
        double[][][] assignment = this.clustering;
        double score = getScore(current, assignment, this.stationaryProcess);
        double previousScore = Double.NEGATIVE_INFINITY;
        boolean classification = false;
        int iteration = 0;
        System.out.println("Starting with stochastic EM.");
        while (score > previousScore) {
            if (iteration >= STOCHASTIC_EM_ITERATIONS && !classification) {
                classification = true;
                System.out.println("Changing to classification EM.");
            }
            previous = current;
            previousScore = score;
            current = trainNetworks(assignment);
            assignment = computeClusters(current, this.stationaryProcess, classification);
            score = getScore(current, assignment, this.stationaryProcess);
            iteration++;
        }
        this.networks = previous;
        // NOTE(review): this final assignment is always computed as stationary and
        // stochastic, regardless of this.stationaryProcess - confirm intended.
        this.clustering = computeClusters(previous, true, false);
    }

    /**
     * Returns twice the summed per-cluster log-likelihood score minus the total
     * number of parameters times the log of the number of subjects.
     */
    public double getBICScore() {
        List<Attribute> attributes = this.o.getAttributes();
        int[][][] observed = this.o.getObservationsMatrix();
        double fit = 0.0d;
        double numParameters = 0.0d;
        int cluster = 0;
        for (DynamicBayesNet net : this.networks) {
            Observations clusterObservations = new Observations(attributes, observed, selectCluster(this.clustering, cluster));
            fit += 2.0d * net.getScore(clusterObservations, new LLScoringFunction(), this.stationaryProcess);
            numParameters += net.getNumberParameters(clusterObservations);
            cluster++;
        }
        return fit - (numParameters * Math.log(this.o.getNumSubjects()));
    }

    /**
     * Scores a set of networks against an assignment: the sum of each network's
     * {@code LLScoringFunction} score on its cluster's weighted observations,
     * plus, for every cluster, its assigned mass times the log of its weight.
     *
     * <p>Clusters with no assigned mass contribute nothing to the second term
     * (0 * log 0 is taken as 0) instead of turning the whole score into NaN.
     *
     * @param list networks, one per cluster
     * @param dArr assignment array indexed {@code [transition][subject][cluster]}
     * @param z    stationarity flag forwarded to {@code DynamicBayesNet.getScore}
     * @return the combined score
     */
    public double getScore(List<DynamicBayesNet> list, double[][][] dArr, boolean z) {
        int numSubjects = this.o.getNumSubjects();
        int[][][] observed = this.o.getObservationsMatrix();
        List<Attribute> attributes = this.o.getAttributes();
        int numClusters = this.networks.size();
        double[] alpha = getAlpha(dArr);

        double networkScore = 0.0d;
        int cluster = 0;
        for (DynamicBayesNet net : list) {
            Observations clusterObservations = new Observations(attributes, observed, selectCluster(dArr, cluster));
            networkScore += net.getScore(clusterObservations, new LLScoringFunction(), z);
            cluster++;
        }

        double mixingScore = 0.0d;
        for (int k = 0; k < numClusters; k++) {
            double mass = 0.0d;
            for (int subject = 0; subject < numSubjects; subject++) {
                mass += dArr[0][subject][k];
            }
            // An empty cluster has alpha == 0, and 0 * log(0) evaluates to NaN in
            // floating point; skip it so one empty cluster cannot poison the score.
            if (mass != 0.0d) {
                mixingScore += mass * Math.log(alpha[k]);
            }
        }
        return networkScore + mixingScore;
    }

    /**
     * Writes the cluster of every subject to a CSV file with the header
     * {@code subject_id, Class}. Subject ids are taken, in iteration order, from
     * the keys of {@code Observations.getSubjectIsPresent()} whose first flag is
     * true. On an I/O failure an error is printed and the JVM exits with status 1.
     *
     * @param str path of the file to write
     */
    public void writeToFile(String str) {
        int numClusters = this.networks.size();
        Map<String, boolean[]> subjectIsPresent = this.o.getSubjectIsPresent();
        try {
            File file = new File(str);
            file.createNewFile();
            CSVWriter writer = new CSVWriter(new FileWriter(file));
            // try/finally so the file handle is released even if a row fails to build.
            try {
                int numSubjects = this.o.getNumSubjects();
                writer.writeNext(new String[] {"subject_id", "Class"});
                Iterator<String> ids = subjectIsPresent.keySet().iterator();
                for (int subject = 0; subject < numSubjects; subject++) {
                    // Reset per subject so a label can never carry over from the previous row.
                    int bestCluster = 0;
                    double bestValue = Double.NEGATIVE_INFINITY;
                    double[] memberships = this.clustering[0][subject];
                    for (int k = 0; k < numClusters; k++) {
                        if (bestValue < memberships[k]) {
                            bestCluster = k;
                            bestValue = memberships[k];
                        }
                    }
                    List<String> row = new ArrayList<String>(2);
                    // Advance to the next id flagged as present; the iterator is shared
                    // across subjects so each id is used at most once.
                    while (ids.hasNext()) {
                        String id = ids.next();
                        if (subjectIsPresent.get(id)[0]) {
                            row.add(id);
                            break;
                        }
                    }
                    row.add(Integer.toString(bestCluster));
                    writer.writeNext(row.toArray(new String[0]));
                }
            } finally {
                writer.close();
            }
        } catch (IOException e) {
            System.err.println("Could not write to " + str + ".");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Builds a textual report: the number of clusters and subjects, then for
     * each cluster its network's {@code toString(z)} and its weight.
     *
     * @param z forwarded to {@code DynamicBayesNet.toString(boolean)}
     */
    public String toString(boolean z) {
        int numClusters = this.networks.size();
        int numSubjects = this.o.getNumSubjects();
        String newline = System.getProperty("line.separator");
        double[] alpha = getAlpha(this.clustering);
        StringBuilder sb = new StringBuilder();
        sb.append("Number of clusters : ").append(numClusters).append(newline);
        sb.append("Number of Observations : ").append(numSubjects).append(newline).append(newline);
        for (int cluster = 0; cluster < numClusters; cluster++) {
            sb.append("--- Cluster ").append(cluster).append(" ---").append(newline);
            sb.append(this.networks.get(cluster).toString(z));
            sb.append("Alpha: ").append(alpha[cluster]).append(newline).append(newline);
        }
        return sb.toString();
    }

    /**
     * Declares the command-line options {@code -i/--inputFile} and
     * {@code -n/--numClusters}.
     *
     * <p>NOTE(review): as visible here the options are only declared; the
     * arguments are never parsed and no clustering is started.
     */
    public static void main(String[] strArr) {
        Options options = new Options();
        Option inputFile = Option.builder("i").hasArg().required(true).desc("Folder of the dataset.").argName("path").longOpt("inputFile").build();
        Option numClusters = Option.builder("n").longOpt("numClusters").desc("Number of clusters.").required(true).hasArg().argName("int").build();
        options.addOption(inputFile);
        options.addOption(numClusters);
        new DefaultParser();
    }
}
