package com.antbrains.crf;

import com.google.gson.Gson;
import de.ruedigermoeller.serialization.FSTObjectInput;
import de.ruedigermoeller.serialization.FSTObjectOutput;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TObjectIntProcedure;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/* loaded from: input_file:com/antbrains/crf/SgdCrf.class */
public class SgdCrf {
    /**
     * Resets every weight vector held by {@code trainingWeights} to zero: the BOS
     * transition weights, the EOS transition weights, the label-to-label
     * transition weights and the attribute weights.
     */
    private static void initWeights(TrainingWeights trainingWeights) {
        double[][] weightVectors = {
            trainingWeights.getBosTransitionWeights(),
            trainingWeights.getEosTransitionWeights(),
            trainingWeights.getTransitionWeights(),
            trainingWeights.getAttributeWeights()
        };
        for (double[] vector : weightVectors) {
            Arrays.fill(vector, 0.0d);
        }
    }

    /**
     * Computes per-position, per-label state scores for an instance.
     *
     * <p>For position {@code t} and label {@code y}, the entry at
     * {@code t * labelNum + y} is the sum of {@code attrWeights[a * labelNum + y]}
     * over every attribute id {@code a} in row {@code t} of the instance's
     * attribute-id matrix ({@code rowSize} ids per row). Negative attribute ids
     * are skipped.
     *
     * @param instance     sequence whose attribute ids are scored
     * @param exponentiate when true, every score is replaced by {@code Math.exp(score)}
     * @param labelNum     number of labels
     * @param attrWeights  flattened attribute weights indexed {@code attrId * labelNum + label}
     * @return an array of length {@code instance.length() * labelNum}
     */
    private static double[] computeStateScores(Instance instance, boolean exponentiate, int labelNum, double[] attrWeights) {
        final int seqLen = instance.length();
        final int rowSize = instance.rowSize();
        final int[] attrIds = instance.getAttrIds();
        final double[] scores = new double[seqLen * labelNum];
        for (int pos = 0; pos < seqLen; pos++) {
            final int scoreBase = pos * labelNum;
            final int rowStart = pos * rowSize;
            for (int col = 0; col < rowSize; col++) {
                final int attr = attrIds[rowStart + col];
                if (attr < 0) {
                    continue;
                }
                final int weightBase = attr * labelNum;
                for (int label = 0; label < labelNum; label++) {
                    scores[scoreBase + label] += attrWeights[weightBase + label];
                }
            }
        }
        if (exponentiate) {
            for (int k = 0; k < scores.length; k++) {
                scores[k] = Math.exp(scores[k]);
            }
        }
        return scores;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    private static double[][] computeForwardScores(Instance instance, double[] dArr, double[] dArr2, double[][] dArr3, double[] dArr4, double[] dArr5, int i) {
        int length = instance.length();
        ?? r0 = new double[length];
        for (int i2 = 0; i2 < length; i2++) {
            r0[i2] = new double[i];
        }
        double d = 0.0d;
        double[] dArr6 = r0[0];
        for (int i3 = 0; i3 < i; i3++) {
            double d2 = dArr[i3] * dArr4[i3];
            dArr6[i3] = d2;
            d += d2;
        }
        dArr5[0] = d != 0.0d ? 1.0d / d : 1.0d;
        for (int i4 = 0; i4 < i; i4++) {
            int i5 = i4;
            dArr6[i5] = dArr6[i5] * dArr5[0];
        }
        for (int i6 = 1; i6 < length; i6++) {
            double d3 = 0.0d;
            Object[] objArr = r0[i6 - 1];
            double[] dArr7 = r0[i6];
            for (int i7 = 0; i7 < i; i7++) {
                double d4 = 0.0d;
                for (int i8 = 0; i8 < i; i8++) {
                    d4 += objArr[i8] * dArr3[i8][i7];
                }
                double d5 = d4 * dArr4[(i6 * i) + i7];
                dArr7[i7] = d5;
                d3 += d5;
            }
            dArr5[i6] = d3 != 0.0d ? 1.0d / d3 : 1.0d;
            for (int i9 = 0; i9 < i; i9++) {
                int i10 = i9;
                dArr7[i10] = dArr7[i10] * dArr5[i6];
            }
        }
        double d6 = 0.0d;
        Object[] objArr2 = r0[length - 1];
        for (int i11 = 0; i11 < i; i11++) {
            d6 += objArr2[i11] * dArr2[i11];
        }
        dArr5[length] = d6 != 0.0d ? 1.0d / d6 : 1.0d;
        return r0;
    }

    /**
     * Returns the log-probability of the instance's gold label sequence.
     *
     * <p>The BOS term is recovered from the scaled forward table as
     * {@code log(alpha[0][y0]) - log(scale[0])}. The EOS term comes from the
     * scaled backward table as {@code log(beta[T-1][yLast]) - log(scale[T-1])}.
     * Transition terms are read directly from the log-domain
     * {@code transWeights}, and state terms are {@code log(stateScores)}.
     * {@code logZ} is subtracted at the end.
     *
     * @param instance     sequence whose {@code labelIds()} are scored
     * @param stateScores  exponentiated state scores, {@code t * labelNum + label}
     * @param alpha        scaled forward table
     * @param beta         scaled backward table
     * @param scale        per-position scaling factors from the forward pass
     * @param logZ         log partition value
     * @param transWeights log-domain transition weights, {@code from * labelNum + to}
     * @param labelNum     number of labels
     */
    private static double computeLogProb(Instance instance, double[] stateScores, double[][] alpha, double[][] beta, double[] scale, double logZ, double[] transWeights, int labelNum) {
        final int length = instance.length();
        final int[] gold = instance.labelIds();
        int prevLabel = gold[0];
        double logProb = Math.log(alpha[0][prevLabel]) - Math.log(scale[0]);
        for (int t = 1; t < length; t++) {
            final int label = gold[t];
            logProb = logProb + transWeights[(prevLabel * labelNum) + label] + Math.log(stateScores[(t * labelNum) + label]);
            prevLabel = label;
        }
        final double eosTerm = Math.log(beta[length - 1][prevLabel]) - Math.log(scale[length - 1]);
        return (logProb + eosTerm) - logZ;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v9, types: [double[], double[][]] */
    private static double computeInitialLoglikelihood(TrainingDataSet trainingDataSet, int i, double d, double[] dArr, double[] dArr2) {
        int labelNum = trainingDataSet.getLabelNum();
        double[] dArr3 = new double[labelNum];
        double[] dArr4 = new double[labelNum];
        for (int i2 = 0; i2 < labelNum; i2++) {
            dArr3[i2] = 1.0d;
            dArr4[i2] = 1.0d;
        }
        ?? r0 = new double[labelNum];
        for (int i3 = 0; i3 < labelNum; i3++) {
            r0[i3] = new double[labelNum];
            for (int i4 = 0; i4 < labelNum; i4++) {
                r0[i3][i4] = 4607182418800017408;
            }
        }
        double d2 = 0.0d;
        List<Instance> instances = trainingDataSet.getInstances();
        for (int i5 = 0; i5 < i; i5++) {
            Instance instance = instances.get(i5);
            double[] computeStateScores = computeStateScores(instance, true, labelNum, dArr);
            double[] dArr5 = new double[instance.length() + 1];
            double[][] computeForwardScores = computeForwardScores(instance, dArr3, dArr4, r0, computeStateScores, dArr5, labelNum);
            double d3 = 0.0d;
            for (int i6 = 0; i6 <= instance.length(); i6++) {
                d3 -= Math.log(dArr5[i6]);
            }
            d2 += computeLogProb(instance, computeStateScores, computeForwardScores, computeBackwardScores(instance, dArr4, r0, computeStateScores, dArr5, labelNum), dArr5, d3, dArr2, labelNum);
        }
        double d4 = 0.0d;
        for (int i7 = 0; i7 < trainingDataSet.getAttributeNum(); i7++) {
            for (int i8 = 0; i8 < labelNum; i8++) {
                double d5 = dArr[(i7 * labelNum) + i8];
                d4 += d5 * d5;
            }
        }
        return d2 - (((0.5d * d) * d4) * i);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    private static double[][] computeBackwardScores(Instance instance, double[] dArr, double[][] dArr2, double[] dArr3, double[] dArr4, int i) {
        int length = instance.length();
        ?? r0 = new double[length];
        for (int i2 = 0; i2 < length; i2++) {
            r0[i2] = new double[i];
        }
        double[] dArr5 = r0[length - 1];
        double d = dArr4[length - 1];
        for (int i3 = 0; i3 < i; i3++) {
            dArr5[i3] = dArr[i3] * d;
        }
        for (int i4 = length - 2; i4 >= 0; i4--) {
            double[] dArr6 = r0[i4];
            Object[] objArr = r0[i4 + 1];
            double d2 = dArr4[i4];
            for (int i5 = 0; i5 < i; i5++) {
                double d3 = 0.0d;
                double[] dArr7 = dArr2[i5];
                for (int i6 = 0; i6 < i; i6++) {
                    d3 += dArr7[i6] * dArr3[((i4 + 1) * i) + i6] * objArr[i6];
                }
                dArr6[i5] = d3 * d2;
            }
        }
        return r0;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v10, types: [double[], double[][]] */
    private static double calibrateSgd(List<Instance> list, int i, double d, double d2, int i2, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4) {
        int i3 = 0;
        double d3 = 1.0d;
        double d4 = 0.0d;
        double[] dArr5 = new double[i2];
        double[] dArr6 = new double[i2];
        ?? r0 = new double[i2];
        for (int i4 = 0; i4 < i2; i4++) {
            r0[i4] = new double[i2];
        }
        for (int i5 = 0; i5 < i; i5++) {
            Instance instance = list.get(i5);
            int length = instance.length();
            double d5 = 1.0d / (d2 * (d + i3));
            d3 *= 1.0d - (d5 * d2);
            double d6 = d5 / (d3 * 1.0d);
            for (int i6 = 0; i6 < i2; i6++) {
                dArr5[i6] = Math.exp(dArr[i6]);
                dArr6[i6] = Math.exp(dArr2[i6]);
            }
            for (int i7 = 0; i7 < i2; i7++) {
                for (int i8 = 0; i8 < i2; i8++) {
                    r0[i7][i8] = Math.exp(dArr3[(i7 * i2) + i8]);
                }
            }
            double[] computeStateScores = computeStateScores(instance, true, i2, dArr4);
            double[] dArr7 = new double[length + 1];
            double[][] computeForwardScores = computeForwardScores(instance, dArr5, dArr6, r0, computeStateScores, dArr7, i2);
            double[][] computeBackwardScores = computeBackwardScores(instance, dArr6, r0, computeStateScores, dArr7, i2);
            double d7 = 0.0d;
            for (int i9 = 0; i9 <= length; i9++) {
                d7 -= Math.log(dArr7[i9]);
            }
            d4 += computeLogProb(instance, computeStateScores, computeForwardScores, computeBackwardScores, dArr7, d7, dArr3, i2);
            updateFeatureWeights(instance, r0, computeStateScores, computeForwardScores, computeBackwardScores, dArr7, d6, i2, dArr4, dArr, dArr2, dArr3);
            i3++;
        }
        return d4;
    }

    /**
     * Adds {@code delta} to {@code weights[index]} in place.
     *
     * @return the resulting change of that weight's square,
     *         {@code (w + delta)^2 - w^2}, computed as {@code delta * (delta + 2w)}
     */
    private static double updateWeight(double[] weights, int index, double delta) {
        final double before = weights[index];
        weights[index] = before + delta;
        return delta * (delta + (before * 2.0d));
    }

    /**
     * Applies one SGD step to {@code weights[index]}. It first adds
     * {@code -gain * expectation}. When {@code observed} is true, it then also
     * adds {@code gain}.
     *
     * @return the sum of the squared-value changes reported by the individual updates
     */
    private static double updateWeight(double[] weights, int index, double gain, double expectation, boolean observed) {
        final double expectedPart = updateWeight(weights, index, (-gain) * expectation);
        return observed ? expectedPart + updateWeight(weights, index, gain) : expectedPart;
    }

    /**
     * Applies one SGD update for a single instance to the BOS, EOS, attribute
     * and transition weights, in place.
     *
     * <p>For every weight, the model expectation is subtracted with factor
     * {@code gain}, and {@code gain} is added when the feature fires for the
     * gold labels (see the five-argument {@code updateWeight}). Node marginals
     * at position {@code t} are {@code alpha[t][y] * beta[t][y] * scale[length] / scale[t]}.
     * Edge marginals between {@code t} and {@code t+1} are
     * {@code alpha[t][p] * trans[p][y] * state[t+1][y] * beta[t+1][y] * scale[length]}.
     *
     * <p>Each position's state features are updated exactly once. For a
     * single-token instance, row 0 is both the first and the last row.
     *
     * @param instance     the instance being learned
     * @param transScores  exponentiated transition scores, {@code [from][to]}
     * @param stateScores  exponentiated state scores, {@code t * labelNum + label}
     * @param alpha        scaled forward table
     * @param beta         scaled backward table
     * @param scale        scaling factors from the forward pass
     * @param gain         step size for this update
     * @param labelNum     number of labels
     * @param attrWeights  attribute weights (updated in place)
     * @param bosWeights   BOS transition weights (updated in place)
     * @param eosWeights   EOS transition weights (updated in place)
     * @param transWeights transition weights, {@code from * labelNum + to} (updated in place)
     * @return the summed change in squared weight values reported by every update
     */
    private static double updateFeatureWeights(Instance instance, double[][] transScores, double[] stateScores, double[][] alpha, double[][] beta, double[] scale, double gain, int labelNum, double[] attrWeights, double[] bosWeights, double[] eosWeights, double[] transWeights) {
        final int length = instance.length();
        final int rowSize = instance.rowSize();
        final int[] attrIds = instance.getAttrIds();
        final int[] gold = instance.labelIds();
        final int last = length - 1;
        final double[] marginal = new double[labelNum];
        double normDelta = 0.0d;

        // First position: BOS weights and the state features of row 0.
        fillNodeMarginals(marginal, alpha[0], beta[0], scale[length] / scale[0]);
        for (int y = 0; y < labelNum; y++) {
            normDelta += updateWeight(bosWeights, y, gain, marginal[y], y == gold[0]);
        }
        normDelta = updateStateFeatureRange(normDelta, attrIds, 0, rowSize, marginal, gold[0], gain, labelNum, attrWeights);

        // Last position: EOS weights and the state features of the final row.
        fillNodeMarginals(marginal, alpha[last], beta[last], scale[length] / scale[last]);
        for (int y = 0; y < labelNum; y++) {
            normDelta += updateWeight(eosWeights, y, gain, marginal[y], y == gold[last]);
        }
        // With a single token, row 0 is also the final row and was updated above.
        // Updating it again would double the gradient of those features.
        if (length > 1) {
            normDelta = updateStateFeatureRange(normDelta, attrIds, last * rowSize, attrIds.length, marginal, gold[last], gain, labelNum, attrWeights);
        }

        // Interior positions: state features only.
        for (int t = 1; t < last; t++) {
            fillNodeMarginals(marginal, alpha[t], beta[t], scale[length] / scale[t]);
            normDelta = updateStateFeatureRange(normDelta, attrIds, t * rowSize, (t + 1) * rowSize, marginal, gold[t], gain, labelNum, attrWeights);
        }

        // Label-to-label transitions between positions t and t + 1.
        final double finalScale = scale[length];
        for (int t = 0; t < last; t++) {
            final double[] alphaT = alpha[t];
            final double[] betaNext = beta[t + 1];
            final int nextStateBase = (t + 1) * labelNum;
            for (int p = 0; p < labelNum; p++) {
                final double[] fromP = transScores[p];
                for (int y = 0; y < labelNum; y++) {
                    final double edgeMarginal = alphaT[p] * fromP[y] * stateScores[nextStateBase + y] * betaNext[y] * finalScale;
                    normDelta += updateWeight(transWeights, (p * labelNum) + y, gain, edgeMarginal, p == gold[t] && y == gold[t + 1]);
                }
            }
        }
        return normDelta;
    }

    /** Sets {@code out[y] = alphaT[y] * betaT[y] * factor} for every label {@code y}. */
    private static void fillNodeMarginals(double[] out, double[] alphaT, double[] betaT, double factor) {
        for (int y = 0; y < out.length; y++) {
            out[y] = alphaT[y] * betaT[y] * factor;
        }
    }

    /**
     * Updates the attribute weights, for all labels, of every non-negative
     * attribute id in {@code attrIds[from, to)}. Returns {@code acc} plus the
     * reported squared-value changes, added in call order.
     */
    private static double updateStateFeatureRange(double acc, int[] attrIds, int from, int to, double[] marginal, int goldLabel, double gain, int labelNum, double[] attrWeights) {
        for (int k = from; k < to; k++) {
            final int attr = attrIds[k];
            if (attr < 0) {
                continue;
            }
            final int base = attr * labelNum;
            for (int y = 0; y < labelNum; y++) {
                acc += updateWeight(attrWeights, base + y, gain, marginal[y], y == goldLabel);
            }
        }
        return acc;
    }

    /**
     * Chooses the SGD step-size offset {@code t0} through trial runs over a sample.
     *
     * <p>Each trial resets the weights and runs {@code calibrateSgd} over the
     * first {@code min(#instances, samplesNum)} instances. The result is compared
     * with the log-likelihood of the all-zero model. Trials start at
     * {@code trainingParams.getEta()}. While trials improve, eta is multiplied by
     * {@code rate}. After the first non-improving trial, the search restarts at
     * {@code eta / rate} and keeps dividing by {@code rate}. The search stops once
     * {@code candidatesNum} improving trials have been seen and the decreasing
     * phase has begun. If no trial improves, the initial eta is kept.
     *
     * <p>NOTE(review): there is no trial limit. In the decreasing phase the
     * remaining count only drops on improving trials.
     *
     * @param lambda L2 regularisation coefficient
     * @return {@code 1 / (lambda * bestEta)}
     */
    private static double calibrate(TrainingDataSet trainingDataSet, double lambda, TrainingParams trainingParams, TrainingWeights trainingWeights) {
        final List<Instance> instances = trainingDataSet.getInstances();
        final int samples = Math.min(instances.size(), trainingParams.getSamplesNum());
        System.out.println(String.format("sgd.calibration.eta: %f\n", Double.valueOf(trainingParams.getEta())));
        System.out.println(String.format("sgd.calibration.rate: %f\n", Double.valueOf(trainingParams.getRate())));
        System.out.println(String.format("sgd.calibration.samples: %d\n", Integer.valueOf(samples)));
        System.out.println(String.format("sgd.calibration.candidates: %d\n", Integer.valueOf(trainingParams.getCandidatesNum())));
        Collections.shuffle(trainingDataSet.getInstances());
        initWeights(trainingWeights);
        final double baseline = computeInitialLoglikelihood(trainingDataSet, samples, lambda, trainingWeights.getAttributeWeights(), trainingWeights.getTransitionWeights());
        System.out.println(String.format("Initial Log-likelihood: %f\n", Double.valueOf(baseline)));

        boolean decreasing = false;
        int remaining = trainingParams.getCandidatesNum();
        double bestLoglikelihood = -Double.MAX_VALUE;
        final double initialEta = trainingParams.getEta();
        double bestEta = initialEta;
        double eta = initialEta;
        System.out.println("calibrating");
        for (int trial = 1; remaining > 0 || !decreasing; trial++) {
            System.out.println(String.format("Trial #%d (eta = %f): ", Integer.valueOf(trial), Double.valueOf(eta)));
            initWeights(trainingWeights);
            final double loglikelihood = calibrateSgd(instances, samples, 1.0d / (lambda * eta), lambda, trainingDataSet.getLabelNum(), trainingWeights.getBosTransitionWeights(), trainingWeights.getEosTransitionWeights(), trainingWeights.getTransitionWeights(), trainingWeights.getAttributeWeights());
            final boolean improved = !Double.isInfinite(loglikelihood) && loglikelihood > baseline;
            System.out.println(String.format(improved ? "%f\n" : "%f (worse)\n", Double.valueOf(loglikelihood)));
            if (improved) {
                remaining--;
                if (loglikelihood > bestLoglikelihood) {
                    bestLoglikelihood = loglikelihood;
                    bestEta = eta;
                }
            }
            if (decreasing) {
                System.out.println(String.format("etaValue(%f)/=rate(%f)", Double.valueOf(eta), Double.valueOf(trainingParams.getRate())));
                eta /= trainingParams.getRate();
            } else if (improved) {
                System.out.println(String.format("etaValue(%f)*=rate(%f)", Double.valueOf(eta), Double.valueOf(trainingParams.getRate())));
                eta *= trainingParams.getRate();
            } else {
                decreasing = true;
                System.out.println(String.format("initEtaValue(%f)/=rate(%f)", Double.valueOf(initialEta), Double.valueOf(trainingParams.getRate())));
                eta = initialEta / trainingParams.getRate();
            }
            System.out.println("etaValue=" + eta);
        }
        System.out.println(String.format("Best learning rate (eta): %f\n", Double.valueOf(bestEta)));
        return 1.0d / (lambda * bestEta);
    }

    /**
     * Writes {@code trainingParams} followed by {@code trainingWeights} to the
     * file {@code str} with FST serialisation.
     *
     * @param trainingParams  parameters to store first
     * @param trainingWeights weights to store second
     * @param str             destination file path
     * @throws IOException if the file cannot be created or written
     */
    public static void saveModel(TrainingParams trainingParams, TrainingWeights trainingWeights, String str) throws IOException {
        // Declaring the file stream as its own resource releases the file handle
        // even if the FSTObjectOutput constructor throws. A failure in close()
        // becomes a suppressed exception instead of masking the original one.
        try (FileOutputStream fileOut = new FileOutputStream(str);
             FSTObjectOutput out = new FSTObjectOutput(fileOut)) {
            out.writeObject(trainingParams, new Class[]{TrainingParams.class});
            out.writeObject(trainingWeights, new Class[]{TrainingWeights.class});
        }
    }

    /**
     * Reads a model written by {@code saveModel}: first a {@code TrainingParams},
     * then a {@code TrainingWeights}.
     *
     * <p>The FST reader wrapping {@code inputStream} is closed before this method
     * returns.
     * NOTE(review): FST deserialisation should not be used on untrusted input.
     *
     * @param inputStream stream positioned at the start of a serialised model
     * @return a model built from the deserialised parameters and weights
     * @throws Exception if reading or deserialisation fails
     */
    public static CrfModel loadModel(InputStream inputStream) throws Exception {
        // try-with-resources keeps the original failure if close() also throws.
        try (FSTObjectInput in = new FSTObjectInput(inputStream)) {
            TrainingParams params = (TrainingParams) in.readObject(new Class[]{TrainingParams.class});
            TrainingWeights weights = (TrainingWeights) in.readObject(new Class[]{TrainingWeights.class});
            return new CrfModel(params, weights);
        }
    }

    /**
     * Reads a model written by {@code saveModel} from the file {@code str}.
     *
     * @param str path of the model file
     * @return the loaded model
     * @throws Exception if the file cannot be opened or the model cannot be deserialised
     */
    public static CrfModel loadModel(String str) throws Exception {
        // Owning the file stream here guarantees it is closed even if the FST
        // reader cannot be constructed. Closing it twice is harmless.
        try (InputStream fileIn = new FileInputStream(str)) {
            return loadModel(fileIn);
        }
    }

    /**
     * Deserialises the JSON text that follows the first tab in {@code str} into an
     * {@link Instance}. Everything before the first tab is ignored.
     */
    private static Instance readInstance(Gson gson, String str) {
        final String[] halves = str.split("\t", 2);
        final String json = halves[1];
        return gson.fromJson(json, Instance.class);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v25, types: [double[], double[][]] */
    public static void train(TrainingDataSet trainingDataSet, int i, int i2, TrainingParams trainingParams, TrainingWeights trainingWeights, TrainingProgress trainingProgress) {
        List<Instance> instances = trainingDataSet.getInstances();
        int labelNum = trainingDataSet.getLabelNum();
        Collections.shuffle(instances);
        List<Instance> subList = instances.subList(0, i);
        List<Instance> subList2 = instances.subList(i, instances.size());
        trainingProgress.startTraining();
        boolean z = subList != null && subList.size() > 0;
        int i3 = 0;
        double d = 1.0d;
        double d2 = 1.0d;
        double sigma = 1.0d / ((trainingParams.getSigma() * trainingParams.getSigma()) * subList2.size());
        double[] dArr = new double[labelNum];
        double[] dArr2 = new double[labelNum];
        ?? r0 = new double[labelNum];
        for (int i4 = 0; i4 < labelNum; i4++) {
            r0[i4] = new double[labelNum];
        }
        double calibrate = trainingParams.getT0() == 0.0d ? calibrate(trainingDataSet, sigma, trainingParams, trainingWeights) : trainingParams.getT0();
        double d3 = 0.0d;
        initWeights(trainingWeights);
        double[] bosTransitionWeights = trainingWeights.getBosTransitionWeights();
        double[] eosTransitionWeights = trainingWeights.getEosTransitionWeights();
        double[] transitionWeights = trainingWeights.getTransitionWeights();
        double[] attributeWeights = trainingWeights.getAttributeWeights();
        for (int i5 = 1; i5 <= i2; i5++) {
            trainingProgress.doIter(i5);
            Collections.shuffle(subList2);
            double d4 = 0.0d;
            for (Instance instance : subList2) {
                double d5 = 1.0d / (sigma * (calibrate + i3));
                d *= 1.0d - (d5 * sigma);
                d4 = d * d2;
                double d6 = d5 / d4;
                for (int i6 = 0; i6 < labelNum; i6++) {
                    dArr[i6] = Math.exp(bosTransitionWeights[i6]);
                    dArr2[i6] = Math.exp(eosTransitionWeights[i6]);
                }
                for (int i7 = 0; i7 < labelNum; i7++) {
                    for (int i8 = 0; i8 < labelNum; i8++) {
                        r0[i7][i8] = Math.exp(transitionWeights[(i7 * labelNum) + i8]);
                    }
                }
                int length = instance.length();
                double[] computeStateScores = computeStateScores(instance, true, labelNum, attributeWeights);
                double[] dArr3 = new double[length + 1];
                d3 += updateFeatureWeights(instance, r0, computeStateScores, computeForwardScores(instance, dArr, dArr2, r0, computeStateScores, dArr3, labelNum), computeBackwardScores(instance, dArr2, r0, computeStateScores, dArr3, labelNum), dArr3, d6, labelNum, attributeWeights, bosTransitionWeights, eosTransitionWeights, transitionWeights);
                double d7 = d3 * d4 * d4 * sigma;
                if (d7 > 1.0d) {
                    d2 = 1.0d / Math.sqrt(d7);
                }
                i3++;
            }
            if (d4 < 1.0E-20d) {
                for (int i9 = 0; i9 < labelNum; i9++) {
                    int i10 = i9;
                    bosTransitionWeights[i10] = bosTransitionWeights[i10] * d4;
                    int i11 = i9;
                    eosTransitionWeights[i11] = eosTransitionWeights[i11] * d4;
                }
                for (int i12 = 0; i12 < labelNum; i12++) {
                    for (int i13 = 0; i13 < labelNum; i13++) {
                        int i14 = (i12 * labelNum) + i13;
                        transitionWeights[i14] = transitionWeights[i14] * d4;
                    }
                }
                for (int i15 = 0; i15 < attributeWeights.length; i15++) {
                    int i16 = i15;
                    attributeWeights[i16] = attributeWeights[i16] * d4;
                }
                d = 1.0d;
                d2 = 1.0d;
            }
            if (z) {
                trainingProgress.doValidate(evaluate(subList, trainingWeights).toString());
            }
        }
        trainingProgress.doValidate(evaluate(instances, trainingWeights).toString());
    }

    /**
     * Tags every instance in {@code list} and tallies the results in an
     * {@code EvaluationResult}.
     *
     * <p>For each position and each label, one of four counters in
     * {@code labelIndex2count[label]} is incremented:
     * <ul>
     *   <li>[0]: gold and predicted both equal the label;</li>
     *   <li>[1]: gold equals it, prediction does not;</li>
     *   <li>[3]: prediction equals it, gold does not;</li>
     *   <li>[2]: neither equals it.</li>
     * </ul>
     * Item-level and whole-sequence accuracy counters are also updated.
     */
    private static EvaluationResult evaluate(List<Instance> list, TrainingWeights trainingWeights) {
        final String[] labels = trainingWeights.getLabelTexts();
        final EvaluationResult result = new EvaluationResult(labels);
        for (Instance instance : list) {
            final int[] predicted = tagId(instance, trainingWeights);
            result.totalItemCount += predicted.length;
            boolean sequenceCorrect = true;
            for (int pos = 0; pos < instance.length(); pos++) {
                final int guess = predicted[pos];
                final int gold = instance.labelIds()[pos];
                for (int label = 0; label < labels.length; label++) {
                    final int[] counts = result.labelIndex2count[label];
                    final boolean isGold = gold == label;
                    final boolean isGuess = guess == label;
                    if (isGold && isGuess) {
                        counts[0]++;
                    } else if (isGold) {
                        counts[1]++;
                    } else if (isGuess) {
                        counts[3]++;
                    } else {
                        counts[2]++;
                    }
                }
                if (gold == guess) {
                    result.correctItemCount++;
                } else {
                    sequenceCorrect = false;
                }
            }
            result.totalSeqCount++;
            if (sequenceCorrect) {
                result.correctSeqCount++;
            }
        }
        return result;
    }

    /**
     * Scores the gold tag sequence of a segmented sentence under {@code crfModel}.
     * The instance is built with the model's dictionaries and template, without
     * adding unseen features or labels.
     */
    public static double getScore(String[] strArr, TagConvertor tagConvertor, CrfModel crfModel) {
        final TrainingWeights weights = crfModel.weights;
        final Instance instance = buildInstance(strArr, tagConvertor, weights.getAttributeDict(), weights.getLabelDict(), weights.getTemplate(), false, false);
        return getScore(weights, instance);
    }

    /**
     * Returns the unnormalised log-domain score of the instance's label sequence.
     * The score is the BOS weight plus the state score at position 0, plus for
     * every later position the transition weight and the state score, plus the
     * EOS weight of the final label.
     */
    public static double getScore(TrainingWeights trainingWeights, Instance instance) {
        final int labelNum = trainingWeights.getLabelTexts().length;
        final int[] labels = instance.labelIds();
        final double[] stateScores = computeStateScores(instance, false, labelNum, trainingWeights.getAttributeWeights());
        final int firstLabel = labels[0];
        double total = trainingWeights.getBosTransitionWeights()[firstLabel] + stateScores[firstLabel];
        final double[] transitions = trainingWeights.getTransitionWeights();
        int pos = 1;
        while (pos < labels.length) {
            final double increment = transitions[(labels[pos - 1] * labelNum) + labels[pos]] + stateScores[(pos * labelNum) + labels[pos]];
            total += increment;
            pos++;
        }
        final int lastLabel = labels[labels.length - 1];
        return total + trainingWeights.getEosTransitionWeights()[lastLabel];
    }

    /**
     * Finds the highest-scoring label sequence for {@code instance} with Viterbi
     * decoding over the log-domain weights.
     *
     * <p>Ties are broken in favour of the smallest label index, both when
     * choosing predecessors and when choosing the final label.
     *
     * @param instance        instance to tag; {@code null} yields an empty array
     * @param trainingWeights model weights
     * @return one label id per position ({@code instance.length()} entries)
     */
    public static int[] tagId(Instance instance, TrainingWeights trainingWeights) {
        if (instance == null) {
            return new int[0];
        }
        final int length = instance.length();
        if (length == 0) {
            // The original indexed position -1 here and threw.
            return new int[0];
        }
        final int labelNum = trainingWeights.getLabelTexts().length;
        final double[] stateScores = computeStateScores(instance, false, labelNum, trainingWeights.getAttributeWeights());
        final double[] bos = trainingWeights.getBosTransitionWeights();
        final double[] trans = trainingWeights.getTransitionWeights();
        final double[] eos = trainingWeights.getEosTransitionWeights();

        // best[t * labelNum + y]: score of the best path ending in label y at t.
        final double[] best = new double[length * labelNum];
        // backPointer[t * labelNum + y]: predecessor label of that path at t - 1.
        final int[] backPointer = new int[length * labelNum];
        for (int y = 0; y < labelNum; y++) {
            best[y] = bos[y] + stateScores[y];
        }
        for (int t = 1; t < length; t++) {
            final int cur = t * labelNum;
            final int prev = cur - labelNum;
            for (int y = 0; y < labelNum; y++) {
                double bestScore = best[prev] + trans[y];
                int bestPrev = 0;
                // Bounded loop: the decompiled version had no exit once all
                // predecessors were checked, so it never terminated.
                for (int p = 1; p < labelNum; p++) {
                    final double candidate = best[prev + p] + trans[(p * labelNum) + y];
                    if (candidate > bestScore) {
                        bestScore = candidate;
                        bestPrev = p;
                    }
                }
                best[cur + y] = bestScore + stateScores[cur + y];
                backPointer[cur + y] = bestPrev;
            }
        }

        // Choose the final label including its EOS weight, then backtrack.
        int offset = (length - 1) * labelNum;
        int label = 0;
        double bestFinal = best[offset] + eos[0];
        for (int y = 1; y < labelNum; y++) {
            final double candidate = best[offset + y] + eos[y];
            if (candidate > bestFinal) {
                bestFinal = candidate;
                label = y;
            }
        }
        final int[] path = new int[length];
        path[length - 1] = label;
        for (int t = length - 2; t >= 0; t--) {
            label = backPointer[offset + label];
            path[t] = label;
            offset -= labelNum;
        }
        return path;
    }

    /**
     * Looks up the id of feature {@code str}.
     *
     * <p>The dictionary is first queried with {@code get(str, false)}. If that
     * yields a negative id and {@code z} is true, it is queried again with
     * {@code get(str, true)} (presumably adding the feature; confirm in
     * FeatureDict). The exception is features containing {@code "_B-"} or
     * {@code "_B+"}, which are never re-queried and keep the negative id.
     */
    private static int getFeatureIndex(FeatureDict featureDict, String str, boolean z) {
        final int existing = featureDict.get(str, false);
        if (existing >= 0 || !z) {
            return existing;
        }
        final boolean boundaryFeature = str.contains("_B-") || str.contains("_B+");
        return boundaryFeature ? existing : featureDict.get(str, true);
    }

    /**
     * Returns the id of label {@code str}. When the lookup yields a negative
     * value and {@code z} is true, the label is registered with id
     * {@code map.size()} and that id is returned.
     *
     * <p>NOTE(review): Trove's default no-entry value for {@code TObjectIntHashMap}
     * is 0, not negative. This check only detects missing labels if the map was
     * created with a negative no-entry value. Confirm at the construction site.
     */
    private static int getLabelIndex(TObjectIntHashMap<String> tObjectIntHashMap, String str, boolean z) {
        final int found = tObjectIntHashMap.get(str);
        if (found >= 0 || !z) {
            return found;
        }
        final int newId = tObjectIntHashMap.size();
        tObjectIntHashMap.put(str, newId);
        return newId;
    }

    /**
     * Builds an {@link Instance} from a segmented sentence.
     *
     * <p>{@code tagConvertor.tokens2Tags(tokens)} supplies the tags, one per
     * position. Each token is split into single {@code char}s (UTF-16 code
     * units), which form the feature columns, optionally expanded by
     * {@code template}. The columns are mapped to attribute ids row by row,
     * {@code columns.size() / tags.length} per position. Each tag is mapped to
     * a label id.
     *
     * @param tokens        segmented tokens
     * @param tagConvertor  converts tokens into per-position tags
     * @param attributeDict feature dictionary
     * @param labelDict     label dictionary
     * @param template      feature template, or {@code null} to use the raw characters
     * @param addFeatures   passed to getFeatureIndex as its "may add" flag
     * @param addLabels     passed to getLabelIndex; when true, unseen labels are appended
     * @return the attribute-id matrix and label ids as an {@code Instance}
     */
    public static Instance buildInstance(String[] tokens, TagConvertor tagConvertor, FeatureDict attributeDict, TObjectIntHashMap<String> labelDict, Template template, boolean addFeatures, boolean addLabels) {
        String[] tags = tagConvertor.tokens2Tags(tokens);
        int length = tags.length;
        // Typed list instead of the raw `new ArrayList()` (unchecked conversion).
        List<String> columns = new ArrayList<>();
        for (String token : tokens) {
            for (int c = 0; c < token.length(); c++) {
                columns.add(token.substring(c, c + 1));
            }
        }
        if (template != null) {
            columns = template.expandTemplate(columns, length);
        }
        int[] attrIds = new int[columns.size()];
        int perPosition = columns.size() / length;
        int[] labelIds = new int[length];
        int cursor = 0;
        for (int pos = 0; pos < length; pos++) {
            labelIds[pos] = getLabelIndex(labelDict, tags[pos], addLabels);
            for (int k = 0; k < perPosition; k++) {
                attrIds[cursor] = getFeatureIndex(attributeDict, columns.get(cursor), addFeatures);
                cursor++;
            }
        }
        return new Instance(attrIds, labelIds);
    }

    /**
     * Builds an {@link Instance} from pre-split feature columns using the
     * dictionaries and template stored in {@code trainingWeights}.
     *
     * <p>{@code list} is expanded by the template when one is present. The
     * result is mapped to attribute ids, {@code size / i} per position. When
     * {@code list2} is non-null, its first {@code i} entries are mapped to label
     * ids and a labelled instance is returned. Otherwise the instance is built
     * from the attribute ids and the length {@code i} only. {@code z} is passed
     * to both lookups as their "may add" flag.
     */
    private static Instance buildInstance(List<String> list, int i, List<String> list2, TrainingWeights trainingWeights, boolean z) {
        final Template template = trainingWeights.getTemplate();
        final TObjectIntHashMap<String> labelDict = trainingWeights.getLabelDict();
        final FeatureDict attributeDict = trainingWeights.getAttributeDict();
        final List<String> features = template == null ? list : template.expandTemplate(list, i);
        final int[] attrIds = new int[features.size()];
        final int perPosition = features.size() / i;
        final boolean labelled = list2 != null;
        final int[] labelIds = labelled ? new int[i] : null;
        int cursor = 0;
        for (int pos = 0; pos < i; pos++) {
            if (labelled) {
                labelIds[pos] = getLabelIndex(labelDict, list2.get(pos), z);
            }
            for (int k = 0; k < perPosition; k++, cursor++) {
                attrIds[cursor] = getFeatureIndex(attributeDict, features.get(cursor), z);
            }
        }
        if (labelled) {
            return new Instance(attrIds, labelIds);
        }
        return new Instance(attrIds, i);
    }

    /**
     * Reads test instances through {@code getInstances}, passing flags
     * {@code true} and {@code false} and the dictionaries held by
     * {@code weights}. The first two arguments are presumably a file path and a
     * charset name, as in readAndEvaluate; confirm against getInstances.
     */
    public static List<Instance> readTestData(String path, String encoding, TrainingWeights weights) throws IOException {
        final List<Instance> testInstances = getInstances(path, encoding, true, weights, false);
        return testInstances;
    }

    /**
     * Reads tab-separated test data (one sequence per line) using {@code tagConvertor} to
     * derive labels, with the dictionary-growth flag set to {@code false}.
     *
     * @param path            file to read
     * @param encoding        charset name of the file
     * @param trainingWeights model whose dictionaries/template are used
     * @param tagConvertor    converter between token text and tag sequences
     * @return one instance per non-blank line
     * @throws IOException if the file cannot be read
     */
    public static List<Instance> readTestData2(String path, String encoding, TrainingWeights trainingWeights, TagConvertor tagConvertor) throws IOException {
        final boolean labelled = true;
        final boolean growDictionaries = false;
        return getInstances2(path, encoding, labelled, trainingWeights, growDictionaries, tagConvertor);
    }

    /**
     * Tags every non-blank line of a tab-separated file with the model and accumulates
     * accuracy statistics.
     *
     * <p>Per position and per label {@code l}, {@code labelIndex2count[l]} is incremented at:
     * index 0 when gold == l and predicted == l, index 1 when gold == l but predicted != l,
     * index 3 when predicted == l but gold != l, and index 2 otherwise.
     * A sequence counts as correct only if every position matches.
     *
     * @param path            evaluation file
     * @param encoding        charset name of the file
     * @param trainingWeights trained model
     * @param tagConvertor    converter used to build gold labels from each line
     * @return the accumulated evaluation result
     * @throws IOException if the file cannot be read
     */
    public static EvaluationResult readAndEvaluate(String path, String encoding, TrainingWeights trainingWeights, TagConvertor tagConvertor) throws IOException {
        String[] labelTexts = trainingWeights.getLabelTexts();
        EvaluationResult result = new EvaluationResult(labelTexts);
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), encoding));
        try {
            int evaluated = 0;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                if (line.trim().length() == 0) {
                    continue;
                }
                evaluated++;
                Instance instance = buildInstance(line.split("\t"), tagConvertor, trainingWeights.getAttributeDict(), trainingWeights.getLabelDict(), trainingWeights.getTemplate(), false, false);
                if (evaluated % 10000 == 0) {
                    System.out.println(evaluated + " lines evaluated");
                }
                int[] predicted = tagId(instance, trainingWeights);
                result.totalItemCount += predicted.length;

                boolean wholeSequenceCorrect = true;
                for (int pos = 0; pos < instance.length(); pos++) {
                    int predLabel = predicted[pos];
                    int goldLabel = instance.labelIds()[pos];
                    for (int label = 0; label < labelTexts.length; label++) {
                        int[] counts = result.labelIndex2count[label];
                        boolean isGold = goldLabel == label;
                        boolean isPred = predLabel == label;
                        if (isGold && isPred) {
                            counts[0]++;
                        } else if (isGold) {
                            counts[1]++;
                        } else if (isPred) {
                            counts[3]++;
                        } else {
                            counts[2]++;
                        }
                    }
                    if (goldLabel == predLabel) {
                        result.correctItemCount++;
                    } else {
                        wholeSequenceCorrect = false;
                    }
                }
                result.totalSeqCount++;
                if (wholeSequenceCorrect) {
                    result.correctSeqCount++;
                }
            }
        } finally {
            reader.close();
        }
        return result;
    }

    /**
     * Optionally prunes rare features, wraps the instances in a {@link TrainingDataSet} and
     * allocates zero-filled weight arrays sized from the dictionaries.
     *
     * <p>Prints the label and attribute counts followed by every {@code label<TAB>id} entry,
     * and records the id-to-text mapping via {@code setLabelTexts}.
     *
     * @param minFeatureFreq pruning threshold forwarded to {@code shrinkAttributeDict};
     *                       pruning only runs when it is greater than 1
     * @param instances      instances read from the training file
     * @param trainingWeights weights/dictionaries to initialise in place
     * @return the assembled data set
     */
    private static TrainingDataSet shrinkAndInit(int minFeatureFreq, List<Instance> instances, TrainingWeights trainingWeights) {
        if (minFeatureFreq > 1) {
            shrinkAttributeDict(instances, minFeatureFreq, trainingWeights.getAttributeDict());
        }
        TrainingDataSet dataSet = new TrainingDataSet();
        dataSet.setInstances(instances);

        int attrNum = trainingWeights.getAttributeDict().size();
        int labelNum = trainingWeights.getLabelDict().size();
        dataSet.setAttributeNum(attrNum);
        dataSet.setLabelNum(labelNum);
        System.out.println("labelNum: " + labelNum);
        System.out.println("attrNum: " + attrNum);

        // Single pass: echo every label entry and build the id -> text table.
        final String[] labelTexts = new String[labelNum];
        trainingWeights.getLabelDict().forEachEntry(new TObjectIntProcedure<String>() {
            @Override
            public boolean execute(String label, int id) {
                System.out.println(label + "\t" + id);
                labelTexts[id] = label;
                return true;
            }
        });

        trainingWeights.setAttributeWeights(new double[labelNum * attrNum]);
        trainingWeights.setTransitionWeights(new double[labelNum * labelNum]);
        trainingWeights.setBosTransitionWeights(new double[labelNum]);
        trainingWeights.setEosTransitionWeights(new double[labelNum]);
        trainingWeights.setLabelTexts(labelTexts);
        return dataSet;
    }

    /**
     * Reads labelled CRF++-style training data, growing the dictionaries as it goes, then
     * prunes and initialises the model via {@code shrinkAndInit}.
     *
     * @param path            training file
     * @param encoding        charset name of the file
     * @param trainingWeights weights/dictionaries to fill
     * @param minFeatureFreq  feature-pruning threshold
     * @return the training data set
     * @throws IOException if the file cannot be read
     */
    public static TrainingDataSet readTrainingData(String path, String encoding, TrainingWeights trainingWeights, int minFeatureFreq) throws IOException {
        List<Instance> instances = getInstances(path, encoding, true, trainingWeights, true);
        return shrinkAndInit(minFeatureFreq, instances, trainingWeights);
    }

    /**
     * Reads tab-separated training data whose labels are derived by {@code tagConvertor}.
     *
     * <p>Every tag the converter can produce is registered in the label dictionary first,
     * so ids are assigned densely in the converter's order. Tags already present keep
     * their existing id. Re-putting an existing tag would overwrite its id with
     * {@code size()}, which collides with the id given to the next new tag.
     *
     * @param path            training file
     * @param encoding        charset name of the file
     * @param trainingWeights weights/dictionaries to fill
     * @param minFeatureFreq  feature-pruning threshold
     * @param tagConvertor    converter defining the tag set and label derivation
     * @return the training data set
     * @throws IOException if the file cannot be read
     */
    public static TrainingDataSet readTrainingData2(String path, String encoding, TrainingWeights trainingWeights, int minFeatureFreq, TagConvertor tagConvertor) throws IOException {
        TObjectIntHashMap<String> labelDict = trainingWeights.getLabelDict();
        Iterator<String> tags = tagConvertor.getTags().iterator();
        while (tags.hasNext()) {
            String tag = tags.next();
            if (!labelDict.containsKey(tag)) {
                labelDict.put(tag, labelDict.size());
            }
        }
        return shrinkAndInit(minFeatureFreq, getInstances2(path, encoding, true, trainingWeights, true, tagConvertor), trainingWeights);
    }

    /**
     * Builds one instance per non-blank line of a tab-separated file.
     *
     * <p>Each line is split on tabs and handed to the {@code String[]} overload of
     * {@code buildInstance} together with {@code tagConvertor}. Progress is printed every
     * 10000 instances.
     *
     * @param path            file to read
     * @param encoding        charset name of the file
     * @param hasLabels       currently unused
     * @param trainingWeights model whose dictionaries/template are used
     * @param addNew          forwarded as the first boolean flag of {@code buildInstance};
     *                        the second flag is always {@code false}
     * @param tagConvertor    converter used to derive labels
     * @return the instances, in file order
     * @throws IOException if the file cannot be read
     */
    private static List<Instance> getInstances2(String path, String encoding, boolean hasLabels, TrainingWeights trainingWeights, boolean addNew, TagConvertor tagConvertor) throws IOException {
        List<Instance> instances = new ArrayList<Instance>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), encoding));
        try {
            int lineCount = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue;
                }
                lineCount++;
                instances.add(buildInstance(line.split("\t"), tagConvertor, trainingWeights.getAttributeDict(), trainingWeights.getLabelDict(), trainingWeights.getTemplate(), addNew, false));
                if (lineCount % 10000 == 0) {
                    System.out.println(lineCount + " lines read");
                }
            }
        } finally {
            reader.close();
        }
        return instances;
    }

    /**
     * Reads CRF++-style data: one token per line with whitespace-separated columns, and a
     * blank line between sequences.
     *
     * <p>When {@code hasLabels} is true the last column is the gold label and the remaining
     * columns are features. Every non-blank line in the file must have the same number of
     * feature columns, otherwise an {@link IllegalStateException} is thrown.
     *
     * <p>The reader is always closed, including on error. Unlabelled input is supported:
     * row counts are tracked independently of the label list.
     *
     * @param path            file to read
     * @param encoding        charset name of the file
     * @param hasLabels       whether the last column is a label
     * @param trainingWeights model whose dictionaries/template are used
     * @param addNew          forwarded to {@code buildInstance} (dictionary-growth flag)
     * @return one instance per sequence, in file order
     * @throws IOException if the file cannot be read
     */
    private static List<Instance> getInstances(String path, String encoding, boolean hasLabels, TrainingWeights trainingWeights, boolean addNew) throws IOException {
        List<Instance> instances = new ArrayList<Instance>();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), encoding))) {
            List<String> tokens = new ArrayList<String>();
            List<String> labels = hasLabels ? new ArrayList<String>() : null;
            int rowsInSequence = 0;
            int instanceCount = 0;
            int itemCount = 0;
            int expectedColumns = -1;
            System.out.println("extracting instances");
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().length() == 0) {
                    // A blank line terminates the current sequence (if any).
                    if (rowsInSequence > 0) {
                        instances.add(buildInstance(tokens, rowsInSequence, labels, trainingWeights, addNew));
                        instanceCount++;
                        if (instanceCount % 10000 == 0) {
                            System.out.println(instanceCount + " lines read");
                        }
                        tokens.clear();
                        if (labels != null) {
                            labels.clear();
                        }
                        rowsInSequence = 0;
                    }
                    continue;
                }
                String[] split = line.split("\\s+");
                int featureColumns = hasLabels ? split.length - 1 : split.length;
                if (expectedColumns < 0) {
                    expectedColumns = featureColumns;
                } else if (expectedColumns != featureColumns) {
                    throw new IllegalStateException("inconsistent input format: " + line);
                }
                for (int c = 0; c < featureColumns; c++) {
                    tokens.add(split[c]);
                }
                if (labels != null) {
                    labels.add(split[split.length - 1]);
                }
                rowsInSequence++;
                itemCount++;
            }
            // The file may end without a trailing blank line: flush the last sequence.
            if (rowsInSequence > 0) {
                instances.add(buildInstance(tokens, rowsInSequence, labels, trainingWeights, addNew));
                instanceCount++;
            }
            System.out.println("found " + instanceCount + " instances and " + itemCount + " items");
        }
        return instances;
    }

    /**
     * Prints command-line usage to stderr and terminates the JVM with exit status 1.
     *
     * <p>The interactive segmentation command is {@code seg}, which is what {@code main}
     * dispatches on. It was previously advertised as {@code tag}, which {@code main} rejects
     * as an unknown command.
     */
    public static void showUsageAndExit() {
        System.err.println("Usage:");
        System.err.println("\tSgdCrf help");
        System.err.println("\tSgdCrf train <CRF++_format_train_file> <model_file> <crf_train_properties_file> [encoding]");
        System.err.println("\tSgdCrf train2 <tab_sep_text_train_file> <model_file> <crf_train_properties_file> [encoding]");
        System.err.println("\tSgdCrf hdfs-train <hdfs_dir> <model_file> <crf_train_properties_file> <feature_dict> [encoding] [hdfsconf1] [hdfsconf2] ...");
        System.err.println("\tSgdCrf test  <test_file> <model_file> [encoding]");
        System.err.println("\tSgdCrf test2  <test_file> <model_file> [encoding]");
        System.err.println("\tSgdCrf seg <model_file> [nBest] [encoding]");
        System.exit(1);
    }

    /**
     * Loads training hyper-parameters from a {@code .properties} file.
     *
     * <p>Recognised keys (default in parentheses): {@code mininumFeatureFrequency} (1; the
     * correctly spelled {@code minimumFeatureFrequency} is accepted as a fallback),
     * {@code eta} (0.1), {@code sigma} (10.0), {@code rate} (2.0), {@code iterateCount}
     * (100), {@code candidatesNum} (10), {@code samplesNum} (1000), {@code t0} (0.0), and
     * the required {@code templateFile}, which is read with {@code readTemplates}.
     *
     * @param str path of the properties file
     * @return the populated parameters
     * @throws IOException              if the properties or template file cannot be read
     * @throws IllegalArgumentException if {@code templateFile} is not set
     */
    public static TrainingParams loadParams(String str) throws IOException {
        TrainingParams trainingParams = new TrainingParams();
        Properties properties = new Properties();
        // Close the stream; the previous version leaked the file handle.
        try (InputStream in = new FileInputStream(new File(str))) {
            properties.load(in);
        }
        // The historical (misspelled) key wins; the correct spelling is only a fallback.
        int minFreq = getIntParam(properties, "mininumFeatureFrequency",
                getIntParam(properties, "minimumFeatureFrequency", 1));
        trainingParams.setMinFeatureFreq(minFreq);
        trainingParams.setEta(getDoubleParam(properties, "eta", 0.1d));
        trainingParams.setSigma(getDoubleParam(properties, "sigma", 10.0d));
        trainingParams.setRate(getDoubleParam(properties, "rate", 2.0d));
        trainingParams.setIterationNum(getIntParam(properties, "iterateCount", 100));
        trainingParams.setCandidatesNum(getIntParam(properties, "candidatesNum", 10));
        trainingParams.setSamplesNum(getIntParam(properties, "samplesNum", 1000));
        trainingParams.setT0(getDoubleParam(properties, "t0", 0.0d));
        String templateFile = properties.getProperty("templateFile");
        if (templateFile == null) {
            throw new IllegalArgumentException("missing required property 'templateFile' in " + str);
        }
        trainingParams.setTemplates(readTemplates(templateFile));
        return trainingParams;
    }

    /**
     * Reads feature templates from a text file, one per line.
     *
     * <p>Each line is trimmed. Lines that are empty after trimming, or that start with
     * {@code #}, are skipped. The file is decoded with the platform default charset.
     *
     * @param str path of the template file
     * @return the templates in file order
     * @throws IOException if the file cannot be read
     */
    public static List<String> readTemplates(String str) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(str)));
        try {
            List<String> templates = new ArrayList<String>();
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String candidate = line.trim();
                boolean isComment = candidate.startsWith("#");
                if (!isComment && candidate.length() > 0) {
                    templates.add(candidate);
                }
            }
            return templates;
        } finally {
            reader.close();
        }
    }

    /**
     * Drops rare features from the dictionary and renumbers the survivors densely.
     *
     * <p>A feature survives only if it occurs strictly more than {@code minFreq} times over
     * all instances. NOTE(review): this is {@code >}, not {@code >=}; confirm that this is
     * the intended meaning of "minimum frequency".
     *
     * <p>The instances' attribute-id arrays are rewritten in place, and dropped features
     * become {@code -1}. This relies on {@code getAttrIds()} returning the instance's own
     * backing array (assumed, not verified here).
     *
     * @param instances   all training instances
     * @param minFreq     occurrence threshold
     * @param featureDict dictionary to prune and renumber in place
     * @return number of features removed from the dictionary
     */
    private static int shrinkAttributeDict(List<Instance> instances, int minFreq, FeatureDict featureDict) {
        // Pass 1: count occurrences of every known feature id.
        int[] remap = new int[featureDict.size()];
        for (Instance instance : instances) {
            for (int attrId : instance.getAttrIds()) {
                if (attrId >= 0) {
                    remap[attrId]++;
                }
            }
        }

        // Pass 2: turn counts into new dense ids, or -1 for features to drop.
        int nextId = 0;
        for (int oldId = 0; oldId < remap.length; oldId++) {
            remap[oldId] = (remap[oldId] > minFreq) ? nextId++ : -1;
        }

        // Pass 3: rewrite the dictionary.
        int removed = 0;
        TObjectIntIterator<String> entries = featureDict.iterator();
        while (entries.hasNext()) {
            entries.advance();
            int newId = remap[entries.value()];
            if (newId >= 0) {
                entries.setValue(newId);
            } else {
                entries.remove();
                removed++;
            }
        }

        // Pass 4: rewrite the instances' attribute ids.
        for (Instance instance : instances) {
            int[] attrIds = instance.getAttrIds();
            for (int k = 0; k < attrIds.length; k++) {
                int oldId = attrIds[k];
                if (oldId >= 0) {
                    attrIds[k] = remap[oldId];
                }
            }
        }
        return removed;
    }

    /**
     * Reads an integer property, falling back to a default when it is absent.
     *
     * <p>The value is trimmed before parsing. {@link Properties#load} strips leading but not
     * trailing whitespace, so {@code "100 "} previously raised NumberFormatException.
     *
     * @param properties source properties
     * @param str        key to look up
     * @param i          value returned when the key is absent
     * @return the parsed value or {@code i}
     * @throws NumberFormatException if the value is present but not an integer
     */
    private static int getIntParam(Properties properties, String str, int i) {
        String value = properties.getProperty(str);
        return value == null ? i : Integer.parseInt(value.trim());
    }

    /**
     * Reads a floating-point property, falling back to a default when the key is absent.
     *
     * @param properties source properties
     * @param str        key to look up
     * @param d          value returned when the key is absent
     * @return the parsed value or {@code d}
     * @throws NumberFormatException if the value is present but not a number
     */
    private static double getDoubleParam(Properties properties, String str, double d) {
        if (!properties.containsKey(str)) {
            return d;
        }
        return Double.parseDouble(properties.getProperty(str));
    }

    /**
     * Like {@code buildInstance}, but never grows the dictionaries. It also records, in
     * {@code featureNames}, the feature string behind every attribute id it produces.
     *
     * <p>Ids returned for unknown features (whatever {@code getFeatureIndex} yields with
     * {@code false}) are recorded too. With several unknown features sharing one id, only
     * the last string is kept.
     *
     * @param tokens       column values of the sequence, row-major
     * @param rowCount     number of rows; must be non-zero
     * @param labels       gold labels, or {@code null} for an unlabelled instance
     * @param featureNames output map filled with id to feature string
     * @param template     optional template used to expand {@code tokens}
     * @param featureDict  feature dictionary
     * @param labelDict    label dictionary
     * @return the built instance
     */
    private static Instance buildInstance4Explanation(List<String> tokens, int rowCount, List<String> labels, Map<Integer, String> featureNames, Template template, FeatureDict featureDict, TObjectIntHashMap<String> labelDict) {
        List<String> features = (template == null) ? tokens : template.expandTemplate(tokens, rowCount);
        int perRow = features.size() / rowCount;
        int[] attrIds = new int[features.size()];
        int[] labelIds = (labels == null) ? null : new int[rowCount];

        int cursor = 0;
        for (int row = 0; row < rowCount; row++) {
            if (labelIds != null) {
                labelIds[row] = getLabelIndex(labelDict, labels.get(row), false);
            }
            for (int col = 0; col < perRow; col++) {
                String feature = features.get(cursor);
                int id = getFeatureIndex(featureDict, feature, false);
                attrIds[cursor] = id;
                featureNames.put(id, feature);
                cursor++;
            }
        }
        return (labelIds == null) ? new Instance(attrIds, rowCount) : new Instance(attrIds, labelIds);
    }

    /**
     * Computes per-position, per-label state scores and records each contribution.
     *
     * <p>The score of (position {@code t}, label {@code y}) is the sum of
     * {@code attributeWeights[attrId * numLabels + y]} over the non-negative attribute ids at
     * row {@code t}. The sum is exponentiated when {@code exponentiate} is true.
     *
     * <p>For each cell, the contributing feature names and weights are appended to
     * {@code details[t][y]}. They are then reordered by descending absolute weight; equal
     * magnitudes keep their insertion order. The previous comparator never returned 0 and
     * was not antisymmetric on ties. That violates the {@link Comparator} contract, and
     * TimSort may throw "Comparison method violates its general contract!".
     *
     * @param instance         the instance to score
     * @param exponentiate     whether to apply {@code Math.exp} to each score
     * @param details          {@code [length][numLabels]} cells to fill
     * @param featureNames     attribute id to feature string
     * @param numLabels        number of labels
     * @param attributeWeights flattened {@code [attr][label]} weight table
     * @return the flattened {@code [position][label]} score table
     */
    private static double[] computeStateScores4Explanation(Instance instance, boolean exponentiate, FeatureWeightScore[][] details, Map<Integer, String> featureNames, int numLabels, double[] attributeWeights) {
        int length = instance.length();
        int rowSize = instance.rowSize();
        int[] attrIds = instance.getAttrIds();
        double[] scores = new double[length * numLabels];

        for (int pos = 0; pos < length; pos++) {
            for (int col = 0; col < rowSize; col++) {
                int attrId = attrIds[pos * rowSize + col];
                if (attrId < 0) {
                    continue; // unknown / pruned feature contributes nothing
                }
                String featureName = featureNames.get(attrId);
                for (int label = 0; label < numLabels; label++) {
                    double weight = attributeWeights[attrId * numLabels + label];
                    scores[pos * numLabels + label] += weight;
                    FeatureWeightScore cell = details[pos][label];
                    cell.features.add(featureName);
                    cell.weights.add(weight);
                }
            }
        }

        if (exponentiate) {
            for (int k = 0; k < scores.length; k++) {
                scores[k] = Math.exp(scores[k]);
            }
        }

        Comparator<Object[]> byAbsWeightDesc = new Comparator<Object[]>() {
            @Override
            public int compare(Object[] a, Object[] b) {
                return Double.compare(Math.abs((Double) b[1]), Math.abs((Double) a[1]));
            }
        };
        for (int pos = 0; pos < length; pos++) {
            for (int label = 0; label < numLabels; label++) {
                FeatureWeightScore cell = details[pos][label];
                cell.score = scores[pos * numLabels + label];
                ArrayList<String> names = cell.features;
                ArrayList<Double> weights = cell.weights;
                List<Object[]> pairs = new ArrayList<Object[]>(names.size());
                for (int k = 0; k < names.size(); k++) {
                    pairs.add(new Object[] {names.get(k), weights.get(k)});
                }
                Collections.sort(pairs, byAbsWeightDesc);
                names.clear();
                weights.clear();
                for (Object[] pair : pairs) {
                    names.add((String) pair[0]);
                    weights.add((Double) pair[1]);
                }
            }
        }
        return scores;
    }

    /**
     * Tags {@code str} one character ({@code char}) per position and returns the decoding
     * together with per-feature contributions.
     *
     * @param str      text to tag
     * @param crfModel trained model
     * @return the explanation, with {@code tokens} set to the single-character strings
     */
    public static Explanation explain(String str, CrfModel crfModel) {
        int n = str.length();
        ArrayList<String> chars = new ArrayList<String>(n);
        for (int k = 0; k < n; k++) {
            chars.add(String.valueOf(str.charAt(k)));
        }
        Explanation result = tagAndExplain(chars, n, crfModel);
        result.tokens = chars;
        return result;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v33, types: [com.antbrains.crf.FeatureWeightScore[], com.antbrains.crf.FeatureWeightScore[][]] */
    public static Explanation tagAndExplain(List<String> list, int i, CrfModel crfModel) {
        double[] bosTransitionWeights = crfModel.weights.getBosTransitionWeights();
        double[] transitionWeights = crfModel.weights.getTransitionWeights();
        double[] eosTransitionWeights = crfModel.weights.getEosTransitionWeights();
        double[] attributeWeights = crfModel.weights.getAttributeWeights();
        Template template = crfModel.weights.getTemplate();
        FeatureDict attributeDict = crfModel.weights.getAttributeDict();
        TObjectIntHashMap<String> labelDict = crfModel.weights.getLabelDict();
        Explanation explanation = new Explanation();
        HashMap hashMap = new HashMap();
        Instance buildInstance4Explanation = buildInstance4Explanation(list, i, null, hashMap, template, attributeDict, labelDict);
        if (buildInstance4Explanation == null) {
            explanation.bestTagIds = new int[0];
            return explanation;
        }
        int[] iArr = new int[i];
        int size = crfModel.weights.getLabelDict().size();
        ?? r0 = new FeatureWeightScore[i];
        for (int i2 = 0; i2 < r0.length; i2++) {
            r0[i2] = new FeatureWeightScore[size];
            for (int i3 = 0; i3 < r0[i2].length; i3++) {
                r0[i2][i3] = new FeatureWeightScore();
            }
        }
        explanation.details = r0;
        double[] computeStateScores4Explanation = computeStateScores4Explanation(buildInstance4Explanation, false, r0, hashMap, size, attributeWeights);
        int[] iArr2 = new int[i * size];
        double[] dArr = new double[i * size];
        for (int i4 = 0; i4 < size; i4++) {
            dArr[i4] = bosTransitionWeights[i4] + computeStateScores4Explanation[i4];
        }
        int i5 = 1;
        int i6 = size;
        while (true) {
            int i7 = i6;
            if (i5 >= i) {
                break;
            }
            for (int i8 = 0; i8 < size; i8++) {
                double d = dArr[i7 - size] + transitionWeights[i8];
                int i9 = 0;
                int i10 = 1;
                int i11 = size;
                while (true) {
                    int i12 = i11;
                    if (i10 < size) {
                        double d2 = dArr[(i7 - size) + i10] + transitionWeights[i12 + i8];
                        if (d2 > d) {
                            d = d2;
                            i9 = i10;
                        }
                        i10++;
                        i11 = i12 + size;
                    }
                }
                dArr[i7 + i8] = d + computeStateScores4Explanation[i7 + i8];
                iArr2[i7 + i8] = i9;
            }
            i5++;
            i6 = i7 + size;
        }
        int i13 = (i - 1) * size;
        double d3 = dArr[i13] + eosTransitionWeights[0];
        int i14 = 0;
        for (int i15 = 1; i15 < size; i15++) {
            double d4 = dArr[i13 + i15] + eosTransitionWeights[i15];
            if (d4 > d3) {
                d3 = d4;
                i14 = i15;
            }
        }
        iArr[i - 1] = i14;
        int i16 = i - 2;
        while (i16 >= 0) {
            i14 = iArr2[i13 + i14];
            iArr[i16] = i14;
            i16--;
            i13 -= size;
        }
        explanation.bestTagIds = iArr;
        explanation.bosTransitionWeights = bosTransitionWeights;
        explanation.eosTransitionWeights = eosTransitionWeights;
        explanation.transitionWeights = transitionWeights;
        explanation.labelTexts = crfModel.weights.getLabelTexts();
        return explanation;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public static List<String[]> tagNBest(Instance instance, int i, double[] dArr, CrfModel crfModel) {
        ArrayList arrayList = new ArrayList(i);
        if (instance == null) {
            return arrayList;
        }
        int length = instance.length();
        int size = crfModel.weights.getLabelDict().size();
        double[] attributeWeights = crfModel.weights.getAttributeWeights();
        double[] bosTransitionWeights = crfModel.weights.getBosTransitionWeights();
        double[] transitionWeights = crfModel.weights.getTransitionWeights();
        double[] eosTransitionWeights = crfModel.weights.getEosTransitionWeights();
        String[] labelTexts = crfModel.weights.getLabelTexts();
        double[] computeStateScores = computeStateScores(instance, false, size, attributeWeights);
        int[][][] iArr = new int[length][];
        double[][] dArr2 = new double[length];
        dArr2[0] = new double[size];
        for (int i2 = 0; i2 < size; i2++) {
            dArr2[0][i2] = new double[i];
        }
        for (int i3 = 0; i3 < size; i3++) {
            dArr2[0][i3][0] = computeStateScores[i3] + bosTransitionWeights[i3];
            for (int i4 = 1; i4 < i; i4++) {
                dArr2[0][i3][i4] = -4503599627370497;
            }
        }
        for (int i5 = 1; i5 < length; i5++) {
            dArr2[i5] = new double[size];
            iArr[i5] = new int[size];
            for (int i6 = 0; i6 < size; i6++) {
                double[] dArr3 = new double[i];
                int[][] iArr2 = new int[i][2];
                int i7 = 0;
                dArr3[0] = -1.7976931348623157E308d;
                for (int i8 = 0; i8 < size; i8++) {
                    for (int i9 = 0; i9 < i && dArr2[i5 - 1][i8][i9] != -1.7976931348623157E308d; i9++) {
                        double d = dArr2[i5 - 1][i8][i9] + transitionWeights[(i8 * size) + i6];
                        if (i7 < i || d > dArr3[i - 1]) {
                            int i10 = 0;
                            while (i10 < i && d <= dArr3[i10]) {
                                i10++;
                            }
                            for (int i11 = i - 1; i11 > i10; i11--) {
                                dArr3[i11] = dArr3[i11 - 1];
                                iArr2[i11] = iArr2[i11 - 1];
                            }
                            dArr3[i10] = d;
                            int[] iArr3 = new int[2];
                            iArr3[0] = i8;
                            iArr3[1] = i9;
                            iArr2[i10] = iArr3;
                            i7++;
                        }
                    }
                }
                for (int i12 = 0; i12 < i7 && i12 < i; i12++) {
                    int i13 = i12;
                    dArr3[i13] = dArr3[i13] + computeStateScores[(i5 * size) + i6];
                }
                dArr2[i5][i6] = dArr3;
                iArr[i5][i6] = iArr2;
            }
        }
        ArrayList arrayList2 = new ArrayList();
        for (int i14 = 0; i14 < size; i14++) {
            Object[] objArr = dArr2[length - 1][i14];
            for (int i15 = 0; i15 < i; i15++) {
                long j = objArr[i15];
                if (j == -1.7976931348623157E308d) {
                    break;
                }
                arrayList2.add(new Object[]{Double.valueOf(j + eosTransitionWeights[i14]), Integer.valueOf(i14), Integer.valueOf(i15)});
            }
        }
        Collections.sort(arrayList2, new Comparator<Object[]>() { // from class: com.antbrains.crf.SgdCrf.4
            @Override // java.util.Comparator
            public int compare(Object[] objArr2, Object[] objArr3) {
                return ((Double) objArr2[0]).doubleValue() >= ((Double) objArr3[0]).doubleValue() ? -1 : 1;
            }
        });
        for (int i16 = 0; i16 < i; i16++) {
            String[] strArr = new String[length];
            Object[] objArr2 = (Object[]) arrayList2.get(i16);
            int intValue = ((Integer) objArr2[1]).intValue();
            int intValue2 = ((Integer) objArr2[2]).intValue();
            strArr[length - 1] = labelTexts[intValue];
            dArr[i16] = ((Double) objArr2[0]).doubleValue();
            for (int i17 = length - 2; i17 >= 0; i17--) {
                Object[] objArr3 = iArr[i17 + 1][intValue][intValue2];
                intValue = objArr3[0];
                intValue2 = objArr3[1];
                strArr[i17] = labelTexts[intValue];
            }
            arrayList.add(strArr);
        }
        return arrayList;
    }

    /**
     * Maps label ids to their label strings using the model's label-text table.
     *
     * @param iArr     label ids
     * @param crfModel trained model
     * @return label strings, position for position
     */
    public static String[] tagId2Text(int[] iArr, CrfModel crfModel) {
        String[] labelTexts = crfModel.weights.getLabelTexts();
        String[] texts = new String[iArr.length];
        int cursor = 0;
        for (int id : iArr) {
            texts[cursor++] = labelTexts[id];
        }
        return texts;
    }

    /**
     * Tags a token sequence with the model's first-best labels.
     *
     * <p>An empty list yields an empty array. Building an instance for zero rows would
     * divide by zero in {@code buildInstance}.
     *
     * @param list     one token per position
     * @param crfModel trained model
     * @return one label string per token
     */
    public static String[] tag(List<String> list, CrfModel crfModel) {
        if (list.isEmpty()) {
            return new String[0];
        }
        Instance instance = buildInstance(list, list.size(), null, crfModel.weights, false);
        return tagId2Text(tagId(instance, crfModel.weights), crfModel);
    }

    /**
     * Segments {@code str}: tags each character ({@code char}) with the first-best label
     * and lets {@code tagConvertor} turn the tags into tokens.
     *
     * <p>An empty string yields an empty list. Building an instance for zero rows would
     * divide by zero in {@code buildInstance}.
     *
     * @param str          text to segment
     * @param crfModel     trained model
     * @param tagConvertor converter from tag sequence to token list
     * @return the tokens
     */
    public static List<String> segment(String str, CrfModel crfModel, TagConvertor tagConvertor) {
        if (str.isEmpty()) {
            return new ArrayList<String>();
        }
        List<String> chars = new ArrayList<String>(str.length());
        for (int k = 0; k < str.length(); k++) {
            chars.add(String.valueOf(str.charAt(k)));
        }
        Instance instance = buildInstance(chars, chars.size(), null, crfModel.weights, false);
        return tagConvertor.tags2TokenList(tagId2Text(tagId(instance, crfModel.weights), crfModel), str);
    }

    /**
     * Produces up to {@code i} alternative segmentations of {@code str}, best first.
     *
     * <p>An empty string or a non-positive {@code i} yields an empty list. Previously these
     * divided by zero in {@code buildInstance} or failed allocating the score buffer.
     *
     * @param str          text to segment
     * @param crfModel     trained model
     * @param tagConvertor converter from tag sequence to tokens
     * @param i            number of candidate segmentations wanted
     * @return the candidate token arrays
     */
    public static List<String[]> segment(String str, CrfModel crfModel, TagConvertor tagConvertor, int i) {
        List<String[]> segmentations = new ArrayList<String[]>();
        if (str.isEmpty() || i <= 0) {
            return segmentations;
        }
        List<String> chars = new ArrayList<String>(str.length());
        for (int k = 0; k < str.length(); k++) {
            chars.add(String.valueOf(str.charAt(k)));
        }
        Instance instance = buildInstance(chars, chars.size(), null, crfModel.weights, false);
        for (String[] tags : tagNBest(instance, i, new double[i], crfModel)) {
            segmentations.add(tagConvertor.tags2Tokens(tags, str));
        }
        return segmentations;
    }

    /**
     * Command-line entry point dispatching on the first argument:
     * {@code help}, {@code train}, {@code train2}, {@code test}, {@code test2} or {@code seg}.
     * Any other command prints an error plus usage and exits.
     *
     * @param strArr command-line arguments
     * @throws Exception propagated from training, model I/O or evaluation
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 1) {
            showUsageAndExit();
        }
        String command = strArr[0];
        switch (command) {
            case "help":
                showUsageAndExit();
                break;
            case "train":
                runTrainCommand(strArr, false);
                break;
            case "train2":
                runTrainCommand(strArr, true);
                break;
            case "test":
                runTestCommand(strArr);
                break;
            case "test2":
                runTest2Command(strArr);
                break;
            case "seg":
                runSegmentCommand(strArr);
                break;
            default:
                System.err.println("unknown command: " + command);
                showUsageAndExit();
                break;
        }
    }

    /** Handles {@code train} (CRF++ format) and {@code train2} (tab-separated text). */
    private static void runTrainCommand(String[] args, boolean tabSeparated) throws Exception {
        if (args.length != 4 && args.length != 5) {
            showUsageAndExit();
        }
        String trainFile = args[1];
        String modelFile = args[2];
        String paramsFile = args[3];
        String encoding = args.length > 4 ? args[4] : "UTF8";
        TrainingParams params = loadParams(paramsFile);
        TrainingWeights weights = new TrainingWeights(new Template((String[]) params.getTemplates().toArray(new String[0])), FeatureDictEnum.TROVE_HASHMAP);
        TrainingDataSet data = tabSeparated
                ? readTrainingData2(trainFile, encoding, weights, params.getMinFeatureFreq(), new BESB1B2MTagConvertor())
                : readTrainingData(trainFile, encoding, weights, params.getMinFeatureFreq());
        train(data, 0, params.getIterationNum(), params, weights, new PrintTrainingProgress());
        saveModel(params, weights, modelFile);
    }

    /** Handles {@code test}: evaluates a CRF++-format file and prints the result. */
    private static void runTestCommand(String[] args) throws Exception {
        if (args.length != 3 && args.length != 4) {
            showUsageAndExit();
        }
        String testFile = args[1];
        String modelFile = args[2];
        String encoding = args.length > 3 ? args[3] : "UTF8";
        CrfModel model = loadModel(modelFile);
        System.out.println(evaluate(readTestData(testFile, encoding, model.weights), model.weights));
    }

    /** Handles {@code test2}: evaluates a tab-separated file and prints the result. */
    private static void runTest2Command(String[] args) throws Exception {
        if (args.length != 3 && args.length != 4) {
            showUsageAndExit();
        }
        String encoding = args.length > 3 ? args[3] : "UTF8";
        System.out.println(readAndEvaluate(args[1], encoding, loadModel(args[2]).weights, new BESB1B2MTagConvertor()));
    }

    /** Handles {@code seg}: interactive segmentation of stdin lines until EOF or "quit". */
    private static void runSegmentCommand(String[] args) throws Exception {
        if (args.length != 3 && args.length != 2 && args.length != 4) {
            showUsageAndExit();
        }
        final String prompt = "Enter Chinese sentences to be segment, enter quit to exit!\n";
        String modelFile = args[1];
        int nBest = args.length > 2 ? Integer.parseInt(args[2]) : 1;
        String encoding = args.length > 3 ? args[3] : "";
        CrfModel model = loadModel(modelFile);

        BufferedReader in;
        BufferedWriter out;
        if (encoding.equals("")) {
            in = new BufferedReader(new InputStreamReader(System.in));
            out = new BufferedWriter(new OutputStreamWriter(System.out));
        } else {
            in = new BufferedReader(new InputStreamReader(System.in, encoding));
            out = new BufferedWriter(new OutputStreamWriter(System.out, encoding));
        }
        out.write(prompt);
        out.flush();

        BESB1B2MTagConvertor convertor = new BESB1B2MTagConvertor();
        String line;
        while ((line = in.readLine()) != null && !line.trim().equals("quit")) {
            if (line.trim().equals("")) {
                continue;
            }
            out.write("Input: " + line + "\n");
            if (nBest < 2) {
                writeTabSeparated(out, segment(line, model, convertor));
                out.write("\n");
            } else {
                for (String[] candidate : segment(line, model, convertor, nBest)) {
                    writeTabSeparated(out, Arrays.asList(candidate));
                    out.write("\n");
                }
            }
            out.write(prompt);
            out.flush();
        }
        in.close();
        out.close();
    }

    /** Writes {@code tokens} separated by single tab characters (no trailing newline). */
    private static void writeTabSeparated(BufferedWriter out, Iterable<String> tokens) throws IOException {
        boolean first = true;
        for (String token : tokens) {
            if (!first) {
                out.write("\t");
            }
            out.write(token);
            first = false;
        }
    }
}
