package ml.dmlc.xgboost4j.java.flink;

import java.lang.invoke.SerializedLambda;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/flink/XGBoost.class */
public class XGBoost {
    private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);

    /* loaded from: input_file:ml/dmlc/xgboost4j/java/flink/XGBoost$MapFunction.class */
    private static class MapFunction extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
        private final Map<String, Object> params;
        private final int round;
        private final Map<String, Object> workerEnvs;

        /* loaded from: input_file:ml/dmlc/xgboost4j/java/flink/XGBoost$MapFunction$VectorToPointMapper.class */
        private static class VectorToPointMapper implements Function<Tuple2<Vector, Double>, LabeledPoint> {
            public static VectorToPointMapper INSTANCE = new VectorToPointMapper();

            private VectorToPointMapper() {
            }

            @Override // java.util.function.Function
            public LabeledPoint apply(Tuple2<Vector, Double> tuple2) {
                SparseVector sparse = ((Vector) tuple2.f0).toSparse();
                double[] dArr = sparse.values;
                int length = dArr.length;
                float[] fArr = new float[length];
                for (int i = 0; i < length; i++) {
                    fArr[i] = (float) dArr[i];
                }
                return new LabeledPoint(((Double) tuple2.f1).floatValue(), sparse.size(), sparse.indices, fArr);
            }
        }

        public MapFunction(Map<String, Object> map, int i, Map<String, Object> map2) {
            this.params = map;
            this.round = i;
            this.workerEnvs = map2;
        }

        public void mapPartition(Iterable<Tuple2<Vector, Double>> iterable, Collector<XGBoostModel> collector) throws XGBoostError {
            this.workerEnvs.put("DMLC_TASK_ID", String.valueOf(getRuntimeContext().getIndexOfThisSubtask()));
            if (XGBoost.logger.isInfoEnabled()) {
                XGBoost.logger.info("start with env: {}", this.workerEnvs.entrySet().stream().map(entry -> {
                    return String.format("\"%s\": \"%s\"", entry.getKey(), entry.getValue());
                }).collect(Collectors.joining(", ")));
            }
            Iterator it = StreamSupport.stream(iterable.spliterator(), false).map(VectorToPointMapper.INSTANCE).iterator();
            if (it.hasNext()) {
                collector.collect(new XGBoostModel(trainBooster(new DMatrix(it, (String) null), ((Integer) Optional.ofNullable(this.params.get("numEarlyStoppingRounds")).map(obj -> {
                    return Integer.valueOf(Integer.parseInt(obj.toString()));
                }).orElse(0)).intValue())));
            } else {
                XGBoost.logger.warn("Nothing to train with.");
            }
        }

        private Booster trainBooster(final DMatrix dMatrix, int i) throws XGBoostError {
            HashMap<String, DMatrix> hashMap = new HashMap<String, DMatrix>() { // from class: ml.dmlc.xgboost4j.java.flink.XGBoost.MapFunction.1
                {
                    put("train", dMatrix);
                }
            };
            try {
                try {
                    Communicator.init(this.workerEnvs);
                    Booster train = ml.dmlc.xgboost4j.java.XGBoost.train(dMatrix, this.params, this.round, hashMap, (float[][]) null, (IObjective) null, (IEvaluation) null, i);
                    Communicator.shutdown();
                    return train;
                } catch (XGBoostError e) {
                    XGBoost.logger.warn(String.format("XGBooster worker %s has failed due to", String.valueOf(getRuntimeContext().getIndexOfThisSubtask())), e);
                    throw e;
                }
            } catch (Throwable th) {
                Communicator.shutdown();
                throw th;
            }
        }
    }

    public static XGBoostModel loadModelFromHadoopFile(String str) throws Exception {
        FSDataInputStream open = FileSystem.get(new Configuration()).open(new Path(str));
        try {
            XGBoostModel xGBoostModel = new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(open));
            if (open != null) {
                open.close();
            }
            return xGBoostModel;
        } catch (Throwable th) {
            if (open != null) {
                try {
                    open.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dataSet, Map<String, Object> map, int i) throws Exception {
        RabitTracker rabitTracker = new RabitTracker(dataSet.getExecutionEnvironment().getParallelism());
        if (rabitTracker.start()) {
            return (XGBoostModel) dataSet.mapPartition(new MapFunction(map, i, rabitTracker.getWorkerArgs())).reduce((xGBoostModel, xGBoostModel2) -> {
                return xGBoostModel;
            }).collect().get(0);
        }
        throw new Error("Tracker cannot be started");
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 295831057:
                if (implMethodName.equals("lambda$train$8cddfdbd$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/ReduceFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("ml/dmlc/xgboost4j/java/flink/XGBoost") && serializedLambda.getImplMethodSignature().equals("(Lml/dmlc/xgboost4j/java/flink/XGBoostModel;Lml/dmlc/xgboost4j/java/flink/XGBoostModel;)Lml/dmlc/xgboost4j/java/flink/XGBoostModel;")) {
                    return (xGBoostModel, xGBoostModel2) -> {
                        return xGBoostModel;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
