package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.dmg.pmml.general_regression.PPMatrix;
import org.dmg.pmml.general_regression.ParamMatrix;
import org.dmg.pmml.general_regression.ParameterList;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.general_regression.GeneralRegressionModelUtil;
import org.jpmml.sparkml.RegressionModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.VectorUtil;

/* loaded from: input_file:org/jpmml/sparkml/model/GeneralizedLinearRegressionModelConverter.class */
/**
 * Converts a Spark ML {@link GeneralizedLinearRegressionModel} into a PMML
 * {@link GeneralRegressionModel} of model type {@code GENERALIZED_LINEAR}.
 *
 * <p>The Spark {@code family} string selects the PMML distribution and the Spark
 * {@code link} string selects the PMML link function (plus a link parameter for the
 * links that are expressed as a power link). A {@code "binomial"} family is treated as
 * classification; every other family is treated as regression.
 */
public class GeneralizedLinearRegressionModelConverter extends RegressionModelConverter<GeneralizedLinearRegressionModel> {

    /**
     * Creates a converter for the given fitted model.
     *
     * @param generalizedLinearRegressionModel the Spark ML model to convert
     */
    public GeneralizedLinearRegressionModelConverter(GeneralizedLinearRegressionModel generalizedLinearRegressionModel) {
        super(generalizedLinearRegressionModel);
    }

    /**
     * Determines the mining function from the model's family.
     *
     * @return {@link MiningFunction#CLASSIFICATION} when the family is {@code "binomial"},
     *         {@link MiningFunction#REGRESSION} otherwise
     */
    @Override // org.jpmml.sparkml.RegressionModelConverter, org.jpmml.sparkml.ModelConverter
    public MiningFunction getMiningFunction() {
        String family = ((GeneralizedLinearRegressionModel) getTransformer()).getFamily();
        switch (family) {
            case "binomial":
                return MiningFunction.CLASSIFICATION;
            default:
                return MiningFunction.REGRESSION;
        }
    }

    /**
     * Registers the output fields of the superclass and, for classification models,
     * additionally the probability fields created for the values of the categorical label.
     *
     * @param label the model label; must be a {@link CategoricalLabel} when the mining
     *        function is classification
     * @param sparkMLEncoder the encoder passed through to the superclass
     * @return the superclass' output fields, extended with {@code DOUBLE} probability fields
     *         for classification models
     */
    @Override // org.jpmml.sparkml.RegressionModelConverter, org.jpmml.sparkml.ModelConverter
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder sparkMLEncoder) {
        List<OutputField> inheritedFields = super.registerOutputFields(label, sparkMLEncoder);
        if (getMiningFunction() != MiningFunction.CLASSIFICATION) {
            return inheritedFields;
        }
        // Copy before appending: the list returned by the superclass is not modified in place.
        List<OutputField> outputFields = new ArrayList<>(inheritedFields);
        outputFields.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, ((CategoricalLabel) label).getValues()));
        return outputFields;
    }

    /**
     * Builds the PMML general regression model for the wrapped Spark model.
     *
     * <p>For classification the label must have exactly two categories; the value at
     * index 1 is handed to the regression table encoder as the target category. For
     * regression the target category is {@code null}.
     *
     * @param schema the schema supplying the label and the features
     * @return the encoded {@link GeneralRegressionModel}
     * @throws IllegalArgumentException if a classification label does not have exactly two
     *         categories, or if the model's family or link is not one of the supported values
     */
    @Override // org.jpmml.sparkml.ModelConverter
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public GeneralRegressionModel mo6encodeModel(Schema schema) {
        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel) getTransformer();
        MiningFunction miningFunction = getMiningFunction();

        String targetCategory = null;
        if (miningFunction == MiningFunction.CLASSIFICATION) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
            if (categoricalLabel.size() != 2) {
                throw new IllegalArgumentException("Expected a categorical label with 2 categories, got " + categoricalLabel.size());
            }
            targetCategory = categoricalLabel.getValue(1);
        }

        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(
                GeneralRegressionModel.ModelType.GENERALIZED_LINEAR,
                miningFunction,
                ModelUtil.createMiningSchema(schema.getLabel()),
                (ParameterList) null,
                (PPMatrix) null,
                (ParamMatrix) null)
            .setDistribution(parseFamily(model.getFamily()))
            .setLinkFunction(parseLinkFunction(model.getLink()))
            .setLinkParameter(parseLinkParameter(model.getLink()));

        GeneralRegressionModelUtil.encodeRegressionTable(
                generalRegressionModel,
                schema.getFeatures(),
                Double.valueOf(model.intercept()),
                VectorUtil.toList(model.coefficients()),
                targetCategory);
        return generalRegressionModel;
    }

    /**
     * Maps a Spark family name to the PMML distribution.
     *
     * @param str one of {@code "binomial"}, {@code "gamma"}, {@code "gaussian"}, {@code "poisson"}
     * @return the matching distribution ({@code "gaussian"} maps to {@code NORMAL})
     * @throws IllegalArgumentException if the family name is not one of the above
     */
    private static GeneralRegressionModel.Distribution parseFamily(String str) {
        switch (str) {
            case "binomial":
                return GeneralRegressionModel.Distribution.BINOMIAL;
            case "gamma":
                return GeneralRegressionModel.Distribution.GAMMA;
            case "gaussian":
                return GeneralRegressionModel.Distribution.NORMAL;
            case "poisson":
                return GeneralRegressionModel.Distribution.POISSON;
            default:
                throw new IllegalArgumentException(str);
        }
    }

    /**
     * Maps a Spark link name to the PMML link function.
     *
     * <p>{@code "inverse"} and {@code "sqrt"} both map to {@code POWER}; the exponent that
     * tells them apart is supplied separately by {@link #parseLinkParameter(String)}.
     *
     * @param str one of {@code "cloglog"}, {@code "identity"}, {@code "inverse"},
     *        {@code "log"}, {@code "logit"}, {@code "probit"}, {@code "sqrt"}
     * @return the matching link function
     * @throws IllegalArgumentException if the link name is not one of the above
     */
    private static GeneralRegressionModel.LinkFunction parseLinkFunction(String str) {
        switch (str) {
            case "cloglog":
                return GeneralRegressionModel.LinkFunction.CLOGLOG;
            case "identity":
                return GeneralRegressionModel.LinkFunction.IDENTITY;
            case "inverse":
                return GeneralRegressionModel.LinkFunction.POWER;
            case "log":
                return GeneralRegressionModel.LinkFunction.LOG;
            case "logit":
                return GeneralRegressionModel.LinkFunction.LOGIT;
            case "probit":
                return GeneralRegressionModel.LinkFunction.PROBIT;
            case "sqrt":
                return GeneralRegressionModel.LinkFunction.POWER;
            default:
                throw new IllegalArgumentException(str);
        }
    }

    /**
     * Returns the link parameter for the links that are encoded as a power link.
     *
     * @param str the Spark link name
     * @return {@code -1.0} for {@code "inverse"}, {@code 0.5} for {@code "sqrt"},
     *         {@code null} for any other link name
     */
    private static Double parseLinkParameter(String str) {
        switch (str) {
            case "inverse":
                return Double.valueOf(-1.0d);
            case "sqrt":
                return Double.valueOf(0.5d);
            default:
                return null;
        }
    }
}
