import java.awt.Color; import java.awt.Graphics; import java.awt.image.BufferedImage; import java.io.BufferedInputStream; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import javax.swing.ImageIcon; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.border.Border; public class MNISTOneLayer { public static void main(String[] args) { FileInputStream fin; FileInputStream finn; BufferedInputStream inputImageFile; FileInputStream lin; FileInputStream linn; BufferedInputStream inputLabelFile; try { fin = new FileInputStream("train-images-idx3-ubyte"); lin = new FileInputStream("train-labels-idx1-ubyte"); inputImageFile = new BufferedInputStream(fin); inputLabelFile = new BufferedInputStream(lin); for (int i = 0; i < 16; i++) inputImageFile.read(); // reads off 16 bytes double[][] images; System.out.println("Reading training file images"); images = new double[60000][784]; //60,000 for (int k = 0; k < 60000; k++) { // 60,000 for (int i = 0; i < 784; i++) { int pixel = inputImageFile.read(); if (pixel > 80) {images[k][i] = 1.0; } else images[k][i] = 0.0; // images[k][i] = pixel/255; } } System.out.println("Finished reading training file images"); for (int i = 0; i < 8; i++) inputLabelFile.read(); // reads off 8 bytes double[][] labels; System.out.println("Reading training file labels"); labels = new double[60000][10]; // 60,000 for (int k = 0; k < 60000; k++) { // 60,000 for (int j = 0; j < 10; j++) labels[k][j] = 0.0; int l = inputLabelFile.read(); labels[k][l] = 1.0; } System.out.println("Finished reading training file labels"); FFNetOneLayerRelu net = new FFNetOneLayerRelu(); //FFNetOneLayer net = new FFNetOneLayer(); net.initNetwork(784, 1000, 10, images, labels); net.trainNetwork(2, true); net.testNetwork(images, labels, 60000); finn = new FileInputStream("t10k-images-idx3-ubyte"); linn = new FileInputStream("t10k-labels-idx1-ubyte"); inputImageFile = new BufferedInputStream(finn); inputLabelFile = new BufferedInputStream(linn); for (int i = 0; i < 16; i++) inputImageFile.read(); // reads off 16 bytes double[][] testingImages; System.out.println("Reading testing file images"); testingImages = new double[10000][784]; // 10,000 for (int k = 0; k < 10000; k++) { // 10,000 for (int i = 0; i < 784; i++) { int pixel = inputImageFile.read(); if (pixel > 80) {testingImages[k][i] = 1.0; } else testingImages[k][i] = 0.0; } } System.out.println("Finished reading testing file images"); double[][] testingLabels; for (int i = 0; i < 8; i++) inputLabelFile.read(); // reads off 8 bytes System.out.println("Reading testing file labels"); testingLabels = new double[10000][10]; // 10,000 for (int k = 0; k < 10000; k++) { // 10,000 for (int j = 0; j < 10; j++) testingLabels[k][j] = 0.0; int l = inputLabelFile.read(); testingLabels[k][l] = 1.0; } System.out.println("Finished reading testing file labels"); net.testNetwork(testingImages, testingLabels, 10000); double[][] twoDImage = new double[28][28]; int shiftOffset = 2; for (int k = 0; k < 1000; k++) { // convert to 2D int c = 0; for (int i = 0; i < 28; i++) { for (int j = 0; j < 28; j++) { twoDImage[i][j] = testingImages[k][c++]; } } // shift if (Math.random() < 0.5) { twoDImage = shiftRight(twoDImage, shiftOffset); } else twoDImage = shiftLeft(twoDImage, shiftOffset); if (Math.random() < 0.5) { twoDImage = shiftDown(twoDImage, shiftOffset); } else twoDImage = shiftUp(twoDImage, shiftOffset); // convert back to 1D c = 0; for (int i = 0; i < 28; i++) { for (int j = 0; j < 28; j++) { testingImages[k][c++] = twoDImage[i][j]; } } } net.testNetwork(testingImages, testingLabels, 1000); //net.printWeightsToFile("weights.txt"); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } } private static double[][] shiftRight(double[][] image, int dx){ // System.out.println(dx); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxRow][maxCol]; for (int row = maxRow-1; row >= dx; row--) { for (int col = 0; col < maxCol; col++) { shiftedImage[col][row] = image[col][row-dx]; } } // printImage(shiftedImage); return shiftedImage; } private static double[][] shiftLeft(double[][] image, int dx){ // System.out.println(dx); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxCol][maxRow]; for (int row = 0; row <= maxRow-dx-1; row++) { for (int col = 0; col < maxCol; col++) { shiftedImage[col][row] = image[col][row+dx]; } } // printImage(shiftedImage); return shiftedImage; } private static double[][] shiftUp(double[][] image, int dy){ // System.out.println(dy); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxCol][maxRow]; for (int col = 0; col <= maxCol-dy-1; col++) { for (int row = 0; row < maxRow; row++) { shiftedImage[col][row] = image[col+dy][row]; } } // printImage(shiftedImage); return shiftedImage; } private static double[][] shiftDown(double[][] image, int dy){ // System.out.println(dy); // printImage(image); int maxRow = image.length; int maxCol = image[0].length; double[][] shiftedImage = new double[maxCol][maxRow]; for (int col = maxCol-1; col >= dy; col--) { for (int row = 0; row < maxRow; row++) { shiftedImage[col][row] = image[col-dy][row]; } } // printImage(shiftedImage); return shiftedImage; } }