diff --git a/PaddleCV/image_classification/README.md b/PaddleCV/image_classification/README.md index 402e0efd43531e7a014f967ec300c5ed6a7daca1..6bc4291c4a5140e737bbc7ec7ccbb78eaf752974 100644 --- a/PaddleCV/image_classification/README.md +++ b/PaddleCV/image_classification/README.md @@ -90,6 +90,23 @@ python train.py \ bash run.sh train 模型名 ``` +**多进程模型训练:** + +如果你有多张GPU卡的话,我们强烈建议你使用多进程模式来训练模型,这会极大地提升训练速度。启动方式如下: +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch train.py \ + --model=ResNet50 \ + --batch_size=256 \ + --total_images=1281167 \ + --class_dim=1000 \ + --image_shape=3,224,224 \ + --model_save_dir=output/ \ + --lr_strategy=piecewise_decay \ + --reader_thread=4 \ + --lr=0.1 +``` +或者参考 scripts/train/ResNet50_dist.sh + **参数说明:** 环境配置部分: diff --git a/PaddleCV/image_classification/README_en.md b/PaddleCV/image_classification/README_en.md index 5511f1427f93870e2d5e905dd310c5a7b0793596..b66f5aeffcacd491cd55b3b16e78724c9d82b98b 100644 --- a/PaddleCV/image_classification/README_en.md +++ b/PaddleCV/image_classification/README_en.md @@ -82,6 +82,25 @@ or running run.sh scripts bash run.sh train model_name ``` +**multiprocess training:** + +If you have multiple GPUs, multiprocess training is strongly recommended because it can dramatically improve training speed. 
+You can start multiprocess training with: +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch train.py \ + --model=ResNet50 \ + --batch_size=256 \ + --total_images=1281167 \ + --class_dim=1000 \ + --image_shape=3,224,224 \ + --model_save_dir=output/ \ + --lr_strategy=piecewise_decay \ + --reader_thread=4 \ + --lr=0.1 +``` + +or refer to scripts/train/ResNet50_dist.sh + **parameter introduction:** Environment settings: diff --git a/PaddleCV/image_classification/eval.py b/PaddleCV/image_classification/eval.py index 31a1e64fa9f8bea9359c32cb793a0f31ff14859d..1594774894fc86fd82d711dd9845b9601f5d8c50 100644 --- a/PaddleCV/image_classification/eval.py +++ b/PaddleCV/image_classification/eval.py @@ -101,9 +101,9 @@ def eval(args): exe.run(fluid.default_startup_program()) fluid.io.load_persistables(exe, args.pretrained_model) + imagenet_reader = reader.ImageNetReader() + val_reader = imagenet_reader.val(settings=args) - val_reader = paddle.batch( - reader.val(settings=args), batch_size=args.batch_size) feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) test_info = [[], [], []] diff --git a/PaddleCV/image_classification/fast_imagenet/train.py b/PaddleCV/image_classification/fast_imagenet/train.py index f2cdc75a61ce949453a7f4890ddb3236fac7e11f..2e2352468e13bac3d92e0cb470c9c72e08a99c15 100644 --- a/PaddleCV/image_classification/fast_imagenet/train.py +++ b/PaddleCV/image_classification/fast_imagenet/train.py @@ -284,7 +284,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, min_scale=min_scale, shuffle_seed=epoch_id + 1) train_py_reader.decorate_paddle_reader( - paddle.batch( + fluid.io.batch( train_reader, batch_size=train_bs)) test_reader = reader.test( @@ -292,7 +292,7 @@ def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, bs=val_bs * DEVICE_NUM, sz=img_dim, rect_val=rect_val) - test_batched_reader = paddle.batch( + test_batched_reader = fluid.io.batch( test_reader, batch_size=val_bs * 
DEVICE_NUM) return test_batched_reader diff --git a/PaddleCV/image_classification/infer.py b/PaddleCV/image_classification/infer.py index 93e06cb11e9cdc0c0a5853b9c03b5f946ef2b9d7..30fa925594e5f07f2f30591be2063e99d44b83b0 100644 --- a/PaddleCV/image_classification/infer.py +++ b/PaddleCV/image_classification/infer.py @@ -88,9 +88,9 @@ def infer(args): print("model: ", args.model, " is already saved") exit(0) - test_batch_size = 1 - test_reader = paddle.batch( - reader.test(settings=args), batch_size=test_batch_size) + args.test_batch_size = 1 + imagenet_reader = reader.ImageNetReader() + test_reader = imagenet_reader.test(settings=args) feeder = fluid.DataFeeder(place=place, feed_list=[image]) TOPK = args.topk diff --git a/PaddleCV/image_classification/legacy/dist_train/dist_train.py b/PaddleCV/image_classification/legacy/dist_train/dist_train.py index 8cc226f312c00b7919dd82c1c8edab4117821024..11ce147983e9dc36544e9150849dd03558959c56 100644 --- a/PaddleCV/image_classification/legacy/dist_train/dist_train.py +++ b/PaddleCV/image_classification/legacy/dist_train/dist_train.py @@ -92,7 +92,7 @@ def prepare_reader(is_train, pyreader, args, pass_id=1): bs = args.batch_size / get_device_num() else: bs = 16 - pyreader.decorate_paddle_reader(paddle.batch(reader, batch_size=bs)) + pyreader.decorate_paddle_reader(fluid.io.batch(reader, batch_size=bs)) def build_program(is_train, main_prog, startup_prog, args): diff --git a/PaddleCV/image_classification/legacy/reader_pil.py b/PaddleCV/image_classification/legacy/reader_pil.py index c445d233b5b212107a81fb4e8a470352a336ec0b..1f22fca4a30eefffeacff6a049a9e29c52f8837f 100755 --- a/PaddleCV/image_classification/legacy/reader_pil.py +++ b/PaddleCV/image_classification/legacy/reader_pil.py @@ -19,7 +19,7 @@ import functools import numpy as np from PIL import Image, ImageEnhance -import paddle +from paddle import fluid random.seed(0) np.random.seed(0) @@ -190,7 +190,7 @@ def _reader_creator(file_list, mapper = functools.partial( 
process_batch_data, mode=mode, color_jitter=color_jitter, rotate=rotate) - return paddle.reader.xmap_readers(mapper, data_reader, THREAD, BUF_SIZE) + return fluid.io.xmap_readers(mapper, data_reader, THREAD, BUF_SIZE) def train(batch_size, data_dir=DATA_DIR, shuffle_seed=0, infinite=False): diff --git a/PaddleCV/image_classification/reader.py b/PaddleCV/image_classification/reader.py index 997f788f949d9c1e0796fa92c24e52d84154f810..c600895c3b2ef7ddc8a1be06669ced2ec2d42f4c 100644 --- a/PaddleCV/image_classification/reader.py +++ b/PaddleCV/image_classification/reader.py @@ -20,6 +20,7 @@ import numpy as np import cv2 import paddle +from paddle import fluid from utils.autoaugment import ImageNetPolicy from PIL import Image @@ -163,18 +164,11 @@ def create_mixup_reader(settings, rd): tmp_l2 = [] tmp_lam = [] - batch_size = settings.batch_size alpha = settings.mixup_alpha def fetch_data(): - - data_list = [] - for i, item in enumerate(rd()): - data_list.append(item) - if i % batch_size == batch_size - 1: - - yield data_list - data_list = [] + for item in rd(): + yield item def mixup_data(): for data_list in fetch_data(): @@ -245,113 +239,151 @@ def process_image(sample, settings, mode, color_jitter, rotate): elif mode == 'test': return (img, ) - -def _reader_creator(settings, - file_list, - mode, - shuffle=False, - color_jitter=False, - rotate=False, - data_dir=None): - def reader(): - with open(file_list) as flist: - full_lines = [line.strip() for line in flist] - if mode != "test" and len(full_lines) < settings.batch_size: - print( - "Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!" 
- .format(len(full_lines), settings.batch_size)) - os._exit(1) - - if shuffle: - np.random.shuffle(full_lines) - for line in full_lines: - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - if not os.path.exists(img_path): - print("Warning: {} doesn't exist!".format(img_path)) - if mode == "train" or mode == "val": - yield img_path, int(label) - elif mode == "test": - yield [img_path] - - mapper = functools.partial( - process_image, - settings=settings, - mode=mode, - color_jitter=color_jitter, - rotate=rotate) - - return paddle.reader.xmap_readers( - mapper, - reader, - settings.reader_thread, - settings.reader_buf_size, - order=False) - - -def train(settings): - """Create a reader for trainning - - Args: - settings: arguments - - Returns: - train reader - """ - file_list = os.path.join(settings.data_dir, 'train_list.txt') - assert os.path.isfile( - file_list), "{} doesn't exist, please check data list path".format( - file_list) - - if 'use_aa' in settings and settings.use_aa: - global policy - policy = ImageNetPolicy() - - reader = _reader_creator( - settings, - file_list, - 'train', - shuffle=True, - color_jitter=False, - rotate=False, - data_dir=settings.data_dir) - - if settings.use_mixup == True: - reader = create_mixup_reader(settings, reader) - return reader - - -def val(settings): - """Create a reader for eval - - Args: - settings: arguments - - Returns: - eval reader - """ - - file_list = os.path.join(settings.data_dir, 'val_list.txt') - assert os.path.isfile( - file_list), "{} doesn't exist, please check data list path".format( - file_list) - - return _reader_creator( - settings, file_list, 'val', shuffle=False, data_dir=settings.data_dir) - - -def test(settings): - """Create a reader for testing - - Args: - settings: arguments - - Returns: - test reader - """ - file_list = os.path.join(settings.data_dir, 'val_list.txt') - assert os.path.isfile( - file_list), "{} doesn't exist, please check data list path".format( - file_list) 
- return _reader_creator( - settings, file_list, 'test', shuffle=False, data_dir=settings.data_dir) +def process_batch_data(input_data, settings, mode, color_jitter, rotate): + batch_data = [] + for sample in input_data: + if os.path.isfile(sample[0]): + batch_data.append( + process_image(sample, settings, mode, color_jitter, rotate)) + else: + print("File not exist : %s" % sample[0]) + return batch_data + +class ImageNetReader: + def __init__(self, seed=None): + self.shuffle_seed = seed + + def set_shuffle_seed(self, seed): + assert isinstance(seed, int), "shuffle seed must be int" + self.shuffle_seed = seed + + def _reader_creator(self, settings, + file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=None): + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + batch_size = settings.batch_size / paddle.fluid.core.get_cuda_device_count() + def reader(): + def read_file_list(): + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if mode != "test" and len(full_lines) < settings.batch_size: + print( + "Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!" + .format(len(full_lines), settings.batch_size)) + os._exit(1) + if num_trainers > 1 and mode == "train": + assert self.shuffle_seed is not None, "multiprocess train, shuffle seed must be set!" 
+ np.random.RandomState(self.shuffle_seed).shuffle(full_lines) + elif shuffle: + np.random.shuffle(full_lines) + + batch_data = [] + for line in full_lines: + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + batch_data.append([img_path, int(label)]) + if len(batch_data) == batch_size: + if mode == 'train' or mode == 'val' or mode == 'test': + yield batch_data + + batch_data = [] + + return read_file_list + + data_reader = reader() + if mode == 'train' and num_trainers > 1: + assert self.shuffle_seed is not None, \ + "If num_trainers > 1, the shuffle_seed must be set, because " \ + "the order of batch data generated by reader " \ + "must be the same in the respective processes." + data_reader = paddle.fluid.contrib.reader.distributed_batch_reader(data_reader) + + mapper = functools.partial( + process_batch_data, + settings=settings, + mode=mode, + color_jitter=color_jitter, + rotate=rotate) + + return fluid.io.xmap_readers( + mapper, + data_reader, + settings.reader_thread, + settings.reader_buf_size, + order=False) + + + def train(self, settings): + """Create a reader for trainning + + Args: + settings: arguments + + Returns: + train reader + """ + file_list = os.path.join(settings.data_dir, 'train_list.txt') + assert os.path.isfile( + file_list), "{} doesn't exist, please check data list path".format( + file_list) + + if 'use_aa' in settings and settings.use_aa: + global policy + policy = ImageNetPolicy() + + reader = self._reader_creator( + settings, + file_list, + 'train', + shuffle=True, + color_jitter=False, + rotate=False, + data_dir=settings.data_dir) + + if settings.use_mixup == True: + reader = create_mixup_reader(settings, reader) + reader = fluid.io.batch( + reader, + batch_size=int(settings.batch_size / paddle.fluid.core.get_cuda_device_count()), + drop_last=True) + return reader + + + def val(self, settings): + """Create a reader for eval + + Args: + settings: arguments + + Returns: + eval reader + """ + + file_list = 
os.path.join(settings.data_dir, 'val_list.txt') + assert os.path.isfile( + file_list), "{} doesn't exist, please check data list path".format( + file_list) + + return self._reader_creator( + settings, file_list, 'val', shuffle=False, data_dir=settings.data_dir) + + + def test(self, settings): + """Create a reader for testing + + Args: + settings: arguments + + Returns: + test reader + """ + file_list = os.path.join(settings.data_dir, 'val_list.txt') + assert os.path.isfile( + file_list), "{} doesn't exist, please check data list path".format( + file_list) + return self._reader_creator( + settings, file_list, 'test', shuffle=False, data_dir=settings.data_dir) diff --git a/PaddleCV/image_classification/scripts/train/ResNet50_dist.sh b/PaddleCV/image_classification/scripts/train/ResNet50_dist.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb74449d6ffd5ba9253d26f3ef4d262e74b56d85 --- /dev/null +++ b/PaddleCV/image_classification/scripts/train/ResNet50_dist.sh @@ -0,0 +1,19 @@ +##Training details +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export FLAGS_fast_eager_deletion_mode=1 +export FLAGS_eager_delete_tensor_gb=0.0 +export FLAGS_fraction_of_gpu_memory_to_use=0.98 + +#ResNet50: +python -m paddle.distributed.launch train.py \ + --model=ResNet50 \ + --batch_size=256 \ + --total_images=1281167 \ + --class_dim=1000 \ + --image_shape=3,224,224 \ + --model_save_dir=output/ \ + --lr_strategy=piecewise_decay \ + --num_epochs=120 \ + --lr=0.1 \ + --reader_thread=4 \ + --l2_decay=1e-4 diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index dfd0f591986ef53e46b438897828097b0dbfba0b..d7f7644b3657d59ce6d41a47bce9a284496984c2 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -42,7 +42,6 @@ from utils import * import models from build_model import create_model - def build_program(is_train, main_prog, startup_prog, args): """build program, and add grad op in program 
accroding to different mode @@ -167,25 +166,20 @@ def train(args): #init model by checkpoint or pretrianed model. init_model(exe, args, train_prog) - - train_reader = reader.train(settings=args) - train_reader = paddle.batch( - train_reader, - batch_size=int(args.batch_size / fluid.core.get_cuda_device_count()), - drop_last=True) - - test_reader = reader.val(settings=args) - test_reader = paddle.batch( - test_reader, batch_size=args.test_batch_size, drop_last=True) + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + imagenet_reader = reader.ImageNetReader(0 if num_trainers > 1 else None) + train_reader = imagenet_reader.train(settings=args) + test_reader = imagenet_reader.val(settings=args) train_py_reader.decorate_sample_list_generator(train_reader, place) test_py_reader.decorate_sample_list_generator(test_reader, place) compiled_train_prog = best_strategy_compiled(args, train_prog, - train_fetch_vars[0]) - + train_fetch_vars[0], exe) + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) for pass_id in range(args.num_epochs): - + if num_trainers > 1: + imagenet_reader.set_shuffle_seed(pass_id + (args.random_seed if args.random_seed else 0)) train_batch_id = 0 train_batch_time_record = [] train_batch_metrics_record = [] @@ -203,30 +197,32 @@ def train(args): train_batch_metrics_avg = np.mean( np.array(train_batch_metrics), axis=1) train_batch_metrics_record.append(train_batch_metrics_avg) - - print_info(pass_id, train_batch_id, args.print_step, - train_batch_metrics_avg, train_batch_elapse, "batch") - sys.stdout.flush() + if trainer_id == 0: + print_info(pass_id, train_batch_id, args.print_step, + train_batch_metrics_avg, train_batch_elapse, "batch") + sys.stdout.flush() train_batch_id += 1 except fluid.core.EOFException: train_py_reader.reset() - if args.use_ema: - print('ExponentialMovingAverage validate start...') - with ema.apply(exe): - validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) - 
print('ExponentialMovingAverage validate over!') + if trainer_id == 0: + if args.use_ema: + print('ExponentialMovingAverage validate start...') + with ema.apply(exe): + validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) + print('ExponentialMovingAverage validate over!') - validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) - #For now, save model per epoch. - if pass_id % args.save_step == 0: - save_model(args, exe, train_prog, pass_id) + validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) + #For now, save model per epoch. + if pass_id % args.save_step == 0: + save_model(args, exe, train_prog, pass_id) def main(): args = parse_args() - print_arguments(args) + if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: + print_arguments(args) check_args(args) train(args) diff --git a/PaddleCV/image_classification/utils/dist_utils.py b/PaddleCV/image_classification/utils/dist_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..29df3d3b110357653bd46723298de1d98d296659 --- /dev/null +++ b/PaddleCV/image_classification/utils/dist_utils.py @@ -0,0 +1,93 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import paddle.fluid as fluid + + +def nccl2_prepare(args, startup_prog, main_prog): + config = fluid.DistributeTranspilerConfig() + config.mode = "nccl2" + t = fluid.DistributeTranspiler(config=config) + + envs = args.dist_env + + t.transpile( + envs["trainer_id"], + trainers=','.join(envs["trainer_endpoints"]), + current_endpoint=envs["current_endpoint"], + startup_program=startup_prog, + program=main_prog) + + +def pserver_prepare(args, train_prog, startup_prog): + config = fluid.DistributeTranspilerConfig() + config.slice_var_up = args.split_var + t = fluid.DistributeTranspiler(config=config) + envs = args.dist_env + training_role = envs["training_role"] + + t.transpile( + envs["trainer_id"], + program=train_prog, + pservers=envs["pserver_endpoints"], + trainers=envs["num_trainers"], + sync_mode=not args.async_mode, + startup_program=startup_prog) + if training_role == "PSERVER": + pserver_program = t.get_pserver_program(envs["current_endpoint"]) + pserver_startup_program = t.get_startup_program( + envs["current_endpoint"], + pserver_program, + startup_program=startup_prog) + return pserver_program, pserver_startup_program + elif training_role == "TRAINER": + train_program = t.get_trainer_program() + return train_program, startup_prog + else: + raise ValueError( + 'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER' + ) + + +def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog): + config = fluid.DistributeTranspilerConfig() + config.mode = "nccl2" + t = fluid.DistributeTranspiler(config=config) + t.transpile( + trainer_id, + trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'), + current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'), + startup_program=startup_prog, + program=main_prog) + + +def prepare_for_multi_process(exe, build_strategy, train_prog): + # prepare for multi-process + trainer_id = 
int(os.environ.get('PADDLE_TRAINER_ID', 0)) + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if num_trainers < 2: return + print("PADDLE_TRAINERS_NUM", num_trainers) + print("PADDLE_TRAINER_ID", trainer_id) + build_strategy.num_trainers = num_trainers + build_strategy.trainer_id = trainer_id + # NOTE(zcd): use multi processes to train the model, + # and each process use one GPU card. + startup_prog = fluid.Program() + nccl2_prepare_paddle(trainer_id, startup_prog, train_prog) + # the startup_prog are run two times, but it doesn't matter. + exe.run(startup_prog) diff --git a/PaddleCV/image_classification/utils/optimizer.py b/PaddleCV/image_classification/utils/optimizer.py index dd33bd3794fcd3caae48005b4638c4742d00a729..16b96267d274434c6e496e586cef47a13ae9e074 100644 --- a/PaddleCV/image_classification/utils/optimizer.py +++ b/PaddleCV/image_classification/utils/optimizer.py @@ -18,7 +18,6 @@ from __future__ import print_function import math -import paddle import paddle.fluid as fluid import paddle.fluid.layers.ops as ops from paddle.fluid.initializer import init_on_cpu @@ -142,7 +141,6 @@ class Optimizer(object): """ def __init__(self, args): - self.batch_size = args.batch_size self.lr = args.lr self.lr_strategy = args.lr_strategy diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index 46633aadda376a3cf7a2b922dc419b8c91a1cbb3..c0e9b6c58b3762af825164a5da9229dc6447f7a0 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -32,6 +32,7 @@ import paddle.fluid as fluid from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program from paddle.fluid import unique_name, layers +from utils import dist_utils def print_arguments(args): """Print argparse's arguments. 
@@ -376,7 +377,7 @@ def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode): raise Exception("Illegal info_mode") -def best_strategy_compiled(args, program, loss): +def best_strategy_compiled(args, program, loss, exe): """make a program which wrapped by a compiled program """ @@ -391,6 +392,13 @@ def best_strategy_compiled(args, program, loss): exec_strategy.num_threads = fluid.core.get_cuda_device_count() exec_strategy.num_iteration_per_drop_scope = 10 + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if num_trainers > 1 and args.use_gpu: + dist_utils.prepare_for_multi_process(exe, build_strategy, program) + # NOTE: the process is fast when num_threads is 1 + # for multi-process training. + exec_strategy.num_threads = 1 + compiled_program = fluid.CompiledProgram(program).with_data_parallel( loss_name=loss.name, build_strategy=build_strategy,