提交 a4f2e728 编写于 作者: M ms_yan

repair vgg precision problem

上级 81833943
......@@ -158,7 +158,7 @@ def test(cloud_args=None):
args.models = [args.pre_trained,]
for model in args.models:
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size)
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, mode='eval')
eval_dataloader = dataset.create_tuple_iterator()
network = vgg16(args.num_classes, args, phase="test")
......
......@@ -64,7 +64,7 @@ imagenet_cfg = edict({
"image_size": '224,224',
"pad_mode": 'pad',
"padding": 1,
"has_bias": True,
"has_bias": False,
"batch_norm": False,
"keep_checkpoint_max": 10,
"initialize_mode": "KaimingNormal",
......
......@@ -31,10 +31,11 @@ def _make_layer(base, args, batch_norm):
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
weight_shape = (v, in_channels, 3, 3)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
if args.initialize_mode == "KaimingNormal":
weight = 'normal'
weight = 'ones'
if args.initialize_mode == "XavierUniform":
weight_shape = (v, in_channels, 3, 3)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
conv2d = nn.Conv2d(in_channels=in_channels,
out_channels=v,
kernel_size=3,
......
......@@ -127,7 +127,7 @@ def parse_args(cloud_args=None):
# logging and checkpoint related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval')
parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
......@@ -200,12 +200,12 @@ if __name__ == '__main__':
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
parameter_broadcast=True, mirror_mean=True)
else:
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# select for master rank save ckpt or all rank save, compatiable for model parallel
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册