提交 03e4da6d 编写于 作者: Y yuyang18

Fix bug

上级 27e4ce72
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -162,8 +163,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -162,8 +163,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (static_cast<bool>(boost::get<int>(op->GetAttr( if (static_cast<bool>(boost::get<int>(op->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) & OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) { static_cast<int>(OpRole::kBackward))) {
auto &backward_vars = boost::get<std::vector<std::string>>( auto backward_vars = boost::get<std::vector<std::string>>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
std::vector<std::string>()));
for (auto &og : backward_vars) { for (auto &og : backward_vars) {
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
...@@ -404,8 +406,9 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, ...@@ -404,8 +406,9 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
return boost::get<int>( return boost::get<int>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) | (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)); static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode
} }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -223,6 +223,16 @@ Attribute OpDesc::GetAttr(const std::string &name) const { ...@@ -223,6 +223,16 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second; return it->second;
} }
Attribute OpDesc::GetAttrOrDefault(
const std::string &name, paddle::framework::Attribute default_attr) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
return it->second;
} else {
return default_attr;
}
}
int OpDesc::GetBlockAttr(const std::string &name) const { int OpDesc::GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
......
...@@ -78,6 +78,9 @@ class OpDesc { ...@@ -78,6 +78,9 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
Attribute GetAttrOrDefault(const std::string &name,
Attribute default_attr) const;
int GetBlockAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const;
void Rename(const std::string &old_name, const std::string &new_name); void Rename(const std::string &old_name, const std::string &new_name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册