提交 06d51479 编写于 作者: R root

update

...@@ -37,12 +37,8 @@ def parse_args(): ...@@ -37,12 +37,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help= help='Model type for testing, which is one of {}'.format(
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' str(list(MODELS.keys()))),
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
......
...@@ -18,7 +18,8 @@ import paddle ...@@ -18,7 +18,8 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
__all__ = [ __all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......
...@@ -38,12 +38,8 @@ def parse_args(): ...@@ -38,12 +38,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help= help='Model type for training, which is one of {}'.format(
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' str(list(MODELS.keys()))),
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
...@@ -186,6 +182,7 @@ def train(model, ...@@ -186,6 +182,7 @@ def train(model,
total_steps = steps_per_epoch * (num_epochs - start_epoch) total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0 num_steps = 0
best_mean_iou = -1.0 best_mean_iou = -1.0
best_model_epoch = -1
for epoch in range(start_epoch, num_epochs): for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader): for step, data in enumerate(loader):
images = data[0] images = data[0]
...@@ -245,9 +242,9 @@ def train(model, ...@@ -245,9 +242,9 @@ def train(model,
best_model_dir = os.path.join(save_dir, "best_model") best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(model.state_dict(), fluid.save_dygraph(model.state_dict(),
os.path.join(best_model_dir, 'model')) os.path.join(best_model_dir, 'model'))
logging.info( logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}' 'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou)) .format(best_model_epoch, best_mean_iou))
if use_vdl: if use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou, log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
......
...@@ -39,12 +39,8 @@ def parse_args(): ...@@ -39,12 +39,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help= help='Model type for evaluation, which is one of {}'.format(
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' str(list(MODELS.keys()))),
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册