package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.logging.Logger;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Mean;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TNumber;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;

/* loaded from: input_file:org/tribuo/interop/tensorflow/LabelConverter.class */
/**
 * Converts {@link Label} outputs to and from TensorFlow {@link Tensor}s.
 * <p>
 * Ground-truth labels are written as one-hot float32 tensors of shape
 * {@code [numExamples, outputInfo.size()]}. Model outputs are passed through a softmax. The
 * training loss is the softmax cross-entropy between logits and targets, averaged over axis 0.
 * <p>
 * Predictions are read back from float16 or float32 tensors of shape
 * {@code [batchSize, outputInfo.size()]}. Column {@code c} holds the score for the label whose
 * id is {@code c} in the supplied {@link ImmutableOutputInfo}.
 */
public class LabelConverter implements OutputConverter<Label> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(LabelConverter.class.getName());

    /**
     * Batch size required by the single-example conversion methods.
     * Kept separate from {@link #serialVersionUID}, which has the same value only by coincidence.
     */
    private static final long SINGLE_EXAMPLE_BATCH_SIZE = 1L;

    /**
     * Returns a function that builds the training loss.
     * <p>
     * The loss is the softmax cross-entropy between the logits ({@code pair.getB()}) and the
     * target placeholder ({@code pair.getA()}), averaged over axis 0.
     *
     * @return The loss-building function.
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public BiFunction<Ops, Pair<Placeholder<? extends TNumber>, Operand<TNumber>>, Operand<TNumber>> loss() {
        // softmaxCrossEntropyWithLogits needs both operands to share one type parameter. The
        // placeholder is only known as "? extends TNumber", so raw casts bridge the two.
        return (ops, pair) -> ops.math.mean(
                ops.nn.softmaxCrossEntropyWithLogits((Operand) pair.getB(), (Placeholder) pair.getA()).loss(),
                ops.constant(0),
                new Mean.Options[0]);
    }

    /**
     * Returns a function that applies a softmax to the model output.
     *
     * @param <V> The numeric type of the model output.
     * @return The output transformation function.
     */
    @Override
    public <V extends TNumber> BiFunction<Ops, Operand<V>, Op> outputTransformFunction() {
        return (ops, logits) -> ops.nn.softmax(logits);
    }

    /**
     * Converts a tensor holding exactly one row of scores into a {@link Prediction}.
     *
     * @param tensor A float16/float32 tensor of shape {@code [1, outputInfo.size()]}.
     * @param immutableOutputInfo The label id mapping.
     * @param i Forwarded unchanged to the {@link Prediction} constructor (presumably the number
     *          of features used; confirm against {@code OutputConverter}).
     * @param example The example this prediction is for.
     * @return A probabilistic prediction whose output is the highest-scoring label.
     * @throws IllegalArgumentException If the tensor's shape or type is invalid, or its batch size is not 1.
     */
    @Override
    public Prediction<Label> convertToPrediction(Tensor tensor, ImmutableOutputInfo<Label> immutableOutputInfo, int i, Example<Label> example) {
        return generatePrediction(extractSingleRow(tensor, immutableOutputInfo), immutableOutputInfo, i, example);
    }

    /**
     * Builds a prediction from a rank-1 vector of per-label scores.
     * The argmax is kept as the predicted label; on ties, the lowest id wins.
     */
    private Prediction<Label> generatePrediction(FloatNdArray scores, ImmutableOutputInfo<Label> outputInfo, int numUsed, Example<Label> example) {
        int numLabels = checkedLength(scores);
        Label best = null;
        HashMap<String, Label> scoreMap = new HashMap<>();
        for (int id = 0; id < numLabels; id++) {
            Label candidate = new Label(outputInfo.getOutput(id).getLabel(), scores.getFloat(id));
            scoreMap.put(candidate.getLabel(), candidate);
            if (best == null || candidate.getScore() > best.getScore()) {
                best = candidate;
            }
        }
        return new Prediction<>(best, scoreMap, numUsed, example, true);
    }

    /**
     * Converts a tensor holding exactly one row of scores into the highest-scoring {@link Label}.
     *
     * @param tensor A float16/float32 tensor of shape {@code [1, outputInfo.size()]}.
     * @param immutableOutputInfo The label id mapping.
     * @return The argmax label, carrying its score.
     * @throws IllegalArgumentException If the tensor's shape or type is invalid, or its batch size is not 1.
     */
    @Override
    public Label convertToOutput(Tensor tensor, ImmutableOutputInfo<Label> immutableOutputInfo) {
        return generateLabel(extractSingleRow(tensor, immutableOutputInfo), immutableOutputInfo);
    }

    /**
     * Returns the argmax label of a rank-1 score vector, with its score.
     * On ties, the lowest id wins. If no score exceeds negative infinity, id 0 is returned with
     * score negative infinity.
     */
    private Label generateLabel(FloatNdArray scores, ImmutableOutputInfo<Label> outputInfo) {
        int numLabels = checkedLength(scores);
        int bestId = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int id = 0; id < numLabels; id++) {
            float score = scores.getFloat(id);
            if (score > bestScore) {
                bestId = id;
                bestScore = score;
            }
        }
        return new Label(outputInfo.getOutput(bestId).getLabel(), bestScore);
    }

    /** Validates that {@code scores} is rank 1 with an int-addressable length, and returns that length. */
    private static int checkedLength(FloatNdArray scores) {
        long[] shape = scores.shape().asArray();
        if (shape.length != 1) {
            throw new IllegalArgumentException("Failed to get scalar predictions. Found " + Arrays.toString(shape));
        }
        if (shape[0] > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("More than Integer.MAX_VALUE predictions. Found " + shape[0]);
        }
        return (int) shape[0];
    }

    /** Validates the batch tensor and returns its only row; the batch size must be exactly 1. */
    private FloatNdArray extractSingleRow(Tensor tensor, ImmutableOutputInfo<Label> outputInfo) {
        FloatNdArray batchPredictions = getBatchPredictions(tensor, outputInfo);
        long batchSize = batchPredictions.shape().asArray()[0];
        if (batchSize != SINGLE_EXAMPLE_BATCH_SIZE) {
            throw new IllegalArgumentException("Supplied tensor has too many results, batchSize = " + batchSize);
        }
        return sliceRow(batchPredictions, 0);
    }

    /** Returns row {@code index} of a rank-2 array as a rank-1 view. */
    private static FloatNdArray sliceRow(FloatNdArray batch, long index) {
        return batch.slice(Indices.at(index), Indices.all());
    }

    /**
     * Validates that {@code tensor} is a rank-2 float16/float32 tensor whose second dimension
     * equals {@code outputInfo.size()}, and returns it as a {@link FloatNdArray}.
     */
    private FloatNdArray getBatchPredictions(Tensor tensor, ImmutableOutputInfo<Label> immutableOutputInfo) {
        long[] shape = tensor.shape().asArray();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        // Compare as longs so an oversized dimension cannot wrap around when narrowed to int.
        if (shape[1] != immutableOutputInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has incorrect number of elements, tensor output dimension: " + shape[1] + ", outputInfo dimension: " + immutableOutputInfo.size());
        }
        if (tensor instanceof TFloat16) {
            return (TFloat16) tensor;
        }
        if (tensor instanceof TFloat32) {
            return (TFloat32) tensor;
        }
        throw new IllegalArgumentException("Tensor is not a probability distribution. Found type " + tensor.getClass().getName());
    }

    /**
     * Converts a batch of scores into one {@link Prediction} per row.
     *
     * @param tensor A float16/float32 tensor of shape {@code [list.size(), outputInfo.size()]}.
     * @param immutableOutputInfo The label id mapping.
     * @param iArr Per-example values forwarded to each {@link Prediction} constructor
     *             (presumably feature counts; confirm against {@code OutputConverter}).
     * @param list The examples, in row order.
     * @return The predictions, in row order.
     * @throws IllegalArgumentException If the tensor is invalid, or its batch size does not match
     *                                  both {@code list.size()} and {@code iArr.length}.
     */
    @Override
    public List<Prediction<Label>> convertToBatchPrediction(Tensor tensor, ImmutableOutputInfo<Label> immutableOutputInfo, int[] iArr, List<Example<Label>> list) {
        FloatNdArray batchPredictions = getBatchPredictions(tensor, immutableOutputInfo);
        long batchSize = batchPredictions.shape().asArray()[0];
        if (batchSize != list.size() || batchSize != iArr.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + iArr.length + ", received " + batchSize);
        }
        List<Prediction<Label>> predictions = new ArrayList<>(iArr.length);
        for (int row = 0; row < iArr.length; row++) {
            predictions.add(generatePrediction(sliceRow(batchPredictions, row), immutableOutputInfo, iArr[row], list.get(row)));
        }
        return predictions;
    }

    /**
     * Converts a batch of scores into one argmax {@link Label} per row.
     *
     * @param tensor A float16/float32 tensor of shape {@code [batchSize, outputInfo.size()]}.
     * @param immutableOutputInfo The label id mapping.
     * @return The labels, in row order.
     * @throws IllegalArgumentException If the tensor's shape or type is invalid.
     */
    @Override
    public List<Label> convertToBatchOutput(Tensor tensor, ImmutableOutputInfo<Label> immutableOutputInfo) {
        FloatNdArray batchPredictions = getBatchPredictions(tensor, immutableOutputInfo);
        long batchSize = batchPredictions.shape().asArray()[0];
        List<Label> labels = new ArrayList<>();
        // Iterate with a long index so a huge first dimension cannot wrap when narrowed to int.
        for (long row = 0; row < batchSize; row++) {
            labels.add(generateLabel(sliceRow(batchPredictions, row), immutableOutputInfo));
        }
        return labels;
    }

    /** Returns the id of {@code label} in {@code immutableOutputInfo}; throws if the label is unknown (id -1). */
    private int innerTransform(Label label, ImmutableOutputInfo<Label> immutableOutputInfo) {
        int id = immutableOutputInfo.getID(label);
        if (id == -1) {
            throw new IllegalArgumentException("Label " + label + " isn't known by the supplied outputIDInfo, " + immutableOutputInfo.toString());
        }
        return id;
    }

    /** Writes a one-hot row: zeros in {@code [0, numLabels)}, then 1.0 at {@code hotIndex}. */
    private static void writeOneHot(TFloat32 tensor, long row, int hotIndex, int numLabels) {
        for (long col = 0; col < numLabels; col++) {
            tensor.setFloat(0.0f, row, col);
        }
        tensor.setFloat(1.0f, row, hotIndex);
    }

    /**
     * Encodes a single label as a one-hot float32 tensor of shape {@code [1, outputInfo.size()]}.
     *
     * @param label The label to encode.
     * @param immutableOutputInfo The label id mapping.
     * @return The one-hot tensor.
     * @throws IllegalArgumentException If the label is not in {@code immutableOutputInfo}.
     */
    @Override
    public Tensor convertToTensor(Label label, ImmutableOutputInfo<Label> immutableOutputInfo) {
        // Resolve the id before allocating, so an unknown label fails without creating a tensor.
        int id = innerTransform(label, immutableOutputInfo);
        int numLabels = immutableOutputInfo.size();
        TFloat32 tensor = TFloat32.tensorOf(Shape.of(SINGLE_EXAMPLE_BATCH_SIZE, numLabels));
        writeOneHot(tensor, 0, id, numLabels);
        return tensor;
    }

    /**
     * Encodes the examples' labels as one-hot rows of a float32 tensor of shape
     * {@code [list.size(), outputInfo.size()]}.
     *
     * @param list The examples, in row order.
     * @param immutableOutputInfo The label id mapping.
     * @return The one-hot tensor.
     * @throws IllegalArgumentException If any label is not in {@code immutableOutputInfo}.
     */
    @Override
    public Tensor convertToTensor(List<Example<Label>> list, ImmutableOutputInfo<Label> immutableOutputInfo) {
        // Resolve every id first. An unknown label then throws before the tensor is allocated,
        // instead of leaving behind a tensor that is never returned to the caller or closed.
        int[] ids = new int[list.size()];
        int next = 0;
        for (Example<Label> example : list) {
            ids[next++] = innerTransform(example.getOutput(), immutableOutputInfo);
        }
        int numLabels = immutableOutputInfo.size();
        TFloat32 tensor = TFloat32.tensorOf(Shape.of(ids.length, numLabels));
        for (int row = 0; row < ids.length; row++) {
            writeOneHot(tensor, row, ids[row], numLabels);
        }
        return tensor;
    }

    /** Returns {@code true}: predictions come from a softmax and are flagged as probabilistic. */
    @Override
    public boolean generatesProbabilities() {
        return true;
    }

    @Override
    public String toString() {
        return "LabelConverter()";
    }

    /**
     * Returns the configured-object provenance of this converter.
     * NOTE(review): this name was produced by a decompiler ("renamed from: getProvenance"). It is
     * kept for compatibility; {@link #getProvenance()} is the conventional entry point.
     */
    public ConfiguredObjectProvenance m6getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputConverter");
    }

    /**
     * Returns the configured-object provenance of this converter, under its original method name.
     *
     * @return The provenance, tagged as an "OutputConverter".
     */
    public ConfiguredObjectProvenance getProvenance() {
        return m6getProvenance();
    }
}
