package org.jpmml.sparkml.model;

import java.util.List;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;

/* loaded from: input_file:org/jpmml/sparkml/model/GBTClassificationModelConverter.class */
/**
 * Converts a Spark ML {@link GBTClassificationModel} into a PMML {@link MiningModel}.
 *
 * <p>The tree ensemble is first encoded as a regression {@link MiningModel} whose segments are
 * summed. That model is then handed to
 * {@link MiningModelUtil#createBinaryLogisticClassification} to obtain the final classification
 * model.
 */
public class GBTClassificationModelConverter extends ClassificationModelConverter<GBTClassificationModel> {

    /** Name of the intermediate output field carrying the ensemble's predicted value. */
    private static final String GBT_VALUE_FIELD = "gbtValue";

    /** Name of the intermediate output field carrying the thresholded ensemble value. */
    private static final String BINARIZED_GBT_VALUE_FIELD = "binarizedGbtValue";

    /**
     * Coefficient passed to {@link MiningModelUtil#createBinaryLogisticClassification}.
     * NOTE(review): presumably chosen large so the logistic step acts as a near-hard 0/1 switch
     * on the binarized value — confirm against the jpmml-converter version in use.
     */
    private static final double LOGISTIC_COEFFICIENT = 1000.0d;

    /**
     * Creates a converter for the given fitted model.
     *
     * @param model the Spark ML gradient-boosted tree classification model to convert
     */
    public GBTClassificationModelConverter(GBTClassificationModel model) {
        super(model);
    }

    /**
     * Encodes the wrapped GBT classifier as a PMML mining model.
     *
     * <p>Decompiler note: this method overrides {@code ModelConverter.encodeModel}; the
     * {@code mo4encodeModel} name was assigned by the decompiler after merging the bridge method.
     *
     * @param schema the schema describing the label and features of the model
     * @return the binary logistic classification model wrapping the summed tree ensemble
     */
    @Override // org.jpmml.sparkml.ModelConverter
    public MiningModel mo4encodeModel(Schema schema) {
        GBTClassificationModel model = (GBTClassificationModel) getTransformer();

        // Member trees are encoded against an anonymous copy of the schema.
        Schema segmentSchema = schema.toAnonymousSchema();

        // Tree weights go into the tree encoding itself. They are presumably applied there,
        // because the segmentation below uses a plain SUM rather than a weighted sum.
        List<TreeModel> trees =
            TreeModelUtil.encodeDecisionTreeEnsemble(model, model.treeWeights(), segmentSchema);

        MiningModel ensemble =
            new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema))
                .setSegmentation(
                    MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, trees))
                .setOutput(encodeOutput());

        // The trailing flag is passed as false. NOTE(review): in jpmml-converter this is
        // presumably the "hasProbabilityDistribution" switch — confirm.
        return MiningModelUtil.createBinaryLogisticClassification(
            schema, ensemble, LOGISTIC_COEFFICIENT, false);
    }

    /**
     * Builds the output section of the ensemble model.
     *
     * <p>It holds two continuous double fields, both marked non-final:
     * <ul>
     *   <li>{@code gbtValue}: the ensemble's predicted value;</li>
     *   <li>{@code binarizedGbtValue}: {@code -1.0} when {@code gbtValue > 0}, otherwise
     *       {@code 1.0}.</li>
     * </ul>
     * NOTE(review): {@code binarizedGbtValue} is listed last. It is presumably the field
     * consumed by {@code createBinaryLogisticClassification} — confirm.
     *
     * @return a new {@link Output} containing {@code gbtValue} followed by
     *     {@code binarizedGbtValue}
     */
    private static Output encodeOutput() {
        OutputField gbtValue = new OutputField(FieldName.create(GBT_VALUE_FIELD), DataType.DOUBLE)
            .setOpType(OpType.CONTINUOUS)
            .setResultFeature(ResultFeature.PREDICTED_VALUE)
            .setFinalResult(false);

        // Condition: gbtValue > 0
        Expression isPositive = PMMLUtil.createApply("greaterThan", new Expression[]{
            new FieldRef(gbtValue.getName()),
            PMMLUtil.createConstant(Double.valueOf(0.0d))
        });

        // isPositive ? -1.0 : 1.0
        Expression binarized = PMMLUtil.createApply("if", new Expression[]{
            isPositive,
            PMMLUtil.createConstant(Double.valueOf(-1.0d)),
            PMMLUtil.createConstant(Double.valueOf(1.0d))
        });

        OutputField binarizedGbtValue =
            new OutputField(FieldName.create(BINARIZED_GBT_VALUE_FIELD), DataType.DOUBLE)
                .setOpType(OpType.CONTINUOUS)
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setFinalResult(false)
                .setExpression(binarized);

        OutputField[] fields = new OutputField[]{gbtValue, binarizedGbtValue};
        return new Output().addOutputFields(fields);
    }
}
