diff --git a/dygraph/infer.py b/dygraph/infer.py index 1cc15d319f09e86693eb35006fd6d7efc3f5becc..f5caf7a435d3083f7d84106024096684a9d4f3b8 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -37,12 +37,8 @@ def parse_args(): parser.add_argument( '--model_name', dest='model_name', - help= - 'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' - '"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", ' - '"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", ' - '"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", ' - '"SE_HRNet_W60", "SE_HRNet_W64")', + help='Model type for testing, which is one of {}'.format( + str(list(MODELS.keys()))), type=str, default='UNet') diff --git a/dygraph/models/hrnet.py b/dygraph/models/hrnet.py index fac8a929be40acce2d801c3cdbbe89bb634bead3..bccc303bb435e48554991a21b4fd72dd90a3cb37 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/hrnet.py @@ -18,7 +18,8 @@ import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear +from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm __all__ = [ "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30", diff --git a/dygraph/train.py b/dygraph/train.py index 8573591e25f2964610bd3da33b224c52bcfe1da9..709a66bb8c7f55dac0a83e5435a42893eb4d2e9a 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -38,12 +38,8 @@ def parse_args(): parser.add_argument( '--model_name', dest='model_name', - help= - 'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' - '"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", ' - '"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", ' - '"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", ' - '"SE_HRNet_W60", "SE_HRNet_W64")', + help='Model type for training, which is one of {}'.format( + str(list(MODELS.keys()))), type=str, default='UNet') @@ -186,6 +182,7 @@ def train(model, total_steps = steps_per_epoch * (num_epochs - start_epoch) num_steps = 0 best_mean_iou = -1.0 + best_model_epoch = -1 for epoch in range(start_epoch, num_epochs): for step, data in enumerate(loader): images = data[0] @@ -245,9 +242,9 @@ def train(model, best_model_dir = os.path.join(save_dir, "best_model") fluid.save_dygraph(model.state_dict(), os.path.join(best_model_dir, 'model')) - logging.info( - 'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}' - .format(best_model_epoch, best_mean_iou)) + logging.info( + 'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}' + .format(best_model_epoch, best_mean_iou)) if use_vdl: log_writer.add_scalar('Evaluate/mean_iou', mean_iou, diff --git a/dygraph/val.py b/dygraph/val.py index 36d4242966f0e98e381130c533d67c46b31aefe1..41d0d33485d1052bef3b1c4d70b546cdf89d3922 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -39,12 +39,8 @@ def parse_args(): parser.add_argument( '--model_name', dest='model_name', - help= - 'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' - '"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", ' - '"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", ' - '"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", ' - '"SE_HRNet_W60", "SE_HRNet_W64")', + help='Model type for evaluation, which is one of {}'.format( + str(list(MODELS.keys()))), type=str, default='UNet')