未验证 提交 60c51de5 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add input data type checking in BF16 placement pass (#38702)

上级 bbe83ed1
......@@ -2441,11 +2441,13 @@ PDNode *patterns::Bfloat16Placement::operator()(
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
node->Op()->Type() == "reshape2";
});
op->LinksFrom({op_in});
return op;
}
......
......@@ -1446,6 +1446,7 @@ struct Bfloat16Placement : public PatternBase {
PDNode* operator()(
const std::unordered_set<std::string>& bfloat16_enabled_op_types);
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
......
......@@ -41,8 +41,12 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);
// Only float input can be converted to bfloat16
if (op_in->Var()->GetDataType() != proto::VarType::FP32) return;
if ((op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) {
......
......@@ -68,7 +68,7 @@ ProgramDesc BuildProgramDesc() {
for (auto& v :
std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
"n", "o", "p", "r", "s"})) {
prog.MutableBlock(0)->Var(v);
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
......@@ -86,9 +86,8 @@ ProgramDesc BuildProgramDesc() {
}
void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
unsigned expected_bfloat16_data_type_count,
const ProgramDesc& prog) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
......@@ -110,8 +109,8 @@ void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}
void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
void DefaultAttrTest(unsigned expected_bfloat16_data_type_count,
const ProgramDesc& prog) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
graph.reset(pass->Apply(graph.release()));
......@@ -128,15 +127,39 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
}
TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8);
MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8,
BuildProgramDesc());
}
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
// 2 conv2d + 2 pool2 - 1 orphaned conv2d
MainTest({"conv2d", "pool2d"}, 3);
MainTest({"conv2d", "pool2d"}, 3, BuildProgramDesc());
}
TEST(Bfloat16PlacementPass, default_attr_value) {
DefaultAttrTest(10, BuildProgramDesc());
}
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(10); }
ProgramDesc BuildProgramDescWithDataType() {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e"})) {
if (v == "a") {
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::INT32);
} else {
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
}
SetOp(&prog, "conv2d", "conv1", {"a"}, {"b"});
SetOp(&prog, "pool2d", "pool1", {"b"}, {"c"});
SetOp(&prog, "concat", "concat1", {"c", "d"}, {"e"});
return prog;
}
TEST(Bfloat16PlacementPass, check_data_types) {
DefaultAttrTest(2, BuildProgramDescWithDataType());
}
} // namespace ir
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册