From cbff80f3e3f653a9eeee43d0d383a0385aba546b Mon Sep 17 00:00:00 2001 From: Waleed Abdulla Date: Tue, 5 Jun 2018 23:06:45 -0700 Subject: [PATCH] Raise clear error if last training weights are not foundIf using the --weights=last (or --model=last) to resume trainingbut the weights are not found now it raises a clear error message. --- mrcnn/model.py | 14 +++++++++----- samples/balloon/balloon.py | 2 +- samples/balloon/inspect_balloon_model.ipynb | 2 +- samples/coco/coco.py | 2 +- samples/coco/inspect_model.ipynb | 2 +- samples/coco/inspect_weights.ipynb | 2 +- samples/nucleus/inspect_nucleus_model.ipynb | 2 +- samples/nucleus/nucleus.py | 2 +- samples/shapes/train_shapes.ipynb | 7 +++---- 9 files changed, 19 insertions(+), 16 deletions(-) diff --git a/mrcnn/model.py b/mrcnn/model.py index b676033..aeb83f3 100644 --- a/mrcnn/model.py +++ b/mrcnn/model.py @@ -2062,8 +2062,7 @@ class MaskRCNN(): """Finds the last checkpoint file of the last trained model in the model directory. Returns: - log_dir: The directory where events and weights are saved - checkpoint_path: the path to the last checkpoint file + The path of the last checkpoint file """ # Get directory names. Each directory corresponds to a model dir_names = next(os.walk(self.model_dir))[1] @@ -2071,7 +2070,10 @@ class MaskRCNN(): dir_names = filter(lambda f: f.startswith(key), dir_names) dir_names = sorted(dir_names) if not dir_names: - return None, None + import errno + raise FileNotFoundError( + errno.ENOENT, + "Could not find model directory under {}".format(self.model_dir)) # Pick last directory dir_name = os.path.join(self.model_dir, dir_names[-1]) # Find the last checkpoint @@ -2079,9 +2081,11 @@ class MaskRCNN(): checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints) checkpoints = sorted(checkpoints) if not checkpoints: - return dir_name, None + import errno + raise FileNotFoundError( + errno.ENOENT, "Could not find weight files in {}".format(dir_name)) checkpoint = os.path.join(dir_name, checkpoints[-1]) - return dir_name, checkpoint + return checkpoint def load_weights(self, filepath, by_name=False, exclude=None): """Modified version of the correspoding Keras function with diff --git a/samples/balloon/balloon.py b/samples/balloon/balloon.py index 0fbd2d2..7f4d34d 100644 --- a/samples/balloon/balloon.py +++ b/samples/balloon/balloon.py @@ -336,7 +336,7 @@ if __name__ == '__main__': utils.download_trained_weights(weights_path) elif args.weights.lower() == "last": # Find last trained weights - weights_path = model.find_last()[1] + weights_path = model.find_last() elif args.weights.lower() == "imagenet": # Start from ImageNet trained weights weights_path = model.get_imagenet_weights() diff --git a/samples/balloon/inspect_balloon_model.ipynb b/samples/balloon/inspect_balloon_model.ipynb index f8923dd..a554acd 100644 --- a/samples/balloon/inspect_balloon_model.ipynb +++ b/samples/balloon/inspect_balloon_model.ipynb @@ -265,7 +265,7 @@ "# weights_path = \"/path/to/mask_rcnn_balloon.h5\"\n", "\n", "# Or, load the last model you trained\n", - "weights_path = model.find_last()[1]\n", + "weights_path = model.find_last()\n", "\n", "# Load weights\n", "print(\"Loading weights \", weights_path)\n", diff --git a/samples/coco/coco.py b/samples/coco/coco.py index 61fe0fa..e844724 100644 --- a/samples/coco/coco.py +++ b/samples/coco/coco.py @@ -462,7 +462,7 @@ if __name__ == '__main__': model_path = COCO_MODEL_PATH elif args.model.lower() == "last": # Find last trained weights - model_path = model.find_last()[1] + model_path = model.find_last() elif args.model.lower() == "imagenet": # Start from ImageNet trained weights model_path = model.get_imagenet_weights() diff --git a/samples/coco/inspect_model.ipynb b/samples/coco/inspect_model.ipynb index 5a278ba..a116f12 100644 --- a/samples/coco/inspect_model.ipynb +++ b/samples/coco/inspect_model.ipynb @@ -270,7 +270,7 @@ "elif config.NAME == \"coco\":\n", " weights_path = COCO_MODEL_PATH\n", "# Or, uncomment to load the last model you trained\n", - "# weights_path = model.find_last()[1]\n", + "# weights_path = model.find_last()\n", "\n", "# Load weights\n", "print(\"Loading weights \", weights_path)\n", diff --git a/samples/coco/inspect_weights.ipynb b/samples/coco/inspect_weights.ipynb index 3443ad9..9ded431 100644 --- a/samples/coco/inspect_weights.ipynb +++ b/samples/coco/inspect_weights.ipynb @@ -150,7 +150,7 @@ "elif config.NAME == \"coco\":\n", " weights_path = COCO_MODEL_PATH\n", "# Or, uncomment to load the last model you trained\n", - "# weights_path = model.find_last()[1]\n", + "# weights_path = model.find_last()\n", "\n", "# Load weights\n", "print(\"Loading weights \", weights_path)\n", diff --git a/samples/nucleus/inspect_nucleus_model.ipynb b/samples/nucleus/inspect_nucleus_model.ipynb index 7210cf4..0c0c71a 100644 --- a/samples/nucleus/inspect_nucleus_model.ipynb +++ b/samples/nucleus/inspect_nucleus_model.ipynb @@ -258,7 +258,7 @@ "# weights_path = \"/path/to/mask_rcnn_nucleus.h5\"\n", "\n", "# Or, load the last model you trained\n", - "weights_path = model.find_last()[1]\n", + "weights_path = model.find_last()\n", "\n", "# Load weights\n", "print(\"Loading weights \", weights_path)\n", diff --git a/samples/nucleus/nucleus.py b/samples/nucleus/nucleus.py index cbf2946..c3f16b8 100644 --- a/samples/nucleus/nucleus.py +++ b/samples/nucleus/nucleus.py @@ -464,7 +464,7 @@ if __name__ == '__main__': utils.download_trained_weights(weights_path) elif args.weights.lower() == "last": # Find last trained weights - weights_path = model.find_last()[1] + weights_path = model.find_last() elif args.weights.lower() == "imagenet": # Start from ImageNet trained weights weights_path = model.get_imagenet_weights() diff --git a/samples/shapes/train_shapes.ipynb b/samples/shapes/train_shapes.ipynb index f29f2b1..d1574cb 100644 --- a/samples/shapes/train_shapes.ipynb +++ b/samples/shapes/train_shapes.ipynb @@ -458,7 +458,7 @@ " \"mrcnn_bbox\", \"mrcnn_mask\"])\n", "elif init_with == \"last\":\n", " # Load the last model you trained and continue training\n", - " model.load_weights(model.find_last()[1], by_name=True)" + " model.load_weights(model.find_last(), by_name=True)" ] }, { @@ -875,10 +875,9 @@ "# Get path to saved weights\n", "# Either set a specific path or find last trained weights\n", "# model_path = os.path.join(ROOT_DIR, \".h5 file name here\")\n", - "model_path = model.find_last()[1]\n", + "model_path = model.find_last()\n", "\n", - "# Load trained weights (fill in path to trained weights here)\n", - "assert model_path != \"\", \"Provide path to trained weights\"\n", + "# Load trained weights\n", "print(\"Loading weights from \", model_path)\n", "model.load_weights(model_path, by_name=True)" ] -- GitLab