提交 cbff80f3 编写于 作者: W Waleed Abdulla

Raise clear error if last training weights are not foundIf using the...

Raise clear error if last training weights are not foundIf using the --weights=last (or --model=last) to resume trainingbut the weights are not found now it raises a clear error message.
上级 a688a66b
...@@ -2062,8 +2062,7 @@ class MaskRCNN(): ...@@ -2062,8 +2062,7 @@ class MaskRCNN():
"""Finds the last checkpoint file of the last trained model in the """Finds the last checkpoint file of the last trained model in the
model directory. model directory.
Returns: Returns:
log_dir: The directory where events and weights are saved The path of the last checkpoint file
checkpoint_path: the path to the last checkpoint file
""" """
# Get directory names. Each directory corresponds to a model # Get directory names. Each directory corresponds to a model
dir_names = next(os.walk(self.model_dir))[1] dir_names = next(os.walk(self.model_dir))[1]
...@@ -2071,7 +2070,10 @@ class MaskRCNN(): ...@@ -2071,7 +2070,10 @@ class MaskRCNN():
dir_names = filter(lambda f: f.startswith(key), dir_names) dir_names = filter(lambda f: f.startswith(key), dir_names)
dir_names = sorted(dir_names) dir_names = sorted(dir_names)
if not dir_names: if not dir_names:
return None, None import errno
raise FileNotFoundError(
errno.ENOENT,
"Could not find model directory under {}".format(self.model_dir))
# Pick last directory # Pick last directory
dir_name = os.path.join(self.model_dir, dir_names[-1]) dir_name = os.path.join(self.model_dir, dir_names[-1])
# Find the last checkpoint # Find the last checkpoint
...@@ -2079,9 +2081,11 @@ class MaskRCNN(): ...@@ -2079,9 +2081,11 @@ class MaskRCNN():
checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints) checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints)
checkpoints = sorted(checkpoints) checkpoints = sorted(checkpoints)
if not checkpoints: if not checkpoints:
return dir_name, None import errno
raise FileNotFoundError(
errno.ENOENT, "Could not find weight files in {}".format(dir_name))
checkpoint = os.path.join(dir_name, checkpoints[-1]) checkpoint = os.path.join(dir_name, checkpoints[-1])
return dir_name, checkpoint return checkpoint
def load_weights(self, filepath, by_name=False, exclude=None): def load_weights(self, filepath, by_name=False, exclude=None):
"""Modified version of the correspoding Keras function with """Modified version of the correspoding Keras function with
......
...@@ -336,7 +336,7 @@ if __name__ == '__main__': ...@@ -336,7 +336,7 @@ if __name__ == '__main__':
utils.download_trained_weights(weights_path) utils.download_trained_weights(weights_path)
elif args.weights.lower() == "last": elif args.weights.lower() == "last":
# Find last trained weights # Find last trained weights
weights_path = model.find_last()[1] weights_path = model.find_last()
elif args.weights.lower() == "imagenet": elif args.weights.lower() == "imagenet":
# Start from ImageNet trained weights # Start from ImageNet trained weights
weights_path = model.get_imagenet_weights() weights_path = model.get_imagenet_weights()
......
...@@ -265,7 +265,7 @@ ...@@ -265,7 +265,7 @@
"# weights_path = \"/path/to/mask_rcnn_balloon.h5\"\n", "# weights_path = \"/path/to/mask_rcnn_balloon.h5\"\n",
"\n", "\n",
"# Or, load the last model you trained\n", "# Or, load the last model you trained\n",
"weights_path = model.find_last()[1]\n", "weights_path = model.find_last()\n",
"\n", "\n",
"# Load weights\n", "# Load weights\n",
"print(\"Loading weights \", weights_path)\n", "print(\"Loading weights \", weights_path)\n",
...@@ -462,7 +462,7 @@ if __name__ == '__main__': ...@@ -462,7 +462,7 @@ if __name__ == '__main__':
model_path = COCO_MODEL_PATH model_path = COCO_MODEL_PATH
elif args.model.lower() == "last": elif args.model.lower() == "last":
# Find last trained weights # Find last trained weights
model_path = model.find_last()[1] model_path = model.find_last()
elif args.model.lower() == "imagenet": elif args.model.lower() == "imagenet":
# Start from ImageNet trained weights # Start from ImageNet trained weights
model_path = model.get_imagenet_weights() model_path = model.get_imagenet_weights()
......
...@@ -270,7 +270,7 @@ ...@@ -270,7 +270,7 @@
"elif config.NAME == \"coco\":\n", "elif config.NAME == \"coco\":\n",
" weights_path = COCO_MODEL_PATH\n", " weights_path = COCO_MODEL_PATH\n",
"# Or, uncomment to load the last model you trained\n", "# Or, uncomment to load the last model you trained\n",
"# weights_path = model.find_last()[1]\n", "# weights_path = model.find_last()\n",
"\n", "\n",
"# Load weights\n", "# Load weights\n",
"print(\"Loading weights \", weights_path)\n", "print(\"Loading weights \", weights_path)\n",
...@@ -150,7 +150,7 @@ ...@@ -150,7 +150,7 @@
"elif config.NAME == \"coco\":\n", "elif config.NAME == \"coco\":\n",
" weights_path = COCO_MODEL_PATH\n", " weights_path = COCO_MODEL_PATH\n",
"# Or, uncomment to load the last model you trained\n", "# Or, uncomment to load the last model you trained\n",
"# weights_path = model.find_last()[1]\n", "# weights_path = model.find_last()\n",
"\n", "\n",
"# Load weights\n", "# Load weights\n",
"print(\"Loading weights \", weights_path)\n", "print(\"Loading weights \", weights_path)\n",
...@@ -258,7 +258,7 @@ ...@@ -258,7 +258,7 @@
"# weights_path = \"/path/to/mask_rcnn_nucleus.h5\"\n", "# weights_path = \"/path/to/mask_rcnn_nucleus.h5\"\n",
"\n", "\n",
"# Or, load the last model you trained\n", "# Or, load the last model you trained\n",
"weights_path = model.find_last()[1]\n", "weights_path = model.find_last()\n",
"\n", "\n",
"# Load weights\n", "# Load weights\n",
"print(\"Loading weights \", weights_path)\n", "print(\"Loading weights \", weights_path)\n",
...@@ -464,7 +464,7 @@ if __name__ == '__main__': ...@@ -464,7 +464,7 @@ if __name__ == '__main__':
utils.download_trained_weights(weights_path) utils.download_trained_weights(weights_path)
elif args.weights.lower() == "last": elif args.weights.lower() == "last":
# Find last trained weights # Find last trained weights
weights_path = model.find_last()[1] weights_path = model.find_last()
elif args.weights.lower() == "imagenet": elif args.weights.lower() == "imagenet":
# Start from ImageNet trained weights # Start from ImageNet trained weights
weights_path = model.get_imagenet_weights() weights_path = model.get_imagenet_weights()
......
...@@ -458,7 +458,7 @@ ...@@ -458,7 +458,7 @@
" \"mrcnn_bbox\", \"mrcnn_mask\"])\n", " \"mrcnn_bbox\", \"mrcnn_mask\"])\n",
"elif init_with == \"last\":\n", "elif init_with == \"last\":\n",
" # Load the last model you trained and continue training\n", " # Load the last model you trained and continue training\n",
" model.load_weights(model.find_last()[1], by_name=True)" " model.load_weights(model.find_last(), by_name=True)"
] ]
}, },
{ {
...@@ -875,10 +875,9 @@ ...@@ -875,10 +875,9 @@
"# Get path to saved weights\n", "# Get path to saved weights\n",
"# Either set a specific path or find last trained weights\n", "# Either set a specific path or find last trained weights\n",
"# model_path = os.path.join(ROOT_DIR, \".h5 file name here\")\n", "# model_path = os.path.join(ROOT_DIR, \".h5 file name here\")\n",
"model_path = model.find_last()[1]\n", "model_path = model.find_last()\n",
"\n", "\n",
"# Load trained weights (fill in path to trained weights here)\n", "# Load trained weights\n",
"assert model_path != \"\", \"Provide path to trained weights\"\n",
"print(\"Loading weights from \", model_path)\n", "print(\"Loading weights from \", model_path)\n",
"model.load_weights(model_path, by_name=True)" "model.load_weights(model_path, by_name=True)"
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册