package io.kroxylicious.proxy.filter.oauthbearer;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.HexFormat;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.message.SaslAuthenticateResponseData;
import org.apache.kafka.common.message.SaslHandshakeRequestData;
import org.apache.kafka.common.message.SaslHandshakeResponseData;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallbackHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.github.benmanes.caffeine.cache.LoadingCache;

import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.RequestFilterResult;
import io.kroxylicious.proxy.filter.ResponseFilterResult;
import io.kroxylicious.proxy.filter.SaslAuthenticateRequestFilter;
import io.kroxylicious.proxy.filter.SaslAuthenticateResponseFilter;
import io.kroxylicious.proxy.filter.SaslHandshakeRequestFilter;
import io.kroxylicious.proxy.filter.oauthbearer.sasl.BackoffStrategy;

/* loaded from: input_file:io/kroxylicious/proxy/filter/oauthbearer/OauthBearerValidationFilter.class */
public class OauthBearerValidationFilter implements SaslHandshakeRequestFilter, SaslAuthenticateRequestFilter, SaslAuthenticateResponseFilter {
    private static final Logger LOGGER = LoggerFactory.getLogger(OauthBearerValidationFilter.class);
    private final ScheduledExecutorService executorService;
    private final BackoffStrategy strategy;
    private final LoadingCache<String, AtomicInteger> rateLimiter;
    private final OAuthBearerValidatorCallbackHandler oauthHandler;
    private SaslServer saslServer;
    private boolean validateAuthentication = true;

    public OauthBearerValidationFilter(ScheduledExecutorService scheduledExecutorService, SharedOauthBearerValidationContext sharedOauthBearerValidationContext) {
        this.executorService = scheduledExecutorService;
        this.strategy = sharedOauthBearerValidationContext.backoffStrategy();
        this.rateLimiter = sharedOauthBearerValidationContext.rateLimiter();
        this.oauthHandler = sharedOauthBearerValidationContext.oauthHandler();
    }

    public CompletionStage<RequestFilterResult> onSaslHandshakeRequest(short s, RequestHeaderData requestHeaderData, SaslHandshakeRequestData saslHandshakeRequestData, FilterContext filterContext) {
        if (this.saslServer != null) {
            LOGGER.debug("SASL error : Handshake request with a not null SASL server");
            return filterContext.requestFilterResultBuilder().shortCircuitResponse(new SaslHandshakeResponseData().setErrorCode(Errors.ILLEGAL_SASL_STATE.code())).withCloseConnection().completed();
        }
        try {
            if ("OAUTHBEARER".equals(saslHandshakeRequestData.mechanism()) && this.validateAuthentication) {
                this.saslServer = Sasl.createSaslServer("OAUTHBEARER", "kafka", (String) null, (Map) null, this.oauthHandler);
            } else {
                this.validateAuthentication = false;
            }
            return filterContext.forwardRequest(requestHeaderData, saslHandshakeRequestData);
        } catch (SaslException e) {
            LOGGER.debug("SASL error : {}", e.getMessage(), e);
            return filterContext.requestFilterResultBuilder().shortCircuitResponse(new SaslHandshakeResponseData().setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code())).withCloseConnection().completed();
        }
    }

    public CompletionStage<RequestFilterResult> onSaslAuthenticateRequest(short s, RequestHeaderData requestHeaderData, SaslAuthenticateRequestData saslAuthenticateRequestData, FilterContext filterContext) {
        if (!this.validateAuthentication) {
            return filterContext.forwardRequest(requestHeaderData, saslAuthenticateRequestData);
        }
        SaslServer saslServer = this.saslServer;
        if (saslServer != null) {
            this.saslServer = null;
            return authenticate(saslServer, saslAuthenticateRequestData.authBytes()).thenCompose(bArr -> {
                return filterContext.forwardRequest(requestHeaderData, saslAuthenticateRequestData);
            }).exceptionallyCompose(th -> {
                if (!(th.getCause() instanceof SaslAuthenticationException)) {
                    LOGGER.debug("SASL error : {}", th.getMessage(), th);
                    return filterContext.requestFilterResultBuilder().shortCircuitResponse(new SaslAuthenticateResponseData().setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()).setAuthBytes(saslAuthenticateRequestData.authBytes())).withCloseConnection().completed();
                }
                SaslAuthenticateResponseData authBytes = new SaslAuthenticateResponseData().setErrorCode(Errors.SASL_AUTHENTICATION_FAILED.code()).setErrorMessage(th.getMessage()).setAuthBytes(saslAuthenticateRequestData.authBytes());
                LOGGER.debug("SASL Authentication failed : {}", th.getMessage(), th);
                return filterContext.requestFilterResultBuilder().shortCircuitResponse(authBytes).withCloseConnection().completed();
            });
        }
        SaslAuthenticateResponseData authBytes = new SaslAuthenticateResponseData().setErrorCode(Errors.ILLEGAL_SASL_STATE.code()).setErrorMessage("Unexpected SASL request").setAuthBytes(saslAuthenticateRequestData.authBytes());
        LOGGER.debug("SASL invalid state");
        return filterContext.requestFilterResultBuilder().shortCircuitResponse(authBytes).withCloseConnection().completed();
    }

    public CompletionStage<ResponseFilterResult> onSaslAuthenticateResponse(short s, ResponseHeaderData responseHeaderData, SaslAuthenticateResponseData saslAuthenticateResponseData, FilterContext filterContext) {
        if (saslAuthenticateResponseData.errorCode() == Errors.NONE.code()) {
            this.validateAuthentication = false;
        }
        return filterContext.forwardResponse(responseHeaderData, saslAuthenticateResponseData);
    }

    private CompletionStage<byte[]> authenticate(SaslServer saslServer, byte[] bArr) {
        try {
            String createCacheKey = createCacheKey(bArr);
            return schedule(() -> {
                try {
                    return CompletableFuture.completedStage(doAuthenticate(saslServer, bArr));
                } catch (Exception e) {
                    return CompletableFuture.failedStage(e);
                }
            }, this.strategy.getDelay(((AtomicInteger) this.rateLimiter.get(createCacheKey)).get())).whenComplete((bArr2, th) -> {
                if (th != null) {
                    ((AtomicInteger) this.rateLimiter.get(createCacheKey)).incrementAndGet();
                } else {
                    this.rateLimiter.invalidate(createCacheKey);
                }
            });
        } catch (NoSuchAlgorithmException e) {
            return CompletableFuture.failedStage(e);
        }
    }

    private byte[] doAuthenticate(SaslServer saslServer, byte[] bArr) throws SaslException {
        try {
            byte[] evaluateResponse = saslServer.evaluateResponse(bArr);
            if (saslServer.isComplete()) {
                return evaluateResponse;
            }
            throw new SaslAuthenticationException("SASL failed : " + new String(evaluateResponse, StandardCharsets.UTF_8));
        } finally {
            saslServer.dispose();
        }
    }

    private <A> CompletionStage<A> schedule(Supplier<CompletionStage<A>> supplier, Duration duration) {
        if (duration.equals(Duration.ZERO)) {
            return supplier.get();
        }
        CompletableFuture completableFuture = new CompletableFuture();
        this.executorService.schedule(() -> {
            ((CompletionStage) supplier.get()).whenComplete((obj, th) -> {
                if (th != null) {
                    completableFuture.completeExceptionally(th);
                } else {
                    completableFuture.complete(obj);
                }
            });
        }, duration.toMillis(), TimeUnit.MILLISECONDS);
        return completableFuture;
    }

    static String createCacheKey(byte[] bArr) throws NoSuchAlgorithmException {
        return HexFormat.of().formatHex(MessageDigest.getInstance("SHA-256").digest(bArr));
    }
}
