Unverified · Commit 60c51de5, authored by joanna.wozna.intel, committed by GitHub

Add input data type checking in BF16 placement pass (#38702)

Parent bbe83ed1
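In brief: until now, the cpu_bfloat16_placement_pass decided whether to set mkldnn_data_type to "bfloat16" from the operator's attributes alone. This change extends the detection pattern with the operator's input variable node (op_in) and makes the handler skip any operator whose input data type is not FP32, since only float data can be converted to bfloat16. The unit tests are reworked so that each test supplies its own ProgramDesc, and a new test covers a graph that contains an INT32 input.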
...
@@ -2441,11 +2441,13 @@ PDNode *patterns::Bfloat16Placement::operator()(
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
+  auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
   auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
            node->Op()->Type() == "reshape2";
   });
+  op->LinksFrom({op_in});
   return op;
 }
...
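The hunk above registers the input variable as a pattern node (AsInput) and links it into op, so the pattern detector now hands the handler both the operator and the variable feeding it. The matching PATTERN_DECL_NODE declaration and the data-type check itself follow below.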
@@ -1446,6 +1446,7 @@ struct Bfloat16Placement : public PatternBase {
   PDNode* operator()(
       const std::unordered_set<std::string>& bfloat16_enabled_op_types);
 
+  PATTERN_DECL_NODE(op_in);
   PATTERN_DECL_NODE(op);
 };
...
@@ -41,8 +41,12 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);
+
+    // Only float input can be converted to bfloat16
+    if (op_in->Var()->GetDataType() != proto::VarType::FP32) return;
     if ((op->Op()->HasAttr("mkldnn_data_type") ||
          op->Op()->HasProtoAttr("mkldnn_data_type")) &&
         !platform::HasOpINT8DataType(op->Op())) {
...
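To make the new guard easier to follow in isolation, here is a minimal standalone sketch of the decision the handler now implements; the enum and helper below are hypothetical stand-ins, not Paddle APIs. An operator is marked for bfloat16 only when its input variable is FP32 and the operator is not already placed as INT8:

// A hypothetical, self-contained model of the handler's decision (not Paddle code).
#include <iostream>

enum class VarType { FP32, INT32, INT8, BF16 };

// Non-float inputs are never converted (the check added by this commit),
// and operators already placed as INT8 keep their data type.
bool ShouldMarkBfloat16(VarType input_type, bool has_mkldnn_data_type_attr,
                        bool is_int8_op) {
  if (input_type != VarType::FP32) return false;
  return has_mkldnn_data_type_attr && !is_int8_op;
}

int main() {
  std::cout << ShouldMarkBfloat16(VarType::FP32, true, false)   // 1: marked bfloat16
            << ShouldMarkBfloat16(VarType::INT32, true, false)  // 0: non-float input
            << ShouldMarkBfloat16(VarType::FP32, true, true)    // 0: INT8 operator
            << "\n";
}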
@@ -68,7 +68,7 @@ ProgramDesc BuildProgramDesc() {
   for (auto& v :
        std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
                                  "n", "o", "p", "r", "s"})) {
-    prog.MutableBlock(0)->Var(v);
+    prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
   }
 
   SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
@@ -86,9 +86,8 @@ ProgramDesc BuildProgramDesc() {
 }
 
 void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
-              unsigned expected_bfloat16_data_type_count) {
-  auto prog = BuildProgramDesc();
+              unsigned expected_bfloat16_data_type_count,
+              const ProgramDesc& prog) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
@@ -110,8 +109,8 @@ void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
   EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
 }
 
-void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
-  auto prog = BuildProgramDesc();
+void DefaultAttrTest(unsigned expected_bfloat16_data_type_count,
+                     const ProgramDesc& prog) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
   graph.reset(pass->Apply(graph.release()));
@@ -128,15 +127,39 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
 }
 
 TEST(Bfloat16PlacementPass, enable_all) {
-  MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8);
+  MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8,
+           BuildProgramDesc());
 }
 
 TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
   // 2 conv2d + 2 pool2 - 1 orphaned conv2d
-  MainTest({"conv2d", "pool2d"}, 3);
+  MainTest({"conv2d", "pool2d"}, 3, BuildProgramDesc());
 }
 
-TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(10); }
+TEST(Bfloat16PlacementPass, default_attr_value) {
+  DefaultAttrTest(10, BuildProgramDesc());
+}
+
+ProgramDesc BuildProgramDescWithDataType() {
+  ProgramDesc prog;
+
+  for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e"})) {
+    if (v == "a") {
+      prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::INT32);
+    } else {
+      prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
+    }
+  }
+
+  SetOp(&prog, "conv2d", "conv1", {"a"}, {"b"});
+  SetOp(&prog, "pool2d", "pool1", {"b"}, {"c"});
+  SetOp(&prog, "concat", "concat1", {"c", "d"}, {"e"});
+  return prog;
+}
+
+TEST(Bfloat16PlacementPass, check_data_types) {
+  DefaultAttrTest(2, BuildProgramDescWithDataType());
+}
 
 } // namespace ir
 } // namespace framework
...
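A note on the new expectations: BuildProgramDescWithDataType wires conv1 -> pool1 -> concat1, with only variable "a" typed as INT32. The pass skips conv1 because its input is not FP32, but still marks pool1 and concat1, whose inputs are FP32; hence check_data_types expects a bfloat16 count of 2 rather than 3.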