package org.jpmml.model.visitors;

import java.util.Collections;
import java.util.Set;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.model.ChainedSegmentationTest;
import org.jpmml.model.FieldNameUtil;
import org.jpmml.model.NestedSegmentationTest;
import org.jpmml.model.ResourceUtil;
import org.junit.Assert;
import org.junit.Test;
import org.xml.sax.XMLFilter;

/* loaded from: input_file:org/jpmml/model/visitors/FieldResolverTest.class */
public class FieldResolverTest {
    @Test
    public void resolveChained() throws Exception {
        PMML unmarshal = ResourceUtil.unmarshal(ChainedSegmentationTest.class, new XMLFilter[0]);
        final Set<FieldName> create = FieldNameUtil.create("y", "x1", "x2", "x3", "x4");
        final Set<FieldName> create2 = FieldNameUtil.create(create, "x1_squared", "x1_cubed");
        FieldResolver fieldResolver = new FieldResolver() { // from class: org.jpmml.model.visitors.FieldResolverTest.1
            public VisitorAction visit(Apply apply) {
                Set fields = getFields();
                String function = apply.getFunction();
                if ("*".equals(function)) {
                    FieldName name = getParent().getName();
                    if ("x1_squared".equals(name.getValue())) {
                        FieldResolverTest.checkFields(create, fields);
                    } else {
                        if (!"x1_cubed".equals(name.getValue())) {
                            throw new AssertionError();
                        }
                        FieldResolverTest.checkFields(FieldNameUtil.create(create, "x1_squared"), fields);
                    }
                } else if ("pow".equals(function)) {
                    FieldResolverTest.checkFields(FieldNameUtil.create("x"), fields);
                } else if ("square".equals(function)) {
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output"), fields);
                } else {
                    if (!"cube".equals(function)) {
                        throw new AssertionError();
                    }
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output", "x2_squared"), fields);
                }
                return super.visit(apply);
            }
        };
        fieldResolver.applyTo(unmarshal);
        Assert.assertEquals(Collections.emptySet(), fieldResolver.getFields());
        FieldResolver fieldResolver2 = new FieldResolver() { // from class: org.jpmml.model.visitors.FieldResolverTest.2
            public VisitorAction visit(RegressionTable regressionTable) {
                Set fields = getFields();
                String id = getParent(1).getId();
                if ("first".equals(id)) {
                    FieldResolverTest.checkFields(create2, fields);
                } else if ("second".equals(id)) {
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output", "x2_squared", "x2_cubed"), fields);
                } else if ("third".equals(id)) {
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output", "second_output"), fields);
                } else {
                    if (!"sum".equals(id)) {
                        throw new AssertionError();
                    }
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output", "second_output", "third_output"), fields);
                }
                return super.visit(regressionTable);
            }
        };
        fieldResolver2.applyTo(unmarshal);
        Assert.assertEquals(Collections.emptySet(), fieldResolver2.getFields());
        FieldResolver fieldResolver3 = new FieldResolver() { // from class: org.jpmml.model.visitors.FieldResolverTest.3
            public VisitorAction visit(SimplePredicate simplePredicate) {
                Set fields = getFields();
                String id = getParent().getId();
                if ("first".equals(id)) {
                    FieldResolverTest.checkFields(create2, fields);
                } else if ("second".equals(id)) {
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output"), fields);
                } else {
                    if (!"third".equals(id)) {
                        throw new AssertionError();
                    }
                    FieldResolverTest.checkFields(FieldNameUtil.create(create2, "first_output", "second_output"), fields);
                }
                return super.visit(simplePredicate);
            }
        };
        fieldResolver3.applyTo(unmarshal);
        Assert.assertEquals(Collections.emptySet(), fieldResolver3.getFields());
    }

    @Test
    public void resolveNested() throws Exception {
        PMML unmarshal = ResourceUtil.unmarshal(NestedSegmentationTest.class, new XMLFilter[0]);
        final Set<FieldName> create = FieldNameUtil.create("y", "x1", "x2", "x3", "x4", "x5");
        new FieldResolver() { // from class: org.jpmml.model.visitors.FieldResolverTest.4
            public VisitorAction visit(Apply apply) {
                Set fields = getFields();
                FieldName name = getParent().getName();
                if ("x12".equals(name.getValue())) {
                    FieldResolverTest.checkFields(create, fields);
                } else if ("x123".equals(name.getValue())) {
                    FieldResolverTest.checkFields(FieldNameUtil.create(create, "x12"), fields);
                } else if ("x1234".equals(name.getValue())) {
                    FieldResolverTest.checkFields(FieldNameUtil.create(create, "x12", "x123"), fields);
                } else {
                    if (!"x12345".equals(name.getValue())) {
                        throw new AssertionError();
                    }
                    FieldResolverTest.checkFields(FieldNameUtil.create(create, "x12", "x123", "x1234"), fields);
                }
                return super.visit(apply);
            }
        }.applyTo(unmarshal);
        new FieldResolver() { // from class: org.jpmml.model.visitors.FieldResolverTest.5
            public VisitorAction visit(RegressionTable regressionTable) {
                FieldResolverTest.checkFields(FieldNameUtil.create(create, "x12", "x123", "x1234", "x12345"), getFields());
                return super.visit(regressionTable);
            }
        }.applyTo(unmarshal);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static void checkFields(Set<FieldName> set, Set<Field<?>> set2) {
        Assert.assertEquals(set, FieldUtil.nameSet(set2));
    }
}
