package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TrainTest.class */
public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    /* loaded from: input_file:org/tribuo/interop/tensorflow/TrainTest$InputType.class */
    /**
     * The kind of input the network expects; {@code main} switches on this
     * value to choose which feature converter to construct.
     */
    public enum InputType {
        /** Selects the {@code DenseFeatureConverter} in {@code main}. */
        DENSE,
        /** Selects the {@code ImageConverter} in {@code main}, shaped by the image-format option. */
        IMAGE;
    }

    /* loaded from: input_file:org/tribuo/interop/tensorflow/TrainTest$TensorflowOptions.class */
    public static class TensorflowOptions implements Options {
        private static List<String> DEFAULT_PARAM_NAMES = new ArrayList();
        private static List<Float> DEFAULT_PARAM_VALUES = new ArrayList();

        @Option(charName = 'f', longName = "model-output-path", usage = "Path to serialize model to.")
        public Path outputPath;

        @Option(charName = 'u', longName = "training-file", usage = "Path to the libsvm format training file.")
        public Path trainingPath;

        @Option(charName = 'v', longName = "testing-file", usage = "Path to the libsvm format testing file.")
        public Path testingPath;

        @Option(charName = 'l', longName = "output-name", usage = "Name of the output operation.")
        public String outputName;

        @Option(charName = 'n', longName = "input-name", usage = "Name of the input placeholder.")
        public String inputName;

        @Option(charName = 'm', longName = "model-protobuf", usage = "Path to the protobuf containing the network description.")
        public Path protobufPath;

        @Option(charName = 'p', longName = "checkpoint-dir", usage = "Path to the checkpoint base directory.")
        public Path checkpointPath;

        @Option(longName = "optimizer-param-names", usage = "Gradient optimizer param names, see org.tribuo.interop.tensorflow.GradientOptimiser.")
        public List<String> gradientParamNames = DEFAULT_PARAM_NAMES;

        @Option(longName = "optimizer-param-values", usage = "Gradient optimizer param values, see org.tribuo.interop.tensorflow.GradientOptimiser.")
        public List<Float> gradientParamValues = DEFAULT_PARAM_VALUES;

        @Option(charName = 'g', longName = "gradient-optimizer", usage = "The gradient optimizer to use.")
        public GradientOptimiser optimiser = GradientOptimiser.ADAGRAD;

        @Option(longName = "test-batch-size", usage = "Test time minibatch size.")
        public int testBatchSize = 16;

        @Option(charName = 'b', longName = "batch-size", usage = "Minibatch size.")
        public int batchSize = 128;

        @Option(charName = 'e', longName = "num-epochs", usage = "Number of gradient descent epochs.")
        public int epochs = 5;

        @Option(longName = "logging-interval", usage = "Interval between logging the loss.")
        public int loggingInterval = 1000;

        @Option(longName = "image-format", usage = "Image format, in [W,H,C]. Defaults to MNIST.")
        public String imageFormat = "28,28,1";

        @Option(charName = 't', longName = "input-type", usage = "Input type.")
        public InputType inputType = InputType.IMAGE;

        public String getOptionsDescription() {
            return "Trains and tests a Tensorflow classification model.";
        }

        public Map<String, Float> getGradientParams() {
            if (this.gradientParamNames.size() != this.gradientParamValues.size()) {
                throw new IllegalArgumentException("Must supply both name and value for the gradient parameters, found " + this.gradientParamNames.size() + " names, and " + this.gradientParamValues.size() + "values.");
            }
            HashMap hashMap = new HashMap();
            for (int i = 0; i < this.gradientParamNames.size(); i++) {
                hashMap.put(this.gradientParamNames.get(i), this.gradientParamValues.get(i));
            }
            return hashMap;
        }

        static {
            DEFAULT_PARAM_NAMES.add("learningRate");
            DEFAULT_PARAM_NAMES.add("initialAccumulatorValue");
            DEFAULT_PARAM_VALUES.add(Float.valueOf(0.01f));
            DEFAULT_PARAM_VALUES.add(Float.valueOf(0.1f));
        }
    }

    /**
     * Loads the training and testing datasets from libsvm format files.
     * <p>
     * The testing data source is constructed with the zero-indexing flag and
     * maximum feature id read from the training source, and the testing dataset
     * is built against the training dataset's feature and output id maps.
     *
     * @param trainingPath Path to the training file.
     * @param testingPath Path to the testing file.
     * @param outputFactory The factory used to construct the labels.
     * @return A pair whose first element is the training dataset and whose second is the testing dataset.
     * @throws IOException Declared by the data source constructors; presumably raised when a file cannot be read.
     */
    private static Pair<Dataset<Label>, Dataset<Label>> load(Path trainingPath, Path testingPath, OutputFactory<Label> outputFactory) throws IOException {
        logger.info(String.format("Loading data from %s", trainingPath));
        LibSVMDataSource<Label> trainingSource = new LibSVMDataSource<>(trainingPath, outputFactory);
        Dataset<Label> train = new MutableDataset<>(trainingSource);
        boolean zeroIndexed = trainingSource.isZeroIndexed();
        int maxFeatureID = trainingSource.getMaxFeatureID();
        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
        logger.info("Found " + train.getFeatureIDMap().size() + " features");

        LibSVMDataSource<Label> testingSource = new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID);
        // NOTE(review): the trailing 'false' is passed straight to the ImmutableDataset
        // constructor; confirm its meaning against the ImmutableDataset documentation.
        Dataset<Label> test = new ImmutableDataset<>(testingSource, train.getFeatureIDMap(), train.getOutputIDInfo(), false);
        logger.info(String.format("Loaded %d testing examples", test.size()));
        return new Pair<>(train, test);
    }

    /**
     * Command line entry point: trains a TensorFlow classification model on a
     * libsvm format training file, evaluates it on a testing file, prints the
     * evaluation and confusion matrix to standard out, and optionally
     * serializes the trained model.
     * <p>
     * Usage information is logged and the method returns without training when
     * the arguments cannot be parsed, when either data path is missing, or when
     * the image format or input type is invalid.
     *
     * @param args The command line arguments.
     * @throws IOException If loading the data or serializing the model fails.
     * @throws IllegalArgumentException If the input or output name is missing or empty.
     */
    public static void main(String[] args) throws IOException {
        LabsLogFormatter.setAllLogFormatters();

        TensorflowOptions options = new TensorflowOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, options);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        if (options.trainingPath == null || options.testingPath == null) {
            logger.info(cm.usage());
            return;
        }

        // Validate the cheap-to-check arguments before paying for the data load.
        if (options.inputName == null || options.inputName.isEmpty() || options.outputName == null || options.outputName.isEmpty()) {
            throw new IllegalArgumentException("Must specify both 'input-name' and 'output-name'");
        }

        Pair<Dataset<Label>, Dataset<Label>> data = load(options.trainingPath, options.testingPath, new LabelFactory());
        Dataset<Label> train = data.getA();
        Dataset<Label> test = data.getB();

        FeatureConverter inputConverter;
        switch (options.inputType) {
            case IMAGE:
                String[] dims = options.imageFormat.split(",");
                if (dims.length != 3) {
                    logger.info(cm.usage());
                    logger.info("Invalid image format specified. Found " + options.imageFormat);
                    return;
                }
                int first;
                int second;
                int third;
                try {
                    first = Integer.parseInt(dims[0].trim());
                    second = Integer.parseInt(dims[1].trim());
                    third = Integer.parseInt(dims[2].trim());
                } catch (NumberFormatException e) {
                    // A non-numeric dimension is reported like any other malformed format
                    // rather than surfacing as an uncaught exception.
                    logger.info(cm.usage());
                    logger.info("Invalid image format specified. Found " + options.imageFormat);
                    return;
                }
                inputConverter = new ImageConverter(options.inputName, first, second, third);
                break;
            case DENSE:
                inputConverter = new DenseFeatureConverter(options.inputName);
                break;
            default:
                logger.info(cm.usage());
                logger.info("Unknown input type. Found " + options.inputType);
                return;
        }
        LabelConverter outputConverter = new LabelConverter();

        TensorFlowTrainer<Label> trainer;
        if (options.checkpointPath == null) {
            logger.info("Using TensorflowTrainer");
            trainer = new TensorFlowTrainer<>(options.protobufPath, options.outputName, options.optimiser, options.getGradientParams(), inputConverter, outputConverter, options.batchSize, options.epochs, options.testBatchSize, options.loggingInterval);
        } else {
            logger.info("Using TensorflowCheckpointTrainer, writing to path " + options.checkpointPath);
            trainer = new TensorFlowTrainer<>(options.protobufPath, options.outputName, options.optimiser, options.getGradientParams(), inputConverter, outputConverter, options.batchSize, options.epochs, options.testBatchSize, options.loggingInterval, options.checkpointPath);
        }

        logger.info("Training using " + trainer.toString());
        final long trainStart = System.currentTimeMillis();
        Model<Label> model = trainer.train(train);
        try {
            logger.info("Finished training classifier " + Util.formatDuration(trainStart, System.currentTimeMillis()));

            final long testStart = System.currentTimeMillis();
            LabelEvaluation evaluation = new LabelEvaluator().evaluate(model, test);
            logger.info("Finished evaluating model " + Util.formatDuration(testStart, System.currentTimeMillis()));

            if (model.generatesProbabilities()) {
                logger.info("Average AUC = " + evaluation.averageAUCROC(false));
                logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
            }

            System.out.println(evaluation.toString());
            System.out.println(evaluation.getConfusionMatrix().toString());

            if (options.outputPath != null) {
                try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(options.outputPath.toFile()))) {
                    oos.writeObject(model);
                }
                logger.info("Serialized model to file: " + options.outputPath);
            }
        } finally {
            // Close the model even when evaluation or serialization fails,
            // so it is not left open on the error path.
            if (options.checkpointPath == null) {
                ((TensorFlowNativeModel<?>) model).close();
            } else {
                ((TensorFlowCheckpointModel<?>) model).close();
            }
        }
    }
}
