From 03e4da6d046414a6cab81b87cb1cd0eea4e19a1d Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Tue, 22 May 2018 20:46:41 +0800 Subject: [PATCH] Fix bug --- .../framework/details/multi_devices_graph_builder.cc | 11 +++++++---- paddle/fluid/framework/op_desc.cc | 10 ++++++++++ paddle/fluid/framework/op_desc.h | 3 +++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 6506af652..447dfa965 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" +#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" #ifdef PADDLE_WITH_CUDA @@ -162,8 +163,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( if (static_cast(boost::get(op->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { - auto &backward_vars = boost::get>( - op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + auto backward_vars = boost::get>( + op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(), + std::vector())); for (auto &og : backward_vars) { switch (strategy_.reduce_) { case BuildStrategy::ReduceStrategy::kReduce: @@ -404,8 +406,9 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { return boost::get( op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - (static_cast(OpRole::kBackward) | - static_cast(OpRole::kLoss)); + (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss)) && + !loss_var_name_.empty(); // If loss_var is empty. This is test mode } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index b68421afe..d14d9cb8a 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -223,6 +223,16 @@ Attribute OpDesc::GetAttr(const std::string &name) const { return it->second; } +Attribute OpDesc::GetAttrOrDefault( + const std::string &name, paddle::framework::Attribute default_attr) const { + auto it = attrs_.find(name); + if (it != attrs_.end()) { + return it->second; + } else { + return default_attr; + } +} + int OpDesc::GetBlockAttr(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 3ee36a47c..82542a83c 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -78,6 +78,9 @@ class OpDesc { Attribute GetAttr(const std::string &name) const; + Attribute GetAttrOrDefault(const std::string &name, + Attribute default_attr) const; + int GetBlockAttr(const std::string &name) const; void Rename(const std::string &old_name, const std::string &new_name); -- GitLab