package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/flipkart/fdp/ml/transformer/LogisticRegressionTransformer.class */
/**
 * {@link Transformer} that scores a feature vector with a binary logistic-regression model.
 *
 * <p>The feature vector ({@code double[]}) stored under the model's first input key is mapped to a
 * probability with the logistic (sigmoid) function. That probability is then compared against the
 * model's threshold to produce a {@code 0.0}/{@code 1.0} prediction.
 */
public class LogisticRegressionTransformer implements Transformer {
    private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionTransformer.class);
    private final LogisticRegressionModelInfo modelInfo;

    /**
     * Creates a transformer backed by the given model parameters.
     *
     * @param logisticRegressionModelInfo supplies the weights, intercept, threshold and the
     *     input/output/probability key names used by this transformer
     */
    public LogisticRegressionTransformer(LogisticRegressionModelInfo logisticRegressionModelInfo) {
        this.modelInfo = logisticRegressionModelInfo;
    }

    /**
     * Computes {@code 1 / (1 + exp(-(w . x + intercept)))} for the given feature vector.
     *
     * @param dArr feature vector; must have exactly as many elements as the model has weights
     * @return the logistic of the linear margin, in the closed range [0, 1]
     * @throws NullPointerException if {@code dArr} is {@code null}
     * @throws IllegalArgumentException if {@code dArr.length} differs from the number of weights
     */
    public double getProbability(double[] dArr) {
        Objects.requireNonNull(dArr, "feature vector must not be null");
        // Read the weights once instead of once per loop iteration.
        final double[] weights = this.modelInfo.getWeights();
        // A length mismatch used to either throw ArrayIndexOutOfBoundsException (too long)
        // or silently score a truncated dot product (too short); reject both explicitly.
        if (dArr.length != weights.length) {
            throw new IllegalArgumentException("feature vector has " + dArr.length
                    + " elements but the model has " + weights.length + " weights");
        }
        double dot = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            dot += dArr[i] * weights[i];
        }
        // The intercept is added after the dot product to keep the original summation order.
        return 1.0d / (1.0d + Math.exp(-(dot + this.modelInfo.getIntercept())));
    }

    /**
     * Converts a probability into a binary class label.
     *
     * @param d probability, typically as returned by {@link #getProbability(double[])}
     * @return {@code 1.0} if {@code d} is strictly greater than the model threshold, else {@code 0.0}
     */
    public double predict(double d) {
        return d > this.modelInfo.getThreshold() ? 1.0d : 0.0d;
    }

    /**
     * Scores the feature vector held in {@code map} and writes the results back into it.
     *
     * <p>Reads the {@code double[]} stored under the first element of the model's input keys. It
     * then stores the probability under the model's probability key and the 0/1 prediction under
     * the first element of the model's output keys.
     *
     * @param map mutable record holding the input vector; receives the probability and prediction
     * @throws NullPointerException if no feature vector is present under the input key
     * @throws IllegalArgumentException if the feature vector length does not match the model
     */
    @Override // com.flipkart.fdp.ml.transformer.Transformer
    public void transform(Map<String, Object> map) {
        final double[] features = (double[]) map.get(this.modelInfo.getInputKeys().iterator().next());
        final double probability = getProbability(features);
        map.put(this.modelInfo.getProbabilityKey(), Double.valueOf(probability));
        // Use the locally computed probability rather than re-reading and casting it from the map.
        map.put(this.modelInfo.getOutputKeys().iterator().next(), Double.valueOf(predict(probability)));
    }

    /**
     * Returns the keys this transformer reads from its input map.
     *
     * @return the model's input key set, as returned by the model info (not copied)
     */
    @Override // com.flipkart.fdp.ml.transformer.Transformer
    public Set<String> getInputKeys() {
        return this.modelInfo.getInputKeys();
    }

    /**
     * Returns the keys under which this transformer's prediction is written.
     *
     * @return the model's output key set, as returned by the model info (not copied)
     */
    @Override // com.flipkart.fdp.ml.transformer.Transformer
    public Set<String> getOutputKeys() {
        return this.modelInfo.getOutputKeys();
    }
}
