import javax.swing.JFrame; public class TrainingExample { // Create a simple synthetic training example public static double[][] createSyntheticDigit(int digit) { double[][] image = new double[32][32]; // Create simple patterns for each digit (very simplified) // In reality, you'd use actual MNIST data switch(digit) { case 0: // Draw a circle/oval for 0 for (int i = 8; i < 24; i++) { for (int j = 8; j < 24; j++) { double di = i - 16; double dj = j - 16; if (di*di + dj*dj < 64 && di*di + dj*dj > 36) { image[i][j] = 1.0; } } } break; case 1: // Draw a vertical line for 1 for (int i = 8; i < 24; i++) { image[i][16] = 1.0; image[i][15] = 1.0; } break; case 2: // Draw a simplified 2 for (int j = 10; j < 22; j++) { image[10][j] = 1.0; // top image[16][j] = 1.0; // middle image[22][j] = 1.0; // bottom } for (int i = 10; i < 16; i++) { image[i][21] = 1.0; // right side } for (int i = 16; i < 22; i++) { image[i][10] = 1.0; // left side } break; default: // For other digits, create random patterns for (int i = 8; i < 24; i++) { for (int j = 8; j < 24; j++) { if (Math.random() < 0.3) { image[i][j] = 1.0; } } } } return image; } public static void main(String[] args) { System.out.println("=== LeNet-5 Training Example ===\n"); // Create network lenet5 net = new lenet5(); // Create small training dataset int numSamples = 3; double[][][] trainingData = new double[numSamples][][]; int[] labels = new int[numSamples]; // Generate synthetic data for digits 0, 1, 2 trainingData[0] = createSyntheticDigit(0); labels[0] = 0; // displayImage(trainingData[0]); trainingData[1] = createSyntheticDigit(1); labels[1] = 1; trainingData[2] = createSyntheticDigit(2); labels[2] = 2; // Test before training System.out.println("Before training:"); for (int i = 0; i < numSamples; i++) { double[] output = net.forward(trainingData[i], false); int prediction = net.getPrediction(); double loss = net.computeLoss(labels[i]); System.out.printf("Sample %d (label=%d): predicted=%d, loss=%.4f, confidence=%.4f\n", i, labels[i], prediction, loss, output[prediction]); } // Training loop System.out.println("\n=== Training ==="); int epochs = 100; for (int epoch = 0; epoch < epochs; epoch++) { double totalLoss = 0.0; int correct = 0; // Train on each sample for (int i = 0; i < numSamples; i++) { net.train(trainingData[i], labels[i], false); // Compute metrics after training net.forward(trainingData[i], false); totalLoss += net.computeLoss(labels[i]); if (net.getPrediction() == labels[i]) { correct++; } } // Print progress every 10 epochs if ((epoch + 1) % 10 == 0) { double avgLoss = totalLoss / numSamples; double accuracy = (double) correct / numSamples * 100; System.out.printf("Epoch %3d: avg_loss=%.4f, accuracy=%.1f%%\n", epoch + 1, avgLoss, accuracy); } } // Test after training System.out.println("\n=== After Training ==="); net.forward(trainingData[0], false; for (int i = 0; i < numSamples; i++) { // for (int i = 0; i < 1; i++) { double[] output = net.forward(trainingData[i], false); int prediction = net.getPrediction(); double loss = net.computeLoss(labels[i]); System.out.printf("\nSample %d (label=%d):\n", i, labels[i]); System.out.printf(" Predicted: %d\n", prediction); System.out.printf(" Loss: %.4f\n", loss); System.out.printf(" Confidence: %.4f\n", output[prediction]); System.out.println(" All probabilities:"); for (int j = 0; j < output.length; j++) { System.out.printf(" Digit %d: %.4f%s\n", j, output[j], (j == labels[i] ? " (correct)" : "")); } } System.out.println("\n=== Training Complete ==="); } }