package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ClassificationModelConverter;

/* loaded from: input_file:org/jpmml/sparkml/model/MultilayerPerceptronClassificationModelConverter.class */
public class MultilayerPerceptronClassificationModelConverter extends ClassificationModelConverter<MultilayerPerceptronClassificationModel> {

    /**
     * Creates a converter for the given Spark ML multilayer perceptron classification model.
     *
     * @param multilayerPerceptronClassificationModel the fitted Spark ML model to convert
     */
    public MultilayerPerceptronClassificationModelConverter(MultilayerPerceptronClassificationModel multilayerPerceptronClassificationModel) {
        super(multilayerPerceptronClassificationModel);
    }

    /**
     * Encodes the wrapped Spark ML multilayer perceptron as a PMML {@link NeuralNetwork}.
     *
     * <p>Entity ids follow the pattern {@code "<layer>/<position>"}, where layer 0 holds the
     * neural inputs and positions are 1-based. The network-level activation is {@code LOGISTIC}.
     * The output layer overrides it with {@code IDENTITY} activation and {@code SOFTMAX}
     * normalization. Each neural output links one output neuron to one target category, in
     * schema order.
     *
     * @param schema the conversion schema. Its feature count must equal the input layer size,
     *     and its target category count must equal the output layer size.
     * @return the encoded neural network, carrying a probability output
     * @throws IllegalArgumentException in any of these cases:
     *     <ul>
     *     <li>the model has fewer than two layers;</li>
     *     <li>the layer sizes do not match the schema;</li>
     *     <li>a feature is neither a {@link ContinuousFeature} nor a {@link BinaryFeature};</li>
     *     <li>the weight vector length does not match the layer topology.</li>
     *     </ul>
     */
    @Override // org.jpmml.sparkml.ModelConverter
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public NeuralNetwork mo4encodeModel(Schema schema) {
        MultilayerPerceptronClassificationModel model = (MultilayerPerceptronClassificationModel) getTransformer();

        int[] layers = model.layers();
        Vector weights = model.weights();

        // At least an input layer and an output layer are required; otherwise the neural
        // outputs below would reference input entities instead of neurons.
        if (layers.length < 2) {
            throw new IllegalArgumentException("Expected at least 2 layers (input and output), got " + layers.length);
        }

        List<? extends Feature> features = schema.getFeatures();
        if (features.size() != layers[0]) {
            throw new IllegalArgumentException("Expected " + layers[0] + " features for the input layer, got " + features.size());
        }

        FieldName targetField = schema.getTargetField();
        List<String> targetCategories = schema.getTargetCategories();
        int outputSize = layers[layers.length - 1];
        if (targetCategories.size() != outputSize) {
            throw new IllegalArgumentException("Expected " + outputSize + " target categories for the output layer, got " + targetCategories.size());
        }

        // Each non-input layer consumes (numIn * numOut) connection weights followed by
        // numOut biases. Validate the total up front, so that a malformed vector fails here
        // rather than inside weights.apply().
        long expectedWeights = 0L;
        for (int layer = 1; layer < layers.length; layer++) {
            expectedWeights += (long) layers[layer - 1] * layers[layer] + layers[layer];
        }
        if (expectedWeights != weights.size()) {
            throw new IllegalArgumentException("Expected " + expectedWeights + " weights for layers " + java.util.Arrays.toString(layers) + ", got " + weights.size());
        }

        NeuralInputs neuralInputs = new NeuralInputs();
        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);

            DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
            if (feature instanceof ContinuousFeature) {
                // Continuous features are referenced directly.
                derivedField.setExpression(new FieldRef(feature.getName()));
            } else if (feature instanceof BinaryFeature) {
                // Binary features become a NormDiscrete indicator for their category value.
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                derivedField.setExpression(new NormDiscrete(binaryFeature.getName(), binaryFeature.getValue()));
            } else {
                throw new IllegalArgumentException("Unsupported feature type " + feature.getClass().getName() + " at index " + i);
            }

            NeuralInput neuralInput = new NeuralInput()
                .setId("0/" + (i + 1))
                .setDerivedField(derivedField);
            neuralInputs.addNeuralInputs(new NeuralInput[]{neuralInput});
        }

        List<NeuralLayer> neuralLayers = new ArrayList<>();

        // Entities feeding the current layer: the neural inputs at first, then the neurons
        // of the previously built layer.
        List<? extends Entity> previousEntities = neuralInputs.getNeuralInputs();

        // Current read position in the flat weight vector.
        int offset = 0;

        for (int layer = 1; layer < layers.length; layer++) {
            int numIn = previousEntities.size();
            int numOut = layers[layer];

            List<Neuron> neurons = new ArrayList<>(numOut);
            for (int j = 0; j < numOut; j++) {
                Neuron neuron = new Neuron().setId(layer + "/" + (j + 1));
                for (int k = 0; k < numIn; k++) {
                    // Weights of one input are stored contiguously across this layer's neurons:
                    // the weight from input k to neuron j lives at offset + k * numOut + j.
                    Connection connection = new Connection()
                        .setFrom(previousEntities.get(k).getId())
                        .setWeight(weights.apply(offset + (k * numOut) + j));
                    neuron.addConnections(new Connection[]{connection});
                }
                neurons.add(neuron);
            }
            offset += numIn * numOut;

            // The layer's biases immediately follow its connection weights, one per neuron.
            for (Neuron neuron : neurons) {
                neuron.setBias(Double.valueOf(weights.apply(offset)));
                offset++;
            }

            NeuralLayer neuralLayer = new NeuralLayer(neurons);
            if (layer == layers.length - 1) {
                // Output layer: raw scores normalized with softmax into class probabilities.
                neuralLayer
                    .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
                    .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
            }
            neuralLayers.add(neuralLayer);

            previousEntities = neurons;
        }

        // previousEntities now holds the output layer's neurons; map them to target categories.
        NeuralOutputs neuralOutputs = new NeuralOutputs();
        for (int i = 0; i < targetCategories.size(); i++) {
            DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, DataType.STRING)
                .setExpression(new NormDiscrete(targetField, targetCategories.get(i)));
            NeuralOutput neuralOutput = new NeuralOutput()
                .setOutputNeuron(previousEntities.get(i).getId())
                .setDerivedField(derivedField);
            neuralOutputs.addNeuralOutputs(new NeuralOutput[]{neuralOutput});
        }

        return new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC, ModelUtil.createMiningSchema(schema), neuralInputs, neuralLayers)
            .setNeuralOutputs(neuralOutputs)
            .setOutput(ModelUtil.createProbabilityOutput(schema));
    }
}
