Saving the model

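Saving the model during the training process is done with the mx.do_checkpoint callback. Its most important parameters are as follows:

- prefix: the prefix of the file names the model is saved to. The architecture is written to prefix-symbol.json and the weights to prefix-NNNN.params, where NNNN is the zero-padded epoch number (for example, prefix-0012.params for the 12th epoch).
- frequency: a keyword argument (default 1) that sets how often, in epochs, a checkpoint is saved.
- save_epoch_0: a keyword argument (default false) that controls whether a checkpoint is also saved for epoch 0, that is, the initialized model before it has seen any data.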

Let's go back to the MNIST example we have been working on and add mx.do_checkpoint to the mx.fit call. This is the original version:

mx.fit(nnet, mx.ADAM(), train_data_provider, eval_data = validation_data_provider, n_epoch = 50, callbacks = [mx.speedometer()]);

In the original version, the network is already configured to call the mx.speedometer callback. The new version adds a call to mx.do_checkpoint that saves the model every 5 epochs, using weights/mnist as the file prefix:

cp_prefix = "weights/mnist"
callbacks = [
    mx.speedometer(),
    mx.do_checkpoint(cp_prefix, frequency=5)
]

mx.fit(nnet, mx.ADAM(), train_data_provider, eval_data = validation_data_provider, n_epoch = 30, callbacks = callbacks);
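To see which epochs end up checkpointed with this configuration, here is a small sketch in plain Julia. The helper checkpoint_epochs is hypothetical (it is not part of MXNet.jl) and assumes the callback fires whenever the epoch number is a multiple of frequency, plus epoch 0 only when the save_epoch_0 keyword of mx.do_checkpoint is set to true.

```julia
# Hypothetical helper (not part of MXNet.jl): lists the epochs at which a
# do_checkpoint(prefix; frequency=f) callback is expected to save, assuming
# it fires whenever the epoch number is a multiple of `frequency`.
function checkpoint_epochs(n_epoch::Int, frequency::Int; save_epoch_0::Bool=false)
    epochs = [e for e in 1:n_epoch if e % frequency == 0]
    # Epoch 0 (freshly initialized weights) is only included when requested.
    return save_epoch_0 ? vcat(0, epochs) : epochs
end

println(checkpoint_epochs(30, 5))   # [5, 10, 15, 20, 25, 30]
```

With n_epoch = 30 and frequency=5 this gives six checkpoints, the last one at the final epoch.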

Each saved checkpoint is also reported in the training log:

INFO: == Epoch 005/030 ==========
INFO: ## Training summary
INFO: accuracy = 0.9025
INFO: time = 1.9472 seconds
INFO: ## Validation summary
INFO: accuracy = 0.9135
INFO: Saved checkpoint to 'weights/mnist-0005.params'

The model architecture will be saved to mnist-symbol.json in the weights folder, while the weights will be saved to mnist-0005.params, mnist-0010.params, and so on up to mnist-0030.params, one file for every 5th epoch of the 30-epoch run (epochs 5, 10, 15, 20, 25, and 30).
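As a quick sanity check after training, the expected file names can be rebuilt from the prefix and compared against what is on disk. This is a sketch under assumptions: checkpoint_files is a hypothetical helper (not part of MXNet.jl), and it assumes the "prefix-symbol.json" / "prefix-NNNN.params" naming with a four-digit, zero-padded epoch number, as seen in the log above.

```julia
using Printf

# Hypothetical helper (not part of MXNet.jl): builds the file names a
# checkpoint run is expected to leave on disk, assuming the
# "<prefix>-symbol.json" / "<prefix>-NNNN.params" naming shown in the log.
function checkpoint_files(prefix::AbstractString, epochs)
    files = ["$(prefix)-symbol.json"]                         # architecture, one file
    for e in epochs
        push!(files, @sprintf("%s-%04d.params", prefix, e))   # zero-padded epoch number
    end
    return files
end

files = checkpoint_files("weights/mnist", 5:5:30)
# After training, collect any expected file that is not on disk.
missing_files = filter(!isfile, files)
```

An empty missing_files means every checkpoint from the run was written.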