package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorFlowTrainer.class */
public final class TensorFlowTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(TensorFlowTrainer.class.getName());

    // Configurable fields. Non-mandatory ones have no declaration initializer; their defaults are
    // assigned in the constructors instead.

    /** Location of the serialized GraphDef; {@code null} when the graph was supplied in memory. */
    @Config(mandatory = true, description = "Path to the protobuf containing the graph.")
    private Path graphPath;
    /**
     * The graph definition imported for validation and training. Loaded from {@link #graphPath}
     * in {@code postConfig()}, or supplied (or derived from a {@code Graph}) by the programmatic
     * constructors. Not a configurable field.
     */
    private GraphDef graphDef;

    /** Batch size used at test time. The no-arg constructor defaults this to 16. */
    @Config(description = "Test time batch size.")
    private int testBatchSize;

    /**
     * Name of the op whose first output is fed to the loss. validateGraph requires that output
     * to be rank 2.
     */
    @Config(mandatory = true, description = "Name of the output operation before the loss.")
    private String outputName;

    /** Converts examples into input tensors; each of its input names must be an op in the graph. */
    @Config(mandatory = true, description = "Feature extractor.")
    private FeatureConverter featureConverter;

    /** Converts outputs into target tensors and supplies the loss and output transform. */
    @Config(mandatory = true, description = "Response extractor.")
    private OutputConverter<T> outputConverter;

    /**
     * Examples per training step; also the leading dimension of the target placeholder.
     * The no-arg constructor defaults this to 1.
     */
    @Config(description = "Training time batch size.")
    private int trainBatchSize;

    /** Number of passes over the dataset. The no-arg constructor defaults this to 5. */
    @Config(description = "Number of SGD epochs to run.")
    private int epochs;

    /**
     * The loss is logged every {@code loggingInterval} steps; -1 disables loss logging.
     * The no-arg constructor defaults this to 100.
     */
    @Config(description = "Logging interval to print out the loss.")
    private int loggingInterval;

    /** Optimiser applied to the loss to build the training op. */
    @Config(mandatory = true, description = "The gradient optimiser to use.")
    private GradientOptimiser optimiserEnum;

    /**
     * Hyperparameters passed to the optimiser. The programmatic constructors store an
     * unmodifiable copy.
     */
    @Config(mandatory = true, description = "The gradient optimiser parameters.")
    private Map<String, Float> gradientParams;

    /** Saved model format. CHECKPOINT requires {@link #checkpointPath} (checked in postConfig). */
    @Config(description = "Saved model format.")
    private TFModelFormat modelFormat;

    /** Base checkpoint directory. Each train call derives checkpointPath/invocation-N/tribuo from it. */
    @Config(description = "Checkpoint output directory.")
    private Path checkpointPath;

    /** Only applied to the session ConfigProto when greater than -1. */
    @Config(description = "Inter operation thread pool size. -1 uses the default TF value. Tribuo defaults to 1 for deterministic behaviour.")
    private int interOpParallelism;

    /** Only applied to the session ConfigProto when greater than -1. */
    @Config(description = "Intra operation thread pool size. -1 uses the default TF value. Tribuo defaults to 1 for deterministic behaviour.")
    private int intraOpParallelism;
    /**
     * Number of train invocations so far. It is read and incremented inside
     * {@code synchronized (this)} in train, and it names the per-invocation checkpoint directory.
     */
    private int trainInvocationCounter;

    /**
     * How the trained model is saved.
     */
    public enum TFModelFormat {
        /**
         * The default format. It is selected by the constructors that do not take a checkpoint path.
         */
        TRIBUO_NATIVE,

        /**
         * Checkpoint-based format. It requires {@code checkpointPath} to be set.
         */
        CHECKPOINT
    }

    /**
     * Provenance for a {@link TensorFlowTrainer}.
     * <p>
     * Adds two instance values to the standard trainer provenance: a hash of the graph, stored
     * under {@link #GRAPH_HASH}, and a timestamp for it, stored under {@link #GRAPH_LAST_MOD}.
     */
    public static final class TensorFlowTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        /** Instance-value key under which the graph hash is recorded. */
        public static final String GRAPH_HASH = "graph-hash";

        /** Instance-value key under which the graph timestamp is recorded. */
        public static final String GRAPH_LAST_MOD = "graph-last-modified";

        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        /**
         * Captures provenance from a live trainer.
         * <p>
         * If the trainer has a graph path, the hash covers that resource and the timestamp is the
         * file's last-modified time in the system default zone. Otherwise the hash covers the
         * serialized {@link GraphDef} and the timestamp is the current time.
         *
         * @param host The trainer to describe.
         * @param <T>  The output type.
         */
        <T extends Output<T>> TensorFlowTrainerProvenance(TensorFlowTrainer<T> host) {
            super(host);
            Path source = host.graphPath;
            if (source == null) {
                byte[] serializedGraph = host.graphDef.toByteArray();
                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE, GRAPH_HASH,
                        ProvenanceUtil.hashArray(DEFAULT_HASH_TYPE, serializedGraph));
                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD, OffsetDateTime.now());
            } else {
                // Hash first, then read the timestamp, matching the original evaluation order.
                this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE, GRAPH_HASH,
                        ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, source));
                Instant modified = Instant.ofEpochMilli(source.toFile().lastModified());
                this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD,
                        OffsetDateTime.ofInstant(modified, ZoneId.systemDefault()));
            }
        }

        /**
         * Rebuilds the provenance from its map form.
         * <p>
         * The graph hash and timestamp are pulled out by {@link #extractTFProvenanceInfo(Map)}.
         *
         * @param map The provenance map.
         */
        public TensorFlowTrainerProvenance(Map<String, Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorFlowTrainerProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
            this.graphHash = (HashProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        /**
         * Returns the superclass instance values, plus the graph hash and graph timestamp.
         *
         * @return The instance values map.
         */
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> values = super.getInstanceValues();
            values.put(graphHash.getKey(), graphHash);
            values.put(graphLastModified.getKey(), graphLastModified);
            return values;
        }

        /**
         * Extracts the standard trainer provenance information.
         * <p>
         * It then moves the graph hash and graph timestamp out of the configured parameters and
         * into the instance values.
         *
         * @param map The provenance map.
         * @return The extracted information.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractTFProvenanceInfo(Map<String, Provenance> map) {
            SkeletalConfiguredObjectProvenance.ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            String owner = TensorFlowTrainerProvenance.class.getSimpleName();
            info.instanceValues.put(GRAPH_HASH,
                    ObjectProvenance.checkAndExtractProvenance(info.configuredParameters, GRAPH_HASH, HashProvenance.class, owner));
            info.instanceValues.put(GRAPH_LAST_MOD,
                    ObjectProvenance.checkAndExtractProvenance(info.configuredParameters, GRAPH_LAST_MOD, DateTimeProvenance.class, owner));
            return info;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (other == null || other.getClass() != getClass()) {
                return false;
            }
            if (!super.equals(other)) {
                return false;
            }
            TensorFlowTrainerProvenance that = (TensorFlowTrainerProvenance) other;
            return graphHash.equals(that.graphHash) && graphLastModified.equals(that.graphLastModified);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), graphHash, graphLastModified);
        }
    }

    /**
     * No-arg constructor, presumably for the configuration system (given the {@code @Config}
     * fields and {@code postConfig()}).
     * <p>
     * It only assigns defaults to the non-mandatory fields.
     */
    private TensorFlowTrainer() {
        this.trainInvocationCounter = 0;
        this.modelFormat = TFModelFormat.TRIBUO_NATIVE;
        this.trainBatchSize = 1;
        this.testBatchSize = 16;
        this.epochs = 5;
        this.loggingInterval = 100;
        this.interOpParallelism = 1;
        this.intraOpParallelism = 1;
    }

    /**
     * Builds a trainer whose graph is read from {@code graphPath}.
     * <p>
     * The trained model uses {@link TFModelFormat#TRIBUO_NATIVE}.
     *
     * @param graphPath        Path to the serialized GraphDef.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     * @throws IOException If the graph could not be read.
     */
    public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser,
                             Map<String, Float> gradientParams, FeatureConverter featureConverter,
                             OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                             int testBatchSize, int loggingInterval) throws IOException {
        this(graphPath, loadGraphDef(graphPath), outputName, optimiser, gradientParams,
                featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize,
                loggingInterval, null, TFModelFormat.TRIBUO_NATIVE);
    }

    /**
     * Builds a trainer whose graph is read from {@code graphPath}.
     * <p>
     * The trained model uses {@link TFModelFormat#CHECKPOINT}, rooted at {@code checkpointPath}.
     *
     * @param graphPath        Path to the serialized GraphDef.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     * @param checkpointPath   Base checkpoint output directory.
     * @throws IOException If the graph could not be read.
     */
    public TensorFlowTrainer(Path graphPath, String outputName, GradientOptimiser optimiser,
                             Map<String, Float> gradientParams, FeatureConverter featureConverter,
                             OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                             int testBatchSize, int loggingInterval, Path checkpointPath) throws IOException {
        this(graphPath, loadGraphDef(graphPath), outputName, optimiser, gradientParams,
                featureConverter, outputConverter, trainBatchSize, epochs, testBatchSize,
                loggingInterval, checkpointPath, TFModelFormat.CHECKPOINT);
    }

    /**
     * Builds a trainer from an in-memory {@link GraphDef}.
     * <p>
     * The trained model uses {@link TFModelFormat#TRIBUO_NATIVE}.
     *
     * @param graphDef         The graph definition.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     */
    public TensorFlowTrainer(GraphDef graphDef, String outputName, GradientOptimiser optimiser,
                             Map<String, Float> gradientParams, FeatureConverter featureConverter,
                             OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                             int testBatchSize, int loggingInterval) {
        this(null, graphDef, outputName, optimiser, gradientParams, featureConverter,
                outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval,
                null, TFModelFormat.TRIBUO_NATIVE);
    }

    /**
     * Builds a trainer from an in-memory {@link GraphDef}.
     * <p>
     * The trained model uses {@link TFModelFormat#CHECKPOINT}, rooted at {@code checkpointPath}.
     *
     * @param graphDef         The graph definition.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     * @param checkpointPath   Base checkpoint output directory.
     */
    public TensorFlowTrainer(GraphDef graphDef, String outputName, GradientOptimiser optimiser,
                             Map<String, Float> gradientParams, FeatureConverter featureConverter,
                             OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                             int testBatchSize, int loggingInterval, Path checkpointPath) {
        this(null, graphDef, outputName, optimiser, gradientParams, featureConverter,
                outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval,
                checkpointPath, TFModelFormat.CHECKPOINT);
    }

    /**
     * Builds a trainer from a live {@link Graph}, which is serialized via {@code toGraphDef()}.
     * <p>
     * The trained model uses {@link TFModelFormat#TRIBUO_NATIVE}.
     *
     * @param graph            The graph to train.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     */
    public TensorFlowTrainer(Graph graph, String outputName, GradientOptimiser optimiser,
                             Map<String, Float> gradientParams, FeatureConverter featureConverter,
                             OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                             int testBatchSize, int loggingInterval) {
        this(null, graph.toGraphDef(), outputName, optimiser, gradientParams, featureConverter,
                outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval,
                null, TFModelFormat.TRIBUO_NATIVE);
    }

    /**
     * Builds a trainer from a live {@link Graph}, which is serialized via {@code toGraphDef()}.
     * <p>
     * The trained model uses {@link TFModelFormat#CHECKPOINT}, rooted at {@code checkpointPath}.
     *
     * @param graph            The graph to train.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     * @param checkpointPath   Base checkpoint output directory.
     */
    public TensorFlowTrainer(Graph graph, String outputName, GradientOptimiser optimiser,
                             Map<String, Float> gradientParams, FeatureConverter featureConverter,
                             OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                             int testBatchSize, int loggingInterval, Path checkpointPath) {
        this(null, graph.toGraphDef(), outputName, optimiser, gradientParams, featureConverter,
                outputConverter, trainBatchSize, epochs, testBatchSize, loggingInterval,
                checkpointPath, TFModelFormat.CHECKPOINT);
    }

    /**
     * Shared constructor behind the public constructors.
     *
     * @param graphPath        Path the graph was loaded from, or {@code null} if it was supplied in memory.
     * @param graphDef         The graph definition.
     * @param outputName       Name of the output op fed to the loss.
     * @param optimiser        The gradient optimiser.
     * @param gradientParams   The optimiser hyperparameters; an unmodifiable copy is stored.
     * @param featureConverter Converts examples into input tensors.
     * @param outputConverter  Converts outputs into target tensors.
     * @param trainBatchSize   Training batch size.
     * @param epochs           Number of passes over the data.
     * @param testBatchSize    Test time batch size.
     * @param loggingInterval  Steps between loss log messages (-1 disables).
     * @param checkpointPath   Base checkpoint directory; required when {@code modelFormat} is CHECKPOINT.
     * @param modelFormat      The saved model format.
     * @throws IllegalArgumentException If neither a graph nor a graph path is supplied, if
     *                                  CHECKPOINT is requested without a checkpoint path, or if
     *                                  the graph fails validation.
     */
    private TensorFlowTrainer(Path graphPath, GraphDef graphDef, String outputName, GradientOptimiser optimiser,
                              Map<String, Float> gradientParams, FeatureConverter featureConverter,
                              OutputConverter<T> outputConverter, int trainBatchSize, int epochs,
                              int testBatchSize, int loggingInterval, Path checkpointPath,
                              TFModelFormat modelFormat) {
        if (graphPath == null && graphDef == null) {
            throw new IllegalArgumentException("Must supply either a GraphDef or a path to a Graph");
        }
        // Mirrors the check postConfig() applies to configured instances. Without it, train()
        // would derive a null checkpoint location for a CHECKPOINT-format trainer.
        if (modelFormat == TFModelFormat.CHECKPOINT && checkpointPath == null) {
            throw new IllegalArgumentException("Must set 'checkpointPath' when using TFModelFormat.CHECKPOINT");
        }
        // Defaults for the fields not supplied as constructor arguments.
        this.interOpParallelism = 1;
        this.intraOpParallelism = 1;
        this.trainInvocationCounter = 0;

        this.graphPath = graphPath;
        this.graphDef = graphDef;
        this.outputName = outputName;
        this.optimiserEnum = optimiser;
        this.gradientParams = Collections.unmodifiableMap(new HashMap<>(gradientParams));
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
        this.trainBatchSize = trainBatchSize;
        this.epochs = epochs;
        this.testBatchSize = testBatchSize;
        this.loggingInterval = loggingInterval;
        this.checkpointPath = checkpointPath;
        this.modelFormat = modelFormat;
        validateGraph(false);
    }

    /**
     * Completes configuration.
     * <p>
     * Loads the graph from {@link #graphPath} and ensures a checkpoint directory is set when
     * {@link TFModelFormat#CHECKPOINT} is selected. It then validates the graph's input and
     * output operations, reporting failures as configuration errors.
     *
     * @throws IOException       If the graph file cannot be read.
     * @throws PropertyException If {@code checkpointPath} is missing for the checkpoint format,
     *                           or the graph lacks the expected operations.
     */
    public void postConfig() throws IOException {
        graphDef = loadGraphDef(graphPath);
        boolean checkpointFormat = modelFormat == TFModelFormat.CHECKPOINT;
        if (checkpointFormat && checkpointPath == null) {
            throw new PropertyException("", "checkpointPath", "Must set 'checkpointPath' when using TFModelFormat.CHECKPOINT");
        }
        final boolean reportAsConfigError = true;
        validateGraph(reportAsConfigError);
    }

    /**
     * Imports {@link #graphDef} into a temporary graph and checks it.
     * <p>
     * Every input name reported by the feature converter must be an op in the graph. There must
     * also be an op named {@link #outputName} whose first output is 2 dimensional. The temporary
     * graph is always closed.
     *
     * @param fromConfig {@code true} when called from {@code postConfig()}. Failures are then
     *                   reported as {@link PropertyException}s naming the offending field;
     *                   otherwise they are {@link IllegalArgumentException}s.
     * @throws IllegalArgumentException If validation fails and {@code fromConfig} is false.
     * @throws PropertyException        If validation fails and {@code fromConfig} is true.
     */
    private void validateGraph(boolean fromConfig) {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(this.graphDef);
            for (String inputName : this.featureConverter.inputNamesSet()) {
                if (graph.operation(inputName) == null) {
                    throw validationFailure(fromConfig, "featureConverter",
                            "Unable to find an input operation, expected an op with name '" + inputName + "'");
                }
            }
            GraphOperation outputOp = graph.operation(this.outputName);
            if (outputOp == null) {
                throw validationFailure(fromConfig, "outputName",
                        "Unable to find the output operation, expected an op with name '" + this.outputName + "'");
            }
            Shape outputShape = outputOp.output(0).shape();
            if (outputShape.numDimensions() != 2) {
                throw validationFailure(fromConfig, "outputName",
                        "Expected a 2 dimensional output, found " + Arrays.toString(outputShape.asArray()));
            }
        }
    }

    /**
     * Builds the exception for a failed graph check.
     * <p>
     * Returns a {@link PropertyException} naming {@code fieldName} when {@code fromConfig} is
     * set, otherwise an {@link IllegalArgumentException}.
     */
    private static RuntimeException validationFailure(boolean fromConfig, String fieldName, String message) {
        return fromConfig ? new PropertyException("", fieldName, message) : new IllegalArgumentException(message);
    }

    /**
     * Reads a serialized {@link GraphDef} protobuf from {@code path}.
     * <p>
     * The stream is closed on both success and failure. Any exception from closing is attached
     * as suppressed rather than triggering a second close.
     *
     * @param path The file containing the serialized graph.
     * @return The parsed graph definition.
     * @throws IOException          If the file cannot be read or does not parse.
     * @throws NullPointerException If {@code path} is null.
     */
    private static GraphDef loadGraphDef(Path path) throws IOException {
        Objects.requireNonNull(path, "path");
        try (BufferedInputStream stream = new BufferedInputStream(new FileInputStream(path.toFile()))) {
            return GraphDef.parseFrom(stream);
        }
    }

    /**
     * Trains a model on {@code dataset} with no extra run provenance.
     * <p>
     * NOTE(review): the name {@code m21train} appears to be a decompiler rename of
     * {@code train(Dataset)} ("merged with bridge method"). It is kept to preserve this file's
     * interface.
     *
     * @param dataset The training data.
     * @return The trained model.
     */
    public TensorFlowModel<T> m21train(Dataset<T> dataset) {
        Map<String, Provenance> noRunProvenance = Collections.emptyMap();
        return train(dataset, noRunProvenance);
    }

    /**
     * Trains a model on {@code dataset} without resetting the invocation counter.
     *
     * @param dataset The training data.
     * @param map     Additional run provenance.
     * @return The trained model.
     */
    public TensorFlowModel<T> train(Dataset<T> dataset, Map<String, Provenance> map) {
        // -1 tells the three-argument overload to leave the invocation counter unchanged.
        final int keepCurrentInvocationCount = -1;
        return train(dataset, map, keepCurrentInvocationCount);
    }

    /* JADX WARN: Failed to calculate best type for var: r24v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r24v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r25v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r25v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r26v1 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r26v1 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r27v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r27v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Finally extract failed */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 24, insn: 0x0576: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r24 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:195:0x0576 */
    /* JADX WARN: Not initialized variable reg: 25, insn: 0x057b: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r25 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:197:0x057b */
    /* JADX WARN: Not initialized variable reg: 26, insn: 0x0545: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r26 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:176:0x0545 */
    /* JADX WARN: Not initialized variable reg: 27, insn: 0x054a: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r27 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:178:0x054a */
    /* JADX WARN: Type inference failed for: r24v0, types: [org.tensorflow.Graph] */
    /* JADX WARN: Type inference failed for: r25v0, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r26v1, types: [org.tensorflow.Session] */
    /* JADX WARN: Type inference failed for: r27v0, types: [java.lang.Throwable] */
    public TensorFlowModel<T> train(Dataset<T> dataset, Map<String, Provenance> map, int i) {
        Path path;
        ?? r26;
        ?? r27;
        TensorFlowModel tensorFlowCheckpointModel;
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = dataset.getOutputIDInfo();
        ArrayList arrayList = new ArrayList();
        synchronized (this) {
            if (i != -1) {
                setInvocationCount(i);
            }
            path = this.checkpointPath != null ? Paths.get(this.checkpointPath.toString(), "invocation-" + this.trainInvocationCounter, "tribuo") : null;
            this.trainInvocationCounter++;
        }
        ConfigProto.Builder newBuilder = ConfigProto.newBuilder();
        if (this.interOpParallelism > -1) {
            newBuilder.setInterOpParallelismThreads(this.interOpParallelism);
        }
        if (this.intraOpParallelism > -1) {
            newBuilder.setIntraOpParallelismThreads(this.intraOpParallelism);
        }
        ConfigProto build = newBuilder.build();
        try {
            try {
                Graph graph = new Graph();
                Throwable th = null;
                try {
                    Session session = new Session(graph, build);
                    Throwable th2 = null;
                    graph.importGraphDef(this.graphDef);
                    Ops withName = Ops.create(graph).withName("tribuo-internal");
                    org.tensorflow.Output output = graph.operation(this.outputName).output(0);
                    Shape shape = output.shape();
                    Shape of = Shape.of(new long[]{this.trainBatchSize, outputIDInfo.size()});
                    if (!shape.isCompatibleWith(of)) {
                        throw new IllegalArgumentException("Incompatible output shape, expected " + of.toString() + " found " + shape.toString());
                    }
                    Placeholder placeholder = withName.placeholder(TFloat32.class, new Placeholder.Options[]{Placeholder.shape(Shape.of(new long[]{this.trainBatchSize, outputIDInfo.size()}))});
                    Op op = (Op) this.outputConverter.outputTransformFunction().apply(withName, output);
                    Operand<TNumber> apply = this.outputConverter.loss().apply(withName, new Pair<>(placeholder, output));
                    Op applyOptimiser = this.optimiserEnum.applyOptimiser(graph, apply, this.gradientParams);
                    session.initialize();
                    logger.info("Initialised the model parameters");
                    int i2 = 0;
                    for (int i3 = 0; i3 < this.epochs; i3++) {
                        logger.log(Level.INFO, "Starting epoch " + i3);
                        int i4 = 0;
                        while (i4 < dataset.size()) {
                            arrayList.clear();
                            for (int i5 = i4; i5 < i4 + this.trainBatchSize && i5 < dataset.size(); i5++) {
                                arrayList.add(dataset.getExample(i5));
                            }
                            TensorMap convert = this.featureConverter.convert(arrayList, featureIDMap);
                            Throwable th3 = null;
                            try {
                                Tensor convertToTensor = this.outputConverter.convertToTensor(arrayList, outputIDInfo);
                                Throwable th4 = null;
                                try {
                                    TFloat32 tFloat32 = (Tensor) convert.feedInto(session.runner()).feed(placeholder, convertToTensor).addTarget(applyOptimiser).fetch(apply).run().get(0);
                                    Throwable th5 = null;
                                    try {
                                        try {
                                            if (this.loggingInterval != -1 && i2 % this.loggingInterval == 0) {
                                                logger.log(Level.INFO, "Training loss at itr " + i2 + " = " + tFloat32.getFloat(new long[0]));
                                            }
                                            if (tFloat32 != null) {
                                                if (0 != 0) {
                                                    try {
                                                        tFloat32.close();
                                                    } catch (Throwable th6) {
                                                        th5.addSuppressed(th6);
                                                    }
                                                } else {
                                                    tFloat32.close();
                                                }
                                            }
                                            if (convertToTensor != null) {
                                                if (0 != 0) {
                                                    try {
                                                        convertToTensor.close();
                                                    } catch (Throwable th7) {
                                                        th4.addSuppressed(th7);
                                                    }
                                                } else {
                                                    convertToTensor.close();
                                                }
                                            }
                                            if (convert != null) {
                                                if (0 != 0) {
                                                    try {
                                                        convert.close();
                                                    } catch (Throwable th8) {
                                                        th3.addSuppressed(th8);
                                                    }
                                                } else {
                                                    convert.close();
                                                }
                                            }
                                            i2++;
                                            i4 += this.trainBatchSize;
                                        } finally {
                                        }
                                    } catch (Throwable th9) {
                                        if (tFloat32 != null) {
                                            if (th5 != null) {
                                                try {
                                                    tFloat32.close();
                                                } catch (Throwable th10) {
                                                    th5.addSuppressed(th10);
                                                }
                                            } else {
                                                tFloat32.close();
                                            }
                                        }
                                        throw th9;
                                    }
                                } catch (Throwable th11) {
                                    if (convertToTensor != null) {
                                        if (0 != 0) {
                                            try {
                                                convertToTensor.close();
                                            } catch (Throwable th12) {
                                                th4.addSuppressed(th12);
                                            }
                                        } else {
                                            convertToTensor.close();
                                        }
                                    }
                                    throw th11;
                                }
                            } catch (Throwable th13) {
                                if (convert != null) {
                                    if (0 != 0) {
                                        try {
                                            convert.close();
                                        } catch (Throwable th14) {
                                            th3.addSuppressed(th14);
                                        }
                                    } else {
                                        convert.close();
                                    }
                                }
                                throw th13;
                            }
                        }
                    }
                    TensorFlowUtil.annotateGraph(graph, session);
                    if (this.modelFormat == TFModelFormat.CHECKPOINT) {
                        session.save(path.toString());
                    }
                    GraphDef graphDef = graph.toGraphDef();
                    ModelProvenance modelProvenance = new ModelProvenance(TensorFlowModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), m22getProvenance(), map);
                    switch (this.modelFormat) {
                        case TRIBUO_NATIVE:
                            tensorFlowCheckpointModel = new TensorFlowNativeModel("tf-native-model", modelProvenance, featureIDMap, outputIDInfo, graphDef, TensorFlowUtil.extractMarshalledVariables(graph, session), this.testBatchSize, op.op().name(), this.featureConverter, this.outputConverter);
                            break;
                        case CHECKPOINT:
                            tensorFlowCheckpointModel = new TensorFlowCheckpointModel("tf-checkpoint-model", modelProvenance, featureIDMap, outputIDInfo, graphDef, path.getParent().toString(), path.getFileName().toString(), this.testBatchSize, op.op().name(), this.featureConverter, this.outputConverter);
                            break;
                        default:
                            throw new IllegalStateException("Unexpected enum constant " + this.modelFormat);
                    }
                    TensorFlowModel tensorFlowModel = tensorFlowCheckpointModel;
                    if (session != null) {
                        if (0 != 0) {
                            try {
                                session.close();
                            } catch (Throwable th15) {
                                th2.addSuppressed(th15);
                            }
                        } else {
                            session.close();
                        }
                    }
                    if (graph != null) {
                        if (0 != 0) {
                            try {
                                graph.close();
                            } catch (Throwable th16) {
                                th.addSuppressed(th16);
                            }
                        } else {
                            graph.close();
                        }
                    }
                    return tensorFlowModel;
                } catch (Throwable th17) {
                    if (r26 != 0) {
                        if (r27 != 0) {
                            try {
                                r26.close();
                            } catch (Throwable th18) {
                                r27.addSuppressed(th18);
                            }
                        } else {
                            r26.close();
                        }
                    }
                    throw th17;
                }
            } catch (TensorFlowException e) {
                logger.log(Level.SEVERE, "TensorFlow threw an error", e);
                throw new IllegalStateException(e);
            }
        } finally {
        }
    }

    /**
     * Describes this trainer's configuration.
     * <p>
     * The description covers: graph path (empty if unset), feature and output converters,
     * minibatch size, epochs, gradient optimiser, gradient parameters and model format.
     * The checkpoint path is appended only when the model format is {@code CHECKPOINT}.
     * <p>
     * NOTE(review): in CHECKPOINT mode {@code checkpointPath} is dereferenced unconditionally;
     * presumably configuration validation guarantees it is non-null — confirm.
     *
     * @return the configuration description.
     */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder("TFTrainer(graphPath=");
        if (this.graphPath != null) {
            description.append(this.graphPath.toString());
        }
        description.append(",exampleConverter=").append(this.featureConverter.toString())
                .append(",outputConverter=").append(this.outputConverter.toString())
                .append(",minibatchSize=").append(this.trainBatchSize)
                .append(",epochs=").append(this.epochs)
                .append(",gradientOptimizer=").append(this.optimiserEnum)
                .append(",gradientParams=").append(this.gradientParams.toString())
                .append(",modelFormat=").append(this.modelFormat);
        if (this.modelFormat == TFModelFormat.CHECKPOINT) {
            description.append(",checkpointPath=").append(this.checkpointPath.toString());
        }
        return description.append(')').toString();
    }

    /**
     * Returns the current training invocation counter.
     * <p>
     * {@code train} increments this counter once per call. It is also the value last
     * supplied to {@link #setInvocationCount(int)}, if that was called since.
     * <p>
     * Synchronized on {@code this} because {@code train} reads and increments the counter
     * inside a {@code synchronized (this)} block. An unsynchronized read here could
     * observe a stale value.
     *
     * @return the current invocation count.
     */
    @Override
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Sets the training invocation counter.
     * <p>
     * When a checkpoint path is configured, the next {@code train} call uses this value to
     * build its {@code invocation-N} checkpoint sub-directory, then increments it.
     * <p>
     * Synchronized on {@code this} so writes are guarded by the same lock {@code train}
     * uses for the counter. The lock is reentrant, so {@code train} may call this from
     * within its own {@code synchronized (this)} block.
     *
     * @param invocationCount the new invocation count; must be non-negative.
     * @throws IllegalArgumentException if {@code invocationCount} is negative.
     */
    @Override
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCounter = invocationCount;
    }

    /**
     * Creates a new {@code TensorFlowTrainerProvenance} for this trainer.
     * <p>
     * Decompiler note: this is the class's {@code getProvenance} method, renamed because it
     * was merged with a bridge method. The name is kept because {@code train} calls it.
     *
     * @return the trainer provenance.
     */
    public TrainerProvenance m22getProvenance() {
        final TrainerProvenance provenance = new TensorFlowTrainerProvenance(this);
        return provenance;
    }

    /*
     * Decompiler-synthesized raw-typed bridge (renamed from: train, reason: collision with
     * other method in class). It forwards to the generic three-argument train overload
     * after casting the raw run-provenance map.
     */
    public Model m19train(Dataset examples, Map rawRunProvenance, int invocationCount) {
        @SuppressWarnings("unchecked")
        final Map<String, Provenance> runProvenance = (Map<String, Provenance>) rawRunProvenance;
        return train(examples, runProvenance, invocationCount);
    }

    /*
     * Decompiler-synthesized raw-typed bridge (renamed from: train, reason: collision with
     * other method in class). It forwards to the generic two-argument train overload
     * after casting the raw run-provenance map.
     */
    public Model m20train(Dataset examples, Map rawRunProvenance) {
        @SuppressWarnings("unchecked")
        final Map<String, Provenance> runProvenance = (Map<String, Provenance>) rawRunProvenance;
        return train(examples, runProvenance);
    }
}
