package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/* loaded from: input_file:org/jpmml/sparkml/model/MultilayerPerceptronClassificationModelConverter.class */
/**
 * Converts a Spark ML {@link MultilayerPerceptronClassificationModel} into a PMML {@link NeuralNetwork}
 * classification model.
 */
public class MultilayerPerceptronClassificationModelConverter extends ClassificationModelConverter<MultilayerPerceptronClassificationModel> {

    /**
     * Creates a converter for the given fitted Spark model.
     *
     * @param model the Spark multilayer perceptron classification model to convert
     */
    public MultilayerPerceptronClassificationModelConverter(MultilayerPerceptronClassificationModel model) {
        super(model);
    }

    /**
     * Registers the output fields of this classifier.
     *
     * <p>Starts from the fields registered by the superclass. When the Spark model does not implement
     * {@link HasProbabilityCol}, one {@code DOUBLE} probability field per target category is appended as
     * well (presumably to cover Spark versions in which this model type exposes no probability column —
     * TODO confirm).
     *
     * @param label the model label; cast to {@link CategoricalLabel} when probability fields are added
     * @param sparkMLEncoder the active encoder, forwarded to the superclass
     * @return the registered output fields
     */
    @Override // org.jpmml.sparkml.ClassificationModelConverter, org.jpmml.sparkml.ModelConverter
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder sparkMLEncoder) {
        MultilayerPerceptronClassificationModel model = (MultilayerPerceptronClassificationModel) getTransformer();

        List<OutputField> result = super.registerOutputFields(label, sparkMLEncoder);

        if (!(model instanceof HasProbabilityCol)) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;

            // Copy before appending: the list returned by the superclass may not be mutable.
            result = new ArrayList<>(result);
            result.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));
        }

        return result;
    }

    /**
     * Encodes the Spark multilayer perceptron as a PMML classification {@link NeuralNetwork}.
     *
     * <p>Every non-input layer of the Spark model becomes one {@link NeuralLayer}. The flat Spark weight
     * vector is consumed layer by layer: for a layer with {@code numInputs} inputs and {@code numOutputs}
     * neurons, the connection weight from input {@code r} to neuron {@code c} is read at
     * {@code offset + r * numOutputs + c}, and the layer's {@code numOutputs} biases follow immediately
     * after that block. Neuron ids have the form {@code "<layer>/<1-based neuron index>"}. All layers use
     * the network-wide {@code LOGISTIC} activation except the output layer, which uses {@code IDENTITY}
     * activation with {@code SOFTMAX} normalization.
     *
     * <p>NOTE(review): the method keeps the decompiler-assigned name {@code mo10encodeModel}; the
     * supertype declaration was presumably renamed the same way — verify before renaming it back to
     * {@code encodeModel}.
     *
     * @param schema the model schema; its label must be a {@link CategoricalLabel}
     * @return the encoded neural network
     * @throws IllegalArgumentException if the model has fewer than two layers, if the number of target
     *     categories or features does not match the output/input layer size, or if the weight vector
     *     length does not match the layer sizes
     */
    @Override // org.jpmml.sparkml.ModelConverter
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public NeuralNetwork mo10encodeModel(Schema schema) {
        MultilayerPerceptronClassificationModel model = (MultilayerPerceptronClassificationModel) getTransformer();

        int[] layers = model.layers();
        Vector weights = model.weights();

        // Guard the layers[0] / layers[length - 1] accesses below.
        if (layers.length < 2) {
            throw new IllegalArgumentException("Expected at least 2 layers, got " + layers.length);
        }

        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        int numClasses = layers[layers.length - 1];
        if (categoricalLabel.size() != numClasses) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size());
        }

        // Left as a raw List (as in the original): its element type is declared by the converter
        // library and is not visible from this file — parameterize once confirmed.
        List features = schema.getFeatures();
        if (features.size() != layers[0]) {
            throw new IllegalArgumentException("Expected " + layers[0] + " features, got " + features.size());
        }

        // Fail fast on a mis-sized weight vector instead of running off its end in Vector.apply.
        long expectedWeights = countWeights(layers);
        if (weights.size() != expectedWeights) {
            throw new IllegalArgumentException("Expected " + expectedWeights + " weights, got " + weights.size());
        }

        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

        // Entities feeding the current layer: the network inputs first, then the previous layer's
        // neurons. Raw List because it holds NeuralInput elements and later Neuron elements.
        List entities = neuralInputs.getNeuralInputs();

        List<NeuralLayer> neuralLayers = new ArrayList<>();

        int offset = 0;
        for (int layer = 1; layer < layers.length; layer++) {
            int numInputs = entities.size();
            int numOutputs = layers[layer];

            List<List<Double>> layerWeights = readConnectionWeights(weights, offset, numInputs, numOutputs);
            offset += numInputs * numOutputs;

            NeuralLayer neuralLayer = new NeuralLayer();

            for (int neuronIndex = 0; neuronIndex < numOutputs; neuronIndex++) {
                // Biases are stored right after the layer's connection-weight block, one per neuron.
                Double bias = Double.valueOf(weights.apply(offset));

                Neuron neuron = NeuralNetworkUtil.createNeuron(entities, layerWeights.get(neuronIndex), bias)
                        .setId(layer + "/" + (neuronIndex + 1));

                neuralLayer.addNeurons(new Neuron[]{neuron});

                offset++;
            }

            if (layer == layers.length - 1) {
                // Output layer: identity activation followed by softmax normalization.
                neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
                        .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
            }

            neuralLayers.add(neuralLayer);

            entities = neuralLayer.getNeurons();
        }

        // Consistency check: every weight must have been consumed exactly once.
        if (offset != weights.size()) {
            throw new IllegalArgumentException("Consumed " + offset + " weights, but the model has " + weights.size());
        }

        return new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC,
                ModelUtil.createMiningSchema(categoricalLabel), neuralInputs, neuralLayers)
                .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel));
    }

    /**
     * Reads one layer's connection weights from the flat Spark weight vector.
     *
     * @param weights the flat weight vector of the whole network
     * @param offset index of the first connection weight of this layer
     * @param numInputs number of entities feeding this layer
     * @param numOutputs number of neurons in this layer
     * @return one list per neuron, each holding its {@code numInputs} incoming weights in input order
     */
    private static List<List<Double>> readConnectionWeights(Vector weights, int offset, int numInputs, int numOutputs) {
        List<List<Double>> result = new ArrayList<>(numOutputs);

        for (int output = 0; output < numOutputs; output++) {
            List<Double> neuronWeights = new ArrayList<>(numInputs);

            for (int input = 0; input < numInputs; input++) {
                neuronWeights.add(weights.apply(offset + input * numOutputs + output));
            }

            result.add(neuronWeights);
        }

        return result;
    }

    /**
     * Returns the total number of weights (connection weights plus one bias per neuron) implied by the
     * given layer sizes. Uses {@code long} arithmetic so that oversized layer specs cannot overflow.
     *
     * @param layers layer sizes, input layer first
     * @return the expected length of the flat weight vector
     */
    private static long countWeights(int[] layers) {
        long count = 0;

        for (int layer = 1; layer < layers.length; layer++) {
            count += (long) (layers[layer - 1] + 1) * layers[layer];
        }

        return count;
    }
}
