提交 a943134a 编写于 作者: X Xin Pan

fix a few more tests

test=develop
上级 5839e323
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
......@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
......@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
......
......@@ -24,12 +24,15 @@ namespace paddle {
namespace framework {
namespace ir {
namespace {
void CheckProgram(const ProgramDesc &program) {
std::map<int, bool> visit;
#define _INT(role) static_cast<int>(role)
for (size_t i = 0; i < program.Size(); ++i) {
for (OpDesc *op : program.Block(i).AllOps()) {
// For backward compatibility, some program doesn't have role added.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id = boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true;
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
......@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type,
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Xs", outputs);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
}
TEST(DataFlowGraph, Build_IR_Graph) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册