未验证 提交 030b23da 编写于 作者: T Tomasz Socha 提交者: GitHub

Fix for Bfloat16 placement pass. (#43109)

* Fix bfloat16 placement pass

* Make it nicer

* Fix leftovers

* Style
上级 990c5e7f
......@@ -2631,8 +2631,10 @@ PDNode *patterns::Bfloat16Placement::operator()(
PDNode *patterns::OrphanedBfloat16::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
"mkldnn_data_type") == "float32";
return data_type_is_missing || data_type_is_fp32;
});
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
......@@ -2645,8 +2647,10 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
"mkldnn_data_type") == "float32";
return data_type_is_missing || data_type_is_fp32;
});
prev_op->LinksTo({prev_out});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册