提交 06d51479 编写于 作者: R root

update

......@@ -37,12 +37,8 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help=
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
help='Model type for testing, which is one of {}'.format(
str(list(MODELS.keys()))),
type=str,
default='UNet')
......
......@@ -18,7 +18,8 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
__all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......
......@@ -38,12 +38,8 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help=
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))),
type=str,
default='UNet')
......@@ -186,6 +182,7 @@ def train(model,
total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0
best_mean_iou = -1.0
best_model_epoch = -1
for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader):
images = data[0]
......@@ -245,9 +242,9 @@ def train(model,
best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(model.state_dict(),
os.path.join(best_model_dir, 'model'))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
if use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
......
......@@ -39,12 +39,8 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help=
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
help='Model type for evaluation, which is one of {}'.format(
str(list(MODELS.keys()))),
type=str,
default='UNet')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册