package deepboof.visualization;

import com.mxgraph.swing.mxGraphComponent;
import com.mxgraph.util.mxConstants;
import com.mxgraph.view.mxGraph;
import deepboof.Tensor;
import deepboof.forward.ActivationReLU;
import deepboof.forward.ConfigConvolve2D;
import deepboof.forward.ConfigSpatial;
import deepboof.forward.ConstantPadding2D;
import deepboof.forward.FunctionBatchNorm;
import deepboof.forward.FunctionLinear;
import deepboof.forward.SpatialBatchNorm;
import deepboof.forward.SpatialConvolve2D;
import deepboof.forward.SpatialMaxPooling;
import deepboof.forward.SpatialPadding2D;
import deepboof.graph.Node;
import java.awt.BorderLayout;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.swing.BoxLayout;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;

/* loaded from: input_file:deepboof/visualization/SequentialNetworkDisplay.class */
/**
 * Swing panel that renders a sequential network as a vertical chain of graph vertices, one per
 * {@link Node}, connected by edges in list order. Clicking a vertex shows that layer's
 * configuration as text in a side panel. For convolution layers it also shows a visualization
 * of the first parameter tensor.
 */
public class SequentialNetworkDisplay extends JPanel {
    /** Horizontal position of every vertex. */
    private static final double CELL_X = 20.0;
    /** Width of each vertex. */
    private static final double CELL_WIDTH = 140;
    /** Height of each vertex. */
    private static final double CELL_HEIGHT = 35;
    /** Vertical gap between consecutive vertices. */
    private static final double CELL_GAP = 20;
    /** Fill color used for layer types without a dedicated color. */
    private static final String DEFAULT_FILL = "#DDDAFF";

    mxGraphComponent graphComponent;
    mxGraph graph;
    /** Maps a vertex label (see {@link #getTitle}) back to the node it represents. */
    Map<String, Node> nameToNode = new HashMap<>();
    JTextPaneAA textConfig = new JTextPaneAA();
    JPanel visualizePanel = new JPanel();

    /** Updates the side panels with details of whichever vertex the user clicks. */
    private class ClickHandler extends MouseAdapter {
        @Override
        public void mousePressed(MouseEvent e) {
            Object cell = graphComponent.getCellAt(e.getX(), e.getY());
            if (cell == null) {
                return;
            }
            Node<?, ?> node = nameToNode.get(graph.getLabel(cell));
            if (node == null) {
                // Clicked something that is not a layer vertex (e.g. an edge).
                return;
            }

            String text = "";
            visualizePanel.removeAll();
            if (node.function instanceof SpatialConvolve2D) {
                text = configString((SpatialConvolve2D) node.function);
                List<?> parameters = node.function.getParameters();
                if (parameters != null && !parameters.isEmpty()) {
                    visualizePanel.add(new Kernel2DGridPanel((Tensor) parameters.get(0), 60, 60));
                }
            } else if (node.function instanceof FunctionBatchNorm) {
                text = configString((FunctionBatchNorm) node.function);
            } else if (node.function instanceof SpatialMaxPooling) {
                text = configString((SpatialMaxPooling) node.function);
            }

            // invalidate() alone does not lay out or repaint the panel after its children changed.
            visualizePanel.revalidate();
            visualizePanel.repaint();
            // Mouse events are dispatched on the Swing event thread, so the update can be done directly.
            textConfig.setText(text);
        }
    }

    /**
     * Creates the display for a sequential network.
     *
     * @param nodes layers in execution order; each is drawn below the previous one and joined to
     *     it by an edge. An empty list produces an empty graph.
     */
    public SequentialNetworkDisplay(List<Node<?, ?>> nodes) {
        setLayout(new BorderLayout());
        this.graph = new mxGraph();
        buildGraph(nodes);

        JPanel sidePanel = new JPanel();
        sidePanel.setLayout(new BoxLayout(sidePanel, BoxLayout.Y_AXIS));
        textConfig.setMinimumSize(new Dimension(200, 300));
        textConfig.setPreferredSize(new Dimension(200, 300));
        textConfig.setFont(new Font("monospaced", Font.PLAIN, 12));
        sidePanel.add(textConfig);
        sidePanel.add(visualizePanel);

        this.graphComponent = new mxGraphComponent(graph);
        add(graphComponent, BorderLayout.CENTER);
        add(sidePanel, BorderLayout.EAST);
        graphComponent.getGraphControl().addMouseListener(new ClickHandler());
    }

    /** Inserts one colored vertex per node, chained by edges, inside a single model update. */
    private void buildGraph(List<Node<?, ?>> nodes) {
        Object parent = graph.getDefaultParent();
        graph.getModel().beginUpdate();
        try {
            Object previous = null;
            double y = 0;
            for (Node<?, ?> node : nodes) {
                String title = getTitle(node);
                Object vertex = graph.insertVertex(parent, null, title, CELL_X, y, CELL_WIDTH, CELL_HEIGHT);
                nameToNode.put(title, node);
                setCellColor(node, vertex);
                if (previous != null) {
                    graph.insertEdge(parent, null, null, previous, vertex);
                }
                previous = vertex;
                y += CELL_HEIGHT + CELL_GAP;
            }
        } finally {
            graph.getModel().endUpdate();
        }
    }

    /** Sets the vertex fill color according to the kind of function the node wraps. */
    private void setCellColor(Node<?, ?> node, Object cell) {
        Object function = node.function;
        String color = DEFAULT_FILL;
        if (function instanceof SpatialConvolve2D) {
            color = "#DDDABB";
        } else if (function instanceof SpatialBatchNorm || function instanceof FunctionBatchNorm) {
            color = "#EEFABB";
        } else if (function instanceof FunctionLinear) {
            color = "#FFDADD";
        } else if (function instanceof ActivationReLU) {
            color = "#DAFFDD";
        }
        graph.setCellStyles(mxConstants.STYLE_FILLCOLOR, color, new Object[]{cell});
    }

    /** Vertex label: the function's simple class name, a newline, then the node's name. */
    private String getTitle(Node<?, ?> node) {
        return node.function.getClass().getSimpleName() + "\n" + node.name;
    }

    /**
     * Describes a convolution layer.
     *
     * @return the kernel count, spatial configuration and (if present) padding description
     */
    public String configString(SpatialConvolve2D convolve) {
        ConfigConvolve2D config = convolve.getConfiguration();
        String text = String.format("Number of kernels   %d\n", config.getTotalKernels())
                + configString((ConfigSpatial) config);
        SpatialPadding2D padding = convolve.getPadding();
        if (padding != null) {
            text += configString(padding);
        }
        return text;
    }

    /** Describes a batch-norm layer: whether it has gamma/beta parameters. */
    public String configString(FunctionBatchNorm batchNorm) {
        return "gamma-beta          " + batchNorm.hasGammaBeta() + "\n";
    }

    /** Describes a max-pooling layer by its spatial configuration. */
    public String configString(SpatialMaxPooling pooling) {
        return configString(pooling.getConfiguration());
    }

    /** Formats sampling period and window size; WW is the window width, HH its height. */
    private String configString(ConfigSpatial config) {
        return String.format("Period X            %d\n", config.periodX)
                + String.format("       Y            %d\n", config.periodY)
                + String.format("Window width        %d\n", config.WW)
                + String.format("       height       %d\n", config.HH);
    }

    /** Formats the padding type, its fill value for constant padding, and the four border sizes. */
    private String configString(SpatialPadding2D padding) {
        StringBuilder sb = new StringBuilder();
        sb.append('\n').append(padding.getClass().getSimpleName()).append('\n');
        if (padding instanceof ConstantPadding2D) {
            sb.append(String.format("   value         %6.2f\n", ((ConstantPadding2D) padding).getPaddingValue()));
        }
        sb.append(String.format("      x0           %2d\n", padding.getPaddingCol0()));
        sb.append(String.format("      y0           %2d\n", padding.getPaddingRow0()));
        sb.append(String.format("      x1           %2d\n", padding.getPaddingCol1()));
        sb.append(String.format("      y1           %2d\n", padding.getPaddingRow1()));
        return sb.toString();
    }
}
