Machine Learning - Reading in MNIST dataset in C++

If you want to write your own Multi Layer Perceptron (MLP) or Convolutional Neural Network (CNN) in C++, you need some training data to test your implementation.

MNIST is the classical approach to do a "Machine Learning Hello World". It consists of 60.000 training images of handwritten digits 0,1,...,9 and 10.000 test images together with ground-truth training and test labels.

But this means, you need some code that reads in training/testing images + labels.

Here is my C++ code to realize this task. It also uses OpenCV in order to visualize some sample training images. Hope it helps!

main.cpp:

		
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/videoio.hpp"

#include 
#include 
#include                  // for random numbers & 1D normal distribution
#define _USE_MATH_DEFINES
#include                  // for M_PI

#include "mnist_dataset_reader.h"


using namespace cv;
using namespace std;

int main()
{
 string path_to_extracted_mnist_files = "V:\\01_job\\00_vorlesungen_meine\\17_deep_learning\\07_datasets\\01_mnist\\extracted";
 mnist_dataset_reader my_reader( path_to_extracted_mnist_files );

 /*
 Mat* img = my_reader.get_mnist_image_as_cvmat( my_reader.get_train_images(), 1 );
 Mat img_resized;
 resize( *img, img_resized, cv::Size(0,0), 4,4);
 imshow("A sample digit", img_resized);
 */

 Mat* samples_as_image = my_reader.get_board_of_sample_images( my_reader.get_train_images(), my_reader.get_train_labels(), 60000 );
 imshow("Some sample MNIST training images", *samples_as_image );
 
 waitKey(0);

 _getch();

}			
        

mnist_dataset_reader.h

#pragma once

#include 
#include 
#include 

#include "opencv2/core.hpp"    // for cv::Mat
#include "opencv2/highgui.hpp" // for CV_RGB
#include "opencv2/imgproc.hpp" // for cv::putText

using namespace std;
using namespace cv;

class mnist_dataset_reader
{
  public:

                            mnist_dataset_reader(string path_to_extracted_mnist_files); // reads in all training / test images + labels

    unsigned char**         get_train_images();

    unsigned char**         get_test_images();

    unsigned char*          get_train_labels();

    unsigned char*          get_test_labels();

    Mat*                    get_mnist_image_as_cvmat(unsigned char** images, int image_idx);

    Mat*                    get_board_of_sample_images(unsigned char** images, unsigned char* labels, int nr_of_images);





  private:

    unsigned char**         mnist_dataset_reader::read_mnist_images(string full_path, int& number_of_images, int& image_size); // for reading in images

    unsigned char*          mnist_dataset_reader::read_mnist_labels(string full_path, int& number_of_labels);                  // for reading in ground truth image labels


    string                  path_to_extracted_mnist_files; // where you have extracted the files 

    unsigned char**         train_images;                  // all the 60.000 training images of size 28x28 

    unsigned char**         test_images;                   // all the 10.000 test     images of size 28x28

    unsigned  char*         train_labels;                  // all the 60.000 ground truth labels for the training images

    unsigned  char*         test_labels;                   // all the 10.000 ground truth labels for the test     images

    int                     nr_train_images_read;          // should be 60.000

    int                     nr_test_images_read;           // should be 10.000
    
};
    
    

mnist_dataset_reader.cpp

// MNIST dataset reader
//
// reads in the 60.000 training and 10.000 testing images of
// the "MNIST database of handwritten digits" (0,1,2,...,9)
// see http://yann.lecun.com/exdb/mnist/
//
// note: download yourself the dataset and extract the files to a folder.
//       after extracting, you should get the MNIST dataset files:
//         t10k-images.idx3-ubyte
//         t10k-labels.idx1-ubyte
//         train-images.idx3-ubyte
//         train-labels.idx1-ubyte
//
// by Prof. Dr.-Ing. Jürgen Brauer, www.juergenbrauer.org
//
// parts of the code are inspired by
// http://stackoverflow.com/questions/8286668/how-to-read-mnist-data-in-c


#include "mnist_dataset_reader.h"


//
// read in all the training / test images and their corresponding ground truth labels
//
mnist_dataset_reader::mnist_dataset_reader(std::string path_to_extracted_mnist_files)
{
  // 1. store path to folder where dataset was extracted by user
  this->path_to_extracted_mnist_files = path_to_extracted_mnist_files;


  // 2. read in training data
  int size_of_single_image_read;
  int nr_labels_read;

  // 2.1 read in training images
  cout << endl;  
  train_images = read_mnist_images( path_to_extracted_mnist_files + "\\\\" + "train-images.idx3-ubyte", nr_train_images_read, size_of_single_image_read);
  cout << "I have read " << nr_train_images_read << " training images of size " << size_of_single_image_read << "!" << endl;

  // 2.2 read in training ground truth labels
  cout << endl;
  train_labels = read_mnist_labels( path_to_extracted_mnist_files + "\\\\" + "train-labels.idx1-ubyte", nr_labels_read);
  cout << "I have read " << nr_labels_read << " labels for the training images!" << endl;


  // 3. read in test data

  // 3.1 read in test images
  cout << endl;
  test_images = read_mnist_images(path_to_extracted_mnist_files + "\\\\" + "t10k-images.idx3-ubyte", nr_test_images_read, size_of_single_image_read);
  cout << "I have read " << nr_test_images_read << " test images of size " << size_of_single_image_read << "!" << endl;

  // 3.2 read in testing ground truth labels
  cout << endl;
  test_labels = read_mnist_labels(path_to_extracted_mnist_files + "\\\\" + "t10k-labels.idx1-ubyte", nr_labels_read);
  cout << "I have read " << nr_labels_read << " labels for the testing images!" << endl;


}



//
// read in a set of images and return a pointer to the 2D array (pointer char** to 1D array of row pointers char*)
//
unsigned char** mnist_dataset_reader::read_mnist_images(string full_path, int& number_of_images, int& image_size)
{

  auto reverseInt = [](int i) {
    unsigned char c1, c2, c3, c4;
    c1 = i & 255, c2 = (i >> 8) & 255, c3 = (i >> 16) & 255, c4 = (i >> 24) & 255;
    return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
  };

  cout << "reading MNIST file " << full_path << " ... " << endl;

  ifstream file(full_path, ios::binary);

  if (file.is_open()) {
    int magic_number = 0, n_rows = 0, n_cols = 0;

    file.read((char *)&magic_number, sizeof(magic_number));
    magic_number = reverseInt(magic_number);

    if (magic_number != 2051) throw runtime_error("Invalid MNIST image file!");

    file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images);
    file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows);
    file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols);

    cout << "nr of rows x nr of cols = " << n_rows << " x " << n_cols << endl;

    image_size = n_rows * n_cols;

    unsigned char** _dataset = new unsigned char*[number_of_images];
    for (int i = 0; i < number_of_images; i++) {
      _dataset[i] = new unsigned char[image_size];
      file.read((char*)_dataset[i], image_size);
    }
    cout << ".. all images read!" << endl;

    return _dataset;
  }
  else {
    throw runtime_error("Cannot open file `" + full_path + "`!");
  }

} // read_mnist_images



unsigned char* mnist_dataset_reader::read_mnist_labels(string full_path, int& number_of_labels)
{
  auto reverseInt = [](int i) {
    unsigned char c1, c2, c3, c4;
    c1 = i & 255, c2 = (i >> 8) & 255, c3 = (i >> 16) & 255, c4 = (i >> 24) & 255;
    return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
  };

  cout << "reading MNIST file " << full_path << " ... " << endl;

  ifstream file(full_path, ios::binary);

  if (file.is_open()) {
    int magic_number = 0;
    file.read((char *)&magic_number, sizeof(magic_number));
    magic_number = reverseInt(magic_number);

    if (magic_number != 2049) throw runtime_error("Invalid MNIST label file!");

    file.read((char *)&number_of_labels, sizeof(number_of_labels)), number_of_labels = reverseInt(number_of_labels);

    uchar* _dataset = new uchar[number_of_labels];
    for (int i = 0; i < number_of_labels; i++) {
      file.read((char*)&_dataset[i], 1);
    }
    cout << ".. all labels read!" << endl;
    return _dataset;
  }
  else {
    throw runtime_error("Unable to open file `" + full_path + "`!");
  }

} // read_mnist_labels



unsigned char** mnist_dataset_reader::get_train_images()
{
  return train_images;
}


unsigned char** mnist_dataset_reader::get_test_images()
{
  return test_images;
}


unsigned char* mnist_dataset_reader::get_train_labels()
{
  return train_labels;
}


unsigned char* mnist_dataset_reader::get_test_labels()
{
  return test_labels;
}


//
// given a specified set of images (train or test images)
// this method can be used to render a cv::Mat image of a MNIST digit sample
// that can be displayed to the user
//
Mat* mnist_dataset_reader::get_mnist_image_as_cvmat(unsigned char** images, int image_idx)
{
  Mat* visu = new Mat(28,28, CV_8UC1);

  int idx_1d = 0;
  for (int y = 0; y < 28; y++)
  {
    for (int x = 0; x < 28; x++)    
    {
      unsigned char gray_value = images[image_idx][idx_1d++];

      visu->at(y,x) = gray_value;

      //printf("%03d ", gray_value);
    }
    //cout << endl;
  }

  return visu;
   
} // get_mnist_image_as_cvmat


//
// select some sample images randomly and visualize them together with their
// labels
//
Mat* mnist_dataset_reader::get_board_of_sample_images(unsigned char** images, unsigned char* labels, int nr_of_images)
{
  const int w=1200;
  const int h=800;
  bool still_space_for_drawing = true;
  Mat* visu = new Mat(h,w, CV_8UC3);

  int y = 0;
  int x = 0;
  const int some_space = 7;
  while (still_space_for_drawing)
  {
    // get random image index
    int rnd_idx = rand() % nr_of_images;

    // copy visualization of image of digit to board
    Mat* sample_image = get_mnist_image_as_cvmat( train_images, rnd_idx );
    Mat sample_image_3channels;
    cvtColor(*sample_image, sample_image_3channels, CV_GRAY2RGB);
    Mat* roi = new Mat( *visu, Rect(x,y,28,28) );
    sample_image_3channels.copyTo( *roi );
    x += 28 + some_space;

    // draw image label to the right of image of digit on board as well
    char txt[10];
    sprintf_s(txt, "%d", labels[rnd_idx]);
    putText(*visu,
            txt,
            Point(x, y+20),
            FONT_HERSHEY_SIMPLEX, 0.7, // font face and scale
            CV_RGB(255, 255, 0), // yellow
            1); // line thickness and type
    x += 28 + some_space*2;

    // line break before next visualization of sample digit?
    if (x + 2 * 28 + some_space >= w)
    {
      // yes! line break
      x=0;
      y += 28+2*some_space;

      // is there still enough space for rendering another line?
      if (y+28+some_space>=h)
      {
        // no space for rendering left
        still_space_for_drawing=false;
      }

    } // if
                
  } // while (still_space_for_drawing)

  return visu;

} // get_board_of_sample_images    
    

Output generated: