diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 450fe1508f2a505a233b3d300cb7c500894231e7..f61c9e3a91146704faa6c5b1058137bef67d2a3e 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -193,15 +193,14 @@ ParallelExecutor::ParallelExecutor( const std::unordered_set &bcast_vars, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, const std::vector &local_scopes, - const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy, - size_t num_trainers, size_t trainer_id) + const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; member_->use_cuda_ = exec_strategy.use_cuda_; member_->build_strategy_ = build_strategy; member_->use_all_reduce_ = build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; - member_->nranks_ = num_trainers * places.size(); + member_->nranks_ = build_strategy.num_trainers_ * places.size(); if (!member_->use_all_reduce_) { PADDLE_ENFORCE(places.size() > 1, @@ -253,7 +252,8 @@ ParallelExecutor::ParallelExecutor( } member_->nccl_ctxs_.reset(new platform::NCCLContextMap( - member_->places_, nccl_id, num_trainers, trainer_id)); + member_->places_, nccl_id, build_strategy.num_trainers_, + build_strategy.trainer_id_)); #else PADDLE_THROW("Not compiled with CUDA"); #endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 49d3f0d3f6f2a8965d39b656071d86bde42bfd93..121bbd55ad575477424a2fb12baab82585eae517 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -50,8 +50,7 @@ class ParallelExecutor { const std::string &loss_var_name, Scope *scope, const std::vector &local_scopes, const ExecutionStrategy &exec_strategy, - const BuildStrategy &build_strategy, - size_t num_trainers = 1, size_t trainer_id = 0); + const BuildStrategy &build_strategy); ~ParallelExecutor(); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 2abda933cfc983bcef2433c3be4681c51e78ff1c..34cffd1aa3f8c3df9d4f67db37cbe0985862118f 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -77,6 +77,7 @@ class PreparedOp { framework::OperatorWithKernel::OpKernelFunc func; platform::DeviceContext* dev_ctx; }; + class OpBase; class VarBase { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index dce755c91a58d3291d740bd05c1cf835cbfbf1f0..a540c6fca1f18830b5ac07d6b45c0e586351a21d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1019,8 +1019,7 @@ All parameter, weight, gradient are variables in Paddle. pe.def(py::init &, const std::unordered_set &, const ProgramDesc &, const std::string &, Scope *, std::vector &, - const ExecutionStrategy &, const BuildStrategy &, size_t, - size_t>()) + const ExecutionStrategy &, const BuildStrategy &>()) // NOTE: even we return a vec* to Python use reference policy. // We still cannot get local_scope from this vector, since the element // of vec will be freed by Python GC. We can only return Scope* diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0ef8d150d2102874a9bbea79c177c49b3ebad7 --- /dev/null +++ b/python/paddle/fluid/compiler.py @@ -0,0 +1,204 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +import six +import sys +from .. import compat as cpt + +from . import core + +ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy +BuildStrategy = core.ParallelExecutor.BuildStrategy + + +def _place_obj(place): + p = core.Place() + p.set_place(place) + return p + + +class CompiledProgram(object): + """ + Compiles a Program for execution. + + 1. Users first create the program with layers. + 2. Optionally, users use CompiledProgram to optimize the program before run. + 3. The original program or CompiledProgram is run by executor. + + The CompiledProgram is used to transform a program for various + optimizations, for example. + * Pre-compute some logic once so that each run is faster. + * Transform the program so that it can run in multiple devices. + * TODO: transform the program for optimized inference or distributed + training. + + Example: + .. code-block:: python + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + compiled_prog = compiler.CompiledProgram(main).with_data_parallel( + loss_name=loss.name) + for i in range(5): + test_loss, = exe.run(compiled_prog, + feed=feed_dict, + fetch_list=[loss.name]) + + Args: + program: Program instance that contains the model logic. + """ + + def __init__(self, program): + self._program = program + self._scope = None + self._place = None + self._executor = None + self._compiled = False + self._is_data_parallel = False + + def with_data_parallel(self, + loss_name=None, + build_strategy=None, + exec_strategy=None, + share_vars_from=None): + """Configs the program to run in data parallel way. + + Args: + loss_name (str): The loss name must set in training. Default None. + build_strategy(BuildStrategy): build_strategy is used to + build the graph so it can run on multiple devices/cores with + optimized topology. + For more information, please refer to fluid.BuildStrategy. + Default None. + exec_strategy(ExecutionStrategy): exec_strategy is used to + to select the a way to execute the graph, for example how many + threads are used, how many iterations to clean up the temp + variables. For more information, please refer + to fluid.ExecutionStrategy. Default None. + share_vars_from(CompiledProgram): If provide, this CompiledProgram + will share variables from `share_vars_from`. `share_vars_from` + must be run by the executor before this CompiledProgram so that + vars are ready. + Returns: + self + """ + assert not self._is_data_parallel, "Already compiled with parallel." + self._is_data_parallel = True + self._build_strategy = build_strategy + self._exec_strategy = exec_strategy + self._loss_name = loss_name + self._share_vars_from = share_vars_from + if self._exec_strategy is None: + self._exec_strategy = ExecutionStrategy() + if self._build_strategy is None: + self._build_strategy = BuildStrategy() + return self + + def _with_distributed(self): + raise NotImplementedError() + + def _with_inference_optimize(self): + raise NotImplementedError() + + def _compile_data_parallel(self): + if self._share_vars_from: + if self._scope: + sys.stderr.write("share_vars_from is set, scope is ignored.\n") + if not self._share_vars_from._is_data_parallel: + raise ValueError("share_vars_from is not data parallel. Cannot " + "share vars from it.") + if self._share_vars_from._executor is None: + raise ValueError( + "share_vars_from is not compiled and run, so there is no " + "var to share.") + self._local_scopes = self._share_vars_from._executor.local_scopes() + else: + self._local_scopes = [] + + self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace) + if self._exec_strategy.use_cuda: + gpus_env = os.getenv("FLAGS_selected_gpus") + if gpus_env: + gpus = [int(s) for s in gpus_env.split(",")] + else: + gpus = [ + i for i in six.moves.range(core.get_cuda_device_count()) + ] + self._places = [core.CUDAPlace(i) for i in gpus] + else: + cpu_num = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)] + assert self._places, "no place for execution" + + if self._exec_strategy.num_threads == 0: + if self._exec_strategy.use_cuda: + # Experiments on se-resnext shows that too many threads hurt + # performance. Worth tunning for other models in the future. + self._exec_strategy.num_threads = len(self._places) * 4 + else: + cpu_num = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + self._exec_strategy.num_threads = cpu_num * 2 + + trainers_endpoints = self._program._trainers_endpoints + if self._build_strategy.num_trainers > 1 and trainers_endpoints: + assert self._build_strategy.num_trainers == len( + trainers_endpoints), "num_trainers == len(end_points)" + self._build_strategy.trainers_endpoints = trainers_endpoints + + self._persistable_vars = set([ + cpt.to_text(v.name) + for v in [ + var for var in self._program.list_vars() + if var.persistable and var.type != core.VarDesc.VarType.RAW + ] + ]) + + places = list(map(_place_obj, self._places)) + return core.ParallelExecutor( + places, self._persistable_vars, self._program.desc, + cpt.to_text(self._loss_name) + if self._loss_name else six.u(''), self._scope, self._local_scopes, + self._exec_strategy, self._build_strategy) + + def _compile(self, scope, place): + """Compile the program based on the configs. + + Args: + scope: The variables (resources) that are associated with + this compiled program. + place: The location that the compiled program will be run on. + + Returns: + self + """ + if self._compiled: + if scope and self._scope != scope: + raise ValueError("Cannot compile with different scope") + if place and self._place != place: + raise ValueError("Cannot compile with different place") + return self + self._compiled = True + + self._scope = scope + self._place = place + if self._is_data_parallel: + self._executor = self._compile_data_parallel() + else: + p = _place_obj(self._place) + self._executor = core.Executor(p) + return self diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 5a9e908b61eeeea3fdfdfcc54d1f150f59a3973b..1a940b30c1564c7622f646b4697375179c607f91 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -14,11 +14,15 @@ from __future__ import print_function +import os +import multiprocessing import numpy as np import contextlib import six from .framework import Program, default_main_program, Variable from . import core +from . import compiler +from .. import compat as cpt __all__ = ['Executor', 'global_scope', 'scope_guard'] @@ -204,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True): return tensor -def _get_program_cache_key(feed, fetch_list): - feed_var_names = list(feed.keys()) +def _to_name_str(var): + if isinstance(var, Variable): + return var.desc.name() + elif isinstance(var, str): + return var + elif isinstance(var, six.string_types): + return str(var) + else: + raise TypeError(str(var) + " should be Variable or str") - def to_name_str(var): - if isinstance(var, Variable): - return var.desc.name() - elif isinstance(var, str): - return var - elif isinstance(var, six.string_types): - return str(var) - else: - raise TypeError(str(var) + " should be Variable or str") - fetch_var_names = list(map(to_name_str, fetch_list)) +def _get_program_cache_key(feed, fetch_list): + feed_var_names = list(feed.keys()) + fetch_var_names = list(map(_to_name_str, fetch_list)) return str(feed_var_names + fetch_var_names) @@ -266,6 +270,29 @@ class Executor(object): But the global scope variables will be persistent through different runs. All of ops in program will be running in sequence. + + Example: + .. code-block:: python + # First create the Executor. + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + + # Run the startup program once and only once. + # Not need to optimize/compile the startup program. + exe.run(fluid.default_startup_program()) + + # Run the main program directly without compile. + loss, = exe.run(fluid.default_main_program(), + feed=feed_dict, + fetch_list=[loss.name]) + # Or, compiled the program and run. See `CompiledProgram` for more detail. + compiled_prog = compiler.CompiledProgram( + fluid.default_main_program()).with_data_parallel( + loss_name=loss.name) + loss, = exe.run(compiled_prog, + feed=feed_dict, + fetch_list=[loss.name]) + Args: place(core.CPUPlace|core.CUDAPlace(n)): indicate the executor run on which device @@ -275,11 +302,8 @@ class Executor(object): def __init__(self, place): self.place = place - p = core.Place() - p.set_place(place) - self.executor = core.Executor(p) - self.program_caches = dict() + self.executor = None self._closed = False def _get_program_cache(self, program_cache_key): @@ -361,6 +385,7 @@ class Executor(object): You can no long use this executor after calling this method. For the distributed training, this method would free the resource on PServers related to the current Trainer. + TODO(panyx0718): Why ParallelExecutor doesn't have close? Example: >>> cpu = core.CPUPlace() @@ -368,10 +393,55 @@ class Executor(object): >>> ... >>> exe.close() """ - if not self._closed: + if not self._closed and self.executor: self.executor.close() self._closed = True + def _run_parallel(self, scope, feed, fetch_list, fetch_var_name, + return_numpy): + if isinstance(feed, dict): + feed_tensor_dict = dict() + for feed_name in feed: + feed_tensor = feed[feed_name] + if not isinstance(feed_tensor, core.LoDTensor): + feed_tensor = core.LoDTensor() + # always set to CPU place, since the tensor need to be splitted + # it is fast in CPU + feed_tensor.set(feed[feed_name], core.CPUPlace()) + feed_tensor_dict[feed_name] = feed_tensor + + self.executor.feed_and_split_tensor_into_local_scopes( + feed_tensor_dict) + elif isinstance(feed, list) or isinstance(feed, tuple): + if len(feed) != len(self._places): + raise ValueError( + "Feed a list of tensor, the list should be the same size as places" + ) + + res = list() + for i, each in enumerate(feed): + if not isinstance(each, dict): + raise TypeError( + "Each element of feed list should be a dict") + res_dict = dict() + for feed_name in each: + tensor = each[feed_name] + if not isinstance(tensor, core.LoDTensor): + tmp = core.LoDTensor() + tmp.set(tensor, self._places[i]) + tensor = tmp + res_dict[feed_name] = tensor + res.append(res_dict) + self.executor.feed_tensors_into_local_scopes(res) + + fetch_var_names = list(map(_to_name_str, fetch_list)) + self.executor.run(fetch_var_names, fetch_var_name) + arr = scope.find_var(fetch_var_name).get_lod_tensor_array() + + if return_numpy: + return as_numpy(arr) + return [arr[i] for i in range(len(arr))] + def run(self, program=None, feed=None, @@ -391,8 +461,9 @@ class Executor(object): operators in the program but not only the operators dependent by the fetch_list Args: - program(Program): the program that need to run, if not provied, then default_main_program will be used. - feed(dict): feed variable map, e.g. {"image": ImageData, "label": LableData} + program(Program|CompiledProgram): the program that need to run, + if not provided, then default_main_program will be used. + feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData} fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list. feed_var_name(str): the name for the input variable of feed Operator. fetch_var_name(str): the name for the output variable of fetch Operator. @@ -428,14 +499,59 @@ class Executor(object): if self._closed: raise RuntimeError("Attempted to use a closed Executor") + if scope is None: + scope = global_scope() + if fetch_list is None: + fetch_list = [] + + compiled = isinstance(program, compiler.CompiledProgram) + # For backward compatibility, run directly. + if not compiled: + if not self.executor: + p = core.Place() + p.set_place(self.place) + self.executor = core.Executor(p) + return self._run( + program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + scope=scope, + return_numpy=return_numpy, + use_program_cache=use_program_cache) + + program._compile(scope, self.place) + self.executor = program._executor + if program._is_data_parallel: + return self._run_parallel( + scope=scope, + feed=feed, + fetch_list=fetch_list, + fetch_var_name=fetch_var_name, + return_numpy=return_numpy) + else: + # TODO(panyx0718): Can compile program to optimize executor + # performance. + return self._run( + program._program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + scope=scope, + return_numpy=return_numpy, + use_program_cache=use_program_cache) + + def _run(self, program, feed, fetch_list, feed_var_name, fetch_var_name, + scope, return_numpy, use_program_cache): + if feed is None: feed = {} if not isinstance(feed, dict): raise TypeError( "feed requires dict as its Parameter. But you passed in %s" % (type(feed))) - if fetch_list is None: - fetch_list = [] if program is None: program = default_main_program() @@ -444,9 +560,6 @@ class Executor(object): "Executor requires Program as its Parameter. But you passed in %s" % (type(program))) - if scope is None: - scope = global_scope() - cache_key = _get_program_cache_key(feed, fetch_list) if use_program_cache: cached_program = self._get_program_cache(cache_key) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 3b066eda110275dc02e451dfaf0cfe28a3fb7a53..9601a9e73f0e247e3763228b29ee2989bbbb200f 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -181,9 +181,8 @@ class ParallelExecutor(object): # step7: init ParallelExecutor self.executor = core.ParallelExecutor( places, persistable_vars, main.desc, - cpt.to_text(loss_name) - if loss_name else six.u(''), scope, local_scopes, exec_strategy, - build_strategy, num_trainers, trainer_id) + cpt.to_text(loss_name) if loss_name else six.u(''), scope, + local_scopes, exec_strategy, build_strategy) self.scope = scope @@ -294,7 +293,7 @@ class ParallelExecutor(object): res.append(res_dict) self.executor.feed_tensors_into_local_scopes(res) - fetch_var_name = '@FETCHED_VAR_NAME@' + fetch_var_name = 'fetch' self.executor.run(fetch_list, fetch_var_name) arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 2b0ab0cc3bc23bab140d2b7e8cb765e537ff3f5c..1ba47d5a57665e00f08ffde9ebf3b5b10412c2ee 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -19,6 +19,7 @@ import os import unittest import paddle.fluid as fluid import paddle.fluid.core as core +from paddle.fluid import compiler import time import numpy as np import math @@ -44,15 +45,8 @@ class TestParallelExecutorBase(unittest.TestCase): optimizer=fluid.optimizer.Adam, use_fast_executor=False, enable_sequential_execution=False): - def run_executor(exe, feed, fetch_list, program=None): - if isinstance(exe, fluid.ParallelExecutor): - res = exe.run(fetch_list=fetch_list, feed=feed) - elif isinstance(exe, fluid.Executor): - if program is None: - program = fluid.default_main_program() - res = exe.run(program=program, feed=feed, fetch_list=fetch_list) - else: - raise ValueError('Unkown type exe') + def run_executor(exe, binary, feed, fetch_list): + res = exe.run(binary, feed=feed, fetch_list=fetch_list) return res main = fluid.Program() @@ -72,8 +66,8 @@ class TestParallelExecutorBase(unittest.TestCase): fluid.memory_optimize(main) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - startup_exe = fluid.Executor(place) - startup_exe.run(startup) + exe = fluid.Executor(place) + exe.run(startup) exec_strategy = fluid.ExecutionStrategy() exec_strategy.allow_op_delay = allow_op_delay if use_fast_executor: @@ -86,15 +80,13 @@ class TestParallelExecutorBase(unittest.TestCase): build_strategy.enable_sequential_execution = enable_sequential_execution if use_cuda and core.is_compiled_with_cuda(): build_strategy.remove_unnecessary_lock = True - if use_parallel_executor: - exe = fluid.ParallelExecutor( - use_cuda, + binary = compiler.CompiledProgram(main).with_data_parallel( loss_name=loss.name, - exec_strategy=exec_strategy, - build_strategy=build_strategy) + build_strategy=build_strategy, + exec_strategy=exec_strategy) else: - exe = fluid.Executor(place=place) + binary = compiler.CompiledProgram(main) if batch_size is not None: batch_size *= fluid.core.get_cuda_device_count( @@ -102,13 +94,14 @@ class TestParallelExecutorBase(unittest.TestCase): os.environ.get('CPU_NUM', multiprocessing.cpu_count())) begin = time.time() first_loss, = run_executor( - exe=exe, feed=feed_dict, fetch_list=[loss.name]) + exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) for i in range(iter): - run_executor(exe=exe, feed=feed_dict, fetch_list=[]) + run_executor( + exe=exe, binary=binary, feed=feed_dict, fetch_list=[]) last_loss, = run_executor( - exe=exe, feed=feed_dict, fetch_list=[loss.name]) + exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) end = time.time() if batch_size is not None: diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 0caab08f0dc19efceadd723b474c10e1a2deb449..3fcdc57906c214bdc8179c55b576e2e9e8d80973 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -26,6 +26,7 @@ import pickle import numpy as np import paddle.fluid as fluid +from paddle.fluid import compiler RUN_STEP = 10 DEFAULT_BATCH_SIZE = 2 @@ -104,8 +105,8 @@ class TestDistRunnerBase(object): else: place = fluid.CPUPlace() - startup_exe = fluid.Executor(place) - startup_exe.run(fluid.default_startup_program()) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) strategy = fluid.ExecutionStrategy() strategy.num_threads = 1 @@ -125,19 +126,16 @@ class TestDistRunnerBase(object): mypass.set_int("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2": - num_trainers = len(args.endpoints.split(",")) - trainer_id = args.trainer_id + build_stra.num_trainers = len(args.endpoints.split(",")) + build_stra.trainer_id = args.trainer_id else: - num_trainers = 1 - trainer_id = 0 + build_stra.num_trainers = 1 + build_stra.trainer_id = 0 - exe = fluid.ParallelExecutor( - args.use_cuda, + binary = compiler.CompiledProgram(trainer_prog).with_data_parallel( loss_name=avg_cost.name, - exec_strategy=strategy, build_strategy=build_stra, - num_trainers=num_trainers, - trainer_id=trainer_id) + exec_strategy=strategy) feed_var_list = [ var for var in trainer_prog.global_block().vars.values() @@ -160,7 +158,8 @@ class TestDistRunnerBase(object): out_losses = [] for _ in six.moves.xrange(RUN_STEP): - loss, = exe.run(fetch_list=[avg_cost.name], + loss, = exe.run(binary, + fetch_list=[avg_cost.name], feed=feeder.feed(get_data())) out_losses.append(loss[0]) if six.PY2: diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py index db2826653edf6bf6ddd498cbd56b07da646cebf4..d89fd87a38be460c561dbff656cdaa069ffbbd53 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py @@ -15,6 +15,7 @@ from __future__ import print_function import paddle.fluid as fluid +from paddle.fluid import compiler import paddle.fluid.core as core import numpy as np import unittest @@ -61,22 +62,21 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): exe.run(startup) feed_dict = {'image': image, 'label': label} - train_exe = fluid.ParallelExecutor( - use_cuda=use_cuda, + train_cp = compiler.CompiledProgram(main).with_data_parallel( + loss_name=loss.name, build_strategy=build_strategy) + test_cp = compiler.CompiledProgram(test_program).with_data_parallel( loss_name=loss.name, - main_program=main, - build_strategy=build_strategy) - - test_exe = fluid.ParallelExecutor( - use_cuda=use_cuda, - main_program=test_program, - share_vars_from=train_exe, - build_strategy=build_strategy) + build_strategy=build_strategy, + share_vars_from=train_cp) for i in range(5): - test_loss, = test_exe.run([loss.name], feed=feed_dict) - - train_loss, = train_exe.run([loss.name], feed=feed_dict) + exe.run(train_cp, feed=feed_dict, fetch_list=[loss.name]) + test_loss, = exe.run(test_cp, + feed=feed_dict, + fetch_list=[loss.name]) + train_loss, = exe.run(train_cp, + feed=feed_dict, + fetch_list=[loss.name]) avg_test_loss_val = np.array(test_loss).mean() if math.isnan(float(avg_test_loss_val)):