import java.io.*;

/**
 * Loads MNIST-style digit images and labels from IDX files and serves them
 * padded to 32x32, the input size expected by LeNet-5.
 *
 * <p>Pixel bytes are normalised to the range [0, 1] while loading. Images are
 * kept at their native size in memory and padded on access.
 */
public class MNISTLoader {

    /** Magic number that opens an IDX file of unsigned-byte images (0x00000803). */
    private static final int IMAGE_MAGIC = 2051;

    /** Magic number that opens an IDX file of unsigned-byte labels (0x00000801). */
    private static final int LABEL_MAGIC = 2049;

    /** Side length of the square frame produced by padding. */
    private static final int PADDED_SIZE = 32;

    private final double[][][] trainImages; // [numImages][rows][cols], values in [0, 1]
    private final int[] trainLabels;
    private final double[][][] testImages;  // [numImages][rows][cols], values in [0, 1]
    private final int[] testLabels;

    /**
     * Reads the training and test sets from the four given IDX files.
     *
     * @param trainImagesPath path of the training image file
     * @param trainLabelsPath path of the training label file
     * @param testImagesPath  path of the test image file
     * @param testLabelsPath  path of the test label file
     * @throws IOException if a file cannot be read, is truncated, does not start
     *                     with the expected magic number, or an image file and its
     *                     label file disagree on the number of entries
     */
    public MNISTLoader(String trainImagesPath, String trainLabelsPath,
                       String testImagesPath, String testLabelsPath) throws IOException {
        trainImages = loadImages(trainImagesPath);
        trainLabels = loadLabels(trainLabelsPath);
        requireSameCount(trainImages.length, trainLabels.length, trainImagesPath, trainLabelsPath);

        testImages = loadImages(testImagesPath);
        testLabels = loadLabels(testLabelsPath);
        requireSameCount(testImages.length, testLabels.length, testImagesPath, testLabelsPath);
    }

    /**
     * Reads every image from an IDX image file, scaling each pixel byte to [0, 1].
     *
     * @param path path of the image file
     * @return the images as {@code [numImages][rows][cols]}
     * @throws IOException if the file is unreadable, truncated, or has a bad header
     */
    private static double[][][] loadImages(String path) throws IOException {
        // try-with-resources closes the file even when a read fails part-way through.
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(new FileInputStream(path)))) {
            int magic = in.readInt();
            if (magic != IMAGE_MAGIC) {
                throw new IOException("Not an IDX image file (magic number " + magic
                        + ", expected " + IMAGE_MAGIC + "): " + path);
            }
            int numImages = in.readInt();
            int numRows = in.readInt();
            int numCols = in.readInt();
            if (numImages < 0 || numRows < 0 || numCols < 0) {
                throw new IOException("Corrupt IDX image header (" + numImages + " images of "
                        + numRows + "x" + numCols + "): " + path);
            }

            double[][][] images = new double[numImages][numRows][numCols];
            byte[] rowBytes = new byte[numCols];
            for (int i = 0; i < numImages; i++) {
                for (int r = 0; r < numRows; r++) {
                    in.readFully(rowBytes); // throws EOFException on a truncated file
                    double[] row = images[i][r];
                    for (int c = 0; c < numCols; c++) {
                        // Mask to treat the byte as unsigned, then normalise to [0, 1].
                        row[c] = (rowBytes[c] & 0xFF) / 255.0;
                    }
                }
            }
            return images;
        }
    }

    /**
     * Reads every label from an IDX label file.
     *
     * @param path path of the label file
     * @return one unsigned-byte label value per entry
     * @throws IOException if the file is unreadable, truncated, or has a bad header
     */
    private static int[] loadLabels(String path) throws IOException {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(new FileInputStream(path)))) {
            int magic = in.readInt();
            if (magic != LABEL_MAGIC) {
                throw new IOException("Not an IDX label file (magic number " + magic
                        + ", expected " + LABEL_MAGIC + "): " + path);
            }
            int numLabels = in.readInt();
            if (numLabels < 0) {
                throw new IOException("Corrupt IDX label header (" + numLabels
                        + " labels): " + path);
            }

            byte[] raw = new byte[numLabels];
            in.readFully(raw);
            int[] labels = new int[numLabels];
            for (int i = 0; i < numLabels; i++) {
                labels[i] = raw[i] & 0xFF;
            }
            return labels;
        }
    }

    /** Fails when an image file and its label file hold a different number of entries. */
    private static void requireSameCount(int imageCount, int labelCount,
                                         String imagesPath, String labelsPath) throws IOException {
        if (imageCount != labelCount) {
            throw new IOException("Image/label count mismatch: " + imageCount + " images in "
                    + imagesPath + " but " + labelCount + " labels in " + labelsPath);
        }
    }

    /**
     * Centres an image inside a zero-filled 32x32 frame.
     *
     * @param image rectangular image, at most 32 rows by 32 columns
     * @return a new 32x32 array holding the centred image
     * @throws IllegalArgumentException if the image is larger than 32x32
     */
    private static double[][] centerInFrame(double[][] image) {
        int rows = image.length;
        int cols = rows == 0 ? 0 : image[0].length;
        if (rows > PADDED_SIZE || cols > PADDED_SIZE) {
            throw new IllegalArgumentException("Image is " + rows + "x" + cols
                    + "; cannot pad to " + PADDED_SIZE + "x" + PADDED_SIZE);
        }
        int rowOffset = (PADDED_SIZE - rows) / 2;
        int colOffset = (PADDED_SIZE - cols) / 2;

        double[][] framed = new double[PADDED_SIZE][PADDED_SIZE];
        for (int r = 0; r < rows; r++) {
            System.arraycopy(image[r], 0, framed[r + rowOffset], colOffset, cols);
        }
        return framed;
    }

    /**
     * Pads an image to 32x32 for LeNet-5 input by centring it in a zero-filled
     * frame. A 28x28 image gets 2 pixels of padding on each side.
     *
     * @param image28 the image to pad, at most 32x32
     * @return a new 32x32 array; the argument is not modified
     * @throws IllegalArgumentException if the image is larger than 32x32
     */
    public double[][] padImage(double[][] image28) {
        return centerInFrame(image28);
    }

    /** Returns the training image at {@code index}, padded to 32x32. */
    public double[][] getTrainImage(int index) {
        return padImage(trainImages[index]);
    }

    /** Returns the label of the training image at {@code index}. */
    public int getTrainLabel(int index) {
        return trainLabels[index];
    }

    /** Returns the test image at {@code index}, padded to 32x32. */
    public double[][] getTestImage(int index) {
        return padImage(testImages[index]);
    }

    /** Returns the label of the test image at {@code index}. */
    public int getTestLabel(int index) {
        return testLabels[index];
    }

    /** Returns the number of training images loaded. */
    public int getNumTrainImages() {
        return trainImages.length;
    }

    /** Returns the number of test images loaded. */
    public int getNumTestImages() {
        return testImages.length;
    }

    /**
     * Returns a run of consecutive training images, each padded to 32x32
     * (useful for quick testing).
     *
     * @param startIdx index of the first image
     * @param count    number of images to return
     */
    public double[][][] getTrainImageSubset(int startIdx, int count) {
        double[][][] subset = new double[count][][];
        for (int i = 0; i < count; i++) {
            subset[i] = getTrainImage(startIdx + i);
        }
        return subset;
    }

    /**
     * Returns the labels for a run of consecutive training images.
     *
     * @param startIdx index of the first label
     * @param count    number of labels to return
     */
    public int[] getTrainLabelSubset(int startIdx, int count) {
        int[] subset = new int[count];
        for (int i = 0; i < count; i++) {
            subset[i] = getTrainLabel(startIdx + i);
        }
        return subset;
    }

    /**
     * Prints a simple text rendering of an image to standard output, one line
     * per row: {@code ##} above 0.5, {@code ..} above 0.2, blank otherwise.
     *
     * @param image the image to print
     */
    public void printImage(double[][] image) {
        for (double[] row : image) {
            StringBuilder line = new StringBuilder(row.length * 2);
            for (double pixel : row) {
                if (pixel > 0.5) {
                    line.append("##");
                } else if (pixel > 0.2) {
                    line.append("..");
                } else {
                    // Two spaces so a blank cell is as wide as an inked one and columns align.
                    line.append("  ");
                }
            }
            System.out.println(line);
        }
    }

    /** Helper class for visualization: a set of padded images with their labels. */
    public static class MNISTData {
        public double[][][] images; // Already padded to 32x32
        public int[] labels;
        public int numImages;

        /**
         * @param images the padded images
         * @param labels the label of each image
         */
        public MNISTData(double[][][] images, int[] labels) {
            this.images = images;
            this.labels = labels;
            this.numImages = images.length;
        }
    }

    /**
     * Loads one image file and one label file for visualization, padding every
     * image to 32x32.
     *
     * @param imagesPath path of the IDX image file
     * @param labelsPath path of the IDX label file
     * @return the padded images together with their labels
     * @throws IOException              if a file is unreadable, truncated, has a bad
     *                                  header, or the two files disagree on the
     *                                  number of entries
     * @throws IllegalArgumentException if the images are larger than 32x32
     */
    public static MNISTData loadData(String imagesPath, String labelsPath) throws IOException {
        double[][][] rawImages = loadImages(imagesPath);
        int[] labels = loadLabels(labelsPath);
        requireSameCount(rawImages.length, labels.length, imagesPath, labelsPath);

        double[][][] padded = new double[rawImages.length][][];
        for (int i = 0; i < rawImages.length; i++) {
            padded[i] = centerInFrame(rawImages[i]);
        }
        return new MNISTData(padded, labels);
    }
}