diff --git a/coco.py b/coco.py index ff4b11a04bff5889290b8ee5af268055a6f90fcc..94b6cd18b068cb6233002709ad240cc93f4b84bc 100644 --- a/coco.py +++ b/coco.py @@ -51,8 +51,9 @@ ROOT_DIR = os.getcwd() # Path to trained weights file COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5") -# Directory to save logs and trained model -MODEL_DIR = os.path.join(ROOT_DIR, "logs") +# Directory to save logs and model checkpoints, if not provided +# through the command line argument --logs +DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs") ############################################################ @@ -321,10 +322,15 @@ if __name__ == '__main__': parser.add_argument('--model', required=True, metavar="/path/to/weights.h5", help="Path to weights .h5 file or 'coco'") + parser.add_argument('--logs', required=False, + default=DEFAULT_LOGS_DIR, + metavar="/path/to/logs/", + help='Directory to save logs and checkpoints. Defaults to logs/') args = parser.parse_args() print("Command: ", args.command) print("Model: ", args.model) print("Dataset: ", args.dataset) + print("Logs: ", args.logs) # Configurations if args.command == "train": @@ -341,10 +347,10 @@ if __name__ == '__main__': # Create model if args.command == "train": model = modellib.MaskRCNN(mode="training", config=config, - model_dir=MODEL_DIR) + model_dir=args.logs) else: model = modellib.MaskRCNN(mode="inference", config=config, - model_dir=MODEL_DIR) + model_dir=args.logs) # Select weights file to load if args.model.lower() == "coco":