未验证 提交 08d22cf7 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14091 from panyx0718/fix2

add program check
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("X", inputs); op->SetInput("X", inputs);
} }
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
} }
// a->OP0->b // a->OP0->b
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetInput("X", inputs); op->SetInput("X", inputs);
} }
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
} }
// a->OP0->b // a->OP0->b
......
...@@ -23,8 +23,62 @@ limitations under the License. */ ...@@ -23,8 +23,62 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
namespace {
void CheckProgram(const ProgramDesc &program) {
std::map<int, bool> visit;
#define _INT(role) static_cast<int>(role)
for (size_t i = 0; i < program.Size(); ++i) {
for (OpDesc *op : program.Block(i).AllOps()) {
// For backward compatibility, some program doesn't have role added.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id = boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true;
switch (role_id) {
case _INT(OpRole::kForward):
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kBackward)) == visit.end(),
"Cannot add forward operator before backward operator.");
break;
case _INT(OpRole::kBackward):
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator before optimize operator.");
break;
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
_INT(OpRole::kLoss)) == visit.end(),
"Cannot add backward|loss operator before "
"forward|loss operator.");
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator before optimize operator.");
break;
case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators must follow backward operator.");
break;
case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist):
case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified):
break;
default:
LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing.";
}
}
}
#undef _INT
}
} // namespace
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
CheckProgram(program_);
// Make the nodes id start from 0. // Make the nodes id start from 0.
Node::ResetId(); Node::ResetId();
auto var_nodes = InitFromProgram(program_); auto var_nodes = InitFromProgram(program_);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
...@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type, ...@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type,
op->SetType(type); op->SetType(type);
op->SetInput("Xs", inputs); op->SetInput("Xs", inputs);
op->SetOutput("Xs", outputs); op->SetOutput("Xs", outputs);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
} }
TEST(DataFlowGraph, Build_IR_Graph) { TEST(DataFlowGraph, Build_IR_Graph) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册