package network.aika.training;

import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import network.aika.Document;
import network.aika.Utils;
import network.aika.neuron.INeuron;
import network.aika.neuron.Synapse;
import network.aika.neuron.activation.Activation;
import network.aika.training.SynapseEvaluation;

/**
 * Long-term learning rules (potentiation and depression) applied to the activations of a
 * {@link Document}.
 *
 * <p>All methods are static and the class holds no state of its own; learning parameters are
 * supplied through a {@link Config} instance.
 */
public class LongTermLearning {

    /**
     * Fluent parameter holder for {@link LongTermLearning}. Each setter stores its argument in the
     * corresponding public field and returns {@code this} so calls can be chained.
     */
    public static class Config {
        /** Consulted per (synapse, input activation, output activation) to decide whether and how strongly to learn. */
        public SynapseEvaluation synapseEvaluation;
        /** Learn rate used by {@link LongTermLearning#longTermPotentiation}. */
        public double ltpLearnRate;
        /** Learn rate used by {@link LongTermLearning#longTermDepression}. */
        public double ltdLearnRate;
        /**
         * Scales the second delta passed to {@code Synapse.updateDelta} during potentiation: that
         * argument is {@code -beta} times the first delta.
         */
        public double beta;
        /**
         * If {@code true}, potentiation treats every other active activation in the document as a
         * candidate input (with no existing synapse); otherwise only existing input links are used.
         */
        public boolean createNewSynapses;

        public Config setSynapseEvaluation(SynapseEvaluation synapseEvaluation) {
            this.synapseEvaluation = synapseEvaluation;
            return this;
        }

        public Config setLTPLearnRate(double d) {
            this.ltpLearnRate = d;
            return this;
        }

        public Config setLTDLearnRate(double d) {
            this.ltdLearnRate = d;
            return this;
        }

        public Config setBeta(double d) {
            this.beta = d;
            return this;
        }

        public Config setCreateNewSynapses(boolean z) {
            this.createNewSynapses = z;
            return this;
        }
    }

    /**
     * Runs one learning pass over every active activation of {@code document}.
     *
     * <p>For each activation that {@link #isActive(Activation) is active}, this applies long-term
     * potentiation first. It then applies long-term depression on the input side, then on the
     * output side.
     *
     * @param document the document whose activations are trained on
     * @param config   learning parameters
     */
    public static void train(Document document, Config config) {
        document.getActivations()
                .filter(LongTermLearning::isActive)
                .forEach(act -> {
                    longTermPotentiation(document, config, act);
                    longTermDepression(document, config, act, false);
                    longTermDepression(document, config, act, true);
                });
    }

    /**
     * Returns whether {@code act} counts as active for learning.
     *
     * <p>If the activation has a target value, that value must be strictly positive. Otherwise the
     * activation must be a final activation.
     */
    private static boolean isActive(Activation act) {
        return act.targetValue == null
                ? act.isFinalActivation()
                : act.targetValue.doubleValue() > 0.0d;
    }

    /**
     * Returns the final net input of {@code act}, divided by the sum of its neuron's
     * {@code biasSum}, {@code posDirSum} and {@code posRecSum}.
     *
     * <p>NOTE(review): there is no guard against a zero denominator. Presumably that sum is
     * non-zero for the conjunctive neurons this is applied to; confirm.
     */
    private static double hConj(Activation act) {
        INeuron neuron = act.getINeuron();
        double norm = neuron.biasSum + neuron.posDirSum + neuron.posRecSum;
        return act.getFinalState().net / norm;
    }

    /**
     * Strengthens the incoming synapses of {@code activation}.
     *
     * <p>The per-activation factor is
     * {@code ltpLearnRate * (1 - value) * max(value, targetValue)}, where {@code value} is the
     * activation's final value and the max is null-safe. The factor is then applied to each
     * candidate input via {@link #synapseLTP}. Candidate inputs are every other active activation
     * in the document when {@link Config#createNewSynapses} is set. Otherwise they are the active
     * inputs of existing links.
     *
     * @param document   the document containing {@code activation}
     * @param config     learning parameters
     * @param activation the output-side activation whose inputs are potentiated
     */
    public static void longTermPotentiation(Document document, Config config, Activation activation) {
        double outValue = activation.getFinalState().value;
        double learnFactor = config.ltpLearnRate
                * (1.0d - outValue)
                * Utils.nullSafeMax(Double.valueOf(outValue), activation.targetValue).doubleValue();

        if (config.createNewSynapses) {
            // Candidate inputs have no existing synapse yet, so null is passed for it.
            document.getActivations()
                    .filter(LongTermLearning::isActive)
                    .filter(in -> in.node != activation.node)
                    .forEach(in -> synapseLTP(config, null, in, activation, learnFactor));
        } else {
            activation.neuronInputs.values().stream()
                    .filter(link -> isActive(link.input))
                    .forEach(link -> synapseLTP(config, link.synapse, link.input, activation, learnFactor));
        }
    }

    /**
     * Applies a potentiation step to the synapse between {@code activation} (input) and
     * {@code activation2} (output). The synapse is created or looked up through
     * {@code Synapse.createOrLookup}.
     *
     * <p>The delta is {@code inputValue * d * significance}. It is further multiplied by
     * {@link #hConj(Activation)} of the output when the existing synapse is conjunctive. Nothing
     * is updated unless the delta is strictly positive.
     *
     * @param config      learning parameters
     * @param synapse     the existing synapse, or {@code null} when a new synapse may be created
     * @param activation  the input-side activation
     * @param activation2 the output-side activation
     * @param d           the per-output learn factor computed by {@link #longTermPotentiation}
     */
    public static void synapseLTP(Config config, Synapse synapse, Activation activation, Activation activation2, double d) {
        SynapseEvaluation.Result eval = config.synapseEvaluation.evaluate(synapse, activation, activation2);
        if (eval == null) {
            return;
        }

        // synapse is null on the createNewSynapses path. Dereferencing it unconditionally used to
        // throw an NPE, so a missing synapse is treated as non-conjunctive (factor 1.0).
        // NOTE(review): confirm this is the intended scaling for newly created synapses.
        boolean conjunctive = synapse != null && synapse.isConjunction(false, false);
        double delta = activation.getFinalState().value * d * eval.significance
                * (conjunctive ? hConj(activation2) : 1.0d);

        if (delta > 0.0d) {
            Synapse target = Synapse.createOrLookup(activation2.doc, null, eval.synapseKey, eval.relations,
                    eval.distanceFunction, activation.getNeuron(), activation2.getNeuron());
            target.updateDelta(activation2.doc, delta, (-config.beta) * delta);
        }
    }

    /**
     * Weakens synapses of {@code activation}'s neuron that were not used by an active neighbour in
     * this document.
     *
     * <p>The method does nothing unless the activation's final value is positive. It first
     * collects the synapses of links to active neighbours; those synapses are spared. It then
     * considers every remaining non-negative synapse on the chosen side: conjunctive synapses on
     * the input side, non-conjunctive synapses on the output side. For each one it applies a delta
     * of {@code -ltdLearnRate * value * significance} and then runs the evaluation's delete check.
     *
     * @param document   the document containing {@code activation}
     * @param config     learning parameters
     * @param activation the activation whose neuron's synapses are depressed
     * @param z          {@code true} for the output side, {@code false} for the input side
     */
    public static void longTermDepression(Document document, Config config, Activation activation, boolean z) {
        if (activation.getFinalState().value <= 0.0d) {
            return;
        }
        INeuron iNeuron = activation.getINeuron();

        // Synapses that carry a link to an active neighbour; these are excluded from depression.
        TreeSet<Synapse> usedSynapses = new TreeSet<>(z ? Synapse.OUTPUT_SYNAPSE_COMP : Synapse.INPUT_SYNAPSE_COMP);
        (z ? activation.neuronOutputs : activation.neuronInputs.values()).forEach(link -> {
            Activation neighbour = z ? link.output : link.input;
            if (isActive(neighbour)) {
                usedSynapses.add(link.synapse);
            }
        });

        // Iterate over a snapshot: checkIfDelete(..., false) presumably may unlink the synapse from
        // iNeuron's synapse map. Removing from the map during a live iteration would throw a
        // ConcurrentModificationException.
        List<Synapse> candidates = new ArrayList<>((z ? iNeuron.outputSynapses : iNeuron.inputSynapses).values());
        for (Synapse candidate : candidates) {
            if (candidate.isNegative() || usedSynapses.contains(candidate)) {
                continue;
            }
            // Input side (z == false): only conjunctive synapses. Output side: only non-conjunctive ones.
            if (candidate.isConjunction(false, false) == z) {
                continue;
            }
            SynapseEvaluation.Result eval = config.synapseEvaluation.evaluate(candidate, z ? activation : null, z ? null : activation);
            if (eval == null) {
                continue;
            }
            candidate.updateDelta(document, (-config.ltdLearnRate) * activation.getFinalState().value * eval.significance, 0.0d);
            eval.deleteMode.checkIfDelete(candidate, false);
        }
    }
}
