package io.kroxylicious.filter.encryption.inband;

import edu.umd.cs.findbugs.annotations.NonNull;
import io.kroxylicious.filter.encryption.EncryptionException;
import io.kroxylicious.filter.encryption.EncryptionManager;
import io.kroxylicious.filter.encryption.EncryptionScheme;
import io.kroxylicious.filter.encryption.EncryptionVersion;
import io.kroxylicious.filter.encryption.FilterThreadExecutor;
import io.kroxylicious.filter.encryption.dek.BufferTooSmallException;
import io.kroxylicious.filter.encryption.dek.Dek;
import io.kroxylicious.filter.encryption.dek.ExhaustedDekException;
import io.kroxylicious.filter.encryption.records.RecordStream;
import io.kroxylicious.kms.service.Serde;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.IntFunction;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.MutableRecordBatch;
import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.ByteBufferOutputStream;
import org.apache.kafka.common.utils.CloseableIterator;

/* loaded from: input_file:io/kroxylicious/filter/encryption/inband/InBandEncryptionManager.class */
/**
 * An {@link EncryptionManager} that encrypts the records of a {@link MemoryRecords} using a DEK
 * obtained from an {@link EncryptionDekCache}.
 *
 * <p>Each record is transformed by a {@code RecordEncryptor}, which is given the EDEK {@link Serde}
 * and the {@link EncryptionVersion} used by this manager (presumably so the encrypted DEK is carried
 * in-band with the record — confirm against {@code RecordEncryptor}).</p>
 *
 * @param <K> the type of key-encryption-key identifier used by the {@link EncryptionScheme}
 * @param <E> the type of encrypted DEK (EDEK)
 */
public class InBandEncryptionManager<K, E> implements EncryptionManager<K> {

    /** How many times {@link #attemptEncrypt} will try to obtain a usable DEK before giving up. */
    private static final int MAX_ATTEMPTS = 3;

    private final EncryptionVersion encryptionVersion = EncryptionVersion.V2;
    private final Serde<E> edekSerde;
    private final EncryptionDekCache<K, E> dekCache;

    @NonNull
    private final FilterThreadExecutor filterThreadExecutor;
    private final int recordBufferInitialBytes;
    private final int recordBufferMaxBytes;

    /**
     * Returns the number of records in each batch of {@code memoryRecords}, in batch order.
     *
     * @param memoryRecords the records to inspect
     * @return a mutable list with one entry per batch
     */
    @NonNull
    public static List<Integer> batchRecordCounts(@NonNull MemoryRecords memoryRecords) {
        List<Integer> counts = new ArrayList<>();
        for (MutableRecordBatch batch : memoryRecords.batches()) {
            counts.add(recordCount(batch));
        }
        return counts;
    }

    /**
     * Returns the number of records in {@code batch}, using the count the batch reports when
     * available and otherwise counting the records by iterating over them.
     */
    private static int recordCount(@NonNull MutableRecordBatch batch) {
        Integer reported = batch.countOrNull();
        if (reported != null) {
            return reported;
        }
        int counted = 0;
        // The iterator is Closeable; close it so any resources it holds are released.
        try (CloseableIterator<?> records = batch.skipKeyValueIterator(BufferSupplier.NO_CACHING)) {
            while (records.hasNext()) {
                records.next();
                counted++;
            }
        }
        return counted;
    }

    /**
     * Creates an encryption manager.
     *
     * @param edekSerde serde for encrypted DEKs, passed to each {@code RecordEncryptor}
     * @param recordBufferInitialBytes initial size of the per-record working buffer; must be positive
     * @param recordBufferMaxBytes size the working buffer may not grow beyond; must be positive
     * @param dekCache source of the current DEK for an encryption scheme
     * @param filterThreadExecutor executor passed to {@code dekCache} when looking up DEKs
     * @throws NullPointerException if {@code edekSerde} or {@code dekCache} is null
     * @throws IllegalArgumentException if either buffer size is not positive
     */
    public InBandEncryptionManager(@NonNull Serde<E> edekSerde,
                                   int recordBufferInitialBytes,
                                   int recordBufferMaxBytes,
                                   @NonNull EncryptionDekCache<K, E> dekCache,
                                   @NonNull FilterThreadExecutor filterThreadExecutor) {
        this.filterThreadExecutor = filterThreadExecutor;
        this.edekSerde = Objects.requireNonNull(edekSerde, "edekSerde");
        if (recordBufferInitialBytes <= 0) {
            throw new IllegalArgumentException(
                    "recordBufferInitialBytes must be positive, but was " + recordBufferInitialBytes);
        }
        this.recordBufferInitialBytes = recordBufferInitialBytes;
        if (recordBufferMaxBytes <= 0) {
            throw new IllegalArgumentException(
                    "recordBufferMaxBytes must be positive, but was " + recordBufferMaxBytes);
        }
        this.recordBufferMaxBytes = recordBufferMaxBytes;
        this.dekCache = Objects.requireNonNull(dekCache, "dekCache");
    }

    /** Looks up the DEK for {@code encryptionScheme} in the DEK cache. */
    CompletionStage<Dek<E>> currentDek(@NonNull EncryptionScheme<K> encryptionScheme) {
        return this.dekCache.get(encryptionScheme, this.filterThreadExecutor);
    }

    /**
     * Encrypts {@code records}. Input with no bytes, or whose batches all contain zero records, is
     * returned unchanged without consulting the DEK cache.
     *
     * @param topicName topic the records belong to
     * @param partition partition the records belong to
     * @param encryptionScheme scheme selecting the DEK to use
     * @param records the records to encrypt
     * @param bufferAllocator allocates the output buffer for the encrypted records
     * @return a stage completing with the encrypted records, or failing if no usable DEK could be
     *         obtained within {@value #MAX_ATTEMPTS} attempts or encryption failed
     */
    @Override
    @NonNull
    public CompletionStage<MemoryRecords> encrypt(@NonNull String topicName,
                                                  int partition,
                                                  @NonNull EncryptionScheme<K> encryptionScheme,
                                                  @NonNull MemoryRecords records,
                                                  @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator) {
        if (records.sizeInBytes() == 0) {
            return CompletableFuture.completedFuture(records);
        }
        List<Integer> batchRecordCounts = batchRecordCounts(records);
        if (batchRecordCounts.stream().allMatch(count -> count == 0)) {
            // Only empty batches: nothing to encrypt.
            return CompletableFuture.completedFuture(records);
        }
        return attemptEncrypt(topicName, partition, encryptionScheme, records, 0, batchRecordCounts, bufferAllocator);
    }

    /**
     * Allocates the output buffer for the encrypted records, sized at twice the input size
     * (NOTE(review): presumably headroom for encryption overhead — confirm).
     */
    private ByteBufferOutputStream allocateBufferForEncrypt(@NonNull MemoryRecords records,
                                                            @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator) {
        return bufferAllocator.apply(2 * records.sizeInBytes());
    }

    /**
     * Makes one attempt at encrypting {@code records}, recursing with {@code attempt + 1} when the DEK
     * obtained is destroyed or exhausted, and failing with {@code RequestNotSatisfiable} once
     * {@value #MAX_ATTEMPTS} attempts have been made.
     */
    private CompletionStage<MemoryRecords> attemptEncrypt(@NonNull String topicName,
                                                          int partition,
                                                          @NonNull EncryptionScheme<K> encryptionScheme,
                                                          @NonNull MemoryRecords records,
                                                          int attempt,
                                                          @NonNull List<Integer> batchRecordCounts,
                                                          @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator) {
        int totalRecords = batchRecordCounts.stream().mapToInt(Integer::intValue).sum();
        if (attempt >= MAX_ATTEMPTS) {
            return CompletableFuture.failedFuture(new RequestNotSatisfiable(
                    "failed to reserve an EDEK to encrypt " + totalRecords + " records for topic " + topicName
                            + " partition " + partition + " after " + attempt + " attempts"));
        }
        return currentDek(encryptionScheme).thenCompose(dek -> {
            // A destroyed DEK is not used; fall through and retry, which asks the cache again.
            if (!dek.isDestroyed()) {
                // try-with-resources closes the encryptor on every path, including when encryptBatches throws.
                try (Dek<E>.Encryptor encryptor = dek.encryptor(totalRecords)) {
                    return CompletableFuture.completedFuture(
                            encryptBatches(topicName, partition, encryptionScheme, records, encryptor, bufferAllocator));
                }
                catch (ExhaustedDekException e) {
                    // Retire this DEK and invalidate the cache entry, then retry below.
                    rotateKeyContext(encryptionScheme, dek);
                }
                catch (Exception e) {
                    return CompletableFuture.failedFuture(e);
                }
            }
            return attemptEncrypt(topicName, partition, encryptionScheme, records, attempt + 1, batchRecordCounts, bufferAllocator);
        });
    }

    /**
     * Encrypts every record in {@code records} with {@code encryptor}. If the per-record working buffer
     * is too small, the whole operation is retried with a buffer of double the capacity.
     *
     * @throws EncryptionException if the working buffer would need to exceed {@code recordBufferMaxBytes}
     */
    @NonNull
    private MemoryRecords encryptBatches(@NonNull String topicName,
                                         int partition,
                                         @NonNull EncryptionScheme<K> encryptionScheme,
                                         @NonNull MemoryRecords records,
                                         @NonNull Dek<E>.Encryptor encryptor,
                                         @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator) {
        ByteBuffer recordBuffer = ByteBuffer.allocate(this.recordBufferInitialBytes);
        while (true) {
            try {
                return RecordStream.ofRecords(records)
                        .mapConstant(encryptor)
                        .toMemoryRecords(allocateBufferForEncrypt(records, bufferAllocator),
                                new RecordEncryptor(topicName, partition, this.encryptionVersion, encryptionScheme,
                                        this.edekSerde, recordBuffer));
            }
            catch (BufferTooSmallException e) {
                // Double in long arithmetic: an int doubling of a large capacity would overflow negative
                // and slip past the size check below.
                long newCapacity = 2L * recordBuffer.capacity();
                if (newCapacity > this.recordBufferMaxBytes) {
                    throw new EncryptionException("Record buffer cannot grow greater than " + this.recordBufferMaxBytes + " bytes");
                }
                recordBuffer = ByteBuffer.allocate((int) newCapacity);
            }
        }
    }

    /** Destroys {@code dek} for encryption and invalidates the cached DEK for {@code encryptionScheme}. */
    private void rotateKeyContext(@NonNull EncryptionScheme<K> encryptionScheme, @NonNull Dek<E> dek) {
        dek.destroyForEncrypt();
        this.dekCache.invalidate(encryptionScheme);
    }
}
