package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.spark.ml.fpm.FPGrowthModel;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.association.AssociationModel;
import org.dmg.pmml.association.AssociationRule;
import org.dmg.pmml.association.Item;
import org.dmg.pmml.association.ItemRef;
import org.dmg.pmml.association.Itemset;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.AssociationRulesModelConverter;
import org.jpmml.sparkml.ItemSetFeature;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/* loaded from: input_file:org/jpmml/sparkml/model/FPGrowthModelConverter.class */
/**
 * Converts a Spark ML {@link FPGrowthModel} into a PMML {@link AssociationModel}.
 *
 * <p>The association rules of the Spark model are collected to the driver and re-expressed as
 * PMML items, itemsets and rules. Item and itemset identifiers are 1-based sequence numbers
 * assigned in order of first appearance.
 */
public class FPGrowthModelConverter extends AssociationRulesModelConverter<FPGrowthModel> {
    /** Name of the field that groups items into transactions. */
    private static final String TRANSACTION_FIELD = "transaction";

    public FPGrowthModelConverter(FPGrowthModel fPGrowthModel) {
        super(fPGrowthModel);
    }

    /**
     * Declares the input fields of the model and returns the single item feature.
     *
     * <p>The item field is named after the model's items column, with one trailing {@code "s"}
     * removed (for example {@code "items"} becomes {@code "item"}). A column named exactly
     * {@code "s"} is left as is, because stripping it would yield an empty field name.
     *
     * @param sparkMLEncoder the encoder on which the data fields are created
     * @return a singleton list holding the {@link ItemSetFeature} of the item field
     */
    @Override // org.jpmml.sparkml.ModelConverter
    public List<Feature> getFeatures(SparkMLEncoder sparkMLEncoder) {
        String itemsCol = ((FPGrowthModel) getModel()).getItemsCol();
        if (itemsCol.length() > 1 && itemsCol.endsWith("s")) {
            itemsCol = itemsCol.substring(0, itemsCol.length() - 1);
        }

        // The return value is intentionally unused: the call is made for its effect on the
        // encoder. mo25encodeModel refers to a field of the same name in its mining schema.
        sparkMLEncoder.createDataField(TRANSACTION_FIELD, OpType.CATEGORICAL, DataType.STRING);

        return Collections.singletonList(
                new ItemSetFeature(sparkMLEncoder, sparkMLEncoder.createDataField(itemsCol, OpType.CATEGORICAL, DataType.STRING)));
    }

    /**
     * Builds the PMML association model from the Spark model's association rules.
     *
     * <p>Each rule row is read by position: column 0 is the antecedent, column 1 the consequent
     * and column 2 the confidence. Lift and support are always written as {@code 0.0}.
     *
     * <p>NOTE(review): the method name is a decompiler rename of {@code encodeModel}; it is kept
     * so that references elsewhere in the decompiled project continue to resolve.
     *
     * @param schema the schema, which must hold exactly one feature (the item feature)
     * @return the populated association model
     */
    @Override // org.jpmml.sparkml.ModelConverter
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public AssociationModel mo25encodeModel(Schema schema) {
        FPGrowthModel model = (FPGrowthModel) getModel();

        List<? extends Feature> features = schema.getFeatures();
        SchemaUtil.checkSize(1, features);
        Feature feature = features.get(0);

        // Insertion-ordered maps, so that the emitted items/itemsets follow their numeric ids.
        Map<String, Item> items = new LinkedHashMap<>();
        Map<List<String>, Itemset> itemsets = new LinkedHashMap<>();
        List<AssociationRule> rules = new ArrayList<>();

        for (Row row : model.associationRules().collectAsList()) {
            List<String> antecedent = formatValues(JavaConversions.seqAsJavaList((Seq<?>) row.apply(0)));
            List<String> consequent = formatValues(JavaConversions.seqAsJavaList((Seq<?>) row.apply(1)));
            Double confidence = (Double) row.apply(2);

            // The antecedent is registered before the consequent; ids depend on this order.
            String antecedentId = ensureItemset(feature, antecedent, itemsets, items).requireId();
            String consequentId = ensureItemset(feature, consequent, itemsets, items).requireId();

            rules.add(new AssociationRule()
                    .setAntecedent(antecedentId)
                    .setConsequent(consequentId)
                    .setConfidence(confidence)
                    .setLift(Double.valueOf(0.0d))
                    .setSupport(Double.valueOf(0.0d)));
        }

        MiningSchema miningSchema = new MiningSchema()
                .addMiningFields(ModelUtil.createMiningField(TRANSACTION_FIELD, MiningField.UsageType.GROUP));

        AssociationModel associationModel = new AssociationModel(
                MiningFunction.ASSOCIATION_RULES,
                0,
                Double.valueOf(model.getMinSupport()),
                Double.valueOf(model.getMinConfidence()),
                Integer.valueOf(items.size()),
                Integer.valueOf(itemsets.size()),
                Integer.valueOf(rules.size()),
                miningSchema);
        associationModel.getItems().addAll(items.values());
        associationModel.getItemsets().addAll(itemsets.values());
        associationModel.getAssociationRules().addAll(rules);
        return associationModel;
    }

    /**
     * Returns the itemset registered for the given values, creating it (and any items not seen
     * before) on first use.
     *
     * @param feature the feature whose name is set as the field of newly created items
     * @param list the item values of the itemset; also used as the lookup key
     * @param map registry of itemsets keyed by their value lists; updated in place
     * @param map2 registry of items keyed by value; updated in place
     * @return the existing or newly created itemset
     */
    private static Itemset ensureItemset(Feature feature, List<String> list, Map<List<String>, Itemset> map, Map<String, Item> map2) {
        Itemset existing = map.get(list);
        if (existing != null) {
            return existing;
        }

        Itemset itemset = new Itemset(String.valueOf(map.size() + 1));
        for (String value : list) {
            Item item = map2.get(value);
            if (item == null) {
                item = new Item(String.valueOf(map2.size() + 1), value).setField(feature.getName());
                map2.put(value, item);
            }
            itemset.addItemRefs(new ItemRef(item.getId()));
        }

        // Order the references by numeric item id. Ids are decimal strings, so they are parsed
        // rather than compared lexicographically.
        List<ItemRef> itemRefs = itemset.getItemRefs();
        itemRefs.sort(Comparator.comparingInt((ItemRef itemRef) -> Integer.parseInt(itemRef.requireItemRef())));

        map.put(list, itemset);
        return itemset;
    }

    /**
     * Converts every element of the list to its string form using {@link ValueUtil#asString}.
     *
     * @param list the values to convert
     * @return a new list holding the string form of each element, in the same order
     */
    public static List<String> formatValues(List<?> list) {
        return list.stream()
                .map(value -> ValueUtil.asString(value))
                .collect(Collectors.toList());
    }
}
