
TrainCheckpoints

TrainCheckpoints class

A TrainCheckpoints object handles checkpoint saving: it provides the directory in which to save weights and reports the corresponding events. Checkpoints are the main results of training tasks and should be stored as follows:

out_dir = the_train_checkpoints.get_dir_to_write()
# write any desired data into out_dir... then
the_train_checkpoints.saved(is_best=True)

Each checkpoint is a directory containing NN model weights. The contents of the directory depend entirely on the model implementation. It may also hold any additional information needed to apply the model, continue training, etc.

Hint

A config.json file in the checkpoint directory is displayed in the web interface, so we recommend storing model metadata in it.
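As a minimal sketch, assuming a plain local directory, the function below writes the contents of one checkpoint: a weights file and a config.json holding model metadata. The name write_checkpoint, the file name model.bin and the metadata keys are hypothetical; the real layout is up to the model implementation. In practice, out_dir would be the path returned by get_dir_to_write().

```python
import json
import os
import tempfile


def write_checkpoint(out_dir, weights_bytes, metadata):
    """Write hypothetical checkpoint contents into out_dir.

    The file name 'model.bin' and the metadata keys are illustrative
    assumptions; the real layout depends on the model implementation.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Model weights: the format is entirely up to the model implementation.
    with open(os.path.join(out_dir, "model.bin"), "wb") as f:
        f.write(weights_bytes)
    # config.json is the file shown in the web interface, so metadata goes here.
    with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)


if __name__ == "__main__":
    out_dir = tempfile.mkdtemp()
    write_checkpoint(out_dir, b"\x00\x01", {"arch": "resnet18", "epoch": 3})
    print(sorted(os.listdir(out_dir)))  # ['config.json', 'model.bin']
```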

class TrainCheckpoints:
    def __init__(self, odir):

Create a TrainCheckpoints object.

  • odir — root directory to store checkpoints.

Methods

get_dir_to_write(self)

Returns the path to the directory that should be used to save the current model checkpoint (weights, metadata).

get_last_ckpt_dir(self)

Returns the path to the directory containing the last valid (fully written) checkpoint.
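Because get_last_ckpt_dir() only points at fully written checkpoints, it is a natural place to resume training from. The helper below is a hypothetical sketch: it assumes that None means no checkpoint has been saved yet and that config.json stores an 'epoch' key, neither of which is specified by this API.

```python
import json
import os
import tempfile


def resume_epoch(last_ckpt_dir):
    """Return the epoch to resume training from.

    Hypothetical helper. Assumes last_ckpt_dir is None when nothing has been
    saved yet, and that config.json holds an 'epoch' key (an assumption about
    how the checkpoint was written, not part of the documented API).
    """
    if last_ckpt_dir is None:
        return 0  # no fully written checkpoint: start from scratch
    with open(os.path.join(last_ckpt_dir, "config.json"), encoding="utf-8") as f:
        meta = json.load(f)
    # Resume from the epoch after the one recorded in the checkpoint.
    return meta.get("epoch", -1) + 1


if __name__ == "__main__":
    ckpt_dir = tempfile.mkdtemp()
    with open(os.path.join(ckpt_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"epoch": 3}, f)
    print(resume_epoch(None))      # 0
    print(resume_epoch(ckpt_dir))  # 4
```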

saved

def saved(self, is_best, optional_data=None):

Finishes using the current checkpoint directory and reports that the checkpoint has been saved. Should be called each time a checkpoint is saved.

  • is_best — boolean indicating whether the stored model is the best so far in the training process. Currently unused.

  • optional_data — JSON-serializable object that will be linked to the checkpoint and may be useful to distinguish between checkpoints. For example, one may save model weights after validation and pass the validation results as optional_data.
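To illustrate how the three methods fit together, here is a hypothetical stand-in that mirrors the documented signatures. It is not the library's implementation: the class name TrainCheckpointsSketch, the numbered subdirectories under odir, the automatic creation of the write directory, and the events list standing in for real event reporting are all assumptions made for this sketch.

```python
import json
import os
import tempfile


class TrainCheckpointsSketch:
    """Hypothetical stand-in mirroring the documented TrainCheckpoints methods.

    Assumptions (not the real implementation): each checkpoint goes into a
    numbered subdirectory of odir, the directory is created on request, and
    event reporting is modelled as appending records to self.events.
    """

    def __init__(self, odir):
        self.odir = odir
        self._idx = 0
        self._last_ckpt_dir = None  # no fully written checkpoint yet
        self.events = []

    def get_dir_to_write(self):
        # Directory for the checkpoint currently being written.
        path = os.path.join(self.odir, f"{self._idx:03d}")
        os.makedirs(path, exist_ok=True)
        return path

    def get_last_ckpt_dir(self):
        # Only directories confirmed via saved() count as fully written.
        return self._last_ckpt_dir

    def saved(self, is_best, optional_data=None):
        # Reject optional_data that is not JSON-serializable (raises TypeError).
        json.dumps(optional_data)
        cur = os.path.join(self.odir, f"{self._idx:03d}")
        # is_best is recorded but otherwise unused, as in the documented API.
        self.events.append({"dir": cur, "is_best": is_best, "optional_data": optional_data})
        self._last_ckpt_dir = cur
        self._idx += 1  # the next get_dir_to_write() returns a fresh directory


if __name__ == "__main__":
    ckpts = TrainCheckpointsSketch(tempfile.mkdtemp())
    for epoch in range(2):
        out_dir = ckpts.get_dir_to_write()
        with open(os.path.join(out_dir, "config.json"), "w", encoding="utf-8") as f:
            json.dump({"epoch": epoch}, f)
        # Hypothetical validation result passed along with the checkpoint.
        ckpts.saved(is_best=True, optional_data={"epoch": epoch, "val_loss": 1.0 / (epoch + 1)})
    print(len(ckpts.events))  # 2
```

The point this sketch captures is that, within it, get_last_ckpt_dir() advances only when saved() is called, so a directory that is still being written is never reported as the last checkpoint.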