TrainCheckpoints
TrainCheckpoints class
A TrainCheckpoints object handles checkpoint saving: it provides a directory to save weights into and reports the corresponding events.
Checkpoints are the main results of training tasks and should be stored as follows:
out_dir = the_train_checkpoints.get_dir_to_write()
# write any desired data into out_dir...
# then:
the_train_checkpoints.saved(is_best=True)
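The pattern above typically runs once per checkpoint inside a training loop, as sketched below. To keep the sketch runnable it uses a hypothetical stand-in, StubCheckpoints, that only mimics the two documented calls; it is not the real class. The numbered subdirectory layout and the placeholder file name weights.bin are assumptions for illustration.

```python
import os
import tempfile


class StubCheckpoints:
    """Hypothetical stand-in exposing the two documented calls; not the real class."""

    def __init__(self, odir):
        self.odir = odir
        self.counter = 0
        self.current = None
        self.saved_calls = []

    def get_dir_to_write(self):
        # Assumption: one numbered subdirectory per checkpoint.
        self.current = os.path.join(self.odir, f"{self.counter:03d}")
        os.makedirs(self.current, exist_ok=True)
        return self.current

    def saved(self, is_best, optional_data=None):
        # Record the report; the next get_dir_to_write() call yields a new directory.
        self.saved_calls.append((self.current, is_best, optional_data))
        self.counter += 1


def train(the_train_checkpoints, num_epochs):
    for epoch in range(num_epochs):
        # ... one epoch of training would run here (omitted) ...
        out_dir = the_train_checkpoints.get_dir_to_write()
        # Write any desired data into out_dir; a placeholder weights file here.
        with open(os.path.join(out_dir, "weights.bin"), "wb") as f:
            f.write(b"\x00" * 16)
        # Report the checkpoint only after everything has been written.
        the_train_checkpoints.saved(is_best=True)


with tempfile.TemporaryDirectory() as root:
    ckpts = StubCheckpoints(root)
    train(ckpts, num_epochs=3)
```

The important point is the ordering: saved() is called only after all files for the checkpoint are on disk.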
Each checkpoint is a directory containing neural network model weights. The contents of the directory depend entirely on the model implementation. You may also store additional information there that is needed to apply the model, continue training, etc.
Hint
A config.json file in the checkpoint directory will be displayed in the web interface. We prefer to store model metadata in it.
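A minimal sketch of writing a checkpoint directory that holds the weights plus a config.json with model metadata. The helper name write_checkpoint, the weights file name model.weights and the metadata keys are illustrative assumptions; which fields the web interface actually shows is not specified here.

```python
import json
import os
import tempfile


def write_checkpoint(out_dir, weights_bytes, metadata):
    """Hypothetical helper: store weights and a config.json with model metadata."""
    os.makedirs(out_dir, exist_ok=True)
    # Model weights; the real file name and format depend on the model implementation.
    with open(os.path.join(out_dir, "model.weights"), "wb") as f:
        f.write(weights_bytes)
    # Metadata goes into config.json, so it must be JSON-serializable.
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)


with tempfile.TemporaryDirectory() as out_dir:
    meta = {"architecture": "resnet18", "input_size": [224, 224], "epoch": 5}
    write_checkpoint(out_dir, b"\x01\x02\x03", meta)
    files = sorted(os.listdir(out_dir))
    with open(os.path.join(out_dir, "config.json"), encoding="utf-8") as f:
        loaded = json.load(f)
```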
class TrainCheckpoints:
    def __init__(self, odir):
Creates a TrainCheckpoints object.
odir: root directory in which checkpoints are stored.
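One possible way a class with this interface could manage its directories is sketched below. This is not the actual implementation: the zero-padded numbered subdirectories under odir and the choice to return None from get_last_ckpt_dir() before the first checkpoint are assumptions, and the class is named SketchTrainCheckpoints to keep it distinct from the real one.

```python
import os
import tempfile


class SketchTrainCheckpoints:
    """Illustrative sketch of the documented interface (assumed behavior)."""

    def __init__(self, odir):
        self._odir = odir        # root directory for all checkpoints
        self._idx = 0            # index of the directory currently being written
        self._last_dir = None    # last fully written checkpoint, if any

    def _current_dir(self):
        # Assumption: checkpoints live in zero-padded numbered subdirectories.
        return os.path.join(self._odir, f"{self._idx:06d}")

    def get_dir_to_write(self):
        path = self._current_dir()
        os.makedirs(path, exist_ok=True)
        return path

    def get_last_ckpt_dir(self):
        return self._last_dir

    def saved(self, is_best, optional_data=None):
        # The current directory is complete; later writes go to a new one.
        # is_best and optional_data are accepted but ignored in this sketch.
        self._last_dir = self._current_dir()
        self._idx += 1


with tempfile.TemporaryDirectory() as root:
    ckpts = SketchTrainCheckpoints(root)
    first = ckpts.get_dir_to_write()
    before = ckpts.get_last_ckpt_dir()
    ckpts.saved(is_best=False)
    second = ckpts.get_dir_to_write()
    after = ckpts.get_last_ckpt_dir()
```

Keeping the last completed directory separate from the one being written means an interruption mid-write does not leave get_last_ckpt_dir() pointing at a partial checkpoint.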
Methods
get_dir_to_write(self)
Returns the path to the directory that should be used to save the current model checkpoint (weights, meta information).
get_last_ckpt_dir(self)
Returns the path to the directory containing the last valid (fully written) checkpoint.
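When continuing an interrupted run, the last fully written checkpoint is the natural place to resume from. The sketch below takes the path returned by get_last_ckpt_dir() as an argument and reads the epoch to continue from. It assumes, hypothetically, that the method returns None when nothing has been saved yet and that the checkpoint contains a config.json with an epoch field; the helper name resume_epoch is also illustrative.

```python
import json
import os
import tempfile


def resume_epoch(last_ckpt_dir):
    """Return the epoch to start from, given the last complete checkpoint dir (or None)."""
    if last_ckpt_dir is None:
        return 0  # nothing saved yet: start from scratch
    config_path = os.path.join(last_ckpt_dir, "config.json")
    if not os.path.isfile(config_path):
        return 0  # no metadata stored: cannot tell where training stopped
    with open(config_path, encoding="utf-8") as f:
        meta = json.load(f)
    # Continue with the epoch after the one recorded in the checkpoint.
    return int(meta.get("epoch", -1)) + 1


with tempfile.TemporaryDirectory() as ckpt_dir:
    with open(os.path.join(ckpt_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"epoch": 4}, f)
    start = resume_epoch(ckpt_dir)
```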
saved
def saved(self, is_best, optional_data=None):
Finishes using the current checkpoint directory and reports that the checkpoint has been saved. Should be called every time a checkpoint is saved.
- is_best: a boolean indicating whether the stored model is the best so far in the training process. Currently unused.
- optional_data: a JSON-serializable object that will be linked to the checkpoint and may be useful to distinguish between checkpoints. For example, one may store model weights after validation and pass the validation results as optional_data.
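A sketch of how is_best and optional_data might be filled in after validation. The metric name val_loss, the helper report_checkpoint and the recording stand-in RecordingCheckpoints are hypothetical. Since is_best is currently unused, computing it mainly keeps the call meaningful if that changes.

```python
import json


class RecordingCheckpoints:
    """Hypothetical stand-in that only records what saved() receives."""

    def __init__(self):
        self.reports = []

    def saved(self, is_best, optional_data=None):
        self.reports.append((is_best, optional_data))


def report_checkpoint(ckpts, epoch, val_loss, best_so_far):
    """Report a saved checkpoint and return the updated best validation loss."""
    is_best = best_so_far is None or val_loss < best_so_far
    optional_data = {"epoch": epoch, "val_loss": val_loss}
    # optional_data must be JSON-serializable; json.dumps raises TypeError otherwise.
    json.dumps(optional_data)
    ckpts.saved(is_best=is_best, optional_data=optional_data)
    return val_loss if is_best else best_so_far


ckpts = RecordingCheckpoints()
best = None
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    best = report_checkpoint(ckpts, epoch, loss, best)
```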