package org.jpmml.sparkml;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Properties;
import javax.xml.bind.JAXBException;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.converter.Label;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;

/**
 * Converts Apache Spark ML {@link PipelineModel} objects to PMML.
 *
 * <p>Also maintains a registry that maps {@link Transformer} classes to the
 * {@link TransformerConverter} classes that convert them. The registry is populated once, when this
 * class is initialized, from every {@code META-INF/sparkml2pmml.properties} resource visible to the
 * thread context class loader (or to the system class loader when there is no context class
 * loader). It can be extended afterwards via {@link #putConverterClazz(Class, Class)}.
 *
 * <p>NOTE(review): the registry is a plain, unsynchronized {@link LinkedHashMap}. Registering
 * converters concurrently with lookups is not thread-safe.
 */
public class ConverterUtil {

    /** Transformer class to converter class. Lookups use the transformer's exact runtime class. */
    private static final Map<Class<? extends Transformer>, Class<? extends TransformerConverter>> converters = new LinkedHashMap<>();

    private static final Logger logger = LogManager.getLogger(ConverterUtil.class);

    private ConverterUtil() {
    }

    /**
     * Converts a fitted pipeline to a PMML document.
     *
     * <p>The pipeline is flattened into its leaf transformers. Nested {@link PipelineModel} stages
     * are expanded, and the best model of a {@link CrossValidatorModel} or
     * {@link TrainValidationSplitModel} is used. Each leaf is then handed to its registered
     * converter, in order:
     * <ul>
     *   <li>feature converters register features with the shared encoder;</li>
     *   <li>model converters produce a PMML {@link Model}.</li>
     * </ul>
     * A single model is encoded as-is. Several models are chained into one mining model.
     *
     * @param structType the schema of the pipeline's input data
     * @param pipelineModel the fitted pipeline to convert
     * @return the encoded PMML document
     * @throws IllegalArgumentException if the pipeline contains no model, if a stage has no
     *     registered converter, or if a converter is neither a {@link FeatureConverter} nor a
     *     {@link ModelConverter}
     */
    public static PMML toPMML(StructType structType, PipelineModel pipelineModel) {
        SparkMLEncoder encoder = new SparkMLEncoder(structType);

        List<Model> models = new ArrayList<>();
        for (Transformer transformer : getTransformers(pipelineModel)) {
            TransformerConverter<?> converter = createConverter(transformer);

            if (converter instanceof FeatureConverter) {
                ((FeatureConverter<?>) converter).registerFeatures(encoder);
            } else if (converter instanceof ModelConverter) {
                models.add(((ModelConverter<?>) converter).registerModel(encoder));
            } else {
                throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
            }
        }

        Model model;
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
        } else if (models.size() == 1) {
            model = models.get(0);
        } else {
            model = chainModels(models);
        }

        return encoder.encodePMML(model);
    }

    /**
     * Chains several models into one mining model.
     *
     * <p>The chain's mining schema lists every {@code PREDICTED} and {@code TARGET} mining field of
     * the member models, in member order.
     */
    private static Model chainModels(List<Model> models) {
        List<MiningField> targetFields = new ArrayList<>();

        for (Model member : models) {
            for (MiningField miningField : member.getMiningSchema().getMiningFields()) {
                switch (miningField.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        targetFields.add(miningField);
                        break;
                    default:
                        break;
                }
            }
        }

        return MiningModelUtil.createModelChain(models, new Schema((Label) null, Collections.emptyList()))
            .setMiningSchema(new MiningSchema(targetFields));
    }

    /**
     * Converts a fitted pipeline to PMML and serializes it to a byte array.
     *
     * @param structType the schema of the pipeline's input data
     * @param pipelineModel the fitted pipeline to convert
     * @return the marshalled PMML document
     * @throws RuntimeException wrapping the {@link JAXBException} if marshalling fails
     */
    public static byte[] toPMMLByteArray(StructType structType, PipelineModel pipelineModel) {
        PMML pmml = toPMML(structType, pipelineModel);

        // Pre-size the buffer to 1 MiB to limit re-allocations for typical documents.
        ByteArrayOutputStream os = new ByteArrayOutputStream(1024 * 1024);
        try {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException e) {
            throw new RuntimeException(e);
        }

        return os.toByteArray();
    }

    /**
     * Creates the registered converter for {@code transformer}, expecting a feature converter.
     *
     * @throws ClassCastException if the registered converter is not a {@link FeatureConverter}
     */
    public static FeatureConverter<?> createFeatureConverter(Transformer transformer) {
        return (FeatureConverter<?>) createConverter(transformer);
    }

    /**
     * Creates the registered converter for {@code transformer}, expecting a model converter.
     *
     * @throws ClassCastException if the registered converter is not a {@link ModelConverter}
     */
    public static ModelConverter<?> createModelConverter(Transformer transformer) {
        return (ModelConverter<?>) createConverter(transformer);
    }

    /**
     * Instantiates the converter registered for the exact runtime class of {@code transformer}.
     *
     * <p>The converter class must declare a constructor whose single parameter type is exactly that
     * runtime class.
     *
     * @param transformer the transformer to convert
     * @return a new converter wrapping {@code transformer}
     * @throws IllegalArgumentException if no converter is registered for the transformer's class,
     *     or if the converter cannot be instantiated reflectively (the reflective exception is the
     *     cause)
     */
    @SuppressWarnings("unchecked")
    public static <T extends Transformer> TransformerConverter<T> createConverter(T transformer) {
        // getClass() on a T-typed receiver is statically Class<? extends Transformer>.
        Class<? extends Transformer> clazz = transformer.getClass();

        Class<? extends TransformerConverter> converterClazz = getConverterClazz(clazz);
        if (converterClazz == null) {
            throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported");
        }

        try {
            return converterClazz.getDeclaredConstructor(clazz).newInstance(transformer);
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Looks up the converter class registered for exactly {@code cls}.
     *
     * @return the registered converter class, or {@code null} if none is registered
     */
    public static Class<? extends TransformerConverter> getConverterClazz(Class<? extends Transformer> cls) {
        return converters.get(cls);
    }

    /**
     * Registers (or replaces) the converter class for a transformer class.
     *
     * @param cls the transformer class; must be a {@link Transformer} subclass
     * @param cls2 the converter class; must be a {@link TransformerConverter} subclass
     * @throws IllegalArgumentException if either argument is {@code null} or of the wrong kind.
     *     Both checks matter because callers may pass classes obtained through unchecked casts.
     */
    public static void putConverterClazz(Class<? extends Transformer> cls, Class<? extends TransformerConverter<?>> cls2) {
        if (cls == null || !Transformer.class.isAssignableFrom(cls)) {
            throw new IllegalArgumentException("Expected " + Transformer.class.getName() + " subclass, got " + (cls != null ? cls.getName() : null));
        }

        if (cls2 == null || !TransformerConverter.class.isAssignableFrom(cls2)) {
            throw new IllegalArgumentException("Expected " + TransformerConverter.class.getName() + " subclass, got " + (cls2 != null ? cls2.getName() : null));
        }

        converters.put(cls, cls2);
    }

    /**
     * Flattens a pipeline into its leaf transformers, in pipeline order.
     *
     * <p>Nested pipelines are expanded stage by stage, and tuning models are replaced by their best
     * model. An empty nested pipeline contributes nothing.
     */
    private static List<Transformer> getTransformers(PipelineModel pipelineModel) {
        List<Transformer> result = new ArrayList<>();

        collectTransformers(pipelineModel, result);

        return result;
    }

    /** Appends the leaf transformers reachable from {@code transformer} to {@code result} (pre-order). */
    private static void collectTransformers(Transformer transformer, List<Transformer> result) {
        if (transformer instanceof PipelineModel) {
            for (Transformer stage : ((PipelineModel) transformer).stages()) {
                collectTransformers(stage, result);
            }
        } else if (transformer instanceof CrossValidatorModel) {
            collectTransformers(((CrossValidatorModel) transformer).bestModel(), result);
        } else if (transformer instanceof TrainValidationSplitModel) {
            collectTransformers(((TrainValidationSplitModel) transformer).bestModel(), result);
        } else {
            result.add(transformer);
        }
    }

    /**
     * Loads every {@code META-INF/sparkml2pmml.properties} resource and registers its mappings.
     *
     * <p>Uses the thread context class loader, falling back to the system class loader. I/O
     * failures are logged as warnings and the affected resource is skipped.
     */
    private static void init() {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        if (classLoader == null) {
            classLoader = ClassLoader.getSystemClassLoader();
        }

        Enumeration<URL> resources;
        try {
            resources = classLoader.getResources("META-INF/sparkml2pmml.properties");
        } catch (IOException e) {
            logger.warn("Failed to find resources", e);

            return;
        }

        while (resources.hasMoreElements()) {
            URL resource = resources.nextElement();

            logger.trace("Loading resource " + resource);

            try (InputStream is = resource.openStream()) {
                Properties properties = new Properties();
                properties.load(is);

                init(classLoader, properties);
            } catch (IOException e) {
                logger.warn("Failed to load resource", e);
            }
        }
    }

    /**
     * Registers every {@code transformer class name = converter class name} entry of
     * {@code properties}.
     *
     * <p>If a class cannot be loaded, a warning naming the side that failed is logged and that entry
     * is skipped. Loaded classes of the wrong kind are rejected by
     * {@link #putConverterClazz(Class, Class)} with an {@link IllegalArgumentException}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void init(ClassLoader classLoader, Properties properties) {

        if (properties.isEmpty()) {
            return;
        }

        for (String transformerName : properties.stringPropertyNames()) {
            String converterName = properties.getProperty(transformerName);

            logger.trace("Mapping transformer class " + transformerName + " to transformer converter class " + converterName);

            // Unchecked casts: putConverterClazz validates both classes before registering them.
            Class<? extends Transformer> transformerClazz;
            try {
                transformerClazz = (Class) classLoader.loadClass(transformerName);
            } catch (ClassNotFoundException e) {
                logger.warn("Failed to load transformer class", e);

                continue;
            }

            Class<? extends TransformerConverter<?>> converterClazz;
            try {
                converterClazz = (Class) classLoader.loadClass(converterName);
            } catch (ClassNotFoundException e) {
                logger.warn("Failed to load transformer converter class", e);

                continue;
            }

            putConverterClazz(transformerClazz, converterClazz);
        }
    }

    // Must stay after the field declarations above so that 'converters' and 'logger' are initialized.
    static {
        init();
    }
}
