import java.io.BufferedInputStream; import java.io.DataInputStream; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; public class MNISTTrainer { public static void main(String[] args) throws IOException { System.out.println("=== LeNet-5 MNIST Trainer ===\n"); // Load MNIST data System.out.println("Loading MNIST dataset..."); MNISTLoader mnist = new MNISTLoader( "train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte" ); System.out.println("\nMNIST data loaded successfully!\n"); // // Show a sample image // System.out.println("Sample training image (label = " + mnist.getTrainLabel(0) + "):"); // mnist.printImage(mnist.getTrainImage(0)); // System.out.println(); // Create network // System.out.println("Creating LeNet-5 network..."); lenet5 net = new lenet5(); // System.out.println("Network created!\n"); // Training parameters int numTrainingSamples = 60000; // Full training set int numTestSamples = 10000; // Full test set int maxEpochs = 50; // Maximum epochs (will stop early at 97% accuracy) int reportInterval = 5000; // Report progress every 5000 samples double targetAccuracy = 0.97; // Stop when test accuracy reaches 97% // System.out.println("Training configuration:"); // System.out.println(" Training samples: " + numTrainingSamples); // System.out.println(" Test samples: " + numTestSamples); // System.out.println(" Max epochs: " + maxEpochs); // System.out.println(" Target accuracy: " + (targetAccuracy * 100) + "%"); // System.out.println(" Progress report interval: every " + reportInterval + " samples"); // System.out.println(); // Test before training System.out.println("=== Testing before training ==="); double initialAccuracy = evaluateNetwork(net, mnist, numTestSamples); System.out.printf("Initial test accuracy: %.2f%%\n\n", initialAccuracy * 100); // Training loop System.out.println("=== Starting Training ===\n"); for (int epoch = 0; epoch < maxEpochs; epoch++) { System.out.printf("Epoch %d\n", epoch + 1); long epochStart = System.currentTimeMillis(); // double totalLoss = 0.0; int correct = 0; for (int i = 0; i < numTrainingSamples; i++) { double[][] image = mnist.getTrainImage(i); int label = mnist.getTrainLabel(i); // Train net.train(image, label, false); // Compute loss and accuracy net.forward(image, false); // totalLoss += net.computeLoss(label); if (net.getPrediction() == label) { correct++; } // Print progress if ((i + 1) % reportInterval == 0) { // System.out.printf(" Progress: %d/%d samples (%.1f%%)\r", // i + 1, numTrainingSamples, (i + 1) * 100.0 / numTrainingSamples); System.out.printf(" Progress: %d/%d samples\r", i + 1, numTrainingSamples); } } long epochEnd = System.currentTimeMillis(); double epochTime = (epochEnd - epochStart) / 1000.0; System.out.printf("Time for training: time=%.1fs\n", epochTime); // double avgLoss = totalLoss / numTrainingSamples; // double trainAccuracy = (double) correct / numTrainingSamples; // Test on training set System.out.print(" Evaluating on training set...\r"); double trainingAccuracy = evaluateNetworkTrain(net, mnist, numTrainingSamples); System.out.printf("train_acc=%.2f%%\n", trainingAccuracy * 100); // Test on validation set System.out.print(" Evaluating on test set...\r"); double testAccuracy = evaluateNetwork(net, mnist, numTestSamples); System.out.printf("test_acc=%.2f%%\n", testAccuracy * 100); // Test on shifted data set System.out.print(" Evaluating on shifted test set...\r"); double shiftedAccuracy = evaluateNetworkShifted(net, mnist, 1000, 2); System.out.printf("test_acc=%.2f%%\n", shiftedAccuracy * 100); // System.out.printf("Epoch %d: loss=%.4f, train_acc=%.2f%%, test_acc=%.2f%%, time=%.1fs\n", // epoch + 1, avgLoss, trainAccuracy * 100, testAccuracy * 100, epochTime); // Early stopping if target accuracy reached if (testAccuracy >= targetAccuracy) { System.out.println("\n*** Target accuracy of " + (targetAccuracy * 100) + "% reached! ***"); System.out.println("*** Stopping training early at epoch " + (epoch + 1) + " ***\n"); break; } } // System.out.println("\n=== Training Complete ===\n"); // Final evaluation on larger test set // System.out.println("Final evaluation on " + numTestSamples + " test samples:"); // double finalAccuracy = evaluateNetwork(net, mnist, numTestSamples); // System.out.printf("Final test accuracy: %.2f%%\n", finalAccuracy * 100); // // Show some predictions // System.out.println("\n=== Sample Predictions ==="); // for (int i = 0; i < 10; i++) { // double[][] image = mnist.getTestImage(i); // int label = mnist.getTestLabel(i); // double[] output = net.forward(image, false); // int prediction = net.getPrediction(); // // System.out.printf("Sample %d: True=%d, Predicted=%d, Confidence=%.2f%% %s\n", // i, label, prediction, output[prediction] * 100, // (prediction == label ? "✓" : "✗")); // } // System.out.println("\n=== Program Complete ==="); } // // Evaluate network accuracy on test set // private static double evaluateNetworkShifted(lenet5 net) throws IOException { // // // Load shifted images // String filename = "shifted_test_images.dat"; // DataInputStream dis; // dis = new DataInputStream(new BufferedInputStream(new FileInputStream(filename))); // // // // Read header // int numImages = dis.readInt(); // int height = dis.readInt(); // int width = dis.readInt(); // //// System.out.println("File info:"); //// System.out.println(" Number of images: " + numImages); //// System.out.println(" Image size: " + height + "x" + width); //// System.out.println(); // // int correct = 0; // // for (int i = 0; i < numImages; i++) { // // Read label // int label = dis.readInt(); // // // Read image // double[][] image = new double[height][width]; // for (int row = 0; row < height; row++) { // for (int col = 0; col < width; col++) { // image[row][col] = dis.readDouble(); // } // } // // net.forward(image, false); // if (net.getPrediction() == label) { // correct++; // } // } // dis.close(); // // return (double) correct / numImages; // } // Evaluate network accuracy on test set private static double evaluateNetworkShifted(lenet5 net, MNISTLoader mnist, int numSamples, int shiftOffset) { int correct = 0; for (int i = 0; i < numSamples; i++) { double[][] image = mnist.getTestImage(i); if (Math.random() < 0.5) { image = shiftRight(image, shiftOffset); } else image = shiftLeft(image, shiftOffset); if (Math.random() < 0.5) { image = shiftDown(image, shiftOffset); } else image = shiftUp(image, shiftOffset); int label = mnist.getTestLabel(i); net.forward(image, false); if (net.getPrediction() == label) { correct++; } } return (double) correct / numSamples; } private static double[][] shiftRight(double[][] image, int dx){ // System.out.println(dx); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxRow][maxCol]; for (int row = maxRow-1; row >= dx; row--) { for (int col = 0; col < maxCol; col++) { shiftedImage[col][row] = image[col][row-dx]; } } // printImage(shiftedImage); return shiftedImage; } private static double[][] shiftLeft(double[][] image, int dx){ // System.out.println(dx); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxCol][maxRow]; for (int row = 0; row <= maxRow-dx-1; row++) { for (int col = 0; col < maxCol; col++) { shiftedImage[col][row] = image[col][row+dx]; } } // printImage(shiftedImage); return shiftedImage; } private static double[][] shiftUp(double[][] image, int dy){ // System.out.println(dy); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxCol][maxRow]; for (int col = 0; col <= maxCol-dy-1; col++) { for (int row = 0; row < maxRow; row++) { shiftedImage[col][row] = image[col+dy][row]; } } // printImage(shiftedImage); return shiftedImage; } private static double[][] shiftDown(double[][] image, int dy){ // System.out.println(dy); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxCol][maxRow]; for (int col = maxCol-1; col >= dy; col--) { for (int row = 0; row < maxRow; row++) { shiftedImage[col][row] = image[col-dy][row]; } } // printImage(shiftedImage); return shiftedImage; } private static void printImage(double[][] image) { System.out.println("----"); for (int i = 0; i < image.length; i++) { for (int j = 0; j < image[i].length; j++) { if (image[i][j] > 0.5) { System.out.print("##"); } else if (image[i][j] > 0.2) { System.out.print(".."); } else { System.out.print(" "); } } System.out.println(); } } // Evaluate network accuracy on test set private static double evaluateNetwork(lenet5 net, MNISTLoader mnist, int numSamples) { int correct = 0; for (int i = 0; i < numSamples; i++) { double[][] image = mnist.getTestImage(i); int label = mnist.getTestLabel(i); net.forward(image, false); if (net.getPrediction() == label) { correct++; } } return (double) correct / numSamples; } // Evaluate network accuracy on training set private static double evaluateNetworkTrain(lenet5 net, MNISTLoader mnist, int numSamples) { int correct = 0; for (int i = 0; i < numSamples; i++) { double[][] image = mnist.getTrainImage(i); int label = mnist.getTrainLabel(i); net.forward(image, false); if (net.getPrediction() == label) { correct++; } } return (double) correct / numSamples; } }