Commit 5e928e57 authored by Xin Pan

try unify Executor and ParallelExecutor

test=develop
Parent a1e60ab1
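The change routes multi-trainer settings through BuildStrategy and lets a plain Executor run a compiled program. A hedged sketch of the intended usage, pieced together from the updated tests below (`loss` and `feed_dict` are assumed to come from the model definition):

```python
import paddle.fluid as fluid
from paddle.fluid import compiler

place = fluid.CUDAPlace(0)  # or fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Compile the main program for data-parallel execution instead of constructing
# a fluid.ParallelExecutor directly.
binary = compiler._ProgramCompiler(
    fluid.default_main_program())._with_data_parallel(
        loss_name=loss.name,  # `loss` is assumed to be defined by the model
        build_strategy=fluid.BuildStrategy(),
        exec_strategy=fluid.ExecutionStrategy())

# The same Executor.run() entry point now accepts the compiled binary.
loss_val, = exe.run(binary, feed=feed_dict, fetch_list=[loss.name])
```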
...
@@ -193,8 +193,7 @@ ParallelExecutor::ParallelExecutor(
    const std::unordered_set<std::string> &bcast_vars,
    const ProgramDesc &main_program, const std::string &loss_var_name,
    Scope *scope, const std::vector<Scope *> &local_scopes,
-   const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
-   size_t num_trainers, size_t trainer_id)
+   const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy)
    : member_(new ParallelExecutorPrivate(places)) {
  member_->global_scope_ = scope;
  member_->use_cuda_ = exec_strategy.use_cuda_;
...
@@ -253,7 +252,8 @@ ParallelExecutor::ParallelExecutor(
  }
  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-     member_->places_, nccl_id, num_trainers, trainer_id));
+     member_->places_, nccl_id, build_strategy.num_trainers_,
+     build_strategy.trainer_id_));
#else
  PADDLE_THROW("Not compiled with CUDA");
#endif
...
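With this change, the trainer count and trainer id are no longer separate constructor arguments; the NCCL context map reads them from the BuildStrategy. A small sketch of how the Python side now carries them, based on the updated distributed test further down (the values are placeholders):

```python
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# Placeholder values; previously these were passed as separate
# num_trainers/trainer_id arguments to ParallelExecutor.
build_strategy.num_trainers = 2
build_strategy.trainer_id = 0
```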
...
@@ -50,8 +50,7 @@ class ParallelExecutor {
                   const std::string &loss_var_name, Scope *scope,
                   const std::vector<Scope *> &local_scopes,
                   const ExecutionStrategy &exec_strategy,
-                  const BuildStrategy &build_strategy,
-                  size_t num_trainers = 1, size_t trainer_id = 0);
+                  const BuildStrategy &build_strategy);

  ~ParallelExecutor();
...
...
@@ -1022,8 +1022,7 @@ All parameter, weight, gradient are variables in Paddle.
  pe.def(py::init<const std::vector<platform::Place> &,
                  const std::unordered_set<std::string> &, const ProgramDesc &,
                  const std::string &, Scope *, std::vector<Scope *> &,
-                 const ExecutionStrategy &, const BuildStrategy &, size_t,
-                 size_t>())
+                 const ExecutionStrategy &, const BuildStrategy &>())
      // NOTE: even we return a vec<Scope*>* to Python use reference policy.
      // We still cannot get local_scope from this vector, since the element
      // of vec<Scope*> will be freed by Python GC. We can only return Scope*
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import os
import six
from .. import compat as cpt
from . import core

ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy


def _place_obj(place):
    p = core.Place()
    p.set_place(place)
    return p


class _ProgramCompiler(object):
    def __init__(self, program):
        self._program = program
        self._compiled = False
        self._is_data_parallel = False

    def _with_data_parallel(self,
                            loss_name=None,
                            build_strategy=None,
                            exec_strategy=None):
        assert not self._is_data_parallel, "Already compiled with parallel."
        self._is_data_parallel = True
        self._build_strategy = build_strategy
        self._exec_strategy = exec_strategy
        self._loss_name = loss_name
        return self

    def _compile_data_parallel(self):
        self._places = []
        self._local_scopes = []

        if self._exec_strategy is None:
            self._exec_strategy = ExecutionStrategy()
        if self._build_strategy is None:
            self._build_strategy = BuildStrategy()

        self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
        if self._exec_strategy.use_cuda:
            gpus_env = os.getenv("FLAGS_selected_gpus")
            if gpus_env:
                gpus = [int(s) for s in gpus_env.split(",")]
            else:
                gpus = [
                    i for i in six.moves.range(core.get_cuda_device_count())
                ]
            self._places = [core.CUDAPlace(i) for i in gpus]
        else:
            cpu_num = int(
                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
        assert self._places, "no place for execution"

        if self._exec_strategy.num_threads == 0:
            if self._exec_strategy.use_cuda:
                # Experiments on se-resnext show that too many threads hurt
                # performance. Worth tuning for other models in the future.
                self._exec_strategy.num_threads = len(self._places) * 4
            else:
                cpu_num = int(
                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
                self._exec_strategy.num_threads = cpu_num * 2

        trainers_endpoints = self._program._trainers_endpoints
        if self._build_strategy.num_trainers > 1 and trainers_endpoints:
            assert self._build_strategy.num_trainers == len(
                trainers_endpoints), "num_trainers == len(end_points)"
            self._build_strategy.trainers_endpoints = trainers_endpoints

        self._persistable_vars = set([
            cpt.to_text(v.name)
            for v in [
                var for var in self._program.list_vars()
                if var.persistable and var.type != core.VarDesc.VarType.RAW
            ]
        ])

        places = list(map(_place_obj, self._places))
        return core.ParallelExecutor(
            places, self._persistable_vars, self._program.desc,
            cpt.to_text(self._loss_name)
            if self._loss_name else six.u(''), self._scope, self._local_scopes,
            self._exec_strategy, self._build_strategy)

    def _compile(self, scope, place):
        if self._compiled:
            return self
        self._compiled = True

        self._scope = scope
        self._place = place
        if self._is_data_parallel:
            self._executor = self._compile_data_parallel()
        else:
            p = _place_obj(self._place)
            self._executor = core.Executor(p)
        return self
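In `_compile_data_parallel` the execution places come from the environment: `FLAGS_selected_gpus` selects specific GPU ids when compiling for a CUDAPlace, and `CPU_NUM` controls how many CPU places are created otherwise. An illustrative sketch (the values are examples only):

```python
import os

# Example values only: use GPUs 0 and 1 when the compile place is a CUDAPlace,
# or create 4 CPU places when it is a CPUPlace.
os.environ["FLAGS_selected_gpus"] = "0,1"
os.environ["CPU_NUM"] = "4"
```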
...
@@ -14,11 +14,15 @@
from __future__ import print_function

+import os
+import multiprocessing
import numpy as np
import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
+from . import compiler
+from .. import compat as cpt

__all__ = ['Executor', 'global_scope', 'scope_guard']
...
@@ -275,11 +279,8 @@ class Executor(object):
    def __init__(self, place):
        self.place = place
-       p = core.Place()
-       p.set_place(place)
-       self.executor = core.Executor(p)
        self.program_caches = dict()
+       self.executor = None
        self._closed = False

    def _get_program_cache(self, program_cache_key):
...
@@ -361,6 +362,7 @@ class Executor(object):
        You can no longer use this executor after calling this method.
        For the distributed training, this method would free the resource on PServers related to
        the current Trainer.

+       TODO(panyx0718): Why ParallelExecutor doesn't have close?

        Example:
            >>> cpu = core.CPUPlace()
...
@@ -368,10 +370,58 @@ class Executor(object):
            >>> ...
            >>> exe.close()
        """
-       if not self._closed:
+       if not self._closed and self.executor:
            self.executor.close()
        self._closed = True

    def _run_parallel(self,
                      exe,
                      scope,
                      feed=None,
                      fetch_list=None,
                      return_numpy=True):
        if isinstance(feed, dict):
            feed_tensor_dict = dict()
            for feed_name in feed:
                feed_tensor = feed[feed_name]
                if not isinstance(feed_tensor, core.LoDTensor):
                    feed_tensor = core.LoDTensor()
                    # always set to CPU place, since the tensor needs to be
                    # split; this is fast on CPU
                    feed_tensor.set(feed[feed_name], core.CPUPlace())
                feed_tensor_dict[feed_name] = feed_tensor
            exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
        elif isinstance(feed, list) or isinstance(feed, tuple):
            if len(feed) != len(self._places):
                raise ValueError(
                    "Feed a list of tensor, the list should be the same size as places"
                )

            res = list()
            for i, each in enumerate(feed):
                if not isinstance(each, dict):
                    raise TypeError(
                        "Each element of feed list should be a dict")
                res_dict = dict()
                for feed_name in each:
                    tensor = each[feed_name]
                    if not isinstance(tensor, core.LoDTensor):
                        tmp = core.LoDTensor()
                        tmp.set(tensor, self._places[i])
                        tensor = tmp
                    res_dict[feed_name] = tensor
                res.append(res_dict)
            exe.feed_tensors_into_local_scopes(res)

        fetch_var_name = '@FETCHED_VAR_NAME@'
        exe.run(fetch_list, fetch_var_name)
        arr = scope.find_var(fetch_var_name).get_lod_tensor_array()

        if return_numpy:
            return as_numpy(arr)

        return [arr[i] for i in range(len(arr))]
    def run(self,
            program=None,
            feed=None,
...
@@ -428,6 +478,47 @@ class Executor(object):
        if self._closed:
            raise RuntimeError("Attempted to use a closed Executor")

        if scope is None:
            scope = global_scope()

        compiled = isinstance(program, compiler._ProgramCompiler)
        if not compiled:
            p = core.Place()
            p.set_place(self.place)
            self.executor = core.Executor(p)
            return self._run(
                program,
                feed=feed,
                fetch_list=fetch_list,
                feed_var_name=feed_var_name,
                fetch_var_name=fetch_var_name,
                scope=scope,
                return_numpy=return_numpy,
                use_program_cache=use_program_cache)

        program._compile(scope, self.place)
        self.executor = program._executor
        if program._is_data_parallel:
            return self._run_parallel(
                exe=program._executor,
                scope=scope,
                feed=feed,
                fetch_list=fetch_list,
                return_numpy=return_numpy)
        else:
            return self._run(
                program._program,
                feed=feed,
                fetch_list=fetch_list,
                feed_var_name=feed_var_name,
                fetch_var_name=fetch_var_name,
                scope=scope,
                return_numpy=return_numpy,
                use_program_cache=use_program_cache)

    def _run(self, program, feed, fetch_list, feed_var_name, fetch_var_name,
             scope, return_numpy, use_program_cache):
        if feed is None:
            feed = {}
        if not isinstance(feed, dict):
...
@@ -444,9 +535,6 @@ class Executor(object):
                "Executor requires Program as its Parameter. But you passed in %s"
                % (type(program)))

-       if scope is None:
-           scope = global_scope()
-
        cache_key = _get_program_cache_key(feed, fetch_list)
        if use_program_cache:
            cached_program = self._get_program_cache(cache_key)
...
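`_run_parallel` accepts two feed shapes: a single dict, which is set on CPU and split across the local scopes, or a list/tuple with exactly one dict per place. A minimal sketch of the two forms (names and data are made up for illustration):

```python
import numpy as np

# One dict: the executor splits each tensor across all places.
feed_single = {"image": np.random.random((8, 784)).astype("float32")}

# One dict per place: len(feed) must match the number of places.
feed_per_place = [
    {"image": np.random.random((4, 784)).astype("float32")},  # place 0
    {"image": np.random.random((4, 784)).astype("float32")},  # place 1
]
```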
...
@@ -167,9 +167,8 @@ class ParallelExecutor(object):
        # step7: init ParallelExecutor
        self.executor = core.ParallelExecutor(
            places, persistable_vars, main.desc,
-           cpt.to_text(loss_name)
-           if loss_name else six.u(''), scope, local_scopes, exec_strategy,
-           build_strategy, num_trainers, trainer_id)
+           cpt.to_text(loss_name) if loss_name else six.u(''), scope,
+           local_scopes, exec_strategy, build_strategy)

        self.scope = scope
...
@@ -292,3 +291,6 @@ class ParallelExecutor(object):
    @property
    def device_count(self):
        return len(self._places)

+   def close(self):
+       pass
...
@@ -19,6 +19,7 @@ import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
+from paddle.fluid import compiler
import time
import numpy as np
import math
...
@@ -44,15 +45,8 @@ class TestParallelExecutorBase(unittest.TestCase):
                                 optimizer=fluid.optimizer.Adam,
                                 use_fast_executor=False,
                                 enable_sequential_execution=False):
-       def run_executor(exe, feed, fetch_list, program=None):
-           if isinstance(exe, fluid.ParallelExecutor):
-               res = exe.run(fetch_list=fetch_list, feed=feed)
-           elif isinstance(exe, fluid.Executor):
-               if program is None:
-                   program = fluid.default_main_program()
-               res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
-           else:
-               raise ValueError('Unkown type exe')
+       def run_executor(exe, binary, feed, fetch_list):
+           res = exe.run(binary, feed=feed, fetch_list=fetch_list)
            return res

        main = fluid.Program()
...
@@ -72,8 +66,8 @@ class TestParallelExecutorBase(unittest.TestCase):
            fluid.memory_optimize(main)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-       startup_exe = fluid.Executor(place)
-       startup_exe.run(startup)
+       exe = fluid.Executor(place)
+       exe.run(startup)
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.allow_op_delay = allow_op_delay
        if use_fast_executor:
...
@@ -86,15 +80,13 @@ class TestParallelExecutorBase(unittest.TestCase):
        build_strategy.enable_sequential_execution = enable_sequential_execution
        if use_cuda and core.is_compiled_with_cuda():
            build_strategy.remove_unnecessary_lock = True

        if use_parallel_executor:
-           exe = fluid.ParallelExecutor(
-               use_cuda,
+           binary = compiler._ProgramCompiler(main)._with_data_parallel(
                loss_name=loss.name,
-               exec_strategy=exec_strategy,
-               build_strategy=build_strategy)
+               build_strategy=build_strategy,
+               exec_strategy=exec_strategy)
        else:
-           exe = fluid.Executor(place=place)
+           binary = compiler._ProgramCompiler(main)

        if batch_size is not None:
            batch_size *= fluid.core.get_cuda_device_count(
...
@@ -102,13 +94,14 @@ class TestParallelExecutorBase(unittest.TestCase):
                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
        begin = time.time()
        first_loss, = run_executor(
-           exe=exe, feed=feed_dict, fetch_list=[loss.name])
+           exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])

        for i in range(iter):
-           run_executor(exe=exe, feed=feed_dict, fetch_list=[])
+           run_executor(
+               exe=exe, binary=binary, feed=feed_dict, fetch_list=[])

        last_loss, = run_executor(
-           exe=exe, feed=feed_dict, fetch_list=[loss.name])
+           exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
        end = time.time()

        if batch_size is not None:
...
...
@@ -26,6 +26,7 @@ import pickle
import numpy as np
import paddle.fluid as fluid
+from paddle.fluid import compiler

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2
...
@@ -104,8 +105,8 @@ class TestDistRunnerBase(object):
        else:
            place = fluid.CPUPlace()

-       startup_exe = fluid.Executor(place)
-       startup_exe.run(fluid.default_startup_program())
+       exe = fluid.Executor(place)
+       exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
...
@@ -125,19 +126,16 @@ class TestDistRunnerBase(object):
                mypass.set_int("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
-           num_trainers = len(args.endpoints.split(","))
-           trainer_id = args.trainer_id
+           build_stra.num_trainers = len(args.endpoints.split(","))
+           build_stra.trainer_id = args.trainer_id
        else:
-           num_trainers = 1
-           trainer_id = 0
+           build_stra.num_trainers = 1
+           build_stra.trainer_id = 0

-       exe = fluid.ParallelExecutor(
-           args.use_cuda,
+       binary = compiler._ProgramCompiler(trainer_prog)._with_data_parallel(
            loss_name=avg_cost.name,
-           exec_strategy=strategy,
            build_strategy=build_stra,
-           num_trainers=num_trainers,
-           trainer_id=trainer_id)
+           exec_strategy=strategy)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
...
@@ -160,7 +158,8 @@ class TestDistRunnerBase(object):
        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
-           loss, = exe.run(fetch_list=[avg_cost.name],
+           loss, = exe.run(binary,
+                           fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
        if six.PY2:
...