未验证 提交 d4c771c6 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14235 from panyx0718/fix5

fix to only check block 0
...@@ -26,59 +26,58 @@ namespace ir { ...@@ -26,59 +26,58 @@ namespace ir {
namespace { namespace {
void CheckProgram(const ProgramDesc &program) { void CheckProgram(const ProgramDesc &program) {
std::map<int, bool> visit;
#define _INT(role) static_cast<int>(role) #define _INT(role) static_cast<int>(role)
for (size_t i = 0; i < program.Size(); ++i) { std::map<int, bool> visit;
for (OpDesc *op : program.Block(i).AllOps()) { for (OpDesc *op : program.Block(0).AllOps()) {
// For backward compatibility, some program doesn't have role added. // For backward compatibility, some program doesn't have role added.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue; if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id = boost::get<int>( int role_id =
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())); boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true; visit[role_id] = true;
switch (role_id) { switch (role_id) {
case _INT(OpRole::kForward): case _INT(OpRole::kForward):
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) { if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
LOG(ERROR) LOG(ERROR)
<< "Cannot add backward operator before forward operator %s." << "Cannot add backward operator before forward operator %s."
<< op->Type(); << op->Type();
} }
break; break;
case _INT(OpRole::kBackward): case _INT(OpRole::kBackward):
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss): case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE( PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(), visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator %s before optimize operator.", "Cannot add backward operator %s after optimize operator.",
op->Type()); op->Type());
break; break;
case _INT(OpRole::kForward) | _INT(OpRole::kLoss): case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) | PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
_INT(OpRole::kLoss)) == visit.end(), _INT(OpRole::kLoss)) == visit.end(),
"Cannot add backward|loss operator before " "Cannot add backward|loss operator before "
"forward|loss operator %s.", "forward|loss operator %s.",
op->Type()); op->Type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(), visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add forward|loss operator %s after optimize operator.", "Cannot add forward|loss operator %s after optimize operator.",
op->Type()); op->Type());
break; break;
case _INT(OpRole::kOptimize): case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched): case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(), PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators %s must follow backward operator.", "Optimize operators %s must follow backward operator.",
op->Type()); op->Type());
break; break;
case _INT(OpRole::kLRSched): case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist): case _INT(OpRole::kDist):
case _INT(OpRole::kRPC): case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified): case _INT(OpRole::kNotSpecified):
break; break;
default: default:
LOG(FATAL) << "Unknown operator role. Don't add new role because " LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing."; "you don't know what you are doing.";
}
} }
} }
#undef _INT #undef _INT
} }
} // namespace } // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册