import javax.swing.*; import java.awt.*; import java.awt.image.BufferedImage; public class NetworkVisualizer extends JFrame { private lenet5 network; private JPanel mainPanel; private LayerPanel inputPanel; private LayerPanel conv1Panel; private LayerPanel pool1Panel; private LayerPanel conv2Panel; private LayerPanel pool2Panel; private BarChartPanel fc1Panel; private BarChartPanel fc2Panel; private BarChartPanel outputPanel; private JLabel predictionLabel; public NetworkVisualizer(lenet5 network) { this.network = network; setupUI(); } private void setupUI() { setTitle("LeNet-5 Network Visualization"); setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); setLayout(new BorderLayout()); mainPanel = new JPanel(); mainPanel.setLayout(new BoxLayout(mainPanel, BoxLayout.Y_AXIS)); mainPanel.setBackground(Color.WHITE); // Create panels for each layer (all feature maps in single rows) inputPanel = new LayerPanel("Input (32x32)", 1, 1); conv1Panel = new LayerPanel("Conv1 Output (6x28x28)", 1, 6); pool1Panel = new LayerPanel("Pool1 Output (6x14x14)", 1, 6); conv2Panel = new LayerPanel("Conv2 Output (16x10x10)", 1, 16); pool2Panel = new LayerPanel("Pool2 Output (16x5x5)", 1, 16); fc1Panel = new BarChartPanel("FC1 Output (120 neurons)", 120); fc2Panel = new BarChartPanel("FC2 Output (84 neurons)", 84); outputPanel = new BarChartPanel("Output Probabilities", 10); predictionLabel = new JLabel("Prediction: -"); predictionLabel.setFont(new Font("Arial", Font.BOLD, 18)); predictionLabel.setAlignmentX(Component.CENTER_ALIGNMENT); // Add all panels mainPanel.add(inputPanel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(conv1Panel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(pool1Panel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(conv2Panel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(pool2Panel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(fc1Panel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(fc2Panel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(outputPanel); mainPanel.add(Box.createVerticalStrut(10)); mainPanel.add(predictionLabel); mainPanel.add(Box.createVerticalStrut(10)); JScrollPane scrollPane = new JScrollPane(mainPanel); scrollPane.setVerticalScrollBarPolicy(JScrollPane.VERTICAL_SCROLLBAR_ALWAYS); scrollPane.getVerticalScrollBar().setUnitIncrement(16); add(scrollPane, BorderLayout.CENTER); setSize(1200, 900); setLocationRelativeTo(null); } public void updateVisualization() { // Update input layer double[][] input = network.getInputLayer(); if (input != null) { inputPanel.setFeatureMaps(new double[][][]{input}); } // Update conv1 double[][][] conv1 = network.getConv1Output(); if (conv1 != null) { conv1Panel.setFeatureMaps(conv1); } // Update pool1 double[][][] pool1 = network.getPool1Output(); if (pool1 != null) { pool1Panel.setFeatureMaps(pool1); } // Update conv2 double[][][] conv2 = network.getConv2Output(); if (conv2 != null) { conv2Panel.setFeatureMaps(conv2); } // Update pool2 double[][][] pool2 = network.getPool2Output(); if (pool2 != null) { pool2Panel.setFeatureMaps(pool2); } // Update FC layers double[] fc1 = network.getFc1Output(); if (fc1 != null) { fc1Panel.setValues(fc1); } double[] fc2 = network.getFc2Output(); if (fc2 != null) { fc2Panel.setValues(fc2); } // Update output layer double[] output = network.getOutputLayer(); if (output != null) { outputPanel.setValues(output); int prediction = network.getPrediction(); predictionLabel.setText(String.format("Prediction: %d (%.2f%%)", prediction, output[prediction] * 100)); } repaint(); } // Panel for displaying feature maps as grayscale images private class LayerPanel extends JPanel { private String title; private int gridRows; private int gridCols; private double[][][] featureMaps; public LayerPanel(String title, int gridRows, int gridCols) { this.title = title; this.gridRows = gridRows; this.gridCols = gridCols; setBackground(Color.WHITE); setBorder(BorderFactory.createTitledBorder(title)); // Adjust height based on layout - single rows need less height int height = (gridRows == 1 && gridCols > 1) ? 150 : 200; setMaximumSize(new Dimension(1150, height)); setPreferredSize(new Dimension(1150, height)); } public void setFeatureMaps(double[][][] maps) { this.featureMaps = maps; repaint(); } @Override protected void paintComponent(Graphics g) { super.paintComponent(g); if (featureMaps == null) return; Graphics2D g2d = (Graphics2D) g; int numMaps = featureMaps.length; int mapHeight = featureMaps[0].length; int mapWidth = featureMaps[0][0].length; int cellWidth = (getWidth() - 20) / gridCols; int cellHeight = (getHeight() - 40) / gridRows; int imageSize = Math.min(cellWidth - 10, cellHeight - 10); for (int i = 0; i < numMaps && i < gridRows * gridCols; i++) { int row = i / gridCols; int col = i % gridCols; int x = 10 + col * cellWidth; int y = 30 + row * cellHeight; BufferedImage img = createGrayscaleImage(featureMaps[i], imageSize, imageSize); g2d.drawImage(img, x, y, null); // Draw border g2d.setColor(Color.GRAY); g2d.drawRect(x, y, imageSize, imageSize); } } private BufferedImage createGrayscaleImage(double[][] data, int width, int height) { BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY); int dataHeight = data.length; int dataWidth = data[0].length; // Find min and max for normalization double min = Double.MAX_VALUE; double max = Double.MIN_VALUE; for (int i = 0; i < dataHeight; i++) { for (int j = 0; j < dataWidth; j++) { min = Math.min(min, data[i][j]); max = Math.max(max, data[i][j]); } } // Normalize and draw double range = max - min; if (range == 0) range = 1; for (int y = 0; y < height; y++) { for (int x = 0; x < width; x++) { int dataY = y * dataHeight / height; int dataX = x * dataWidth / width; double normalized = (data[dataY][dataX] - min) / range; int gray = (int) (normalized * 255); gray = Math.max(0, Math.min(255, gray)); int rgb = (gray << 16) | (gray << 8) | gray; img.setRGB(x, y, rgb); } } return img; } } // Panel for displaying bar charts private class BarChartPanel extends JPanel { private String title; private int numValues; private double[] values; public BarChartPanel(String title, int numValues) { this.title = title; this.numValues = numValues; setBackground(Color.WHITE); setBorder(BorderFactory.createTitledBorder(title)); setMaximumSize(new Dimension(1150, 120)); setPreferredSize(new Dimension(1150, 120)); } public void setValues(double[] values) { this.values = values; repaint(); } @Override protected void paintComponent(Graphics g) { super.paintComponent(g); if (values == null) return; Graphics2D g2d = (Graphics2D) g; int width = getWidth() - 40; int height = getHeight() - 60; int barWidth = Math.max(1, width / numValues); // Find max for scaling double max = 0.0; for (double v : values) { max = Math.max(max, v); } if (max == 0) max = 1; // Draw bars for (int i = 0; i < values.length; i++) { int barHeight = (int) ((values[i] / max) * height); int x = 20 + i * barWidth; int y = 40 + height - barHeight; // Color gradient based on value float hue = 0.55f - (float) (values[i] / max) * 0.55f; // Blue to red g2d.setColor(Color.getHSBColor(hue, 0.7f, 0.9f)); g2d.fillRect(x, y, Math.max(1, barWidth - 1), barHeight); // Draw border g2d.setColor(Color.GRAY); g2d.drawRect(x, y, Math.max(1, barWidth - 1), barHeight); } // Draw axis g2d.setColor(Color.BLACK); g2d.drawLine(20, 40 + height, 20 + width, 40 + height); // Draw labels for output layer if (numValues == 10) { g2d.setFont(new Font("Arial", Font.PLAIN, 10)); for (int i = 0; i < 10; i++) { int x = 20 + i * barWidth + barWidth / 2; g2d.drawString(String.valueOf(i), x - 3, 40 + height + 15); } } } } // Example usage public static void main(String[] args) { // Load a trained model or create a new one lenet5 network = new lenet5(); // Create some dummy input double[][] input = new double[32][32]; for (int i = 8; i < 24; i++) { for (int j = 8; j < 24; j++) { double di = i - 16; double dj = j - 16; if (di*di + dj*dj < 64 && di*di + dj*dj > 36) { input[i][j] = 1.0; } } } // Run forward pass network.forward(input, false); // Create and display visualization SwingUtilities.invokeLater(() -> { NetworkVisualizer viz = new NetworkVisualizer(network); viz.updateVisualization(); viz.setVisible(true); }); } // Method to process and display a new image public void processAndDisplay(double[][] input) { network.forward(input, false); updateVisualization(); } }