提交 ca8c77d9 编写于 作者: Y Yancey1989

selecte execution according to strategy test=develop

上级 4743c9cd
......@@ -134,7 +134,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &num_parallel_devices,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else
......@@ -153,9 +153,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("local_scopes");
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes);
pass->Erase("num_parallel_devices");
pass->Set<size_t>("num_parallel_devices",
new size_t(num_parallel_devices));
pass->Erase("nranks");
pass->Set<size_t>("nranks", new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......
......@@ -84,8 +84,6 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false};
bool enable_parallel_graph_{false};
int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;
......@@ -112,7 +110,7 @@ struct BuildStrategy {
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &num_parallel_devices_,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
......@@ -120,6 +118,13 @@ struct BuildStrategy {
const bool use_cuda) const;
#endif
// If set true, ParallelExecutor would build the main_program into multiple
// graphs,
// each of the graphs would run with one device. This approach can achieve
// better performance
// on some scenarios.
mutable bool enable_parallel_graph_ = false;
private:
mutable bool is_finalized_ = false;
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
......
......@@ -138,7 +138,7 @@ static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places";
static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy";
static const char kNumParallelDevices[] = "num_parallel_devices";
static const char kNRanks[] = "nranks";
void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear();
......@@ -174,7 +174,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
size_t num_parallel_devices = Get<size_t>(kNumParallelDevices);
size_t nranks = Get<size_t>(kNRanks);
for (auto &node : nodes) {
if (node->IsVar() && node->Var()) {
......@@ -251,7 +251,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps(&result, node, places_.size());
}
if (!is_forwarding && num_parallel_devices > 1UL) {
if (!is_forwarding && nranks > 1UL) {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
......@@ -649,13 +649,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node, proto::VarType::Type dtype) const {
size_t num_parallel_devices = Get<size_t>("num_parallel_devices");
size_t nranks = Get<size_t>("nranks");
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
num_parallel_devices, local_scopes_[i], places_[i], dev_ctx, dtype);
nranks, local_scopes_[i], places_[i], dev_ctx, dtype);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
......@@ -888,4 +888,4 @@ REGISTER_PASS(multi_devices_pass,
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNumParallelDevices);
.RequirePassAttr(paddle::framework::details::kNRanks);
......@@ -107,7 +107,7 @@ class ParallelExecutorPrivate {
bool own_local_scope_;
bool use_cuda_;
bool use_all_reduce_;
size_t num_parallel_devices_;
size_t nranks_;
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged
......@@ -203,7 +203,7 @@ ParallelExecutor::ParallelExecutor(
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ =
build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
member_->num_parallel_devices_ = num_trainers * places.size();
member_->nranks_ = num_trainers * places.size();
if (!member_->use_all_reduce_) {
PADDLE_ENFORCE(places.size() > 1,
......@@ -211,16 +211,14 @@ ParallelExecutor::ParallelExecutor(
"the number of places must be greater than 1.");
}
if (build_strategy.enable_parallel_graph_) {
PADDLE_ENFORCE(
member_->use_all_reduce_,
"build_strategy.reduce should be `AllReduce` if you want to enable"
"ParallelGraph.");
PADDLE_ENFORCE(
member_->use_cuda_,
"execution_strategy.use_cuda should be True if you want to enable "
"ParallelGraph.");
}
// FIXME(Yancey1989): parallel graph mode get better performance
// in GPU allreduce distributed training. Need an elegant way to
// choice the execution strategy.
build_strategy.enable_parallel_graph_ =
EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
VLOG(1) << "Enable ParallelGraph Execution: "
<< build_strategy.enable_parallel_graph_;
// Step 1. Bcast the bcast_vars to devs.
// Create local scopes
......@@ -242,20 +240,20 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
ncclUniqueId *nccl_id = nullptr;
// nccl collective would broadcast nccl id by gen_nccl_id operator.
std::unique_ptr<ncclUniqueId> nccl_id;
// nccl collective would broadcast ncclUniqueId by gen_nccl_id operator.
if (nccl_id_var != nullptr) {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
}
if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
if (nccl_id == nullptr) {
nccl_id = new ncclUniqueId();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
if (build_strategy.enable_parallel_graph_ && member_->nranks_ > 1UL) {
if (nccl_id.get() == nullptr) {
nccl_id.reset(new ncclUniqueId());
platform::dynload::ncclGetUniqueId(nccl_id.get());
}
}
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, num_trainers, trainer_id));
member_->places_, nccl_id.get(), num_trainers, trainer_id));
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
......@@ -268,27 +266,25 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
std::vector<std::unique_ptr<ir::Graph>> graphs;
member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.enable_parallel_graph_) {
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, member_->num_parallel_devices_,
member_->use_cuda_, member_->nccl_ctxs_.get());
{member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
} else {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->num_parallel_devices_, member_->use_cuda_,
member_->nccl_ctxs_.get());
member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
#else
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->num_parallel_devices_, member_->use_cuda_);
member_->nranks_, member_->use_cuda_);
graphs.push_back(std::move(graph));
#endif
auto max_memory_size = GetEagerDeletionThreshold();
......@@ -470,6 +466,35 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
}
bool ParallelExecutor::EnableParallelGraphExecution(
const ProgramDesc &main_program, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy) const {
bool enable_parallel_graph = true;
// TODO(Yancey1989): support sparse update in ParallelGraph mode.
for (auto &var_desc : main_program.Block(0).AllVars()) {
if (var_desc->GetType() == proto::VarType::SELECTED_ROWS) {
enable_parallel_graph = false;
}
}
// TODO(Yancey1989): support pserver mode
for (auto &op_desc : main_program.Block(0).AllOps()) {
if (op_desc->Type() == "send" || op_desc->Type() == "recv") {
enable_parallel_graph = false;
break;
}
}
if (!member_->use_all_reduce_ || !member_->use_cuda_)
enable_parallel_graph = false;
if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
enable_parallel_graph = false;
return enable_parallel_graph;
}
ParallelExecutor::~ParallelExecutor() {
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
......
......@@ -68,6 +68,9 @@ class ParallelExecutor {
private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
bool EnableParallelGraphExecution(const ProgramDesc &main_program,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy) const;
ParallelExecutorPrivate *member_;
};
......
......@@ -980,14 +980,6 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC")
.def_property(
"enable_parallel_graph",
[](const BuildStrategy &self) { return self.enable_parallel_graph_; },
[](BuildStrategy &self, bool b) { self.enable_parallel_graph_ = b; },
R"DOC(The type is BOOL, if set True, ParallelExecutor would build the main_program into multiple graphs,
each of the graphs would run with one device. This approach can achieve better performance in
some scenarios. Please note, this approach only supports all-reduce mode
on GPU device)DOC")
.def_property(
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
......
......@@ -156,7 +156,8 @@ def __bootstrap__():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus'
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
'sync_nccl_allreduce'
]
core.init_gflags([sys.argv[0]] +
......
......@@ -39,7 +39,6 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None,
use_parallel_executor=True,
use_reduce=False,
use_parallel_graph=False,
use_ir_memory_optimize=False,
fuse_elewise_add_act_ops=False,
optimizer=fluid.optimizer.Adam,
......@@ -80,7 +79,6 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.enable_parallel_graph = use_parallel_graph
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
......
......@@ -175,14 +175,13 @@ class TestCRFModel(unittest.TestCase):
print(pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name])[0])
def _new_build_strategy(self, use_reduce=False, use_parallel_graph=False):
def _new_build_strategy(self, use_reduce=False):
build_strategy = fluid.BuildStrategy()
if use_reduce:
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else:
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.enable_parallel_graph = use_parallel_graph
return build_strategy
......@@ -204,11 +203,6 @@ class TestCRFModel(unittest.TestCase):
is_sparse=False,
build_strategy=self._new_build_strategy(),
use_cuda=True)
self.check_network_convergence(
is_sparse=False,
build_strategy=self._new_build_strategy(
use_parallel_graph=True),
use_cuda=True)
self.check_network_convergence(
is_sparse=False,
......
......@@ -100,10 +100,7 @@ class TestMNIST(TestParallelExecutorBase):
self.assertAlmostEqual(loss[0], loss[1], delta=1e-4)
# simple_fc
def check_simple_fc_convergence(self,
use_cuda,
use_reduce=False,
use_parallel_graph=False):
def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -114,15 +111,13 @@ class TestMNIST(TestParallelExecutorBase):
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=use_reduce,
use_parallel_graph=use_parallel_graph)
use_reduce=use_reduce)
def test_simple_fc(self):
# use_cuda
if core.is_compiled_with_cuda():
self.check_simple_fc_convergence(True)
self.check_simple_fc_convergence(
True, use_reduce=False, use_parallel_graph=True)
self.check_simple_fc_convergence(True, use_reduce=False)
self.check_simple_fc_convergence(False)
def test_simple_fc_with_new_strategy(self):
......@@ -130,9 +125,7 @@ class TestMNIST(TestParallelExecutorBase):
self._compare_reduce_and_allreduce(simple_fc_net, True)
self._compare_reduce_and_allreduce(simple_fc_net, False)
def check_simple_fc_parallel_accuracy(self,
use_cuda,
use_parallel_graph=False):
def check_simple_fc_parallel_accuracy(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -144,16 +137,7 @@ class TestMNIST(TestParallelExecutorBase):
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_parallel_executor=False,
use_parallel_graph=use_parallel_graph)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_parallel_executor=True,
use_parallel_graph=use_parallel_graph)
use_parallel_executor=False)
self.assertAlmostEquals(
np.mean(parallel_first_loss),
......@@ -165,15 +149,11 @@ class TestMNIST(TestParallelExecutorBase):
def test_simple_fc_parallel_accuracy(self):
if core.is_compiled_with_cuda():
self.check_simple_fc_parallel_accuracy(True)
self.check_simple_fc_parallel_accuracy(
True, use_parallel_graph=True)
self.check_simple_fc_parallel_accuracy(True)
# FIXME(Yancey1989): ParallelGraph executor type support CPU mode
self.check_simple_fc_parallel_accuracy(False)
def check_batchnorm_fc_convergence(self,
use_cuda,
use_fast_executor,
use_parallel_graph=False):
def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -184,8 +164,7 @@ class TestMNIST(TestParallelExecutorBase):
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_fast_executor=use_fast_executor,
use_parallel_graph=use_parallel_graph)
use_fast_executor=use_fast_executor)
def test_batchnorm_fc(self):
for use_cuda in (False, True):
......@@ -193,7 +172,7 @@ class TestMNIST(TestParallelExecutorBase):
self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor)
self.check_batchnorm_fc_convergence(
use_cuda=True, use_fast_executor=False, use_parallel_graph=True)
use_cuda=True, use_fast_executor=False)
def test_batchnorm_fc_with_new_strategy(self):
# FIXME(zcd): close this test temporally.
......
......@@ -277,9 +277,7 @@ class TestResnet(TestParallelExecutorBase):
use_cuda=True,
use_reduce=False,
iter=20,
delta2=1e-6,
use_parallel_graph=False,
lr_scale=1.0):
delta2=1e-6):
if use_cuda and not core.is_compiled_with_cuda():
return
......@@ -298,8 +296,7 @@ class TestResnet(TestParallelExecutorBase):
use_cuda=use_cuda,
use_reduce=use_reduce,
optimizer=optimizer,
use_parallel_executor=False,
use_parallel_graph=use_parallel_graph)
use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
......@@ -308,8 +305,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=use_reduce,
optimizer=optimizer,
use_parallel_graph=use_parallel_graph)
optimizer=optimizer)
self.assertAlmostEquals(
np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
......@@ -320,11 +316,6 @@ class TestResnet(TestParallelExecutorBase):
if core.is_compiled_with_cuda():
self._check_resnet_convergence(
model=SE_ResNeXt50Small, use_cuda=True)
self._check_resnet_convergence(
model=SE_ResNeXt50Small,
use_cuda=True,
use_parallel_graph=True,
lr_scale=core.get_cuda_device_count())
self._check_resnet_convergence(
model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3)
......
......@@ -175,8 +175,6 @@ class TestTransformer(TestParallelExecutorBase):
self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(
transformer, use_cuda=True, enable_sequential_execution=True)
self.check_network_convergence(
transformer, use_cuda=True, use_parallel_graph=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册