package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.general_regression.CovariateList;
import org.dmg.pmml.general_regression.FactorList;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PCell;
import org.dmg.pmml.general_regression.PPCell;
import org.dmg.pmml.general_regression.PPMatrix;
import org.dmg.pmml.general_regression.ParamMatrix;
import org.dmg.pmml.general_regression.Parameter;
import org.dmg.pmml.general_regression.ParameterList;
import org.dmg.pmml.general_regression.Predictor;
import org.dmg.pmml.general_regression.PredictorList;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.InteractionFeature;
import org.jpmml.sparkml.RegressionModelConverter;

/* loaded from: input_file:org/jpmml/sparkml/model/GeneralizedLinearRegressionModelConverter.class */
/**
 * Converts a Spark ML {@link GeneralizedLinearRegressionModel} into a PMML
 * {@link GeneralRegressionModel} of type {@code generalizedLinear}.
 *
 * <p>Each feature becomes one parameter ({@code p1}, {@code p2}, ...). A non-zero intercept
 * becomes parameter {@code p0}. Continuous features are listed as covariates, binary features
 * as factors, and interaction features contribute one PPCell per component feature.
 */
public class GeneralizedLinearRegressionModelConverter extends RegressionModelConverter<GeneralizedLinearRegressionModel> {

    public GeneralizedLinearRegressionModelConverter(GeneralizedLinearRegressionModel generalizedLinearRegressionModel) {
        super(generalizedLinearRegressionModel);
    }

    /**
     * Returns {@link MiningFunction#CLASSIFICATION} for the {@code "binomial"} family and
     * {@link MiningFunction#REGRESSION} for every other family.
     */
    @Override // org.jpmml.sparkml.RegressionModelConverter, org.jpmml.sparkml.ModelConverter
    public MiningFunction getMiningFunction() {
        String family = ((GeneralizedLinearRegressionModel) getTransformer()).getFamily();

        return "binomial".equals(family) ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION;
    }

    /**
     * Encodes the wrapped Spark model as a PMML GeneralRegressionModel.
     *
     * <p>NOTE: this is {@code encodeModel} as renamed by the decompiler; the name is kept so that
     * other decompiled call sites keep resolving.
     *
     * @param schema the feature/target schema; its feature count must equal the coefficient count
     * @return the encoded model; for classification it also carries a probability output
     * @throws IllegalArgumentException if the feature count does not match the coefficients, if
     *         there are target categories but not exactly two, or if a feature, family or link
     *         is unsupported
     */
    @Override // org.jpmml.sparkml.ModelConverter
    public GeneralRegressionModel mo4encodeModel(Schema schema) {
        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel) getTransformer();

        double intercept = model.intercept();
        Vector coefficients = model.coefficients();

        List<Feature> features = schema.getFeatures();
        if (features.size() != coefficients.size()) {
            throw new IllegalArgumentException("Expected " + coefficients.size() + " features, got " + features.size());
        }

        // For a binary target the PCells are attributed to the second ("positive") category
        String targetCategory = null;

        List<String> targetCategories = schema.getTargetCategories();
        if (targetCategories != null && !targetCategories.isEmpty()) {
            if (targetCategories.size() != 2) {
                throw new IllegalArgumentException("Expected 2 target categories, got " + targetCategories.size());
            }

            targetCategory = targetCategories.get(1);
        }

        ParameterList parameterList = new ParameterList();
        PPMatrix ppMatrix = new PPMatrix();
        ParamMatrix paramMatrix = new ParamMatrix();

        // The intercept parameter is emitted only when it contributes something
        if (!ValueUtil.isZero(intercept)) {
            Parameter interceptParameter = new Parameter("p0").setLabel("(intercept)");

            parameterList.addParameters(interceptParameter);
            paramMatrix.addPCells(new PCell(interceptParameter.getName(), intercept).setTargetCategory(targetCategory));
        }

        // Insertion-ordered, de-duplicated field names for the covariate and factor lists
        Set<FieldName> covariates = new LinkedHashSet<>();
        Set<FieldName> factors = new LinkedHashSet<>();

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);

            Parameter parameter = new Parameter("p" + (i + 1));
            parameterList.addParameters(parameter);

            List<PPCell> ppCells = createPPCells(parameter, feature, covariates, factors);
            ppMatrix.addPPCells(ppCells.toArray(new PPCell[ppCells.size()]));

            paramMatrix.addPCells(new PCell(parameter.getName(), coefficients.apply(i)).setTargetCategory(targetCategory));
        }

        MiningFunction miningFunction = (targetCategory != null) ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION;

        String link = getLinkName(model);

        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(schema), parameterList, ppMatrix, paramMatrix)
            .setDistribution(parseFamily(model.getFamily()))
            .setLinkFunction(parseLink(link))
            .setCovariateList(createPredictorList(new CovariateList(), covariates))
            .setFactorList(createPredictorList(new FactorList(), factors));

        // Only power-type links ("inverse", "sqrt") need an explicit exponent
        Double linkPower = parseLinkPower(link);
        if (linkPower != null) {
            generalRegressionModel.setLinkPower(linkPower);
        }

        if (miningFunction == MiningFunction.CLASSIFICATION) {
            generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput(schema));
        }

        return generalRegressionModel;
    }

    /**
     * Builds the PPCells that bind {@code parameter} to the field(s) behind {@code feature},
     * recording each field name as a covariate or a factor.
     *
     * @param parameter the parameter the cells refer to
     * @param feature a continuous, binary or interaction feature
     * @param covariates receives the names of continuous fields
     * @param factors receives the names of binary (categorical) fields
     * @return one cell per leaf feature
     * @throws IllegalArgumentException for any other feature type
     */
    private static List<PPCell> createPPCells(Parameter parameter, Feature feature, Set<FieldName> covariates, Set<FieldName> factors) {

        if (feature instanceof InteractionFeature) {
            List<PPCell> result = new ArrayList<>();

            for (Feature component : ((InteractionFeature) feature).getFeatures()) {
                result.addAll(createPPCells(parameter, component, covariates, factors));
            }

            return result;
        }

        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature) feature;

            covariates.add(continuousFeature.getName());

            // Continuous predictors enter the model with exponent 1
            return Collections.singletonList(new PPCell("1", continuousFeature.getName(), parameter.getName()));
        }

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            factors.add(binaryFeature.getName());

            return Collections.singletonList(new PPCell(binaryFeature.getValue(), binaryFeature.getName(), parameter.getName()));
        }

        throw new IllegalArgumentException("Unsupported feature: " + feature);
    }

    /**
     * Resolves the effective link name. Spark's {@code link} param has no default value, so
     * {@code getLink()} fails when the param was never set. In that case Spark uses the
     * family's canonical link, so the same fallback is applied here.
     */
    private static String getLinkName(GeneralizedLinearRegressionModel model) {

        if (model.isDefined(model.link())) {
            return model.getLink();
        }

        return getDefaultLink(model.getFamily());
    }

    /**
     * Maps a Spark family name to the canonical (default) link name Spark uses for it.
     *
     * @throws IllegalArgumentException for an unsupported family
     */
    private static String getDefaultLink(String family) {

        switch (family) {
            case "gaussian":
                return "identity";
            case "binomial":
                return "logit";
            case "poisson":
                return "log";
            case "gamma":
                return "inverse";
            default:
                throw new IllegalArgumentException(family);
        }
    }

    /**
     * Maps a Spark family name to the PMML distribution.
     *
     * @throws IllegalArgumentException for an unsupported family
     */
    private static GeneralRegressionModel.Distribution parseFamily(String family) {

        switch (family) {
            case "binomial":
                return GeneralRegressionModel.Distribution.BINOMIAL;
            case "gamma":
                return GeneralRegressionModel.Distribution.GAMMA;
            case "gaussian":
                return GeneralRegressionModel.Distribution.NORMAL;
            case "poisson":
                return GeneralRegressionModel.Distribution.POISSON;
            default:
                throw new IllegalArgumentException(family);
        }
    }

    /**
     * Maps a Spark link name to the PMML link function. The {@code "inverse"} and
     * {@code "sqrt"} links map to {@code POWER}; their exponent comes from
     * {@link #parseLinkPower(String)}.
     *
     * @throws IllegalArgumentException for an unsupported link
     */
    private static GeneralRegressionModel.LinkFunction parseLink(String link) {

        switch (link) {
            case "cloglog":
                return GeneralRegressionModel.LinkFunction.CLOGLOG;
            case "identity":
                return GeneralRegressionModel.LinkFunction.IDENTITY;
            case "log":
                return GeneralRegressionModel.LinkFunction.LOG;
            case "logit":
                return GeneralRegressionModel.LinkFunction.LOGIT;
            case "probit":
                return GeneralRegressionModel.LinkFunction.PROBIT;
            case "inverse":
            case "sqrt":
                return GeneralRegressionModel.LinkFunction.POWER;
            default:
                throw new IllegalArgumentException(link);
        }
    }

    /**
     * Returns the PMML {@code linkPower} exponent for power-type links, or {@code null} when
     * the link is not a power link.
     */
    private static Double parseLinkPower(String link) {

        switch (link) {
            case "inverse":
                return -1d;
            case "sqrt":
                return 0.5d;
            default:
                return null;
        }
    }

    /**
     * Fills {@code predictorList} with one {@link Predictor} per name.
     *
     * @return the populated list, or {@code null} when {@code names} is empty. Passing
     *         {@code null} to the model setter is presumably meant to omit the element
     *         entirely — confirm against the PMML serializer.
     */
    private static <L extends PredictorList> L createPredictorList(L predictorList, Set<FieldName> names) {

        if (names.isEmpty()) {
            return null;
        }

        List<Predictor> predictors = predictorList.getPredictors();
        for (FieldName name : names) {
            predictors.add(new Predictor(name));
        }

        return predictorList;
    }
}
