From fb620d39e50150fa651c25bb0d164d1ff94f2bb3 Mon Sep 17 00:00:00 2001 From: gengdongjie Date: Sat, 30 May 2020 15:34:19 +0800 Subject: [PATCH] fix bug introduced by gpu support --- example/resnet101_imagenet2012/config.py | 2 +- example/resnet50_cifar10/config.py | 2 +- example/resnet50_cifar10/train.py | 6 ++++-- example/resnet50_imagenet2012/config.py | 2 +- example/resnet50_imagenet2012/train.py | 2 ++ 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/example/resnet101_imagenet2012/config.py b/example/resnet101_imagenet2012/config.py index 5f07014ad..594b28522 100755 --- a/example/resnet101_imagenet2012/config.py +++ b/example/resnet101_imagenet2012/config.py @@ -29,7 +29,7 @@ config = ed({ "image_height": 224, "image_width": 224, "save_checkpoint": True, - "save_checkpoint_epochs": 1, + "save_checkpoint_epochs": 5, "keep_checkpoint_max": 10, "save_checkpoint_path": "./", "warmup_epochs": 0, diff --git a/example/resnet50_cifar10/config.py b/example/resnet50_cifar10/config.py index c148e4329..3c50a6aae 100755 --- a/example/resnet50_cifar10/config.py +++ b/example/resnet50_cifar10/config.py @@ -28,7 +28,7 @@ config = ed({ "image_height": 224, "image_width": 224, "save_checkpoint": True, - "save_checkpoint_steps": 1950, + "save_checkpoint_epochs": 5, "keep_checkpoint_max": 10, "save_checkpoint_path": "./", "warmup_epochs": 5, diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py index 93efed733..275f7188a 100755 --- a/example/resnet50_cifar10/train.py +++ b/example/resnet50_cifar10/train.py @@ -43,6 +43,8 @@ args_opt = parser.parse_args() if __name__ == '__main__': target = args_opt.device_target + ckpt_save_dir = config.save_checkpoint_path + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) @@ -80,13 +82,13 @@ if __name__ == '__main__': else: loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=True) + amp_level="O2", keep_batchnorm_fp32=False) time_cb = TimeMonitor(data_size=step_size) loss_cb = LossMonitor() cb = [time_cb, loss_cb] if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size, keep_checkpoint_max=config.keep_checkpoint_max) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) cb += [ckpt_cb] diff --git a/example/resnet50_imagenet2012/config.py b/example/resnet50_imagenet2012/config.py index e33c2b6aa..cf5093d24 100755 --- a/example/resnet50_imagenet2012/config.py +++ b/example/resnet50_imagenet2012/config.py @@ -29,7 +29,7 @@ config = ed({ "image_height": 224, "image_width": 224, "save_checkpoint": True, - "save_checkpoint_epochs": 1, + "save_checkpoint_epochs": 5, "keep_checkpoint_max": 10, "save_checkpoint_path": "./", "warmup_epochs": 0, diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py index 630148042..a76de78f6 100755 --- a/example/resnet50_imagenet2012/train.py +++ b/example/resnet50_imagenet2012/train.py @@ -46,6 +46,8 @@ args_opt = parser.parse_args() if __name__ == '__main__': target = args_opt.device_target + ckpt_save_dir = config.save_checkpoint_path + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) -- GitLab