Unverified commit 49108efa, authored by wenbin, committed by GitHub

remove bf16 (#38133)

* remove bf16

* remove comments

* remove wrong return

* fix UT
Parent b28c374a
......
@@ -2412,6 +2412,23 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
  return next_op;
}

PDNode *patterns::UnsupportedBfloat16::operator()() {
  auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
  prev_op->assert_more([&](Node *node) {
    return node->Op()->HasAttr("mkldnn_data_type") == false;
  });

  auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();

  auto *op = pattern->NewNode(op_repr())->assert_is_op();
  op->assert_more([&](Node *node) {
    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
           "bfloat16";
  });

  prev_op->LinksTo({prev_out});
  op->LinksFrom({prev_out});
  return op;
}

PDNode *patterns::LastBfloat16Ops::operator()() {
  auto *op = pattern->NewNode(op_repr())->assert_is_op();
  op->assert_more([&](Node *node) {
......
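For orientation, the new pattern matches a two-op producer/consumer chain. The diagram and the note on GetAttrIfExists below are a sketch inferred from the asserts above, not lines from the commit:

// Shape of the subgraph UnsupportedBfloat16 is built to match (a sketch
// inferred from the asserts above):
//
//   prev_op                           prev_out          op
//   [no "mkldnn_data_type" attr] ---> (output var) ---> ["mkldnn_data_type"
//                                                        == "bfloat16"]
//
// GetAttrIfExists<std::string> presumably yields an empty string when the
// attribute is absent, so the second assert matches only ops explicitly
// tagged bfloat16.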
......
@@ -1416,6 +1416,16 @@ struct OrphanedBfloat16 : public PatternBase {
  PATTERN_DECL_NODE(next_op);
};

struct UnsupportedBfloat16 : public PatternBase {
  UnsupportedBfloat16(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "unsupported_bfloat16") {}

  PDNode* operator()();

  PATTERN_DECL_NODE(prev_op);
  PATTERN_DECL_NODE(prev_out);
  PATTERN_DECL_NODE(op);
};

struct LastBfloat16Ops : public PatternBase {
  LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "last_bfloat16_ops") {}
......
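The three PATTERN_DECL_NODE lines supply the prev_op_repr(), prev_out_repr(), and op_repr() helpers used by operator() in the .cc hunk above. A simplified paraphrase of what the macro provides, assuming the usual graph_pattern_detector.h definition rather than quoting it:

// Simplified paraphrase (assumption, not the verbatim macro): each declared
// node gets a <name>_repr() accessor returning the unique key under which
// the PDNode is registered in the pattern, roughly
//   std::string prev_op_repr() const {
//     return PDNodeName(name_scope_, repr_, id_, "prev_op");
//   }
// which is what lets operator() call pattern->NewNode(prev_op_repr()).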
......
@@ -71,10 +71,31 @@ void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
  gpd(graph, handler);
}

void CPUBfloat16PlacementPass::RemoveUnsupportedOperators(
    ir::Graph* graph, int* bfloat16_operators) const {
  // Quantization currently supports FP32 inputs only, so find bfloat16
  // operators whose input type is not FP32 and mark them float32 instead.
  GraphPatternDetector gpd;
  patterns::UnsupportedBfloat16 unsupported_bfloat16_pattern{
      gpd.mutable_pattern(), "unsupported_bfloat16"};
  unsupported_bfloat16_pattern();

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, unsupported_bfloat16_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(op, op, unsupported_bfloat16_pattern);
    if (prev_out->Var()->GetDataType() != proto::VarType::FP32) {
      op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
      (*bfloat16_operators)--;
    }
  };
  gpd(graph, handler);
}

void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
  int bfloat16_operators = 0;
  SetMkldnnDataType(graph, &bfloat16_operators);
  RemoveOrphanedOperators(graph, &bfloat16_operators);
  RemoveUnsupportedOperators(graph, &bfloat16_operators);
  PrettyLogDetail("--- marked %d operators to bfloat16 ",
                  bfloat16_operators);
}
......
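ApplyImpl threads one counter through all three helpers by address, so each helper must modify it through the pointer it receives. A minimal self-contained sketch of that plumbing, with hypothetical helper names standing in for the real pass methods:

#include <cstdio>

// Hypothetical stand-ins for the pass helpers; each adjusts the shared
// counter through the pointer it receives, as the pass methods above do.
void MarkOps(int* bfloat16_operators) { *bfloat16_operators += 3; }
void DropUnsupported(int* bfloat16_operators) {
  // (*ptr)-- decrements the shared count; a bare ptr-- would only move the
  // local pointer and leave the count untouched.
  (*bfloat16_operators)--;
}

int main() {
  int bfloat16_operators = 0;
  MarkOps(&bfloat16_operators);
  DropUnsupported(&bfloat16_operators);
  std::printf("--- marked %d operators to bfloat16\n", bfloat16_operators);
  return 0;
}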
......
@@ -30,6 +30,9 @@ class CPUBfloat16PlacementPass : public Pass {
  void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;

  void RemoveUnsupportedOperators(ir::Graph* graph,
                                  int* bfloat16_operators) const;

  void ApplyImpl(ir::Graph* graph) const override;
};
......
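For completeness, a hedged sketch of how a placement pass like this is typically fetched and run; the pass name string and registry calls follow PaddlePaddle's usual conventions and are assumptions here, not part of this diff:

// Assumed usage via the Pass registry (conventional, not from this commit):
//   auto pass = framework::ir::PassRegistry::Instance().Get(
//       "cpu_bfloat16_placement_pass");
//   graph.reset(pass->Apply(graph.release()));
// ApplyImpl then marks candidate ops, prunes orphaned ones, prunes ops with
// non-FP32 inputs, and logs the final bfloat16 operator count.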