提交 03e4da6d 编写于 作者: Y yuyang18

Fix bug

上级 27e4ce72
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
......@@ -162,8 +163,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (static_cast<bool>(boost::get<int>(op->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) {
auto &backward_vars = boost::get<std::vector<std::string>>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
auto backward_vars = boost::get<std::vector<std::string>>(
op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(),
std::vector<std::string>()));
for (auto &og : backward_vars) {
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
......@@ -404,8 +406,9 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
return boost::get<int>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss));
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)) &&
!loss_var_name_.empty(); // If loss_var is empty. This is test mode
}
} // namespace details
} // namespace framework
......
......@@ -223,6 +223,16 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second;
}
Attribute OpDesc::GetAttrOrDefault(
const std::string &name, paddle::framework::Attribute default_attr) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
return it->second;
} else {
return default_attr;
}
}
int OpDesc::GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
......
......@@ -78,6 +78,9 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const;
Attribute GetAttrOrDefault(const std::string &name,
Attribute default_attr) const;
int GetBlockAttr(const std::string &name) const;
void Rename(const std::string &old_name, const std::string &new_name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册