package org.tribuo.interop.tensorflow;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Paths;
import java.util.Map;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.interop.tensorflow.TensorFlowUtil;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorFlowNativeModel.class */
/**
 * A {@link TensorFlowModel} which stores the values of its graph variables inline.
 * <p>
 * When Java-serialized, it writes the serialized {@link GraphDef} bytes followed by the variable
 * values extracted by {@link TensorFlowUtil#extractMarshalledVariables}. On deserialization it
 * does three things:
 * <ul>
 *   <li>rebuilds the {@link Graph} from those bytes;</li>
 *   <li>opens a fresh {@link Session} on the graph;</li>
 *   <li>restores the variable values into that session.</li>
 * </ul>
 *
 * @param <T> The output type.
 */
public final class TensorFlowNativeModel<T extends Output<T>> extends TensorFlowModel<T> {
    private static final long serialVersionUID = 200;

    /**
     * Builds the model from a graph definition, then restores the supplied variable values into
     * the session created by the superclass constructor.
     * <p>
     * NOTE(review): the decompiler reports this constructor was package-private in the original
     * source. It is left public here to avoid narrowing the visible interface.
     *
     * @param name             The model name.
     * @param provenance       The model provenance.
     * @param featureIDMap     The feature domain.
     * @param outputIDInfo     The output domain.
     * @param trainedGraphDef  The graph definition to load.
     * @param tensorMap        The marshalled variable values to restore into the session.
     * @param batchSize        The batch size passed to the superclass.
     * @param outputName       The output name passed to the superclass.
     * @param featureConverter The feature converter.
     * @param outputConverter  The output converter.
     */
    public TensorFlowNativeModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                                 ImmutableOutputInfo<T> outputIDInfo, GraphDef trainedGraphDef,
                                 Map<String, TensorFlowUtil.TensorTuple> tensorMap, int batchSize,
                                 String outputName, FeatureConverter featureConverter,
                                 OutputConverter<T> outputConverter) {
        super(name, provenance, featureIDMap, outputIDInfo, trainedGraphDef, batchSize, outputName, featureConverter, outputConverter);
        TensorFlowUtil.restoreMarshalledVariables(this.session, tensorMap);
    }

    /**
     * Creates a copy of this model with a new name and provenance.
     * <p>
     * The copy is built from this model's graph definition and from variable values extracted
     * from the current session.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code copy} (merged with a bridge
     * method). It presumably needs to be named {@code copy} to override the superclass
     * method — confirm before renaming.
     *
     * @param newName       The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return A new model with the same graph and variable values.
     * @throws IllegalStateException If this model has been closed.
     */
    public TensorFlowNativeModel<T> m15copy(String newName, ModelProvenance newProvenance) {
        requireOpen("copy");
        return new TensorFlowNativeModel<>(newName, newProvenance, this.featureIDMap, this.outputIDInfo,
                this.modelGraph.toGraphDef(),
                TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session),
                this.batchSize, this.outputName, this.featureConverter, this.outputConverter);
    }

    /**
     * Saves the current session state as a checkpoint, and returns a
     * {@link TensorFlowCheckpointModel} that refers to it.
     *
     * @param checkpointDirectory The directory the checkpoint is written into.
     * @param checkpointName      The checkpoint name, resolved against {@code checkpointDirectory}.
     * @return A checkpoint-backed model with this model's name, provenance, domains and graph.
     * @throws IllegalStateException If this model has been closed.
     */
    public TensorFlowCheckpointModel<T> convertToCheckpointModel(String checkpointDirectory, String checkpointName) {
        requireOpen("convert");
        this.session.save(Paths.get(checkpointDirectory, checkpointName).toString());
        return new TensorFlowCheckpointModel<>(this.name, this.provenance, this.featureIDMap, this.outputIDInfo,
                this.modelGraph.toGraphDef(), checkpointDirectory, checkpointName, this.batchSize,
                this.outputName, this.featureConverter, this.outputConverter);
    }

    /**
     * Throws if this model has been closed, since its graph and session state are then gone.
     *
     * @param action The verb used in the exception message.
     */
    private void requireOpen(String action) {
        if (this.closed) {
            throw new IllegalStateException("Can't " + action + " a closed model, the state has gone.");
        }
    }

    /**
     * Writes the default fields, then the graph definition bytes, then the marshalled variables.
     *
     * @param objectOutputStream The output stream.
     * @throws IOException If the stream write fails.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        requireOpen("serialize");
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeObject(this.modelGraph.toGraphDef().toByteArray());
        objectOutputStream.writeObject(TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session));
    }

    /**
     * Reads the default fields, then rebuilds the graph and session from the data written by
     * {@link #writeObject}.
     * <p>
     * SECURITY NOTE: this performs native Java deserialization of a {@code Map}. Only
     * deserialize models from trusted sources, or install an {@code ObjectInputFilter}.
     *
     * @param objectInputStream The input stream.
     * @throws IOException            If the stream read fails or the graph bytes are not a valid GraphDef.
     * @throws ClassNotFoundException If a serialized class cannot be found.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        byte[] graphDefBytes = (byte[]) objectInputStream.readObject();
        // writeObject always writes the Map produced by extractMarshalledVariables here.
        @SuppressWarnings("unchecked")
        Map<String, TensorFlowUtil.TensorTuple> tensorMap =
                (Map<String, TensorFlowUtil.TensorTuple>) objectInputStream.readObject();

        // Parse before allocating any native resources, so a corrupt stream cannot leak a Graph.
        GraphDef graphDef = GraphDef.parseFrom(graphDefBytes);

        Graph newGraph = new Graph();
        Session newSession = null;
        boolean restored = false;
        try {
            newGraph.importGraphDef(graphDef);
            newSession = new Session(newGraph);
            TensorFlowUtil.restoreMarshalledVariables(newSession, tensorMap);
            restored = true;
        } finally {
            // On any failure, release the native handles instead of leaking them.
            if (!restored) {
                if (newSession != null) {
                    newSession.close();
                }
                newGraph.close();
            }
        }
        this.modelGraph = newGraph;
        this.session = newSession;
    }
}
