diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 6506af6520bb35d99770b804e1204c9a437617c7..447dfa9655fa4daa7be7ab9cc691b904ac0ae651 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" +#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" #ifdef PADDLE_WITH_CUDA @@ -162,8 +163,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( if (static_cast(boost::get(op->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { - auto &backward_vars = boost::get>( - op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + auto backward_vars = boost::get>( + op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(), + std::vector())); for (auto &og : backward_vars) { switch (strategy_.reduce_) { case BuildStrategy::ReduceStrategy::kReduce: @@ -404,8 +406,9 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { return boost::get( op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - (static_cast(OpRole::kBackward) | - static_cast(OpRole::kLoss)); + (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss)) && + !loss_var_name_.empty(); // If loss_var is empty. This is test mode } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index b68421afed9edaa89c723290e11085bdf448403d..d14d9cb8ab8a01913388eeb20941d7a67b00cc09 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -223,6 +223,16 @@ Attribute OpDesc::GetAttr(const std::string &name) const { return it->second; } +Attribute OpDesc::GetAttrOrDefault( + const std::string &name, paddle::framework::Attribute default_attr) const { + auto it = attrs_.find(name); + if (it != attrs_.end()) { + return it->second; + } else { + return default_attr; + } +} + int OpDesc::GetBlockAttr(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 3ee36a47c156da67a9ff70852665fbbd464bea17..82542a83c504a55d76b0fded305ef0f4886f8e6a 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -78,6 +78,9 @@ class OpDesc { Attribute GetAttr(const std::string &name) const; + Attribute GetAttrOrDefault(const std::string &name, + Attribute default_attr) const; + int GetBlockAttr(const std::string &name) const; void Rename(const std::string &old_name, const std::string &new_name);