package org.jpmml.manager;

import java.util.List;
import java.util.Objects;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;

/* loaded from: input_file:org/jpmml/manager/TreeModelManager.class */
public class TreeModelManager extends ModelManager<TreeModel> {

    /** The managed tree model; {@code null} until supplied via a constructor or {@link #createModel}. */
    private TreeModel treeModel;

    /** Root node cached by {@link #getOrCreateRoot()}; {@code null} until first resolved. */
    private Node root;

    /**
     * Creates a manager with no PMML document and no tree model attached.
     */
    public TreeModelManager() {
        // Both fields start out null; nothing else to initialize.
    }

    /**
     * Creates a manager for the tree model located in the given PMML document.
     *
     * @param pmml the PMML document; its content is searched for a {@link TreeModel}
     *     using the inherited {@code find} helper
     */
    public TreeModelManager(PMML pmml) {
        this(pmml, find((List<? extends PMMLObject>) pmml.getContent(), TreeModel.class));
    }

    /**
     * Creates a manager for an explicitly supplied tree model.
     *
     * @param pmml the owning PMML document
     * @param treeModel the tree model to manage; may be {@code null}, in which case
     *     {@link #createModel} can be used to create one later
     */
    public TreeModelManager(PMML pmml, TreeModel treeModel) {
        super(pmml);
        this.treeModel = treeModel;
    }

    /**
     * Returns a short human-readable description of the managed model kind.
     *
     * @return the constant string {@code "Tree"}
     */
    @Override
    public String getSummary() {
        return "Tree";
    }

    /**
     * Returns the managed tree model.
     *
     * <p>The inherited {@code ensureNotNull} check is applied first. It presumably fails
     * when no model has been set or created (TODO: confirm the exception type in
     * {@code ModelManager}).
     *
     * @return the managed tree model
     */
    @Override
    public TreeModel getModel() {
        ensureNotNull(this.treeModel);
        return this.treeModel;
    }

    /**
     * Creates a new tree model with the classification mining function.
     *
     * @return the newly created model
     * @see #createModel(MiningFunctionType)
     */
    public TreeModel createClassificationModel() {
        return createModel(MiningFunctionType.CLASSIFICATION);
    }

    /**
     * Creates a new tree model with an empty mining schema and an empty root node. It is
     * then registered in the list returned by the inherited {@code getModels()}.
     *
     * <p>The inherited {@code ensureNull} check is applied first, so this is intended to be
     * called only when no model is managed yet.
     *
     * @param miningFunctionType the mining function of the new model
     * @return the newly created model
     */
    public TreeModel createModel(MiningFunctionType miningFunctionType) {
        ensureNull(this.treeModel);
        this.treeModel = new TreeModel(new MiningSchema(), new Node(), miningFunctionType);
        getModels().add(this.treeModel);
        return this.treeModel;
    }

    /**
     * Returns the root node of the managed tree, creating it if necessary.
     *
     * <p>On first call this works as follows:
     * <ul>
     *   <li>The model's existing root node is used if present. Otherwise a new
     *       {@link Node} is attached to the model.</li>
     *   <li>If the root has no predicate, it is given a {@link True} predicate.</li>
     * </ul>
     * The result is cached for subsequent calls.
     *
     * @return the (possibly newly created) root node
     */
    public Node getOrCreateRoot() {
        if (this.root == null) {
            TreeModel model = getModel();
            Node node = model.getNode();
            if (node == null) {
                node = new Node();
                model.setNode(node);
            }
            // The root must always match, so default its predicate to True.
            if (node.getPredicate() == null) {
                node.setPredicate(new True());
            }
            this.root = node;
        }
        return this.root;
    }

    /**
     * Appends a new child node with the given predicate under the root node.
     *
     * @param predicate the predicate of the new node
     * @return the newly created child node
     */
    public Node addNode(Predicate predicate) {
        return addNode(getOrCreateRoot(), predicate);
    }

    /**
     * Appends a new child node with the given predicate under {@code parent}.
     *
     * @param parent the node to which the new child is appended
     * @param predicate the predicate of the new node
     * @return the newly created child node
     */
    public Node addNode(Node parent, Predicate predicate) {
        Node child = new Node();
        child.setPredicate(predicate);
        parent.getNodes().add(child);
        return child;
    }

    /**
     * Returns the score distribution of {@code node} whose value equals {@code str}.
     * If none exists, a new one is appended.
     *
     * <p>A newly created distribution is constructed with {@code 0.0} as its second
     * argument (presumably the record count; TODO confirm against {@code ScoreDistribution}).
     * The comparison is null-safe, so an existing distribution with a missing value does
     * not cause a {@link NullPointerException}.
     *
     * @param node the node whose score distributions are searched and possibly extended
     * @param str the target value to look up
     * @return the existing or newly added score distribution
     */
    public ScoreDistribution getOrAddScoreDistribution(Node node, String str) {
        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();
        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            if (Objects.equals(scoreDistribution.getValue(), str)) {
                return scoreDistribution;
            }
        }
        ScoreDistribution created = new ScoreDistribution(str, 0.0d);
        scoreDistributions.add(created);
        return created;
    }
}
