Unverified commit 7b73fc9e, authored by Xin Pan, committed by GitHub

Merge pull request #15089 from panyx0718/api

try unify Executor and ParallelExecutor
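For orientation, here is a minimal sketch of the usage this unification targets, patterned on the docstring example added to executor.py below; `use_cuda`, `feed_dict`, and `loss` are placeholders for a model already defined in the default programs.

import paddle.fluid as fluid
from paddle.fluid import compiler

place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Single-device path: pass the Program directly, as before.
loss_val, = exe.run(fluid.default_main_program(),
                    feed=feed_dict, fetch_list=[loss.name])

# Data-parallel path: wrap the same Program in a CompiledProgram and run it
# with the same Executor, instead of constructing a ParallelExecutor.
compiled_prog = compiler.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
loss_val, = exe.run(compiled_prog, feed=feed_dict, fetch_list=[loss.name])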
......@@ -193,15 +193,14 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
size_t num_trainers, size_t trainer_id)
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_;
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ =
build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
member_->nranks_ = num_trainers * places.size();
member_->nranks_ = build_strategy.num_trainers_ * places.size();
if (!member_->use_all_reduce_) {
PADDLE_ENFORCE(places.size() > 1,
......@@ -253,7 +252,8 @@ ParallelExecutor::ParallelExecutor(
}
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, num_trainers, trainer_id));
member_->places_, nccl_id, build_strategy.num_trainers_,
build_strategy.trainer_id_));
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
......
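As a small worked example of the rank count above (hypothetical numbers): nranks_ is now derived from BuildStrategy and counts devices across all trainer processes.

# Mirrors member_->nranks_ = build_strategy.num_trainers_ * places.size()
num_trainers = 2                                # two trainer processes, e.g. two machines
places = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]   # four local devices per trainer
nranks = num_trainers * len(places)             # 8 ranks participate in all-reduce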
......@@ -50,8 +50,7 @@ class ParallelExecutor {
const std::string &loss_var_name, Scope *scope,
const std::vector<Scope *> &local_scopes,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
size_t num_trainers = 1, size_t trainer_id = 0);
const BuildStrategy &build_strategy);
~ParallelExecutor();
......
......@@ -77,6 +77,7 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx;
};
class OpBase;
class VarBase {
......
......@@ -1019,8 +1019,7 @@ All parameter, weight, gradient are variables in Paddle.
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &, const ProgramDesc &,
const std::string &, Scope *, std::vector<Scope *> &,
const ExecutionStrategy &, const BuildStrategy &, size_t,
size_t>())
const ExecutionStrategy &, const BuildStrategy &>())
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import os
import six
import sys
from .. import compat as cpt
from . import core
ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy
def _place_obj(place):
p = core.Place()
p.set_place(place)
return p
class CompiledProgram(object):
"""
Compiles a Program for execution.
1. Users first create the program with layers.
2. Optionally, users use CompiledProgram to optimize the program before run.
3. The original program or CompiledProgram is run by executor.
The CompiledProgram is used to transform a program for various
optimizations, for example:
* Pre-compute some logic once so that each run is faster.
* Transform the program so that it can run in multiple devices.
* TODO: transform the program for optimized inference or distributed
training.
Example:
.. code-block:: python
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name)
for i in range(5):
test_loss, = exe.run(compiled_prog,
feed=feed_dict,
fetch_list=[loss.name])
Args:
program: Program instance that contains the model logic.
"""
def __init__(self, program):
self._program = program
self._scope = None
self._place = None
self._executor = None
self._compiled = False
self._is_data_parallel = False
def with_data_parallel(self,
loss_name=None,
build_strategy=None,
exec_strategy=None,
share_vars_from=None):
"""Configs the program to run in data parallel way.
Args:
loss_name (str): The loss name, which must be set in training. Default None.
build_strategy(BuildStrategy): build_strategy is used to
build the graph so it can run on multiple devices/cores with
optimized topology.
For more information, please refer to fluid.BuildStrategy.
Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to
select the way to execute the graph, for example how many
threads are used, how many iterations to clean up the temp
variables. For more information, please refer
to fluid.ExecutionStrategy. Default None.
share_vars_from(CompiledProgram): If provided, this CompiledProgram
will share variables from `share_vars_from`. `share_vars_from`
must be run by the executor before this CompiledProgram so that
vars are ready.
Returns:
self
"""
assert not self._is_data_parallel, "Already compiled with parallel."
self._is_data_parallel = True
self._build_strategy = build_strategy
self._exec_strategy = exec_strategy
self._loss_name = loss_name
self._share_vars_from = share_vars_from
if self._exec_strategy is None:
self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None:
self._build_strategy = BuildStrategy()
return self
def _with_distributed(self):
raise NotImplementedError()
def _with_inference_optimize(self):
raise NotImplementedError()
def _compile_data_parallel(self):
if self._share_vars_from:
if self._scope:
sys.stderr.write("share_vars_from is set, scope is ignored.\n")
if not self._share_vars_from._is_data_parallel:
raise ValueError("share_vars_from is not data parallel. Cannot "
"share vars from it.")
if self._share_vars_from._executor is None:
raise ValueError(
"share_vars_from is not compiled and run, so there is no "
"var to share.")
self._local_scopes = self._share_vars_from._executor.local_scopes()
else:
self._local_scopes = []
self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
if self._exec_strategy.use_cuda:
gpus_env = os.getenv("FLAGS_selected_gpus")
if gpus_env:
gpus = [int(s) for s in gpus_env.split(",")]
else:
gpus = [
i for i in six.moves.range(core.get_cuda_device_count())
]
self._places = [core.CUDAPlace(i) for i in gpus]
else:
cpu_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
assert self._places, "no place for execution"
if self._exec_strategy.num_threads == 0:
if self._exec_strategy.use_cuda:
# Experiments on se-resnext show that too many threads hurt
# performance. Worth tuning for other models in the future.
self._exec_strategy.num_threads = len(self._places) * 4
else:
cpu_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
self._exec_strategy.num_threads = cpu_num * 2
trainers_endpoints = self._program._trainers_endpoints
if self._build_strategy.num_trainers > 1 and trainers_endpoints:
assert self._build_strategy.num_trainers == len(
trainers_endpoints), "num_trainers == len(end_points)"
self._build_strategy.trainers_endpoints = trainers_endpoints
self._persistable_vars = set([
cpt.to_text(v.name)
for v in [
var for var in self._program.list_vars()
if var.persistable and var.type != core.VarDesc.VarType.RAW
]
])
places = list(map(_place_obj, self._places))
return core.ParallelExecutor(
places, self._persistable_vars, self._program.desc,
cpt.to_text(self._loss_name)
if self._loss_name else six.u(''), self._scope, self._local_scopes,
self._exec_strategy, self._build_strategy)
def _compile(self, scope, place):
"""Compile the program based on the configs.
Args:
scope: The variables (resources) that are associated with
this compiled program.
place: The location that the compiled program will be run on.
Returns:
self
"""
if self._compiled:
if scope and self._scope != scope:
raise ValueError("Cannot compile with different scope")
if place and self._place != place:
raise ValueError("Cannot compile with different place")
return self
self._compiled = True
self._scope = scope
self._place = place
if self._is_data_parallel:
self._executor = self._compile_data_parallel()
else:
p = _place_obj(self._place)
self._executor = core.Executor(p)
return self
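The place selection in _compile_data_parallel above is driven by environment variables; a sketch of how they are typically set (example values), noting that they must be set before the Executor first runs the CompiledProgram:

import os

# With CUDA, the place list comes from FLAGS_selected_gpus when it is set,
# otherwise every visible GPU is used.
os.environ["FLAGS_selected_gpus"] = "0,1"   # data-parallel over GPUs 0 and 1

# Without CUDA, the number of CPU places comes from CPU_NUM,
# falling back to multiprocessing.cpu_count().
os.environ["CPU_NUM"] = "4"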
......@@ -14,11 +14,15 @@
from __future__ import print_function
import os
import multiprocessing
import numpy as np
import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
from . import compiler
from .. import compat as cpt
__all__ = ['Executor', 'global_scope', 'scope_guard']
......@@ -204,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True):
return tensor
def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys())
def _to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
fetch_var_names = list(map(to_name_str, fetch_list))
def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys())
fetch_var_names = list(map(_to_name_str, fetch_list))
return str(feed_var_names + fetch_var_names)
......@@ -266,6 +270,29 @@ class Executor(object):
But the global scope variables will be persistent through different runs.
All of ops in program will be running in sequence.
Example:
.. code-block:: python
# First create the Executor.
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# Run the startup program once and only once.
# No need to optimize/compile the startup program.
exe.run(fluid.default_startup_program())
# Run the main program directly without compile.
loss, = exe.run(fluid.default_main_program(),
feed=feed_dict,
fetch_list=[loss.name])
# Or, compile the program and run. See `CompiledProgram` for more detail.
compiled_prog = compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name)
loss, = exe.run(compiled_prog,
feed=feed_dict,
fetch_list=[loss.name])
Args:
place(core.CPUPlace|core.CUDAPlace(n)): indicates which device the executor runs on
......@@ -275,11 +302,8 @@ class Executor(object):
def __init__(self, place):
self.place = place
p = core.Place()
p.set_place(place)
self.executor = core.Executor(p)
self.program_caches = dict()
self.executor = None
self._closed = False
def _get_program_cache(self, program_cache_key):
......@@ -361,6 +385,7 @@ class Executor(object):
You can no longer use this executor after calling this method.
For distributed training, this method frees the resources on PServers related to
the current Trainer.
TODO(panyx0718): Why ParallelExecutor doesn't have close?
Example:
>>> cpu = core.CPUPlace()
......@@ -368,10 +393,55 @@ class Executor(object):
>>> ...
>>> exe.close()
"""
if not self._closed:
if not self._closed and self.executor:
self.executor.close()
self._closed = True
def _run_parallel(self, scope, feed, fetch_list, fetch_var_name,
return_numpy):
if isinstance(feed, dict):
feed_tensor_dict = dict()
for feed_name in feed:
feed_tensor = feed[feed_name]
if not isinstance(feed_tensor, core.LoDTensor):
feed_tensor = core.LoDTensor()
# always set to CPU place, since the tensor needs to be split
# and splitting is fast on CPU
feed_tensor.set(feed[feed_name], core.CPUPlace())
feed_tensor_dict[feed_name] = feed_tensor
self.executor.feed_and_split_tensor_into_local_scopes(
feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(self._places):
raise ValueError(
"Feed a list of tensor, the list should be the same size as places"
)
res = list()
for i, each in enumerate(feed):
if not isinstance(each, dict):
raise TypeError(
"Each element of feed list should be a dict")
res_dict = dict()
for feed_name in each:
tensor = each[feed_name]
if not isinstance(tensor, core.LoDTensor):
tmp = core.LoDTensor()
tmp.set(tensor, self._places[i])
tensor = tmp
res_dict[feed_name] = tensor
res.append(res_dict)
self.executor.feed_tensors_into_local_scopes(res)
fetch_var_names = list(map(_to_name_str, fetch_list))
self.executor.run(fetch_var_names, fetch_var_name)
arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
if return_numpy:
return as_numpy(arr)
return [arr[i] for i in range(len(arr))]
def run(self,
program=None,
feed=None,
......@@ -391,8 +461,9 @@ class Executor(object):
operators in the program but not only the operators dependent by the fetch_list
Args:
program(Program): the program that need to run, if not provied, then default_main_program will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LableData}
program(Program|CompiledProgram): the program that need to run,
if not provided, then default_main_program will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator.
fetch_var_name(str): the name for the output variable of fetch Operator.
......@@ -428,14 +499,59 @@ class Executor(object):
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
if scope is None:
scope = global_scope()
if fetch_list is None:
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly.
if not compiled:
if not self.executor:
p = core.Place()
p.set_place(self.place)
self.executor = core.Executor(p)
return self._run(
program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name,
scope=scope,
return_numpy=return_numpy,
use_program_cache=use_program_cache)
program._compile(scope, self.place)
self.executor = program._executor
if program._is_data_parallel:
return self._run_parallel(
scope=scope,
feed=feed,
fetch_list=fetch_list,
fetch_var_name=fetch_var_name,
return_numpy=return_numpy)
else:
# TODO(panyx0718): Can compile program to optimize executor
# performance.
return self._run(
program._program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name,
scope=scope,
return_numpy=return_numpy,
use_program_cache=use_program_cache)
def _run(self, program, feed, fetch_list, feed_var_name, fetch_var_name,
scope, return_numpy, use_program_cache):
if feed is None:
feed = {}
if not isinstance(feed, dict):
raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" %
(type(feed)))
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()
......@@ -444,9 +560,6 @@ class Executor(object):
"Executor requires Program as its Parameter. But you passed in %s"
% (type(program)))
if scope is None:
scope = global_scope()
cache_key = _get_program_cache_key(feed, fetch_list)
if use_program_cache:
cached_program = self._get_program_cache(cache_key)
......
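To make the feed handling in _run_parallel above concrete, here is a sketch of the two feed forms it accepts for a CompiledProgram compiled for two places; `img0` and `img1` stand for numpy arrays of the expected shape.

# 1) A single dict: each value is set on CPU and split across all places.
exe.run(compiled_prog, feed={"image": img0}, fetch_list=[loss.name])

# 2) A list with one dict per place: each dict is fed to its own device as-is,
#    so the list must have exactly as many entries as there are places.
exe.run(compiled_prog,
        feed=[{"image": img0}, {"image": img1}],
        fetch_list=[loss.name])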
......@@ -181,9 +181,8 @@ class ParallelExecutor(object):
# step7: init ParallelExecutor
self.executor = core.ParallelExecutor(
places, persistable_vars, main.desc,
cpt.to_text(loss_name)
if loss_name else six.u(''), scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id)
cpt.to_text(loss_name) if loss_name else six.u(''), scope,
local_scopes, exec_strategy, build_strategy)
self.scope = scope
......@@ -294,7 +293,7 @@ class ParallelExecutor(object):
res.append(res_dict)
self.executor.feed_tensors_into_local_scopes(res)
fetch_var_name = '@FETCHED_VAR_NAME@'
fetch_var_name = 'fetch'
self.executor.run(fetch_list, fetch_var_name)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
......
......@@ -19,6 +19,7 @@ import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler
import time
import numpy as np
import math
......@@ -44,15 +45,8 @@ class TestParallelExecutorBase(unittest.TestCase):
optimizer=fluid.optimizer.Adam,
use_fast_executor=False,
enable_sequential_execution=False):
def run_executor(exe, feed, fetch_list, program=None):
if isinstance(exe, fluid.ParallelExecutor):
res = exe.run(fetch_list=fetch_list, feed=feed)
elif isinstance(exe, fluid.Executor):
if program is None:
program = fluid.default_main_program()
res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
else:
raise ValueError('Unkown type exe')
def run_executor(exe, binary, feed, fetch_list):
res = exe.run(binary, feed=feed, fetch_list=fetch_list)
return res
main = fluid.Program()
......@@ -72,8 +66,8 @@ class TestParallelExecutorBase(unittest.TestCase):
fluid.memory_optimize(main)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
startup_exe = fluid.Executor(place)
startup_exe.run(startup)
exe = fluid.Executor(place)
exe.run(startup)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay
if use_fast_executor:
......@@ -86,15 +80,13 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda,
binary = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
build_strategy=build_strategy,
exec_strategy=exec_strategy)
else:
exe = fluid.Executor(place=place)
binary = compiler.CompiledProgram(main)
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count(
......@@ -102,13 +94,14 @@ class TestParallelExecutorBase(unittest.TestCase):
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
begin = time.time()
first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter):
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
last_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
if batch_size is not None:
......
......@@ -26,6 +26,7 @@ import pickle
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler
RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2
......@@ -104,8 +105,8 @@ class TestDistRunnerBase(object):
else:
place = fluid.CPUPlace()
startup_exe = fluid.Executor(place)
startup_exe.run(fluid.default_startup_program())
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1
......@@ -125,19 +126,16 @@ class TestDistRunnerBase(object):
mypass.set_int("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2":
num_trainers = len(args.endpoints.split(","))
trainer_id = args.trainer_id
build_stra.num_trainers = len(args.endpoints.split(","))
build_stra.trainer_id = args.trainer_id
else:
num_trainers = 1
trainer_id = 0
build_stra.num_trainers = 1
build_stra.trainer_id = 0
exe = fluid.ParallelExecutor(
args.use_cuda,
binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
loss_name=avg_cost.name,
exec_strategy=strategy,
build_strategy=build_stra,
num_trainers=num_trainers,
trainer_id=trainer_id)
exec_strategy=strategy)
feed_var_list = [
var for var in trainer_prog.global_block().vars.values()
......@@ -160,7 +158,8 @@ class TestDistRunnerBase(object):
out_losses = []
for _ in six.moves.xrange(RUN_STEP):
loss, = exe.run(fetch_list=[avg_cost.name],
loss, = exe.run(binary,
fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
out_losses.append(loss[0])
if six.PY2:
......
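Since num_trainers and trainer_id now live on BuildStrategy, the nccl2 multi-trainer setup in the test above reduces to a sketch like this (endpoints and ids are placeholders; trainer_prog, avg_cost, strategy, feeder, and exe are as defined in the test):

build_stra = fluid.BuildStrategy()
build_stra.num_trainers = len("ep0:6170,ep1:6170".split(","))  # 2 trainers
build_stra.trainer_id = 0                                      # rank of this process

binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
    loss_name=avg_cost.name,
    build_strategy=build_stra,
    exec_strategy=strategy)
loss, = exe.run(binary,
                feed=feeder.feed(get_data()),
                fetch_list=[avg_cost.name])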
......@@ -15,6 +15,7 @@
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.core as core
import numpy as np
import unittest
......@@ -61,22 +62,21 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
exe.run(startup)
feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
train_cp = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(test_program).with_data_parallel(
loss_name=loss.name,
main_program=main,
build_strategy=build_strategy)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=test_program,
share_vars_from=train_exe,
build_strategy=build_strategy)
build_strategy=build_strategy,
share_vars_from=train_cp)
for i in range(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
exe.run(train_cp, feed=feed_dict, fetch_list=[loss.name])
test_loss, = exe.run(test_cp,
feed=feed_dict,
fetch_list=[loss.name])
train_loss, = exe.run(train_cp,
feed=feed_dict,
fetch_list=[loss.name])
avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)):
......