package io.kroxylicious.filter.encryption.inband;

import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import io.kroxylicious.filter.encryption.DecryptionManager;
import io.kroxylicious.filter.encryption.EncryptionException;
import io.kroxylicious.filter.encryption.EncryptionVersion;
import io.kroxylicious.filter.encryption.FilterThreadExecutor;
import io.kroxylicious.filter.encryption.dek.Dek;
import io.kroxylicious.filter.encryption.dek.DekManager;
import io.kroxylicious.filter.encryption.inband.DecryptionDekCache;
import io.kroxylicious.filter.encryption.records.RecordStream;
import io.kroxylicious.kms.service.Serde;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.IntFunction;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.utils.ByteBufferOutputStream;

/* loaded from: input_file:io/kroxylicious/filter/encryption/inband/InBandDecryptionManager.class */
/**
 * A {@link DecryptionManager} that decrypts records whose encryption metadata travels with the
 * record itself: a record is treated as encrypted when it carries the
 * {@code kroxylicious.io/encryption} header, and the key material needed to decrypt it is read
 * from the record value.
 *
 * @param <K> the first type parameter of the {@link DekManager} and {@link DecryptionDekCache} in use
 * @param <E> the encrypted-DEK type handled by the {@link DekManager}'s serde
 */
public class InBandDecryptionManager<K, E> implements DecryptionManager {
    /** Key of the record header whose single-byte value carries the encryption version code. */
    private static final String ENCRYPTION_HEADER_NAME = "kroxylicious.io/encryption";

    private final DekManager<K, E> dekManager;
    private final FilterThreadExecutor filterThreadExecutor;
    private final DecryptionDekCache<K, E> dekCache;

    /**
     * Creates a decryption manager.
     *
     * @param dekManager supplies the serde used to read encrypted DEKs out of record values; must not be null
     * @param decryptionDekCache cache used to resolve DEKs for the records being decrypted; must not be null
     * @param filterThreadExecutor executor used to complete DEK lookups on the filter thread.
     *        NOTE(review): annotated {@code @Nullable}, yet {@code resolveAll} dereferences it
     *        unconditionally, so a null value fails on any non-empty {@code decrypt} call; confirm intent.
     * @throws NullPointerException if {@code dekManager} or {@code decryptionDekCache} is null
     */
    public InBandDecryptionManager(@NonNull DekManager<K, E> dekManager, @NonNull DecryptionDekCache<K, E> decryptionDekCache, @Nullable FilterThreadExecutor filterThreadExecutor) {
        this.dekManager = Objects.requireNonNull(dekManager);
        this.dekCache = Objects.requireNonNull(decryptionDekCache);
        this.filterThreadExecutor = filterThreadExecutor;
    }

    /**
     * Determines the encryption version of a record from its encryption header.
     *
     * @param str the topic name (used only in the error message)
     * @param i the partition (used only in the error message)
     * @param record the record to inspect
     * @return the version decoded from the header's single byte, or {@code null} if the record has no encryption header
     * @throws EncryptionException if the encryption header is present but its value is not exactly one byte
     */
    @Nullable
    static EncryptionVersion decryptionVersion(@NonNull String str, int i, @NonNull Record record) {
        for (Header header : record.headers()) {
            if (!ENCRYPTION_HEADER_NAME.equals(header.key())) {
                continue;
            }
            byte[] value = header.value();
            // A missing value is as invalid as one of the wrong length; report both the same way
            // rather than letting a bare NullPointerException escape.
            if (value == null || value.length != 1) {
                throw new EncryptionException("Invalid value for header with key '" + ENCRYPTION_HEADER_NAME + "' in record at offset " + record.offset()
                        + " in partition " + i + " of topic " + str);
            }
            return EncryptionVersion.fromCode(value[0]);
        }
        return null;
    }

    /**
     * Decrypts the given records.
     * The input is returned unchanged (already completed) when it is empty, or when every batch
     * in it reports a record count of zero.
     *
     * @param str the topic name the records belong to
     * @param i the partition the records belong to
     * @param memoryRecords the records to decrypt
     * @param intFunction allocates the output buffer, given the input's size in bytes
     * @return a stage completing with the decrypted records
     */
    @Override // io.kroxylicious.filter.encryption.DecryptionManager
    @NonNull
    public CompletionStage<MemoryRecords> decrypt(@NonNull String str, int i, @NonNull MemoryRecords memoryRecords, @NonNull IntFunction<ByteBufferOutputStream> intFunction) {
        if (memoryRecords.sizeInBytes() == 0) {
            return CompletableFuture.completedFuture(memoryRecords);
        }
        boolean everyBatchEmpty = InBandEncryptionManager.batchRecordCounts(memoryRecords).stream()
                .allMatch(count -> count.intValue() == 0);
        if (everyBatchEmpty) {
            return CompletableFuture.completedFuture(memoryRecords);
        }
        return resolveAll(str, i, memoryRecords).thenApply(states -> {
            try {
                return decrypt(str, i, memoryRecords, states, allocateBufferForDecrypt(memoryRecords, intFunction));
            }
            finally {
                // Release the decryptors whether or not decryption succeeded.
                closeDecryptors(states);
            }
        });
    }

    /**
     * Closes the decryptor of every state that has one.
     * States sharing a decryptor each trigger a {@code close()} call on it.
     */
    private void closeDecryptors(@NonNull List<DecryptState<E>> states) {
        for (DecryptState<E> state : states) {
            if (state != null && state.decryptor() != null) {
                state.decryptor().close();
            }
        }
    }

    /**
     * Builds one {@link DecryptState} per record (in record order), looks up the DEKs for all
     * records through the cache, and attaches a decryptor to each state whose DEK was found.
     *
     * @return a stage completing with the per-record states
     */
    private CompletionStage<List<DecryptState<E>>> resolveAll(String str, int i, MemoryRecords memoryRecords) {
        Serde<E> edekSerde = this.dekManager.edekSerde();
        // Parallel lists: cacheKeys.get(n) and states.get(n) both describe the n-th record.
        List<DecryptionDekCache.CacheKey<E>> cacheKeys = new ArrayList<>();
        List<DecryptState<E>> states = new ArrayList<>();
        RecordStream.ofRecords(memoryRecords).forEachRecord((recordBatch, record, unused) -> {
            EncryptionVersion version = decryptionVersion(str, i, record);
            if (version == null) {
                // No encryption header on this record.
                cacheKeys.add(DecryptionDekCache.CacheKey.unencrypted());
                states.add(DecryptState.none());
            }
            else {
                cacheKeys.add(version.wrapperVersion().readSpecAndEdek(record.value(), edekSerde, DecryptionDekCache.CacheKey::new));
                states.add(new DecryptState<>(version));
            }
        });
        return this.filterThreadExecutor.completingOnFilterThread(this.dekCache.getAll(cacheKeys, this.filterThreadExecutor))
                .thenApply(deks -> issueDecryptors(deks, cacheKeys, states));
    }

    /**
     * Attaches a decryptor to each state, obtaining at most one decryptor per distinct cache key
     * so that records with equal keys share it. A key with no DEK in {@code map} yields a
     * {@code null} decryptor for its states.
     *
     * @param map the resolved DEKs by cache key
     * @param list the cache key of each record, in record order
     * @param list2 the state of each record, in the same order as {@code list}
     * @return {@code list2}, with decryptors attached
     */
    @NonNull
    private List<DecryptState<E>> issueDecryptors(@NonNull Map<DecryptionDekCache.CacheKey<E>, Dek<E>> map, @NonNull List<DecryptionDekCache.CacheKey<E>> list, @NonNull List<DecryptState<E>> list2) {
        Map<DecryptionDekCache.CacheKey<E>, Dek.Decryptor> issued = new HashMap<>();
        try {
            int size = list.size();
            for (int index = 0; index < size; index++) {
                DecryptionDekCache.CacheKey<E> cacheKey = list.get(index);
                // computeIfAbsent stores nothing when the function returns null, so `issued`
                // only ever holds decryptors that were really obtained.
                Dek.Decryptor decryptor = issued.computeIfAbsent(cacheKey, key -> {
                    Dek<E> dek = map.get(key);
                    return dek != null ? dek.decryptor() : null;
                });
                list2.get(index).withDecryptor(decryptor);
            }
            return list2;
        }
        catch (RuntimeException e) {
            // Don't leak the decryptors handed out before the failure.
            for (Dek.Decryptor decryptor : issued.values()) {
                if (decryptor != null) {
                    decryptor.close();
                }
            }
            throw e;
        }
    }

    /** Asks the allocator for an output buffer sized to the input's size in bytes. */
    private static ByteBufferOutputStream allocateBufferForDecrypt(MemoryRecords memoryRecords, IntFunction<ByteBufferOutputStream> intFunction) {
        return intFunction.apply(memoryRecords.sizeInBytes());
    }

    /**
     * Pairs each record with its state by index and writes the transformed records into the
     * given buffer using a {@code RecordDecryptor}.
     *
     * @param list the per-record states, indexed by record position within {@code memoryRecords}
     */
    @NonNull
    private MemoryRecords decrypt(@NonNull String str, int i, @NonNull MemoryRecords memoryRecords, @NonNull List<DecryptState<E>> list, @NonNull ByteBufferOutputStream byteBufferOutputStream) {
        return RecordStream.ofRecordsWithIndex(memoryRecords)
                .mapPerRecord((recordBatch, record, index) -> list.get(index.intValue()))
                .toMemoryRecords(byteBufferOutputStream, new RecordDecryptor(str, i));
    }
}
