package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.types.TFloat32;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.VectorTuple;

/* loaded from: input_file:org/tribuo/interop/tensorflow/ImageConverter.class */
/**
 * A {@link FeatureConverter} that packs feature values into a dense float tensor of shape
 * {@code [batch, width, height, channels]} bound to a single named input.
 * <p>
 * Each feature id (or vector index) is used directly as the offset into a flat
 * {@code float[width * height * channels]} buffer; positions that receive no value stay at
 * {@code 0.0f}.
 */
public class ImageConverter implements FeatureConverter {
    private static final long serialVersionUID = 1L;

    @Config(mandatory = true, description = "TensorFlow Placeholder Input name.")
    private String inputName;

    @Config(mandatory = true, description = "Image width.")
    private int width;

    @Config(mandatory = true, description = "Image height.")
    private int height;

    @Config(mandatory = true, description = "Number of channels.")
    private int channels;

    /** Cached {@code width * height * channels}; set by the constructor or {@link #postConfig()}. */
    private int totalPixels;

    /**
     * No-arg constructor; fields are expected to be populated externally (they carry
     * {@code @Config} annotations) followed by a call to {@link #postConfig()}.
     */
    private ImageConverter() {
    }

    /**
     * Builds a converter for images of the given dimensions.
     *
     * @param inputName name of the input the produced tensor is bound to; must be non-null and non-empty
     * @param width     image width, must be at least 1
     * @param height    image height, must be at least 1
     * @param channels  number of channels, must be at least 1
     * @throws IllegalArgumentException if any dimension is below 1, the name is null/empty, or
     *                                  {@code width * height * channels} exceeds {@link Integer#MAX_VALUE}
     */
    public ImageConverter(String inputName, int width, int height, int channels) {
        if (width < 1 || height < 1 || channels < 1) {
            throw new IllegalArgumentException("Inputs must be positive integers, found [" + width + "," + height + "," + channels + "]");
        }
        if (inputName == null || inputName.isEmpty()) {
            throw new IllegalArgumentException("The input name must be a valid String");
        }
        long values = pixelCount(width, height, channels);
        if (values > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Image size must be less than 2^31, found " + values);
        }
        this.inputName = inputName;
        this.totalPixels = (int) values;
        this.width = width;
        this.height = height;
        this.channels = channels;
    }

    /**
     * Validates the configured dimensions and caches the total element count.
     *
     * @throws PropertyException if any dimension is below 1 or the total element count exceeds
     *                           {@link Integer#MAX_VALUE}
     */
    public void postConfig() {
        if (this.width < 1 || this.height < 1 || this.channels < 1) {
            throw new PropertyException("", "Inputs must be positive integers, found [" + this.width + "," + this.height + "," + this.channels + "]");
        }
        long values = pixelCount(this.width, this.height, this.channels);
        if (values > Integer.MAX_VALUE) {
            throw new PropertyException("", "Image size must be less than 2^31, found " + values);
        }
        this.totalPixels = (int) values;
    }

    /**
     * Computes {@code width * height * channels} in 64-bit arithmetic.
     * <p>
     * All arguments must already be validated as positive. The first product of two ints always
     * fits in a long; the second multiplication is saturated to {@link Long#MAX_VALUE} if it would
     * overflow, so the caller's "greater than Integer.MAX_VALUE" check is always reliable (the
     * reported value is then a lower bound rather than the exact product).
     */
    private static long pixelCount(int width, int height, int channels) {
        // Widen BEFORE multiplying: int * int * int would wrap around silently.
        long area = (long) width * height;
        if (area > Long.MAX_VALUE / channels) {
            return Long.MAX_VALUE;
        }
        return area * channels;
    }

    /** Shape of a batch of {@code batchSize} images: {@code [batchSize, width, height, channels]}. */
    private Shape batchShape(long batchSize) {
        return Shape.of(new long[]{batchSize, this.width, this.height, this.channels});
    }

    /** Shape of a single image without the batch dimension: {@code [width, height, channels]}. */
    private Shape imageShape() {
        return Shape.of(new long[]{this.width, this.height, this.channels});
    }

    /**
     * Returns the single input name this converter writes to.
     *
     * @return an immutable singleton set containing the input name
     */
    @Override // org.tribuo.interop.tensorflow.FeatureConverter
    public Set<String> inputNamesSet() {
        return Collections.singleton(this.inputName);
    }

    /**
     * Converts one example into a tensor of shape {@code [1, width, height, channels]}.
     *
     * @param example             the example whose features are written into the image buffer
     * @param immutableFeatureMap maps feature names to buffer offsets
     * @return a {@link TensorMap} binding the input name to the new tensor
     * @throws IllegalArgumentException if the feature map has more entries than the image has elements
     */
    @Override // org.tribuo.interop.tensorflow.FeatureConverter
    public TensorMap convert(Example<?> example, ImmutableFeatureMap immutableFeatureMap) {
        float[] image = innerTransform(example, immutableFeatureMap);
        return new TensorMap(this.inputName, TFloat32.tensorOf(batchShape(1L), DataBuffers.of(image)));
    }

    /**
     * Flattens an example into a {@code float[totalPixels]} buffer, using each feature's id from
     * the feature map as its offset.
     *
     * @throws IllegalArgumentException if the feature map has more entries than the image has elements
     */
    float[] innerTransform(Example<?> example, ImmutableFeatureMap immutableFeatureMap) {
        if (immutableFeatureMap.size() > this.totalPixels) {
            throw new IllegalArgumentException("Found more values than expected, expected " + this.totalPixels + ", found " + immutableFeatureMap.size());
        }
        float[] image = new float[this.totalPixels];
        for (Feature feature : example) {
            // NOTE(review): the id is used unchecked as an array index; if getID can return a
            // negative value for a name missing from the map this throws
            // ArrayIndexOutOfBoundsException - confirm callers only pass known features.
            image[immutableFeatureMap.getID(feature.getName())] = (float) feature.getValue();
        }
        return image;
    }

    /**
     * Flattens a vector into a {@code float[totalPixels]} buffer, using each tuple's index as its
     * offset.
     *
     * @throws IllegalArgumentException if the vector's size exceeds the number of image elements
     */
    float[] innerTransform(SGDVector sGDVector) {
        if (sGDVector.size() > this.totalPixels) {
            throw new IllegalArgumentException("Found more values than expected, expected " + this.totalPixels + ", found " + sGDVector.size());
        }
        float[] image = new float[this.totalPixels];
        for (VectorTuple tuple : sGDVector) {
            image[tuple.index] = (float) tuple.value;
        }
        return image;
    }

    /**
     * Converts a list of examples into one tensor of shape
     * {@code [list.size(), width, height, channels]}, one image per list position.
     *
     * @param list                the examples to convert, in batch order
     * @param immutableFeatureMap maps feature names to buffer offsets
     * @return a {@link TensorMap} binding the input name to the batch tensor
     * @throws IllegalArgumentException if the feature map has more entries than the image has elements
     */
    @Override // org.tribuo.interop.tensorflow.FeatureConverter
    public TensorMap convert(List<? extends Example<?>> list, ImmutableFeatureMap immutableFeatureMap) {
        TFloat32 batch = TFloat32.tensorOf(batchShape(list.size()));
        long row = 0;
        for (Example<?> example : list) {
            float[] image = innerTransform(example, immutableFeatureMap);
            batch.set(NdArrays.wrap(imageShape(), DataBuffers.of(image)), new long[]{row});
            row++;
        }
        return new TensorMap(this.inputName, batch);
    }

    /**
     * Converts one vector into a tensor of shape {@code [1, width, height, channels]}.
     *
     * @param sGDVector the vector whose values are written into the image buffer
     * @return a {@link TensorMap} binding the input name to the new tensor
     * @throws IllegalArgumentException if the vector's size exceeds the number of image elements
     */
    @Override // org.tribuo.interop.tensorflow.FeatureConverter
    public TensorMap convert(SGDVector sGDVector) {
        float[] image = innerTransform(sGDVector);
        return new TensorMap(this.inputName, TFloat32.tensorOf(batchShape(1L), DataBuffers.of(image)));
    }

    /**
     * Converts a list of vectors into one tensor of shape
     * {@code [list.size(), width, height, channels]}, one image per list position.
     *
     * @param list the vectors to convert, in batch order
     * @return a {@link TensorMap} binding the input name to the batch tensor
     * @throws IllegalArgumentException if any vector's size exceeds the number of image elements
     */
    @Override // org.tribuo.interop.tensorflow.FeatureConverter
    public TensorMap convert(List<? extends SGDVector> list) {
        TFloat32 batch = TFloat32.tensorOf(batchShape(list.size()));
        long row = 0;
        for (SGDVector vector : list) {
            float[] image = innerTransform(vector);
            batch.set(NdArrays.wrap(imageShape(), DataBuffers.of(image)), new long[]{row});
            row++;
        }
        return new TensorMap(this.inputName, batch);
    }

    /**
     * Describes this converter by its input name and image dimensions.
     */
    @Override
    public String toString() {
        return "ImageConverter(inputName='" + this.inputName + "',width=" + this.width + ",height=" + this.height + ",channels=" + this.channels + ")";
    }

    /**
     * Builds the provenance record for this converter.
     * <p>
     * NOTE(review): the decompiler renamed this from {@code getProvenance} when it merged the
     * bridge method; the name is kept as found here - confirm whether it should be restored to
     * {@code getProvenance} before compiling against the real interface.
     *
     * @return a new provenance object describing this instance
     */
    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public ConfiguredObjectProvenance m4getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "FeatureConverter");
    }
}
