Saving the model during training is done with the mx.do_checkpoint callback. Its most important parameters are as follows:
- prefix: The filename prefix used for the saved model files
- frequency: How often, measured in epochs, a checkpoint is saved
Let's move back to the MNIST example we have been working on and adjust the mx.fit function to include mx.do_checkpoint:
mx.fit(nnet, mx.ADAM(), train_data_provider, eval_data = validation_data_provider, n_epoch = 50, callbacks = [mx.speedometer()]);
You can see that in the original version we had already configured the network to call the mx.speedometer callback. The new version adds a call to mx.do_checkpoint to save the model every fifth epoch, with weights/mnist as the prefix:
cp_prefix = "weights/mnist"
callbacks = [
    mx.speedometer(),
    mx.do_checkpoint(cp_prefix, frequency=5)
]
mx.fit(nnet, mx.ADAM(), train_data_provider, eval_data = validation_data_provider, n_epoch = 30, callbacks = callbacks);
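One practical note: the prefix points into a weights/ subdirectory, and the checkpoint files are written directly into it, so that directory should exist before training starts. A minimal sketch in plain Julia (no MXNet-specific call involved):

# Make sure the checkpoint directory exists before calling mx.fit;
# mkpath is a no-op if the directory is already there.
mkpath("weights")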
The checkpointing is also reflected in the training log:
INFO: == Epoch 005/030 ==========
INFO: ## Training summary
INFO: accuracy = 0.9025
INFO: time = 1.9472 seconds
INFO: ## Validation summary
INFO: accuracy = 0.9135
INFO: Saved checkpoint to 'weights/mnist-0005.params'
The model architecture will be saved to mnist-symbol.json in the weights folder, while the weights will be saved to mnist-0005.params, mnist-0010.params, mnist-0015.params, and so on, corresponding to every fifth epoch up to the end of training.
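If you later want to reuse the saved model, for example to run predictions without retraining, both files can be loaded back. The following is a minimal sketch assuming MXNet.jl's mx.load_checkpoint helper and the mx.FeedForward model type; the exact API surface may differ slightly between versions:

using MXNet

# Load the architecture (weights/mnist-symbol.json) and the parameters
# saved after epoch 30 (weights/mnist-0030.params).
arch, arg_params, aux_params = mx.load_checkpoint("weights/mnist", 30)

# Rebuild a FeedForward model around the restored architecture and attach
# the restored parameters.
model = mx.FeedForward(arch, context = mx.cpu())
model.arg_params = arg_params
model.aux_params = aux_params

# The restored model can now be used for prediction, for example:
# probs = mx.predict(model, validation_data_provider)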