diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py index 275f7188a7c396a5b9aa850e018f2f579a817730..323695ae291e0cdc64dfbfe66054304615d1daa7 100755 --- a/example/resnet50_cifar10/train.py +++ b/example/resnet50_cifar10/train.py @@ -15,6 +15,7 @@ """train_imagenet.""" import os import argparse +import numpy as np from dataset import create_dataset from lr_generator import get_lr from config import config @@ -45,6 +46,7 @@ if __name__ == '__main__': target = args_opt.device_target ckpt_save_dir = config.save_checkpoint_path context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + np.random.seed(1) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py index a76de78f6d5780fb133cad9e6e2bd61db3e3a878..abb55731dce84d878c376bdaba4c096e87b594a9 100755 --- a/example/resnet50_imagenet2012/train.py +++ b/example/resnet50_imagenet2012/train.py @@ -15,6 +15,7 @@ """train_imagenet.""" import os import argparse +import numpy as np from dataset import create_dataset from lr_generator import get_lr from config import config @@ -48,6 +49,7 @@ if __name__ == '__main__': target = args_opt.device_target ckpt_save_dir = config.save_checkpoint_path context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + np.random.seed(1) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID'))