diff --git a/dygraph/mobilenet/reader.py b/dygraph/mobilenet/reader.py index b96d1366690edc34f62be75546c47c49476d85e1..bba33c355ba02983c5d9d54b3bc5f2535d53cfb1 100644 --- a/dygraph/mobilenet/reader.py +++ b/dygraph/mobilenet/reader.py @@ -256,8 +256,9 @@ def process_batch_data(input_data, settings, mode, color_jitter, rotate): class ImageNetReader: - def __init__(self, seed=None): + def __init__(self, seed=None, place_num=1): self.shuffle_seed = seed + self.place_num = place_num def set_shuffle_seed(self, seed): assert isinstance(seed, int), "shuffle seed must be int" @@ -275,8 +276,7 @@ class ImageNetReader: if mode == 'test': batch_size = 1 else: - batch_size = settings.batch_size / paddle.fluid.core.get_cuda_device_count( - ) + batch_size = settings.batch_size / self.place_num def reader(): def read_file_list(): @@ -365,8 +365,7 @@ class ImageNetReader: reader = create_mixup_reader(settings, reader) reader = fluid.io.batch( reader, - batch_size=int(settings.batch_size / - paddle.fluid.core.get_cuda_device_count()), + batch_size=int(settings.batch_size / self.place_num), drop_last=True) return reader diff --git a/dygraph/mobilenet/train.py b/dygraph/mobilenet/train.py index 254279baedf3879ada6bc5c92ab3f733e5f3d524..fbf5d54beac044f76228076eb5d6f13e70e252af 100644 --- a/dygraph/mobilenet/train.py +++ b/dygraph/mobilenet/train.py @@ -42,11 +42,12 @@ def eval(net, test_data_loader, eop): total_acc5 = 0.0 total_sample = 0 t_last = 0 + place_num = paddle.fluid.core.get_cuda_device_count( + ) if args.use_gpu else int(os.environ.get('CPU_NUM', 1)) for img, label in test_data_loader(): t1 = time.time() label = to_variable(label.numpy().astype('int64').reshape( - int(args.batch_size // paddle.fluid.core.get_cuda_device_count()), - 1)) + int(args.batch_size // place_num), 1)) out = net(img) softmax_out = fluid.layers.softmax(out, use_cudnn=False) loss = fluid.layers.cross_entropy(input=softmax_out, label=label) @@ -77,6 +78,8 @@ def train_mobilenet(): place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) with fluid.dygraph.guard(place): # 1. init net and optimizer + place_num = paddle.fluid.core.get_cuda_device_count( + ) if args.use_gpu else int(os.environ.get('CPU_NUM', 1)) if args.ce: print("ce mode") seed = 33 @@ -118,7 +121,7 @@ def train_mobilenet(): test_data_loader, test_data = utility.create_data_loader( is_train=False, args=args) num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - imagenet_reader = reader.ImageNetReader(0) + imagenet_reader = reader.ImageNetReader(seed=0, place_num=place_num) train_reader = imagenet_reader.train(settings=args) test_reader = imagenet_reader.val(settings=args) train_data_loader.set_sample_list_generator(train_reader, place) @@ -140,8 +143,7 @@ def train_mobilenet(): for img, label in train_data_loader(): t1 = time.time() label = to_variable(label.numpy().astype('int64').reshape( - int(args.batch_size // - paddle.fluid.core.get_cuda_device_count()), 1)) + int(args.batch_size // place_num), 1)) t_start = time.time() # 4.1.1 call net() @@ -190,7 +192,7 @@ def train_mobilenet(): print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f %2.4f sec" % \ (eop, batch_id, total_loss / total_sample, \ total_acc1 / total_sample, total_acc5 / total_sample, train_batch_elapse)) - + # 4.2 save checkpoint save_parameters = (not args.use_data_parallel) or ( args.use_data_parallel and @@ -199,7 +201,8 @@ def train_mobilenet(): if not os.path.isdir(args.model_save_dir): os.makedirs(args.model_save_dir) model_path = os.path.join( - args.model_save_dir, "_" + model_path_pre + "_epoch{}".format(eop)) + args.model_save_dir, + "_" + model_path_pre + "_epoch{}".format(eop)) fluid.dygraph.save_dygraph(net.state_dict(), model_path) fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path) @@ -212,8 +215,8 @@ def train_mobilenet(): args.use_data_parallel and fluid.dygraph.parallel.Env().local_rank == 0) if save_parameters: - model_path = os.path.join( - args.model_save_dir, "_" + model_path_pre + "_final") + model_path = os.path.join(args.model_save_dir, + "_" + model_path_pre + "_final") fluid.dygraph.save_dygraph(net.state_dict(), model_path)