package org.jpmml.converter.neural_network;

import com.google.common.collect.Iterables;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ValueUtil;

/* loaded from: input_file:org/jpmml/converter/neural_network/NeuralNetworkUtil.class */
/**
 * Static helpers that assemble the parts of a PMML neural network: the input layer
 * ({@link NeuralInputs}), individual {@link Neuron}s, and the output layer ({@link NeuralOutputs})
 * for regression or classification targets.
 */
public class NeuralNetworkUtil {

    private NeuralNetworkUtil() {
    }

    /**
     * Creates one {@link NeuralInput} per feature, preserving list order.
     *
     * <p>Input ids are {@code "input/1"}, {@code "input/2"}, ... (1-based). Each input wraps a
     * continuous {@link DerivedField} of the requested data type whose expression is:
     * <ul>
     *   <li>for a {@link BinaryFeature}: a {@link NormDiscrete} on the feature name and value;</li>
     *   <li>for a {@link BooleanFeature}: a {@link NormDiscrete} on the feature name and {@code "true"};</li>
     *   <li>otherwise: a reference to the feature's continuous representation.</li>
     * </ul>
     *
     * @param list the input features
     * @param dataType the data type of every generated derived field
     * @return the populated input layer
     */
    public static NeuralInputs createNeuralInputs(List<? extends Feature> list, DataType dataType) {
        NeuralInputs neuralInputs = new NeuralInputs();
        for (int i = 0; i < list.size(); i++) {
            Feature feature = list.get(i);

            // The branches produce different expression types (a NormDiscrete vs. the field
            // reference returned by ref()), so the expression is attached per branch instead of
            // being funnelled through a single NormDiscrete-typed local.
            DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, dataType);
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                derivedField.setExpression(new NormDiscrete(binaryFeature.getName(), binaryFeature.getValue()));
            } else if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                derivedField.setExpression(new NormDiscrete(booleanFeature.getName(), "true"));
            } else {
                derivedField.setExpression(feature.toContinuousFeature().ref());
            }

            NeuralInput neuralInput = new NeuralInput()
                .setId("input/" + (i + 1))
                .setDerivedField(derivedField);
            neuralInputs.addNeuralInputs(neuralInput);
        }
        return neuralInputs;
    }

    /**
     * Creates a neuron with one {@link Connection} per source entity.
     *
     * <p>Connections whose weight is NaN or zero are omitted. The bias is set only when it is
     * non-null, not NaN and non-zero. No neuron id is assigned here.
     *
     * @param list the source entities; the i-th entity's id becomes the i-th connection's origin
     * @param list2 the connection weights, aligned index-by-index with {@code list}
     * @param d the bias; {@code null} means "no bias"
     * @return the new neuron
     * @throws IllegalArgumentException if {@code list} and {@code list2} differ in size
     */
    public static Neuron createNeuron(List<? extends Entity> list, List<Double> list2, Double d) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Expected " + list.size() + " weights, got " + list2.size());
        }
        Neuron neuron = new Neuron();
        for (int i = 0; i < list.size(); i++) {
            Double weight = list2.get(i);
            if (weight.isNaN() || ValueUtil.isZero(weight)) {
                continue;
            }
            Entity entity = list.get(i);
            neuron.addConnections(new Connection().setFrom(entity.getId()).setWeight(weight.doubleValue()));
        }
        // The bias is optional; a null bias is treated the same as a zero/NaN bias (left unset)
        // rather than failing with a NullPointerException.
        if (d != null && !d.isNaN() && !ValueUtil.isZero(d)) {
            neuron.setBias(d);
        }
        return neuron;
    }

    /**
     * Creates the output layer of a regression network: a single {@link NeuralOutput} that binds
     * the only output entity to a continuous {@link DerivedField} referencing the target label.
     *
     * @param list the output entities; must contain exactly one element
     * @param continuousLabel the regression target
     * @return the populated output layer
     * @throws IllegalArgumentException if {@code list} does not contain exactly one entity
     */
    public static NeuralOutputs createRegressionNeuralOutputs(List<? extends Entity> list, ContinuousLabel continuousLabel) {
        if (list.size() != 1) {
            throw new IllegalArgumentException("Expected exactly one output entity, got " + list.size());
        }
        Entity entity = Iterables.getOnlyElement(list);
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, continuousLabel.getDataType())
            .setExpression(new FieldRef(continuousLabel.getName()));
        NeuralOutput neuralOutput = new NeuralOutput()
            .setOutputNeuron(entity.getId())
            .setDerivedField(derivedField);
        return new NeuralOutputs().addNeuralOutputs(neuralOutput);
    }

    /**
     * Creates the output layer of a classification network: one {@link NeuralOutput} per label
     * category, where the i-th entity is bound to a categorical {@link DerivedField} holding a
     * {@link NormDiscrete} test against the i-th label value.
     *
     * @param list the output entities, aligned index-by-index with the label's values
     * @param categoricalLabel the classification target
     * @return the populated output layer
     * @throws IllegalArgumentException if the number of entities differs from the label's size
     */
    public static NeuralOutputs createClassificationNeuralOutputs(List<? extends Entity> list, CategoricalLabel categoricalLabel) {
        if (list.size() != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " output entities, got " + list.size());
        }
        NeuralOutputs neuralOutputs = new NeuralOutputs();
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Entity entity = list.get(i);
            DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, categoricalLabel.getDataType())
                .setExpression(new NormDiscrete(categoricalLabel.getName(), categoricalLabel.getValue(i)));
            NeuralOutput neuralOutput = new NeuralOutput()
                .setOutputNeuron(entity.getId())
                .setDerivedField(derivedField);
            neuralOutputs.addNeuralOutputs(neuralOutput);
        }
        return neuralOutputs;
    }
}
