package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.MatrixUtil;
import org.jpmml.sparkml.VectorUtil;

/* loaded from: input_file:org/jpmml/sparkml/model/LogisticRegressionModelConverter.class */
/**
 * Converts a Spark ML {@link LogisticRegressionModel} into a PMML classification {@link RegressionModel}.
 *
 * <p>A target with exactly two categories is encoded through
 * {@code RegressionModelUtil.createBinaryLogisticClassification} with LOGIT normalization. A target with
 * more than two categories is encoded as one regression table per category, combined with SOFTMAX
 * normalization.
 */
public class LogisticRegressionModelConverter extends ClassificationModelConverter<LogisticRegressionModel> {

    /**
     * Creates a converter for the given Spark logistic regression model.
     *
     * @param logisticRegressionModel the Spark model to convert
     */
    public LogisticRegressionModelConverter(LogisticRegressionModel logisticRegressionModel) {
        super(logisticRegressionModel);
    }

    /**
     * Encodes the wrapped Spark model as a PMML classification {@link RegressionModel}.
     *
     * <p>NOTE: the decompiler renamed this method from {@code encodeModel} (merged with a bridge method).
     * Its name must stay consistent with the superclass method it overrides.
     *
     * @param schema supplies the categorical target label and the input features
     * @return the encoded PMML regression model
     * @throws IllegalArgumentException if the target label has fewer than two categories
     */
    @Override // org.jpmml.sparkml.ModelConverter
    public RegressionModel mo7encodeModel(Schema schema) {
        LogisticRegressionModel model = (LogisticRegressionModel) getTransformer();
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();
        int categoryCount = label.size();

        // Validate up front so the failure explains itself instead of throwing a message-less exception.
        if (categoryCount < 2) {
            throw new IllegalArgumentException("Expected 2 or more target categories, got " + categoryCount);
        }

        if (categoryCount == 2) {
            // Binary case: one coefficient vector plus a scalar intercept, LOGIT normalization.
            // The explicit Double.valueOf keeps overload resolution identical to the original code.
            return RegressionModelUtil.createBinaryLogisticClassification(
                    schema.getFeatures(),
                    VectorUtil.toList(model.coefficients()),
                    Double.valueOf(model.intercept()),
                    RegressionModel.NormalizationMethod.LOGIT,
                    true,
                    schema)
                .setOutput((Output) null);
        }

        // Multinomial case: row i of the coefficient matrix and element i of the intercept vector
        // build the regression table for target category i. SOFTMAX combines the tables.
        Matrix coefficientMatrix = model.coefficientMatrix();
        Vector interceptVector = model.interceptVector();
        List<RegressionTable> regressionTables = new ArrayList<>(categoryCount);
        for (int i = 0; i < categoryCount; i++) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(
                    schema.getFeatures(),
                    MatrixUtil.getRow(coefficientMatrix, i),
                    Double.valueOf(interceptVector.apply(i)))
                .setTargetCategory(label.getValue(i));
            regressionTables.add(regressionTable);
        }

        return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), regressionTables)
            .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
    }
}
