未验证 提交 f3e4e42d 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #10130 from reyoung/feature/skip_loss

Add customize_loss_grad option to PE
......@@ -34,7 +34,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_scopes, bool skip_scale_loss,
platform::NCCLContextMap *nccl_ctxs)
: loss_var_name_(loss_var_name),
places_(places),
......@@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes)
const std::vector<Scope *> &local_scopes, bool skip_scale_loss)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes) {
......@@ -53,6 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
skip_scale_loss_ = skip_scale_loss;
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
......@@ -95,7 +96,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// always use the first device
CreateSendOp(&result, *op);
} else if (IsScaleLossOp(*op)) {
if (!skip_scale_loss_) {
CreateScaleLossGradOp(&result);
}
is_forwarding = false;
} else {
CreateComputationalOps(&result, *op);
......
......@@ -34,12 +34,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
bool skip_scale_loss,
platform::NCCLContextMap *nccl_ctxs);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes);
const std::vector<Scope *> &local_scopes,
bool skip_scale_loss);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
......@@ -57,6 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
#endif
bool skip_scale_loss_;
bool IsScaleLossOp(const OpDesc &op) const;
......
......@@ -57,7 +57,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &params,
const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay)
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool customize_scale_loss)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
......@@ -90,12 +91,13 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_,
member_->nccl_ctxs_.get());
details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_,
customize_scale_loss, member_->nccl_ctxs_.get());
#else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_);
params, member_->local_scopes_,
customize_scale_loss);
#endif
auto graph = builder.Build(main_program);
......
......@@ -40,7 +40,7 @@ class ParallelExecutor {
const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes,
bool allow_op_delay);
bool allow_op_delay, bool customize_scale_loss);
~ParallelExecutor();
......
......@@ -502,11 +502,11 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, std::vector<Scope *> &local_scopes,
bool allow_op_delay) {
new (&self)
ParallelExecutor(num_threads, use_event, places, params,
bcast_vars, main_program, loss_var_name,
scope, local_scopes, allow_op_delay);
bool allow_op_delay, bool customize_loss_grad) {
new (&self) ParallelExecutor(num_threads, use_event, places,
params, bcast_vars, main_program,
loss_var_name, scope, local_scopes,
allow_op_delay, customize_loss_grad);
})
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
......
......@@ -29,7 +29,8 @@ class ParallelExecutor(object):
main_program=None,
num_threads=None,
allow_op_delay=False,
share_vars_from=None):
share_vars_from=None,
customize_loss_grad=False):
"""
ParallelExecutor can run program in parallel.
......@@ -122,7 +123,8 @@ class ParallelExecutor(object):
loss_name if loss_name else '',
scope,
local_scopes,
allow_op_delay)
allow_op_delay,
customize_loss_grad)
self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册