提交 23c82fd6 编写于 作者: N Nick 提交者: Waleed

Add custom callbacks to model training

Add an optional parameter for calling a list of keras.callbacks to be add to the original list.
上级 3ba867e2
......@@ -2272,7 +2272,7 @@ class MaskRCNN():
"*epoch*", "{epoch:04d}")
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
augmentation=None):
augmentation=None, custom_callbacks=[]):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
learning_rate: The learning rate to train with
......@@ -2299,6 +2299,10 @@ class MaskRCNN():
imgaug.augmenters.Fliplr(0.5),
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
])
custom_callbacks: (list) Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
"""
assert self.mode == "training", "Create model in training mode."
......@@ -2330,6 +2334,9 @@ class MaskRCNN():
keras.callbacks.ModelCheckpoint(self.checkpoint_path,
verbose=0, save_weights_only=True),
]
# Add custom callbacks to the list
callbacks+=custom_callbacks
# Train
log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册