未验证 提交 2b7e5bd3 编写于 作者: Q qingqing01 提交者: GitHub

Support testing during training by ParallelExecutor. (#9738)

* Support testing during training by ParallelExecutor.

* Add unit test.

* Improve the interface.

* Follow comments.
上级 9a4ce6f1
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -24,6 +23,7 @@ limitations under the License. */ ...@@ -24,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -43,30 +43,40 @@ class ParallelExecutorPrivate { ...@@ -43,30 +43,40 @@ class ParallelExecutorPrivate {
#endif #endif
}; };
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_;
}
ParallelExecutor::ParallelExecutor( ParallelExecutor::ParallelExecutor(
size_t num_threads, bool use_event, size_t num_threads, bool use_event,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program, const std::unordered_set<std::string> &bcast_vars,
const std::string &loss_var_name, Scope *scope, bool allow_op_delay) const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
// Step 1. RunStartupProgram and Bcast the params to devs. // Step 1. Bcast the params to devs.
Executor exe(places[0]);
exe.Run(startup_program, scope, 0);
// Create local scopes // Create local scopes
if (local_scopes.empty()) {
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.push_back(&scope->NewScope()); member_->local_scopes_.push_back(&scope->NewScope());
} }
} else {
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.push_back(local_scopes[i]);
}
}
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_)); member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
#endif #endif
if (platform::is_gpu_place(places[0]) && if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
member_->local_scopes_.size() != 1) { // Is CUDA local_scopes.empty()) { // Is CUDA
BCastParamsToGPUs(startup_program); BCastParamsToGPUs(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
...@@ -99,16 +109,17 @@ ParallelExecutor::ParallelExecutor( ...@@ -99,16 +109,17 @@ ParallelExecutor::ParallelExecutor(
} }
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
const ProgramDesc &startup_program) const { const std::unordered_set<std::string> &vars) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *main_scope = member_->local_scopes_[0]; auto *main_scope = member_->local_scopes_[0];
for (auto *var_desc : startup_program.Block(0).AllVars()) { for (auto &var : vars) {
size_t idx = var_desc->Name().find("@GRAD"); auto *main_var = main_scope->FindVar(var);
if (idx != std::string::npos) continue; if (!main_var->IsType<LoDTensor>()) {
if (var_desc->GetType() == proto::VarType::LOD_TENSOR) { continue;
auto &main_tensor = }
main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
...@@ -123,8 +134,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -123,8 +134,7 @@ void ParallelExecutor::BCastParamsToGPUs(
buffer = const_cast<void *>(main_tensor.data<void>()); buffer = const_cast<void *>(main_tensor.data<void>());
} else { } else {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
auto *t = auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
t->Resize(dims); t->Resize(dims);
buffer = t->mutable_data(place, main_tensor.type()); buffer = t->mutable_data(place, main_tensor.type());
} }
...@@ -136,13 +146,12 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -136,13 +146,12 @@ void ParallelExecutor::BCastParamsToGPUs(
platform::CPUPlace cpu; platform::CPUPlace cpu;
for (size_t i = 1; i < member_->places_.size(); ++i) { for (size_t i = 1; i < member_->places_.size(); ++i) {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>(); auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
t->Resize(dims); t->Resize(dims);
t->mutable_data(cpu, main_tensor.type()); t->mutable_data(cpu, main_tensor.type());
paddle::framework::TensorCopy(main_tensor, cpu, t); paddle::framework::TensorCopy(main_tensor, cpu, t);
} }
} }
}
member_->nccl_ctxs_->WaitAll(); member_->nccl_ctxs_->WaitAll();
} }
#else #else
......
...@@ -36,11 +36,14 @@ class ParallelExecutor { ...@@ -36,11 +36,14 @@ class ParallelExecutor {
explicit ParallelExecutor(size_t num_threads, bool use_event, explicit ParallelExecutor(size_t num_threads, bool use_event,
const std::vector<platform::Place>& places, const std::vector<platform::Place>& places,
const std::unordered_set<std::string>& params, const std::unordered_set<std::string>& params,
const ProgramDesc& startup_program, const std::unordered_set<std::string>& bcast_vars,
const ProgramDesc& main_program, const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope, const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes,
bool allow_op_delay); bool allow_op_delay);
std::vector<Scope*>& GetLocalScopes();
void Run(const std::vector<std::string>& fetch_tensors, void Run(const std::vector<std::string>& fetch_tensors,
const std::string& fetched_var_name, const std::string& fetched_var_name,
const std::unordered_map<std::string, LoDTensor>& feed_tensors); const std::unordered_map<std::string, LoDTensor>& feed_tensors);
...@@ -51,7 +54,7 @@ class ParallelExecutor { ...@@ -51,7 +54,7 @@ class ParallelExecutor {
ParallelExecutorPrivate* member_; ParallelExecutorPrivate* member_;
void BCastParamsToGPUs(const ProgramDesc& startup_program) const; void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
}; };
} // namespace framework } // namespace framework
......
...@@ -544,13 +544,20 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -544,13 +544,20 @@ All parameter, weight, gradient are variables in Paddle.
[](ParallelExecutor &self, size_t num_threads, bool use_event, [](ParallelExecutor &self, size_t num_threads, bool use_event,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, bool allow_op_delay) { Scope *scope, std::vector<Scope *> &local_scopes,
new (&self) ParallelExecutor(num_threads, use_event, places, bool allow_op_delay) {
params, startup_program, main_program, new (&self)
loss_var_name, scope, allow_op_delay); ParallelExecutor(num_threads, use_event, places, params,
bcast_vars, main_program, loss_var_name,
scope, local_scopes, allow_op_delay);
}) })
.def("local_scopes",
[](ParallelExecutor &self) -> std::vector<Scope *> * {
return &self.GetLocalScopes();
},
py::return_value_policy::reference)
.def("run", &ParallelExecutor::Run); .def("run", &ParallelExecutor::Run);
BindRecordIOWriter(&m); BindRecordIOWriter(&m);
......
...@@ -22,10 +22,49 @@ __all__ = ['ParallelExecutor'] ...@@ -22,10 +22,49 @@ __all__ = ['ParallelExecutor']
class ParallelExecutor(object): class ParallelExecutor(object):
def __init__(self, def __init__(self,
loss_name,
use_cuda, use_cuda,
loss_name=None,
main_program=None,
num_threads=None, num_threads=None,
allow_op_delay=False): allow_op_delay=False,
share_vars_from=None):
"""
ParallelExecutor can run program in parallel.
Args:
use_cuda(bool): Whether to use CUDA or not.
loss_name(str, default None): The loss name must set in training.
main_program(Program, default None): The program that need to run,
if not provided, then default_main_program will be used.
num_threads(int, default None): How many threads are used for
training.
allow_op_delay(bool, default False): Whether to delay and buffer
some operators together for scheduling or not, which may
improve performance in some cases, defalut False.
share_vars_from(ParallelExecutor, default None): If provied,
it will share variables from the specified ParallelExecutor.
Returns:
A ParallelExecutor object.
Raises:
TypeError: If share_vars_from is provided, but not ParallelExecutor
object.
Examples:
.. code-block:: python
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name)
test_exe = fluid.ParallelExecutor(
use_cuda=True,
main_program=test_program,
share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
"""
self._places = [] self._places = []
self._act_places = [] self._act_places = []
if use_cuda: if use_cuda:
...@@ -50,10 +89,21 @@ class ParallelExecutor(object): ...@@ -50,10 +89,21 @@ class ParallelExecutor(object):
else: else:
min(len(self._places) * 2, multiprocessing.cpu_count()) min(len(self._places) * 2, multiprocessing.cpu_count())
startup = framework.default_startup_program() main = main_program
main = framework.default_main_program() main = main if main else framework.default_main_program()
scope = executor.global_scope() scope = executor.global_scope()
if share_vars_from and not isinstance(share_vars_from,
ParallelExecutor):
raise TypeError("share_vars_from must be ParallelExecutor.")
local_scopes = share_vars_from.executor.local_scopes(
) if share_vars_from else []
persistable_vars = [
v.name
for v in filter(lambda var: var.persistable, main.list_vars())
]
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
num_threads, num_threads,
True if use_cuda else False, # use_event True if use_cuda else False, # use_event
...@@ -62,10 +112,11 @@ class ParallelExecutor(object): ...@@ -62,10 +112,11 @@ class ParallelExecutor(object):
p.name for p in main.global_block().iter_parameters() p.name for p in main.global_block().iter_parameters()
if not p.stop_gradient if not p.stop_gradient
]), ]),
startup.desc, set(persistable_vars),
main.desc, main.desc,
loss_name, loss_name if loss_name else '',
scope, scope,
local_scopes,
allow_op_delay) allow_op_delay)
self.scope = scope self.scope = scope
......
...@@ -207,7 +207,11 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -207,7 +207,11 @@ class TestParallelExecutorBase(unittest.TestCase):
if memory_opt: if memory_opt:
fluid.memory_optimize(main) fluid.memory_optimize(main)
exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True) place = fluid.CUDAPlace(0)
startup_exe = fluid.Executor(place)
startup_exe.run(startup)
exe = fluid.ParallelExecutor(True, loss_name=loss.name)
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count() batch_size *= fluid.core.get_cuda_device_count()
begin = time.time() begin = time.time()
...@@ -453,3 +457,41 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -453,3 +457,41 @@ class TestTransformer(TestParallelExecutorBase):
@unittest.skip("transformer is buggy in multi gpu") @unittest.skip("transformer is buggy in multi gpu")
def test_main(self): def test_main(self):
self.check_network_convergence(transformer) self.check_network_convergence(transformer)
class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def test_parallel_testing(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = simple_fc_net(True)
test_program = main.clone(for_test=True)
opt = fluid.optimizer.SGD(learning_rate=0.0001)
opt.minimize(loss)
batch_size = 32
image = numpy.random.normal(size=(batch_size,
784)).astype('float32')
label = numpy.random.randint(0, 10, (batch_size, 1), dtype="int64")
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup)
feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=main)
test_exe = fluid.ParallelExecutor(
use_cuda=True,
main_program=test_program,
share_vars_from=train_exe)
for i in xrange(5):
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
test_loss = numpy.array(test_loss)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
train_loss = numpy.array(train_loss)
self.assertTrue(numpy.allclose(train_loss, test_loss))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册