package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.rules.ZeroR;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ClusterMembership;
import weka.filters.unsupervised.attribute.Standardize;

/* loaded from: input_file:WEB-INF/classes/weka/classifiers/functions/RBFNetwork.class */
/**
 * Normalized Gaussian radial basis function network.
 *
 * <p>Training standardizes the data with {@link Standardize}, maps every instance onto its
 * membership of k-means clusters ({@link SimpleKMeans} wrapped in a
 * {@link MakeDensityBasedClusterer} and applied through the {@link ClusterMembership} filter),
 * and then fits a {@link Logistic} model (nominal class) or a {@link LinearRegression} model
 * (numeric class) on those basis-function outputs. When the training data contains only the
 * class attribute, a {@link ZeroR} model is used instead.
 */
public class RBFNetwork extends Classifier implements OptionHandler {
    static final long serialVersionUID = -3669814959712675720L;

    /** Default number of k-means clusters (option -B). */
    private static final int DEFAULT_NUM_CLUSTERS = 2;
    /** Default ridge value for the logistic or linear regression (option -R). */
    private static final double DEFAULT_RIDGE = 1.0E-8d;
    /** Default maximum number of logistic iterations; -1 means "until convergence" (option -M). */
    private static final int DEFAULT_MAX_ITS = -1;
    /** Default random seed handed to k-means (option -S). */
    private static final int DEFAULT_CLUSTERING_SEED = 1;
    /** Default minimum standard deviation for the clusters (option -W). */
    private static final double DEFAULT_MIN_STD_DEV = 0.1d;

    /** Logistic model; non-null only after a build on data with a nominal class. */
    private Logistic m_logistic;
    /** Linear model; non-null only after a build on data with a numeric class. */
    private LinearRegression m_linear;
    /** Filter that turns standardized instances into cluster-membership attributes. */
    private ClusterMembership m_basisFilter;
    /** Filter that standardizes the raw input instances. */
    private Standardize m_standardize;
    private int m_numClusters = DEFAULT_NUM_CLUSTERS;
    protected double m_ridge = DEFAULT_RIDGE;
    private int m_maxIts = DEFAULT_MAX_ITS;
    private int m_clusteringSeed = DEFAULT_CLUSTERING_SEED;
    private double m_minStdDev = DEFAULT_MIN_STD_DEV;
    /** Fallback model used when only the class attribute is present; null otherwise. */
    private Classifier m_ZeroR;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Class that implements a normalized Gaussian radial basis function network.\nIt uses the k-means clustering algorithm to provide the basis functions and learns either a logistic regression (discrete class problems) or linear regression (numeric class problems) on top of that. Symmetric multivariate Gaussians are fit to the data from each cluster. If the class is nominal it uses the given number of clusters per class. It standardizes all numeric attributes to zero mean and unit variance.";
    }

    /**
     * Returns the data this classifier can handle.
     *
     * <p>The result is the union of the Logistic and LinearRegression capabilities, restricted to
     * what SimpleKMeans supports, with the regressors' class capabilities preserved.
     *
     * @return the capabilities of this classifier
     */
    @Override // weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities capabilities = new Logistic().getCapabilities();
        capabilities.or(new LinearRegression().getCapabilities());
        // and() below also intersects the class capabilities, so save the regressors' class
        // capabilities first and OR them back in afterwards.
        Capabilities classCapabilities = capabilities.getClassCapabilities();
        capabilities.and(new SimpleKMeans().getCapabilities());
        capabilities.or(classCapabilities);
        return capabilities;
    }

    /**
     * Builds the network from the given training data.
     *
     * <p>Instances with a missing class value are removed from a copy of the data. If only the
     * class attribute remains, a ZeroR model is built instead. Otherwise the data is
     * standardized, transformed into k-means cluster memberships, and a Logistic (nominal class)
     * or LinearRegression (numeric class) model is fitted to the transformed data.
     *
     * @param instances the training data; it is not modified
     * @throws Exception if the data is not supported by {@link #getCapabilities()} or one of the
     *     underlying filters or models fails to build
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        // Work on a copy so that removing rows does not touch the caller's dataset.
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            // Drop any sub-models left over from a previous build.
            this.m_logistic = null;
            this.m_linear = null;
            this.m_basisFilter = null;
            this.m_standardize = null;
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(data);
            return;
        }
        this.m_ZeroR = null;

        this.m_standardize = new Standardize();
        this.m_standardize.setInputFormat(data);
        Instances standardized = Filter.useFilter(data, this.m_standardize);

        SimpleKMeans kMeans = new SimpleKMeans();
        kMeans.setNumClusters(this.m_numClusters);
        kMeans.setSeed(this.m_clusteringSeed);
        MakeDensityBasedClusterer densityClusterer = new MakeDensityBasedClusterer();
        densityClusterer.setClusterer(kMeans);
        densityClusterer.setMinStdDev(this.m_minStdDev);

        this.m_basisFilter = new ClusterMembership();
        this.m_basisFilter.setDensityBasedClusterer(densityClusterer);
        this.m_basisFilter.setInputFormat(standardized);
        Instances basisData = Filter.useFilter(standardized, this.m_basisFilter);

        if (standardized.classAttribute().isNominal()) {
            this.m_linear = null;
            this.m_logistic = new Logistic();
            this.m_logistic.setRidge(this.m_ridge);
            this.m_logistic.setMaxIts(this.m_maxIts);
            this.m_logistic.buildClassifier(basisData);
            return;
        }
        this.m_logistic = null;
        this.m_linear = new LinearRegression();
        this.m_linear.setAttributeSelectionMethod(
                new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION));
        this.m_linear.setRidge(this.m_ridge);
        this.m_linear.buildClassifier(basisData);
    }

    /**
     * Computes the class distribution (or numeric prediction) for an instance.
     *
     * <p>The instance is passed through the standardization and cluster-membership filters
     * fitted during training and then through whichever regression model was trained.
     *
     * @param instance the instance to classify
     * @return the prediction of the trained model
     * @throws IllegalStateException if no model has been built yet
     * @throws Exception if the underlying filters or model fail
     */
    @Override // weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(instance);
        }
        if (this.m_standardize == null || this.m_basisFilter == null
                || (this.m_logistic == null && this.m_linear == null)) {
            throw new IllegalStateException("No classifier built yet!");
        }
        this.m_standardize.input(instance);
        this.m_basisFilter.input(this.m_standardize.output());
        Instance transformed = this.m_basisFilter.output();
        // Dispatch on the model that was actually trained rather than on the query's header.
        return this.m_logistic != null
                ? this.m_logistic.distributionForInstance(transformed)
                : this.m_linear.distributionForInstance(transformed);
    }

    /**
     * Returns a textual description of the built model.
     *
     * @return the ZeroR fallback description, "No classifier built yet!", or a description of
     *     the regression model fitted on the cluster basis functions
     */
    @Override
    public String toString() {
        if (this.m_ZeroR != null) {
            String simpleName = getClass().getName().replaceAll(".*\\.", "");
            StringBuffer text = new StringBuffer();
            text.append(simpleName).append("\n");
            // Underline the class name with '=' characters of the same length.
            text.append(simpleName.replaceAll(".", "=")).append("\n\n");
            text.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            text.append(this.m_ZeroR.toString());
            return text.toString();
        }
        if (this.m_basisFilter == null) {
            return "No classifier built yet!";
        }
        StringBuffer text = new StringBuffer();
        text.append("Radial basis function network\n");
        text.append(this.m_linear == null ? "(Logistic regression " : "(Linear regression ");
        text.append("applied to K-means clusters as basis functions):\n\n");
        text.append(this.m_linear == null ? this.m_logistic.toString() : this.m_linear.toString());
        return text.toString();
    }

    /** @return tip text for the maxIts property */
    public String maxItsTipText() {
        return "Maximum number of iterations for the logistic regression to perform. Only applied to discrete class problems.";
    }

    /** @return the maximum number of logistic regression iterations */
    public int getMaxIts() {
        return this.m_maxIts;
    }

    /** @param i the maximum number of logistic regression iterations */
    public void setMaxIts(int i) {
        this.m_maxIts = i;
    }

    /** @return tip text for the ridge property */
    public String ridgeTipText() {
        return "Set the Ridge value for the logistic or linear regression.";
    }

    /** @param d the ridge value passed to the logistic or linear regression */
    public void setRidge(double d) {
        this.m_ridge = d;
    }

    /** @return the ridge value passed to the logistic or linear regression */
    public double getRidge() {
        return this.m_ridge;
    }

    /** @return tip text for the numClusters property */
    public String numClustersTipText() {
        return "The number of clusters for K-Means to generate.";
    }

    /**
     * Sets the number of k-means clusters.
     *
     * @param i the number of clusters; non-positive values are ignored
     */
    public void setNumClusters(int i) {
        if (i > 0) {
            this.m_numClusters = i;
        }
    }

    /** @return the number of k-means clusters */
    public int getNumClusters() {
        return this.m_numClusters;
    }

    /** @return tip text for the clusteringSeed property */
    public String clusteringSeedTipText() {
        return "The random seed to pass on to K-means.";
    }

    /** @param i the random seed passed to k-means */
    public void setClusteringSeed(int i) {
        this.m_clusteringSeed = i;
    }

    /** @return the random seed passed to k-means */
    public int getClusteringSeed() {
        return this.m_clusteringSeed;
    }

    /** @return tip text for the minStdDev property */
    public String minStdDevTipText() {
        return "Sets the minimum standard deviation for the clusters.";
    }

    /** @return the minimum standard deviation for the clusters */
    public double getMinStdDev() {
        return this.m_minStdDev;
    }

    /** @param d the minimum standard deviation for the clusters */
    public void setMinStdDev(double d) {
        this.m_minStdDev = d;
    }

    /**
     * Lists the command-line options accepted by {@link #setOptions(String[])}.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override // weka.classifiers.Classifier, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>(6);
        result.addElement(new Option("\tSet the number of clusters (basis functions) to generate. (default = 2).", "B", 1, "-B <number>"));
        result.addElement(new Option("\tSet the random seed to be used by K-means. (default = 1).", "S", 1, "-S <seed>"));
        result.addElement(new Option("\tSet the ridge value for the logistic or linear regression.", "R", 1, "-R <ridge>"));
        result.addElement(new Option("\tSet the maximum number of iterations for the logistic regression. (default -1, until convergence).", "M", 1, "-M <number>"));
        result.addElement(new Option("\tSet the minimum standard deviation for the clusters. (default 0.1).", "W", 1, "-W <number>"));
        // setOptions() accepts -D, so advertise it as well.
        result.addElement(new Option("\tIf set, classifier is run in debug mode and\n\tmay output additional info to the console", "D", 0, "-D"));
        return result.elements();
    }

    /**
     * Parses the given options. Every option that is absent is reset to its default, so that
     * calling this method twice does not leave stale values behind.
     *
     * @param options the options to parse; recognised entries are consumed from the array
     * @throws Exception if a value cannot be parsed or unknown options remain
     */
    @Override // weka.classifiers.Classifier, weka.core.OptionHandler
    public void setOptions(String[] options) throws Exception {
        setDebug(Utils.getFlag('D', options));

        String ridge = Utils.getOption('R', options);
        this.m_ridge = ridge.length() != 0 ? Double.parseDouble(ridge) : DEFAULT_RIDGE;

        String maxIts = Utils.getOption('M', options);
        this.m_maxIts = maxIts.length() != 0 ? Integer.parseInt(maxIts) : DEFAULT_MAX_ITS;

        String numClusters = Utils.getOption('B', options);
        setNumClusters(numClusters.length() != 0 ? Integer.parseInt(numClusters) : DEFAULT_NUM_CLUSTERS);

        String seed = Utils.getOption('S', options);
        setClusteringSeed(seed.length() != 0 ? Integer.parseInt(seed) : DEFAULT_CLUSTERING_SEED);

        String minStdDev = Utils.getOption('W', options);
        setMinStdDev(minStdDev.length() != 0 ? Double.parseDouble(minStdDev) : DEFAULT_MIN_STD_DEV);

        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings in a form accepted by {@link #setOptions(String[])}.
     *
     * @return the current option settings
     */
    @Override // weka.classifiers.Classifier, weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-B");
        result.add(String.valueOf(this.m_numClusters));
        result.add("-S");
        result.add(String.valueOf(this.m_clusteringSeed));
        result.add("-R");
        result.add(String.valueOf(this.m_ridge));
        result.add("-M");
        result.add(String.valueOf(this.m_maxIts));
        result.add("-W");
        result.add(String.valueOf(this.m_minStdDev));
        // Emit -D so that setOptions(getOptions()) round-trips the debug flag.
        if (getDebug()) {
            result.add("-D");
        }
        return result.toArray(new String[result.size()]);
    }

    /** @return the revision string of this class */
    @Override // weka.classifiers.Classifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.10 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new RBFNetwork(), strArr);
    }
}
