package com.code_intelligence.jazzer.junit;

import com.code_intelligence.jazzer.driver.LifecycleMethodsInvoker;
import com.code_intelligence.jazzer.utils.UnsafeProvider;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestInstancePostProcessor;
import org.junit.jupiter.api.extension.TestInstances;
import org.junit.jupiter.engine.execution.AfterEachMethodAdapter;
import org.junit.jupiter.engine.execution.BeforeEachMethodAdapter;
import org.junit.jupiter.engine.execution.DefaultExecutableInvoker;
import org.junit.jupiter.engine.extension.ExtensionRegistry;

/* loaded from: input_file:com/code_intelligence/jazzer/junit/JUnitLifecycleMethodsInvoker.class */
final class JUnitLifecycleMethodsInvoker implements LifecycleMethodsInvoker {
    private final LifecycleMethodsInvoker.ThrowingRunnable testClassInstanceUpdater;
    private final Supplier<Object> testClassInstanceSupplier;
    private final LifecycleMethodsInvoker.ThrowingRunnable[] beforeEachRunnables;
    private final LifecycleMethodsInvoker.ThrowingRunnable[] afterEachRunnables;

    /* JADX INFO: Access modifiers changed from: package-private */
    @FunctionalInterface
    /* loaded from: input_file:com/code_intelligence/jazzer/junit/JUnitLifecycleMethodsInvoker$ThrowingConsumer.class */
    public interface ThrowingConsumer {
        void accept(Object obj) throws Exception;
    }

    private JUnitLifecycleMethodsInvoker(LifecycleMethodsInvoker.ThrowingRunnable throwingRunnable, Supplier<Object> supplier, LifecycleMethodsInvoker.ThrowingRunnable[] throwingRunnableArr, LifecycleMethodsInvoker.ThrowingRunnable[] throwingRunnableArr2) {
        this.testClassInstanceUpdater = throwingRunnable;
        this.testClassInstanceSupplier = supplier;
        this.beforeEachRunnables = throwingRunnableArr;
        this.afterEachRunnables = throwingRunnableArr2;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static LifecycleMethodsInvoker of(ExtensionContext extensionContext, Lifecycle lifecycle) {
        if (lifecycle == Lifecycle.PER_TEST) {
            return LifecycleMethodsInvoker.noop(extensionContext.getRequiredTestInstance());
        }
        if (extensionContext.getTestInstances().isPresent() && ((TestInstances) extensionContext.getTestInstances().get()).getAllInstances().size() > 1) {
            throw new IllegalArgumentException("Jazzer does not support nested test classes with LifecycleMode.PER_EXECUTION. Either move your fuzz test to a top-level class or set lifecycle = LifecycleMode.PER_TEST on @FuzzTest.");
        }
        Optional<ExtensionRegistry> extensionRegistryViaHack = getExtensionRegistryViaHack(extensionContext);
        if (!extensionRegistryViaHack.isPresent()) {
            throw new IllegalArgumentException("Jazzer does not support BeforeEach and AfterEach callbacks with this version of JUnit. Either update to at least JUnit 5.9.0 or set lifecycle = LifecycleMode.PER_TEST on @FuzzTest.");
        }
        ExtensionRegistry extensionRegistry = extensionRegistryViaHack.get();
        Object[] objArr = {extensionContext.getRequiredTestInstance()};
        TestInstances makeTestInstances = makeTestInstances(extensionContext.getRequiredTestClass(), () -> {
            return objArr[0];
        });
        ExtensionContext extensionContext2 = (ExtensionContext) Proxy.newProxyInstance(JUnitLifecycleMethodsInvoker.class.getClassLoader(), new Class[]{ExtensionContext.class}, (obj, method, objArr2) -> {
            String name = method.getName();
            boolean z = -1;
            switch (name.hashCode()) {
                case -1398989379:
                    if (name.equals("getTestInstance")) {
                        z = false;
                        break;
                    }
                    break;
                case -1077194500:
                    if (name.equals("getRequiredTestInstance")) {
                        z = 2;
                        break;
                    }
                    break;
                case -418997674:
                    if (name.equals("getTestInstances")) {
                        z = true;
                        break;
                    }
                    break;
                case 966708983:
                    if (name.equals("getRequiredTestInstances")) {
                        z = 3;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                case true:
                    return Optional.empty();
                case true:
                case true:
                    return Optional.empty().get();
                default:
                    return method.invoke(extensionContext, objArr2);
            }
        });
        ExtensionContext extensionContext3 = (ExtensionContext) Proxy.newProxyInstance(JUnitLifecycleMethodsInvoker.class.getClassLoader(), new Class[]{ExtensionContext.class}, (obj2, method2, objArr3) -> {
            String name = method2.getName();
            boolean z = -1;
            switch (name.hashCode()) {
                case -1398989379:
                    if (name.equals("getTestInstance")) {
                        z = false;
                        break;
                    }
                    break;
                case -1077194500:
                    if (name.equals("getRequiredTestInstance")) {
                        z = true;
                        break;
                    }
                    break;
                case -418997674:
                    if (name.equals("getTestInstances")) {
                        z = 2;
                        break;
                    }
                    break;
                case 966708983:
                    if (name.equals("getRequiredTestInstances")) {
                        z = 3;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                    return Optional.of(objArr[0]);
                case true:
                    return objArr[0];
                case true:
                    return Optional.of(makeTestInstances);
                case true:
                    return makeTestInstances;
                default:
                    return method2.invoke(extensionContext, objArr3);
            }
        });
        LifecycleMethodsInvoker.ThrowingRunnable[] throwingRunnableArr = (LifecycleMethodsInvoker.ThrowingRunnable[]) Stream.concat(extensionRegistry.stream(BeforeEachCallback.class).map(beforeEachCallback -> {
            return () -> {
                beforeEachCallback.beforeEach(extensionContext3);
            };
        }), extensionRegistry.stream(BeforeEachMethodAdapter.class).map(beforeEachMethodAdapter -> {
            return () -> {
                beforeEachMethodAdapter.invokeBeforeEachMethod(extensionContext3, extensionRegistry);
            };
        })).toArray(i -> {
            return new LifecycleMethodsInvoker.ThrowingRunnable[i];
        });
        ArrayList arrayList = (ArrayList) Stream.concat(extensionRegistry.stream(AfterEachCallback.class).map(afterEachCallback -> {
            return () -> {
                afterEachCallback.afterEach(extensionContext3);
            };
        }), extensionRegistry.stream(AfterEachMethodAdapter.class).map(afterEachMethodAdapter -> {
            return () -> {
                afterEachMethodAdapter.invokeAfterEachMethod(extensionContext3, extensionRegistry);
            };
        })).collect(Collectors.toCollection(ArrayList::new));
        Collections.reverse(arrayList);
        Constructor<?> testClassNoArgsConstructor = getTestClassNoArgsConstructor(extensionContext3);
        ThrowingConsumer[] throwingConsumerArr = (ThrowingConsumer[]) extensionRegistry.stream(TestInstancePostProcessor.class).map(testInstancePostProcessor -> {
            return obj3 -> {
                testInstancePostProcessor.postProcessTestInstance(obj3, extensionContext2);
            };
        }).toArray(i2 -> {
            return new ThrowingConsumer[i2];
        });
        return new JUnitLifecycleMethodsInvoker(() -> {
            Object newInstance = testClassNoArgsConstructor.newInstance(new Object[0]);
            for (ThrowingConsumer throwingConsumer : throwingConsumerArr) {
                throwingConsumer.accept(newInstance);
            }
            objArr[0] = newInstance;
        }, () -> {
            return objArr[0];
        }, throwingRunnableArr, (LifecycleMethodsInvoker.ThrowingRunnable[]) arrayList.toArray(new LifecycleMethodsInvoker.ThrowingRunnable[0]));
    }

    private static TestInstances makeTestInstances(final Class<?> cls, final Supplier<Object> supplier) {
        return new TestInstances() { // from class: com.code_intelligence.jazzer.junit.JUnitLifecycleMethodsInvoker.1
            public Object getInnermostInstance() {
                return supplier.get();
            }

            public List<Object> getEnclosingInstances() {
                return Collections.emptyList();
            }

            public List<Object> getAllInstances() {
                return Collections.singletonList(supplier.get());
            }

            public <T> Optional<T> findInstance(Class<T> cls2) {
                return cls == cls2 ? Optional.of(supplier.get()) : Optional.empty();
            }
        };
    }

    private static Constructor<?> getTestClassNoArgsConstructor(ExtensionContext extensionContext) {
        Class requiredTestClass = extensionContext.getRequiredTestClass();
        if (requiredTestClass.getEnclosingClass() != null) {
            throw new IllegalArgumentException(String.format("The test class %s is an inner class, which is not supported with LifecycleMode.PER_EXECUTION. Either make it a top-level class or set lifecycle = LifecycleMode.PER_TEST on @FuzzTest.", requiredTestClass.getName()));
        }
        try {
            Constructor<?> declaredConstructor = requiredTestClass.getDeclaredConstructor(new Class[0]);
            declaredConstructor.setAccessible(true);
            return declaredConstructor;
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(String.format("The test class %s has no default constructor, which is not supported with LifecycleMode.PER_EXECUTION. Either add such a constructor or set lifecycle = LifecycleMode.PER_TEST on @FuzzTest.", requiredTestClass.getName()));
        }
    }

    private static Optional<ExtensionRegistry> getExtensionRegistryViaHack(ExtensionContext extensionContext) {
        try {
            Class.forName("org.junit.jupiter.engine.execution.DefaultExecutableInvoker");
            return Arrays.stream(DefaultExecutableInvoker.class.getDeclaredFields()).filter(field -> {
                return field.getType() == ExtensionRegistry.class;
            }).findFirst().flatMap(field2 -> {
                return Optional.ofNullable((ExtensionRegistry) UnsafeProvider.getUnsafe().getObject(extensionContext.getExecutableInvoker(), UnsafeProvider.getUnsafe().objectFieldOffset(field2)));
            });
        } catch (ClassNotFoundException e) {
            return Optional.empty();
        }
    }

    public void beforeFirstExecution() {
    }

    public void beforeEachExecution() throws Throwable {
        this.testClassInstanceUpdater.run();
        for (LifecycleMethodsInvoker.ThrowingRunnable throwingRunnable : this.beforeEachRunnables) {
            throwingRunnable.run();
        }
    }

    public void afterEachExecution() throws Throwable {
        for (LifecycleMethodsInvoker.ThrowingRunnable throwingRunnable : this.afterEachRunnables) {
            throwingRunnable.run();
        }
    }

    public void afterLastExecution() {
    }

    public Object getTestClassInstance() {
        return this.testClassInstanceSupplier.get();
    }
}
