提交 059ce9d6 编写于 作者: W Waleed Abdulla

Add --logs argument to coco.py

上级 33618386
......@@ -51,8 +51,9 @@ ROOT_DIR = os.getcwd()
# Path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
# Directory to save logs and model checkpoints, if not provided
# through the command line argument --logs
DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs")
############################################################
......@@ -321,10 +322,15 @@ if __name__ == '__main__':
parser.add_argument('--model', required=True,
metavar="/path/to/weights.h5",
help="Path to weights .h5 file or 'coco'")
parser.add_argument('--logs', required=False,
default=DEFAULT_LOGS_DIR,
metavar="/path/to/logs/",
help='Directory to save logs and checkpoints. Defaults to logs/')
args = parser.parse_args()
print("Command: ", args.command)
print("Model: ", args.model)
print("Dataset: ", args.dataset)
print("Logs: ", args.logs)
# Configurations
if args.command == "train":
......@@ -341,10 +347,10 @@ if __name__ == '__main__':
# Create model
if args.command == "train":
model = modellib.MaskRCNN(mode="training", config=config,
model_dir=MODEL_DIR)
model_dir=args.logs)
else:
model = modellib.MaskRCNN(mode="inference", config=config,
model_dir=MODEL_DIR)
model_dir=args.logs)
# Select weights file to load
if args.model.lower() == "coco":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册