package io.leonis.algieba.filter;

import io.leonis.algieba.statistic.Distribution;
import io.leonis.algieba.statistic.SimpleDistribution;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;

/* loaded from: input_file:io/leonis/algieba/filter/ExtendedKalmanFilter.class */
/**
 * Extended Kalman filter: each {@code apply} call performs one predict/update step and returns the
 * resulting state distribution.
 */
public class ExtendedKalmanFilter {

    /**
     * Convenience overload that builds the transition and observation functions from matrices and
     * delegates to {@link #apply(BiFunction, Function, INDArray, INDArray, INDArray, INDArray,
     * Distribution, Distribution)}.
     *
     * <p>NOTE(review): the wiring below is preserved exactly as it was and looks inconsistent with
     * a linear Kalman model. Confirm the intended roles before relying on it:
     * <ul>
     *   <li>the control transition is applied twice ({@code B*(B*u)} rather than {@code B*u});</li>
     *   <li>{@code processNoise} is both added to the predicted state and passed on as the process
     *       noise covariance;</li>
     *   <li>the observation function uses element-wise {@code mul}, not matrix {@code mmul};</li>
     *   <li>an identity sized by {@code stateTransition.rows()} is passed as the observation
     *       Jacobian, and an identity sized by {@code observationTransition.rows()} is passed as
     *       the control input.</li>
     * </ul>
     *
     * @param stateTransition matrix applied to the state mean inside the transition function
     * @param observationTransition array the state is multiplied by element-wise to form the
     *     predicted observation
     * @param controlTransition matrix applied (twice) to the control input
     * @param stateJacobian passed through as the Jacobian of the transition function
     * @param processNoise added to the predicted state and passed through as process noise
     * @param measurement the measurement distribution
     * @param state the prior state distribution
     * @return the distribution produced by the general overload
     */
    public Distribution apply(
            INDArray stateTransition,
            INDArray observationTransition,
            INDArray controlTransition,
            INDArray stateJacobian,
            INDArray processNoise,
            Distribution measurement,
            Distribution state) {
        BiFunction<INDArray, INDArray, INDArray> transition = (mean, control) ->
                stateTransition.mmul(mean)
                        .add(controlTransition.mmul(controlTransition.mmul(control)))
                        .add(processNoise);
        // The bound method reference throws NullPointerException if observationTransition is null.
        return apply(
                transition,
                observationTransition::mul,
                stateJacobian,
                Nd4j.eye(stateTransition.rows()),
                Nd4j.eye(observationTransition.rows()),
                processNoise,
                measurement,
                state);
    }

    /**
     * Runs one extended-Kalman-filter predict/update step.
     *
     * <p>Prediction: {@code x_pred = f(x, u)}, {@code P_pred = F * P * F^T + Q}.
     * Update: {@code K = P_pred * H^T * (H * P_pred * H^T + R)^-1},
     * {@code x = x_pred + K * (z - h(x_pred))}, {@code P = (I - K * H) * P_pred}.
     *
     * @param stateFunction transition function {@code f(mean, control)}
     * @param observationFunction observation function {@code h(state)}; its output is subtracted
     *     from the measurement mean, so it is assumed to share the measurement's shape
     * @param stateJacobian Jacobian {@code F} of the transition function
     * @param observationJacobian Jacobian {@code H} of the observation function
     * @param control control input {@code u} passed to {@code stateFunction}
     * @param processNoise process noise covariance {@code Q}
     * @param measurement measurement; its mean is {@code z} and its covariance is {@code R}
     * @param state prior state; its mean is {@code x} and its covariance is {@code P}
     * @return the posterior state distribution (updated mean and covariance)
     */
    public Distribution apply(
            BiFunction<INDArray, INDArray, INDArray> stateFunction,
            Function<INDArray, INDArray> observationFunction,
            INDArray stateJacobian,
            INDArray observationJacobian,
            INDArray control,
            INDArray processNoise,
            Distribution measurement,
            Distribution state) {
        // Predict step.
        INDArray predictedMean = stateFunction.apply(state.getMean(), control);
        // Propagate the prior covariance (not the mean) through the linearized transition.
        INDArray predictedCovariance = stateJacobian
                .mmul(state.getCovariance().mmul(stateJacobian.transpose()))
                .add(processNoise);

        // Update step.
        INDArray observationJacobianT = observationJacobian.transpose();
        INDArray innovationCovariance = observationJacobian
                .mmul(predictedCovariance.mmul(observationJacobianT))
                .add(measurement.getCovariance());
        INDArray gain = predictedCovariance.mmul(
                observationJacobianT.mmul(InvertMatrix.invert(innovationCovariance, false)));
        // The innovation compares the actual measurement with the predicted observation.
        INDArray innovation = measurement.getMean().sub(observationFunction.apply(predictedMean));
        INDArray updatedMean = predictedMean.add(gain.mmul(innovation));
        // K * H must be a matrix product: K is (state x obs) and H is (obs x state).
        INDArray updatedCovariance = Nd4j.eye(updatedMean.rows())
                .sub(gain.mmul(observationJacobian))
                .mmul(predictedCovariance);
        return new SimpleDistribution(updatedMean, updatedCovariance);
    }
}
