package dbn;

import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import utils.Utils;

/* loaded from: input_file:dbn/CrossValidation.class */
public class CrossValidation {

    /**
     * Attribute whose forecast value {@link #evaluate} writes to the report and scores.
     * NOTE(review): hard-coded; it matches the stratification attribute used in {@link #main} but is
     * not tied to the constructor's stratification argument — confirm before reusing with other data.
     */
    private static final int REPORTED_ATTRIBUTE = 17;

    /** Seed given to {@link #r}; printed as the first line of every evaluation report. */
    long randomSeed;
    /** Random source used to shuffle fold ids. */
    private final Random r;
    private final Observations o;
    /** One row per observation, flattened over all transitions. */
    private final int[][] allData;
    /** Passive (string) attributes, aligned row-by-row with {@link #allData}. */
    private final String[][] allPassiveData;
    /** Rows grouped by class value; each row carries its fold id in one extra trailing column. */
    private final List<int[][]> stratifiedData;
    /** Passive rows, grouped and ordered exactly like {@link #stratifiedData}. */
    private final List<String[][]> stratifiedPassiveData;
    private final int numFolds;

    /** Row counts for one fold: rows assigned to the fold (test) and all remaining rows (train). */
    private static final class Pair {
        final int test;
        final int train;

        Pair(int test, int train) {
            this.test = test;
            this.train = train;
        }
    }

    /**
     * Re-seeds the random source and records the seed that reports will print.
     *
     * <p>Fold ids are assigned in the constructor, so calling this afterwards does NOT change the
     * folds already assigned. Use {@link #CrossValidation(Observations, int, Integer, long)} for
     * reproducible folds.
     *
     * @param j the new seed
     * @return this instance, for chaining
     */
    public CrossValidation setRandomSeed(long j) {
        this.randomSeed = j;
        this.r.setSeed(j);
        return this;
    }

    /** @return the seed recorded for this cross-validation run */
    public long getRandomSeed() {
        return this.randomSeed;
    }

    /**
     * Counts, over all strata, the rows whose fold id equals {@code fold} and the rows whose fold id
     * does not.
     */
    private Pair countInstancesOfFold(int fold) {
        int numAttributes = this.o.numAttributes();
        int markovLag = this.o.getMarkovLag();
        int foldColumn = (markovLag + 1) * numAttributes;
        int inFold = 0;
        int outOfFold = 0;
        for (int[][] stratum : this.stratifiedData) {
            for (int[] row : stratum) {
                if (row[foldColumn] == fold) {
                    inFold++;
                } else {
                    outOfFold++;
                }
            }
        }
        return new Pair(inFold, outOfFold);
    }

    /** Counts the rows of {@code stratum} whose fold id (at {@code foldColumn}) differs from {@code fold}. */
    private static int countOutOfFold(int[][] stratum, int fold, int foldColumn) {
        int count = 0;
        for (int[] row : stratum) {
            if (row[foldColumn] != fold) {
                count++;
            }
        }
        return count;
    }

    /**
     * Builds a shuffled list of {@code count} fold ids in {@code [0, folds)}.
     *
     * <p>Each fold id appears {@code count / folds} times. The {@code count % folds} leftover slots go
     * to folds {@code 0, 1, ...}. The list is then shuffled with {@link #r}.
     */
    private List<Integer> calculateFoldIds(int count, int folds) {
        List<Integer> ids = new ArrayList<>(count);
        int perFold = count / folds;
        int remainder = count % folds;
        for (int fold = 0; fold < folds; fold++) {
            for (int k = 0; k < perFold; k++) {
                ids.add(Integer.valueOf(fold));
            }
        }
        for (int fold = 0; fold < remainder; fold++) {
            ids.add(Integer.valueOf(fold));
        }
        Collections.shuffle(ids, this.r);
        return ids;
    }

    /**
     * Counts the rows of {@link #allData} whose {@code attribute} (in the last time slice) equals
     * {@code value}.
     */
    private int countInstancesOfClass(int attribute, int value) {
        int numAttributes = this.o.numAttributes();
        int markovLag = this.o.getMarkovLag();
        int column = (markovLag * numAttributes) + attribute;
        int count = 0;
        for (int[] row : this.allData) {
            if (row[column] == value) {
                count++;
            }
        }
        return count;
    }

    /**
     * Prepares a cross-validation run with a freshly drawn random seed.
     *
     * @param observations the data to cross-validate
     * @param i number of folds; when {@code <= 0} no fold ids are assigned
     * @param num attribute index whose values split the rows into groups, or {@code null} for a single
     *     group
     */
    public CrossValidation(Observations observations, int i, Integer num) {
        this(observations, i, num, new Random().nextLong());
    }

    /**
     * Prepares a cross-validation run whose fold assignment is reproducible from {@code seed}.
     *
     * <p>All transitions are flattened into one row per observation of
     * {@code (markovLag + 1) * numAttributes} ints. When {@code num} is given, rows are grouped by the
     * value of that attribute in the last time slice. Every row then receives a fold id in an extra
     * trailing column.
     *
     * <p>NOTE(review): fold ids come from a single shuffle over all rows, so the groups are not
     * individually balanced across folds (this is not true stratified sampling).
     *
     * @param observations the data to cross-validate
     * @param i number of folds; when {@code <= 0} no fold ids are assigned
     * @param num attribute index to group rows by, or {@code null} for a single group
     * @param seed seed for the fold shuffle; reported as {@link #getRandomSeed()}
     */
    public CrossValidation(Observations observations, int i, Integer num, long seed) {
        this.randomSeed = seed;
        this.r = new Random(seed);
        this.o = observations;
        int numObservations = observations.numObservations(-1);
        int numAttributes = observations.numAttributes();
        int numTransitions = observations.numTransitions();
        int markovLag = observations.getMarkovLag();
        int numPassiveAttributes = observations.numPassiveAttributes();
        int rowLength = (markovLag + 1) * numAttributes;
        int passiveRowLength = (markovLag + 1) * numPassiveAttributes;

        // Flatten every transition's observations into consecutive rows (rows are shared, not copied).
        this.allData = new int[numObservations][rowLength];
        this.allPassiveData = new String[numObservations][passiveRowLength];
        int[][][] observationsMatrix = observations.getObservationsMatrix();
        String[][][] passiveObservationsMatrix = observations.getPassiveObservationsMatrix();
        int next = 0;
        for (int t = 0; t < numTransitions; t++) {
            for (int k = 0; k < observations.numObservations(t); k++) {
                this.allData[next] = observationsMatrix[t][k];
                this.allPassiveData[next] = passiveObservationsMatrix[t][k];
                next++;
            }
        }

        if (num != null) {
            int classAttribute = num.intValue();
            int classColumn = (markovLag * numAttributes) + classAttribute;
            int numClasses = observations.getAttributes().get(classAttribute).size();
            this.stratifiedData = new ArrayList<>(numClasses);
            this.stratifiedPassiveData = new ArrayList<>(numClasses);
            for (int value = 0; value < numClasses; value++) {
                int classSize = countInstancesOfClass(classAttribute, value);
                int[][] rows = new int[classSize][];
                String[][] passiveRows = new String[classSize][];
                int filled = 0;
                for (int k = 0; k < this.allData.length; k++) {
                    int[] row = this.allData[k];
                    if (row[classColumn] == value) {
                        // One extra slot at the end holds the fold id.
                        rows[filled] = Arrays.copyOf(row, rowLength + 1);
                        passiveRows[filled] = Arrays.copyOf(this.allPassiveData[k], passiveRowLength);
                        filled++;
                    }
                }
                this.stratifiedData.add(rows);
                this.stratifiedPassiveData.add(passiveRows);
            }
        } else {
            this.stratifiedData = new ArrayList<>(1);
            this.stratifiedPassiveData = new ArrayList<>(1);
            int[][] rows = new int[numObservations][];
            for (int k = 0; k < this.allData.length; k++) {
                rows[k] = Arrays.copyOf(this.allData[k], rowLength + 1);
            }
            this.stratifiedData.add(rows);
            this.stratifiedPassiveData.add(this.allPassiveData);
        }

        this.numFolds = i;
        if (i > 0) {
            List<Integer> foldIds = calculateFoldIds(numObservations, i);
            int position = 0;
            for (int[][] stratum : this.stratifiedData) {
                for (int[] row : stratum) {
                    row[rowLength] = foldIds.get(position++).intValue();
                }
            }
        }
    }

    /**
     * Learns a network on {@code trainingSet} and forecasts {@code testSet} one step ahead.
     *
     * @param trainingSet observations used to learn structure and parameters
     * @param testSet observations to forecast
     * @param scoresParam forwarded as the second argument of {@code Scores}
     * @param scoringFunction scoring function used to evaluate candidate structures
     * @param writeDot when true, writes the learned structure to {@code dotFileBase + ".dot"} and
     *     returns {@code null} instead of forecasting
     * @param dotFileBase file name prefix for the {@code .dot} output
     * @param forecastFlag forwarded as the last argument of {@code DynamicBayesNet.forecast}
     * @param useBcDbn when true, builds the network with {@code to_bcDBN}, otherwise {@code toDBN}
     * @param bcDbnParam forwarded to {@code to_bcDBN}
     * @return the forecast observations, or {@code null} when {@code writeDot} is true
     */
    private Observations evaluateFold(Observations trainingSet, Observations testSet, int scoresParam,
            ScoringFunction scoringFunction, boolean writeDot, String dotFileBase, boolean forecastFlag,
            boolean useBcDbn, int bcDbnParam) {
        Scores scores = new Scores(trainingSet, scoresParam, true, true);
        scores.evaluate(scoringFunction);
        DynamicBayesNet dbn = useBcDbn ? scores.to_bcDBN(scoringFunction, bcDbnParam) : scores.toDBN();
        if (writeDot) {
            try {
                Utils.writeToFile(dotFileBase + ".dot", dbn.toDot(false));
            } catch (FileNotFoundException e) {
                // Best effort: a missing .dot file should not abort the evaluation.
                e.printStackTrace();
            }
        }
        dbn.learnParameters(trainingSet, true);
        if (writeDot) {
            return null;
        }
        return dbn.forecast(testSet, 1, true, forecastFlag);
    }

    /** Appends the report header: the seed, the names of {@code attributes}, and the actual-value column. */
    private void appendHeader(StringBuilder sb, List<Integer> attributes, String newline) {
        sb.append(this.randomSeed).append(newline);
        for (Integer index : attributes) {
            sb.append(this.o.getAttributes().get(index.intValue()).getName()).append("\t");
        }
        sb.append("\tactual_value").append(newline);
    }

    /**
     * Updates the confusion counts for labels 1 and 2 (label 2 is the positive class).
     *
     * <p>Layout: {@code [0]} actual 2 predicted 2, {@code [1]} actual 1 predicted 1,
     * {@code [2]} actual 1 mispredicted, {@code [3]} actual 2 mispredicted. Other labels are ignored.
     */
    private static void tally(double[] confusion, int actual, int predicted) {
        if (actual == 2) {
            confusion[predicted == actual ? 0 : 3] += 1.0d;
        } else if (actual == 1) {
            confusion[predicted == actual ? 1 : 2] += 1.0d;
        }
    }

    /**
     * Prints and returns the fold's {accuracy, precision, recall, "AUC"} from its confusion counts.
     * The value printed as "AUC" is the mean of the two per-class recalls (balanced accuracy).
     */
    private static double[] reportFoldMetrics(double[] c) {
        System.out.println("class  " + Arrays.toString(c));
        double accuracy = (c[0] + c[1]) / (((c[0] + c[1]) + c[2]) + c[3]);
        double precision = c[0] / (c[0] + c[2]);
        double recall = c[0] / (c[0] + c[3]);
        double balanced = 0.5d * ((c[0] / (c[0] + c[3])) + (c[1] / (c[1] + c[2])));
        System.out.println("ACC " + accuracy);
        System.out.println("PRE " + precision);
        System.out.println("REC " + recall);
        System.out.println("AUC " + balanced);
        return new double[] {accuracy, precision, recall, balanced};
    }

    /** Prints the per-fold metric sums averaged over {@link #numFolds} (NaN when there are no folds). */
    private void printAveragedMetrics(double[] sums) {
        System.out.println("Accuracy " + (sums[0] / this.numFolds));
        System.out.println("Precision " + (sums[1] / this.numFolds));
        System.out.println("Recall " + (sums[2] / this.numFolds));
        System.out.println("AUC " + (sums[3] / this.numFolds));
    }

    /**
     * Experimental per-class evaluation: for every fold, learns one network on the training rows of
     * the first stratum and another on the training rows of all other strata, then classifies each
     * test row.
     *
     * <p>NOTE(review): this method is unfinished. The first-class likelihood is never computed (it
     * stays 1.0, so every prediction is 0), and predictions are 0/1 while {@link #tally} only counts
     * labels 1 and 2. Many diagnostic lines are also printed to stdout.
     *
     * @param i forwarded as the second argument of {@code Scores}
     * @param scoringFunction scoring function used to learn both networks
     * @param str currently unused
     * @param list attribute indices whose names are written in the report header
     * @param z currently unused
     * @param z2 when true, networks are built with {@code to_bcDBN}, otherwise {@code toDBN}
     * @param i2 forwarded to {@code to_bcDBN}
     * @return the textual report (seed, header, one line per test row per fold)
     */
    public String evaluate2(int i, ScoringFunction scoringFunction, String str, List<Integer> list, boolean z,
            boolean z2, int i2) {
        int numAttributes = this.o.numAttributes();
        int numPassiveAttributes = this.o.numPassiveAttributes();
        int markovLag = this.o.getMarkovLag();
        int rowLength = (markovLag + 1) * numAttributes;
        int passiveRowLength = (markovLag + 1) * numPassiveAttributes;
        int actualColumn = markovLag * numPassiveAttributes;
        String property = System.getProperty("line.separator");
        StringBuilder sb = new StringBuilder();
        appendHeader(sb, list, property);
        double[] metricSums = new double[4];

        for (int fold = 0; fold < this.numFolds; fold++) {
            double[] confusion = new double[4];
            int foldNumber = fold + 1;
            System.out.println("Fold " + foldNumber);
            Pair sizes = countInstancesOfFold(fold);
            int testSize = sizes.test;
            int trainSize = sizes.train;
            // Size each training set exactly so neither model is trained on padding rows of zeros.
            int firstTrainSize = countOutOfFold(this.stratifiedData.get(0), fold, rowLength);
            int[][][] firstTrain = new int[1][firstTrainSize][];
            int[][][] otherTrain = new int[1][trainSize - firstTrainSize][];
            int[][][] test = new int[1][testSize][];
            String[][] testPassive = new String[testSize][];
            System.out.println("Training size: " + trainSize + "\tTest size: " + testSize);
            System.out.println("size stratified data " + this.stratifiedData.size());

            int nextTest = 0;
            int nextFirst = 0;
            int nextOther = 0;
            for (int s = 0; s < this.stratifiedData.size(); s++) {
                int[][] stratum = this.stratifiedData.get(s);
                String[][] passiveStratum = this.stratifiedPassiveData.get(s);
                for (int k = 0; k < stratum.length; k++) {
                    int[] row = stratum[k];
                    if (row[rowLength] == fold) {
                        test[0][nextTest] = Arrays.copyOf(row, rowLength);
                        testPassive[nextTest] = Arrays.copyOf(passiveStratum[k], passiveRowLength);
                        nextTest++;
                    } else if (s == 0) {
                        firstTrain[0][nextFirst++] = Arrays.copyOf(row, rowLength);
                    } else {
                        otherTrain[0][nextOther++] = Arrays.copyOf(row, rowLength);
                    }
                }
            }

            this.o.change0();
            Observations firstObservations = new Observations(this.o, firstTrain);
            Observations otherObservations = new Observations(this.o, otherTrain);
            // NOTE(review): result discarded; kept in case the constructor has side effects.
            new Observations(this.o, test);
            Scores firstScores = new Scores(firstObservations, i, true, true);
            Scores otherScores = new Scores(otherObservations, i, true, true);
            firstScores.evaluate(scoringFunction);
            otherScores.evaluate(scoringFunction);
            DynamicBayesNet firstDbn;
            DynamicBayesNet otherDbn;
            if (z2) {
                firstDbn = firstScores.to_bcDBN(scoringFunction, i2);
                otherDbn = otherScores.to_bcDBN(scoringFunction, i2);
            } else {
                firstDbn = firstScores.toDBN();
                otherDbn = otherScores.toDBN();
            }
            System.out.println(firstDbn.toString());
            System.out.println(otherDbn.toString());
            firstDbn.learnParameters(firstObservations);
            otherDbn.learnParameters(otherObservations);
            System.out.println("-----------------------------------------------");
            sb.append("---Fold-" + foldNumber + "---" + property);

            for (int t = 0; t < testSize; t++) {
                int[] testRow = test[0][t];
                // Likelihoods are per test row. NOTE(review): firstLikelihood is never updated.
                double firstLikelihood = 1.0d;
                double otherLikelihood = 1.0d;

                // Diagnostic dump of the first model's initial network.
                // NOTE(review): 18 is a hard-coded row-prefix length.
                MutableConfiguration initConfiguration =
                        new MutableConfiguration(firstDbn.getAttributes(), 1, Arrays.copyOfRange(testRow, 0, 18));
                System.out.println(firstDbn.getInit().toString());
                Iterator<Integer> initNodes = firstDbn.getInit().getTop().iterator();
                while (initNodes.hasNext()) {
                    int node = initNodes.next().intValue();
                    System.out.println("Node " + node);
                    System.out.println("Attributes " + firstDbn.getAttributes());
                    Configuration mask = initConfiguration.applyMask(firstDbn.getInit().getParents().get(node), node);
                    System.out.println("indexParameters " + mask);
                    System.out.println("One " + Arrays.toString(Arrays.copyOfRange(testRow, 0, 18)));
                    System.out.println(firstDbn.getInit().getParameters().get(node).get(mask));
                }

                // Multiply the probabilities of the row's values under every transition network.
                MutableConfiguration transitionConfiguration =
                        new MutableConfiguration(otherDbn.getAttributes(), 1, testRow);
                for (BayesNet bayesNet : otherDbn.getTrans()) {
                    Iterator<Integer> nodes = bayesNet.getTop().iterator();
                    while (nodes.hasNext()) {
                        int node = nodes.next().intValue();
                        otherLikelihood *= bayesNet.getParameters().get(node)
                                .get(transitionConfiguration.applyMask(bayesNet.getParents().get(node), node))
                                .get(testRow[node]).doubleValue();
                    }
                }

                int predicted = firstLikelihood >= otherLikelihood ? 0 : 1;
                String actualText = testPassive[t][actualColumn];
                sb.append(predicted).append("\t");
                sb.append("\t");
                sb.append(actualText).append("\t");
                tally(confusion, Integer.parseInt(actualText), predicted);
                sb.append(property);
            }

            double[] foldMetrics = reportFoldMetrics(confusion);
            for (int m = 0; m < metricSums.length; m++) {
                metricSums[m] += foldMetrics[m];
            }
        }
        printAveragedMetrics(metricSums);
        return sb.toString();
    }

    /**
     * Runs k-fold cross-validation. For every fold, it learns a network on the out-of-fold rows,
     * forecasts the in-fold rows, and compares the forecast value of attribute 17 with the first
     * passive attribute of the last time slice.
     *
     * <p>Per-fold confusion counts and metrics are printed to stdout, and the metrics averaged over
     * the folds are printed at the end.
     *
     * @param i forwarded as the second argument of {@code Scores}
     * @param scoringFunction scoring function used to learn the network
     * @param str currently unused
     * @param list attribute indices whose names are written in the report header
     * @param z forwarded as the last argument of {@code DynamicBayesNet.forecast}
     * @param z2 when true, the network is built with {@code to_bcDBN}, otherwise {@code toDBN}
     * @param i2 forwarded to {@code to_bcDBN}
     * @return the textual report: seed, header, then one line per test row per fold
     */
    public String evaluate(int i, ScoringFunction scoringFunction, String str, List<Integer> list, boolean z,
            boolean z2, int i2) {
        int numAttributes = this.o.numAttributes();
        int numPassiveAttributes = this.o.numPassiveAttributes();
        int markovLag = this.o.getMarkovLag();
        int rowLength = (markovLag + 1) * numAttributes;
        int passiveRowLength = (markovLag + 1) * numPassiveAttributes;
        int labelColumn = (markovLag * numAttributes) + REPORTED_ATTRIBUTE;
        int actualColumn = markovLag * numPassiveAttributes;
        String property = System.getProperty("line.separator");
        StringBuilder sb = new StringBuilder();
        appendHeader(sb, list, property);
        double[] metricSums = new double[4];

        for (int fold = 0; fold < this.numFolds; fold++) {
            double[] confusion = new double[4];
            int foldNumber = fold + 1;
            System.out.println("Fold " + foldNumber);
            Pair sizes = countInstancesOfFold(fold);
            int testSize = sizes.test;
            int trainSize = sizes.train;
            int[][][] train = new int[1][trainSize][];
            int[][][] test = new int[1][testSize][];
            String[][] testPassive = new String[testSize][];
            System.out.println("Training size: " + trainSize + "\tTest size: " + testSize);

            int nextTest = 0;
            int nextTrain = 0;
            for (int s = 0; s < this.stratifiedData.size(); s++) {
                int[][] stratum = this.stratifiedData.get(s);
                String[][] passiveStratum = this.stratifiedPassiveData.get(s);
                for (int k = 0; k < stratum.length; k++) {
                    int[] row = stratum[k];
                    if (row[rowLength] == fold) {
                        test[0][nextTest] = Arrays.copyOf(row, rowLength);
                        testPassive[nextTest] = Arrays.copyOf(passiveStratum[k], passiveRowLength);
                        nextTest++;
                    } else {
                        train[0][nextTrain++] = Arrays.copyOf(row, rowLength);
                    }
                }
            }

            Observations forecast = evaluateFold(new Observations(this.o, train), new Observations(this.o, test), i,
                    scoringFunction, false, null, z, z2, i2);
            int[][][] forecastMatrix = forecast.getObservationsMatrix();
            sb.append("---Fold-" + foldNumber + "---" + property);
            for (int t = 0; t < testSize; t++) {
                String predicted =
                        this.o.getAttributes().get(REPORTED_ATTRIBUTE).get(forecastMatrix[0][t][labelColumn]);
                sb.append(predicted).append("\t");
                int predictedValue = predicted != null ? (int) Double.parseDouble(predicted) : 0;
                sb.append("\t");
                String actual = testPassive[t][actualColumn];
                sb.append(actual).append("\t");
                // Rows with a missing forecast or a missing actual value are reported but not scored.
                if (predicted != null && actual != null) {
                    tally(confusion, Integer.parseInt(actual), predictedValue);
                }
                sb.append(property);
            }

            double[] foldMetrics = reportFoldMetrics(confusion);
            for (int m = 0; m < metricSums.length; m++) {
                metricSums[m] += foldMetrics[m];
            }
        }
        printAveragedMetrics(metricSums);
        return sb.toString();
    }

    /**
     * Runs 10-fold {@link #evaluate2} stratified on attribute 17.
     *
     * @param strArr optional: {@code [dataFile, classFile]}; defaults to the original hard-coded paths
     */
    public static void main(String[] strArr) {
        String dataFile = strArr.length > 0 ? strArr[0] : "/home/margarida/Documents/NEUROCLIMICS2/data_disc.csv";
        String classFile =
                strArr.length > 1 ? strArr[1] : "/home/margarida/Documents/NEUROCLIMICS2/data_disc_class.csv";
        Observations observations = new Observations(dataFile, classFile, (Integer) 1);
        CrossValidation crossValidation = new CrossValidation(observations, 10, 17);
        System.out.println(crossValidation.evaluate2(1, new LLScoringFunction(), "tDBN_p=2_ll", Arrays.asList(17),
                true, true, 4));
    }
}
