package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

/* loaded from: input_file:ai/djl/pytorch/jni/IValueUtils.class */
/**
 * Utilities for running TorchScript module methods with {@link IValue} arguments.
 *
 * <p>{@link #getInputs(NDList)} packs an {@link NDList} into method arguments based on each
 * array's name, checked in this order:
 *
 * <ol>
 *   <li>a name containing {@code '.'}, e.g. {@code "group.key"}, adds the array under {@code key}
 *       to a string-keyed map argument grouped by {@code group};
 *   <li>a name starting with {@code "module_method:"} is not passed as an argument; the rest of
 *       the name selects the method to run instead of {@code "forward"};
 *   <li>a name matching {@code \w+\[]}, e.g. {@code "data[]"}, appends the array to a list
 *       argument grouped by that exact name;
 *   <li>a name matching {@code \w+\(\)}, e.g. {@code "data()"}, appends the array to a tuple
 *       argument grouped by that exact name;
 *   <li>any other name, or no name, passes the array as a single tensor argument.
 * </ol>
 *
 * <p>Arguments are ordered by the position at which each group first appears in the list.
 */
public final class IValueUtils {
    private static final Pattern PATTERN_LIST = Pattern.compile("\\w+\\[]");
    private static final Pattern PATTERN_TUPLE = Pattern.compile("\\w+\\(\\)");
    /** Name prefix of a marker array that selects which module method to run. */
    private static final String MODULE_METHOD_PREFIX = "module_method:";

    private IValueUtils() {
    }

    /**
     * Runs a module method on the given arrays and converts its result back to an {@link NDList}.
     *
     * <p>The method defaults to {@code "forward"}; an array named {@code "module_method:NAME"}
     * selects {@code NAME} instead. The packed input {@link IValue}s are always closed before this
     * method returns, including when the native call fails.
     *
     * @param block the TorchScript block whose method is run
     * @param inputs the input arrays; must be non-empty, because the manager of the first array is
     *     passed to {@link IValue#toNDList} for the result
     * @param isTrain forwarded unchanged as the last argument of {@code moduleRunMethod}
     *     (presumably the training-mode flag)
     * @return the method's output converted to an {@link NDList}
     * @throws IndexOutOfBoundsException if {@code inputs} is empty
     */
    public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
        // Resolve the manager up front so an empty list fails before any native handle exists.
        PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
        Pair<IValue[], String> packed = getInputs(inputs);
        IValue[] args = packed.getKey();
        long resultHandle;
        try {
            long[] argHandles = Arrays.stream(args).mapToLong(IValue::getHandle).toArray();
            resultHandle =
                    PyTorchLibrary.LIB.moduleRunMethod(
                            block.getHandle(), packed.getValue(), argHandles, isTrain);
        } finally {
            // The packed arguments are owned here; release them even if the native call throws.
            closeAll(args);
        }
        try (IValue result = new IValue(resultHandle)) {
            return result.toNDList(manager);
        }
    }

    /**
     * Runs the {@code "forward"} method of {@code block} with the given arguments.
     *
     * @param block the TorchScript block to run
     * @param iValueArr the method arguments; they are not closed by this method
     * @return the raw result; it is not closed here, so the caller should close it
     */
    public static IValue forward(PtSymbolBlock block, IValue... iValueArr) {
        return runMethod(block, "forward", iValueArr);
    }

    /**
     * Runs the named method of {@code block}, passing {@code false} as the last argument of
     * {@code moduleRunMethod}.
     *
     * @param block the TorchScript block to run
     * @param methodName the name of the method to invoke
     * @param iValueArr the method arguments; they are not closed by this method
     * @return the raw result; it is not closed here, so the caller should close it
     */
    public static IValue runMethod(PtSymbolBlock block, String methodName, IValue... iValueArr) {
        long[] argHandles = Arrays.stream(iValueArr).mapToLong(IValue::getHandle).toArray();
        return new IValue(
                PyTorchLibrary.LIB.moduleRunMethod(block.getHandle(), methodName, argHandles, false));
    }

    /**
     * Returns the index of group {@code key} in {@code groups}. The first time {@code key} is seen,
     * a new empty group is appended and its index is recorded in {@code indexMap}.
     */
    private static int addToMap(
            Map<String, Integer> indexMap, String key, List<PairList<String, PtNDArray>> groups) {
        return indexMap.computeIfAbsent(
                key,
                k -> {
                    groups.add(new PairList<>());
                    return groups.size() - 1;
                });
    }

    /**
     * Packs {@code ndList} into module-method arguments according to the array-naming rules
     * described on this class.
     *
     * <p>If creating any argument fails, the arguments already created are closed before the
     * failure is rethrown.
     *
     * @param ndList the arrays to pack; every array except a {@code "module_method:"} marker must
     *     be a {@link PtNDArray}
     * @return the packed arguments paired with the name of the method to run
     */
    static Pair<IValue[], String> getInputs(NDList ndList) {
        List<PairList<String, PtNDArray>> groups = new ArrayList<>();
        Map<String, Integer> indexMap = new HashMap<>();
        String methodName = "forward";
        for (NDArray array : ndList) {
            String name = array.getName();
            if (name != null && name.contains(".")) {
                // "group.key" -> entry "key" of the map argument "group".
                String[] parts = name.split("\\.", 2);
                groups.get(addToMap(indexMap, parts[0], groups)).add(parts[1], (PtNDArray) array);
            } else if (name != null && name.startsWith(MODULE_METHOD_PREFIX)) {
                methodName = name.substring(MODULE_METHOD_PREFIX.length());
            } else if (name != null && PATTERN_LIST.matcher(name).matches()) {
                groups.get(addToMap(indexMap, name, groups)).add("[]", (PtNDArray) array);
            } else if (name != null && PATTERN_TUPLE.matcher(name).matches()) {
                groups.get(addToMap(indexMap, name, groups)).add("()", (PtNDArray) array);
            } else {
                // A plain tensor argument: a singleton group marked by a null key.
                PairList<String, PtNDArray> single = new PairList<>();
                single.add(null, (PtNDArray) array);
                groups.add(single);
            }
        }

        IValue[] values = new IValue[groups.size()];
        try {
            for (int i = 0; i < values.length; i++) {
                values[i] = toIValue(groups.get(i));
            }
        } catch (Throwable t) {
            // Do not leak the arguments created before the failure.
            closeAll(values);
            throw t;
        }
        return new Pair<>(values, methodName);
    }

    /**
     * Converts one non-empty input group to an {@link IValue}. The key of the group's first entry
     * selects the kind: {@code null} for a plain tensor, {@code "[]"} for a list, {@code "()"} for
     * a tuple, and anything else for a string-keyed map.
     */
    private static IValue toIValue(PairList<String, PtNDArray> group) {
        String key = group.get(0).getKey();
        if (key == null) {
            return IValue.from(group.get(0).getValue());
        }
        if ("[]".equals(key)) {
            return IValue.listFrom(group.values().toArray(new PtNDArray[0]));
        }
        if ("()".equals(key)) {
            IValue[] elements = group.values().stream().map(IValue::from).toArray(IValue[]::new);
            // NOTE(review): these per-element IValues are never closed. If IValue.tupleFrom copies
            // them rather than taking ownership, they leak; confirm against the native side.
            return IValue.tupleFrom(elements);
        }
        // LinkedHashMap keeps the map entries in input order (later duplicate keys overwrite).
        Map<String, PtNDArray> map = new LinkedHashMap<>();
        for (Pair<String, PtNDArray> entry : group) {
            map.put(entry.getKey(), entry.getValue());
        }
        return IValue.stringMapFrom(map);
    }

    /** Closes every non-null value in {@code values}; null slots are skipped. */
    private static void closeAll(IValue[] values) {
        for (IValue value : values) {
            if (value != null) {
                value.close();
            }
        }
    }
}
