Unverified · Commit 4c19b8c7 · Authored by sprouteer, committed by GitHub

[XPU] Support fc_batch_norm (#54157)

Parent f3eccb3f
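This patch extends fc_xpu_fuse_pass so that the matched subgraph may contain a batch_norm op between the mul/matmul/matmul_v2 (+ elementwise_add) output and the activation, and folds the batch-norm parameters into the fused fc_xpu weight and bias at pass time. Writing \mu, \sigma^2, \gamma, \beta, \varepsilon for the batch_norm Mean, Variance, Scale, Bias, and epsilon attributes, the per-output-channel recomputation performed by the handler below is

    \hat{\gamma}_i = \frac{\gamma_i}{\sqrt{\sigma_i^2 + \varepsilon}}, \qquad
    W'_{ji} = W_{ji}\,\hat{\gamma}_i, \qquad
    b'_i = (b_i - \mu_i)\,\hat{\gamma}_i + \beta_i,

where b is the elementwise_add bias (zero when the pattern has no bias), so that batch_norm(xW + b) = xW' + b'.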
@@ -44,11 +44,13 @@ struct FcXPUPattern : public PatternBase {
               const std::string& name_scope,
               const std::string& mul_type,
               bool with_bias,
+              bool with_bn,
               const std::string& act_type);
   // declare operator node's name
   PATTERN_DECL_NODE(mul);
   PATTERN_DECL_NODE(add);
+  PATTERN_DECL_NODE(bn);
   PATTERN_DECL_NODE(act);
   // declare variable node's name
   PATTERN_DECL_NODE(mul_x);
@@ -56,11 +58,21 @@ struct FcXPUPattern : public PatternBase {
   PATTERN_DECL_NODE(mul_out);
   PATTERN_DECL_NODE(bias);
   PATTERN_DECL_NODE(add_out);
+  PATTERN_DECL_NODE(bn_bias);
+  PATTERN_DECL_NODE(bn_mean);
+  PATTERN_DECL_NODE(bn_scale);
+  PATTERN_DECL_NODE(bn_var);
+  PATTERN_DECL_NODE(bn_out);
+  PATTERN_DECL_NODE(bn_var_out);
+  PATTERN_DECL_NODE(bn_mean_out);
+  PATTERN_DECL_NODE(bn_saved_var);
+  PATTERN_DECL_NODE(bn_saved_mean);
   PATTERN_DECL_NODE(act_out);

  private:
   std::string mul_type_;
   bool with_bias_{false};
+  bool with_bn_{false};
   std::string act_type_;
 };
@@ -68,10 +80,12 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
                            const std::string& name_scope,
                            const std::string& mul_type,
                            bool with_bias,
+                           bool with_bn,
                            const std::string& act_type)
     : PatternBase(pattern, name_scope, name_scope),
       mul_type_(mul_type),
       with_bias_(with_bias),
+      with_bn_(with_bn),
       act_type_(act_type) {
   auto* mul_x = pattern->NewNode(mul_x_repr())
                     ->assert_is_op_input(mul_type_, "X")
@@ -118,13 +132,57 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
   } else {
     add_out = mul_out;
   }
+  PDNode* bn = nullptr;
+  PDNode* bn_bias = nullptr;
+  PDNode* bn_mean = nullptr;
+  PDNode* bn_scale = nullptr;
+  PDNode* bn_var = nullptr;
+  PDNode* bn_out = nullptr;
+  PDNode* bn_mean_out = nullptr;
+  PDNode* bn_saved_mean = nullptr;
+  PDNode* bn_var_out = nullptr;
+  PDNode* bn_saved_var = nullptr;
+  if (with_bn_) {
+    add_out->assert_is_op_input("batch_norm", "X");
+    bn_bias = pattern->NewNode(bn_bias_repr())
+                  ->assert_is_op_input("batch_norm", "Bias")
+                  ->assert_has_n_outputs(1);
+    bn_mean = pattern->NewNode(bn_mean_repr())
+                  ->assert_is_op_input("batch_norm", "Mean")
+                  ->assert_has_n_outputs(1);
+    bn_scale = pattern->NewNode(bn_scale_repr())
+                   ->assert_is_op_input("batch_norm", "Scale")
+                   ->assert_has_n_outputs(1);
+    bn_var = pattern->NewNode(bn_var_repr())
+                 ->assert_is_op_input("batch_norm", "Variance")
+                 ->assert_has_n_outputs(1);
+    bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
+    bn_out =
+        pattern->NewNode(bn_out_repr())->assert_is_op_output("batch_norm", "Y");
+    if (!act_type_.empty()) {
+      bn_out->assert_has_n_outputs(1);
+    }
+    bn_mean_out = pattern->NewNode(bn_mean_out_repr())
+                      ->assert_is_op_output("batch_norm", "MeanOut");
+    bn_saved_mean = pattern->NewNode(bn_saved_mean_repr())
+                        ->assert_is_op_output("batch_norm", "SavedMean");
+    bn_var_out = pattern->NewNode(bn_var_out_repr())
+                     ->assert_is_op_output("batch_norm", "VarianceOut");
+    bn_saved_var = pattern->NewNode(bn_saved_var_repr())
+                       ->assert_is_op_output("batch_norm", "SavedVariance");
+    bn->LinksFrom({add_out, bn_bias, bn_mean, bn_scale, bn_var})
+        .LinksTo(
+            {bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var});
+  } else {
+    bn_out = add_out;
+  }
   if (!act_type_.empty()) {
-    add_out->assert_is_op_input(act_type_, "X");
+    bn_out->assert_is_op_input(act_type_, "X");
     act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
     act_out = pattern->NewNode(act_out_repr())
                   ->assert_is_op_output(act_type_, "Out")
                   ->assert_var_not_persistable();
-    act->LinksFrom({add_out}).LinksTo({act_out});
+    act->LinksFrom({bn_out}).LinksTo({act_out});
   }
 }
@@ -151,6 +209,12 @@ Origin subgraph:
       elementwise_add_out
                 |
                 |
+           batch_norm
+                |
+                |
+         batch_norm_out
+                |
+                |
                act
                 |
                 |
@@ -174,6 +238,7 @@ class FcXPUFusePass : public FusePassBase {
   int ApplyImpl(ir::Graph* graph,
                 const std::string& mul_type,
                 bool with_bias,
+                bool with_bn,
                 const std::string& act_type) const;

   const std::string name_scope_{"fc_xpu_fuse_pass"};
@@ -187,13 +252,16 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_subgraph_count = 0;
   for (auto mul_type : {"mul", "matmul", "matmul_v2"}) {
     for (auto with_bias : {true, false}) {
-      for (auto act_type : {
-               "relu",
-               "gelu",
-               "tanh",
-               "",
-           }) {
-        found_subgraph_count += ApplyImpl(graph, mul_type, with_bias, act_type);
+      for (auto with_bn : {true, false}) {
+        for (auto act_type : {
+                 "relu",
+                 "gelu",
+                 "tanh",
+                 "",
+             }) {
+          found_subgraph_count +=
+              ApplyImpl(graph, mul_type, with_bias, with_bn, act_type);
+        }
       }
     }
   }
@@ -203,10 +271,15 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
 int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
                              const std::string& mul_type,
                              bool with_bias,
+                             bool with_bn,
                              const std::string& act_type) const {
   GraphPatternDetector gpd;
-  patterns::FcXPUPattern pattern(
-      gpd.mutable_pattern(), name_scope_, mul_type, with_bias, act_type);
+  patterns::FcXPUPattern pattern(gpd.mutable_pattern(),
+                                 name_scope_,
+                                 mul_type,
+                                 with_bias,
+                                 with_bn,
+                                 act_type);
   int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -219,30 +292,100 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     GET_IR_NODE(bias);
     GET_IR_NODE(add);
     GET_IR_NODE(add_out);
+    GET_IR_NODE(bn);
+    GET_IR_NODE(bn_bias);
+    GET_IR_NODE(bn_mean);
+    GET_IR_NODE(bn_scale);
+    GET_IR_NODE(bn_var);
+    GET_IR_NODE(bn_out);
+    GET_IR_NODE(bn_var_out);
+    GET_IR_NODE(bn_mean_out);
+    GET_IR_NODE(bn_saved_var);
+    GET_IR_NODE(bn_saved_mean);
     GET_IR_NODE(act);
     GET_IR_NODE(act_out);
     auto* block = mul->Op()->Block();
     auto* scope = param_scope();
+    auto* filter_t =
+        scope->FindVar(mul_w->Name())->GetMutable<phi::DenseTensor>();
+    // filter fp16 --> fp32
+    auto tensor_type = filter_t->dtype();
+    if (tensor_type == phi::DataType::FLOAT16) {
+      CastToFp32(filter_t, nullptr);
+    }
+    auto filter_dims = filter_t->dims();
     bool transpose_w = false;
     if (mul_type == "matmul") {
       transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
     } else if (mul_type == "matmul_v2") {
       transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
     }
+    bool has_bias = with_bn || with_bias;
+    Node* fusion_bias_node = nullptr;
+    if (has_bias) {
+      if (bias != nullptr) {
+        PrepareBias(graph, scope, block, bias, &fusion_bias_node);
+      }
+      if (bn != nullptr) {
+        auto bn_bias_t =
+            scope->Var(bn_bias->Name())->GetMutable<phi::DenseTensor>();
+        auto bn_scale_t =
+            scope->Var(bn_scale->Name())->GetMutable<phi::DenseTensor>();
+        auto bn_mean_t =
+            scope->Var(bn_mean->Name())->GetMutable<phi::DenseTensor>();
+        auto bn_var_t =
+            scope->Var(bn_var->Name())->GetMutable<phi::DenseTensor>();
+        float* mul_w_ptr = filter_t->data<float>();
+        float* bn_scale_ptr = bn_scale_t->data<float>();
+        float* bn_bias_ptr = bn_bias_t->data<float>();
+        float* bn_mean_ptr = bn_mean_t->data<float>();
+        float* bn_var_ptr = bn_var_t->data<float>();
+        auto mean_len = bn_mean_t->numel();
+        auto filter_h = filter_dims[0];
+        auto filter_w = filter_dims[1];
+        float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon"));
+        if (fusion_bias_node == nullptr) {  // prev node is conv
+          PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node);
+        }
+        auto fusion_bias_t = scope->Var(fusion_bias_node->Name())
+                                 ->GetMutable<phi::DenseTensor>();
+        float* fusion_bias_ptr = fusion_bias_t->data<float>();
+        // recompute bias and weights
+        if (bias == nullptr) {
+          for (int i = 0; i < mean_len; ++i) {
+            bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
+            fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i];
+            for (int j = 0; j < filter_h; j++) {
+              mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i];
+            }
+          }
+        } else {
+          for (int i = 0; i < mean_len; ++i) {
+            bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
+            bn_bias_ptr[i] +=
+                (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i];
+            for (int j = 0; j < filter_h; j++) {
+              mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i];
+            }
+          }
+          memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float));
+        }
+      }
+    }
     Node* mul_w_int16 = nullptr;
     Node* mul_w_max = nullptr;
     PrepareWeight<int16_t>(
         graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w);
-    Node* bias_fp32 = nullptr;
-    if (bias != nullptr) {
-      PrepareBias(graph, scope, block, bias, &bias_fp32);
-    }
     std::string fc_out_name;
     if (act_out) {
       fc_out_name = act_out->Name();
+    } else if (bn) {
+      fc_out_name = bn_out->Name();
     } else if (add_out) {
       fc_out_name = add_out->Name();
     } else {
@@ -258,8 +401,8 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     fc_xpu_op_desc.SetInput("x", {mul_x->Name()});
     fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()});
     fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()});
-    if (bias_fp32) {
-      fc_xpu_op_desc.SetInput("bias", {bias_fp32->Name()});
+    if (has_bias) {
+      fc_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()});
     }
     fc_xpu_op_desc.SetAttr(
         "in_num_col_dims",
@@ -294,9 +437,13 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     IR_NODE_LINK_TO(mul_x, fc_xpu);
     IR_NODE_LINK_TO(mul_w_int16, fc_xpu);
     IR_NODE_LINK_TO(mul_w_max, fc_xpu);
-    SAFE_IR_NODE_LINK_TO(bias_fp32, fc_xpu);
+    if (bias || bn) {
+      SAFE_IR_NODE_LINK_TO(fusion_bias_node, fc_xpu);
+    }
     if (act_out) {
       IR_NODE_LINK_TO(fc_xpu, act_out);
+    } else if (bn_out) {
+      IR_NODE_LINK_TO(fc_xpu, bn_out);
     } else if (add_out) {
       IR_NODE_LINK_TO(fc_xpu, add_out);
     } else {
@@ -315,6 +462,17 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     } else {
       delete_nodes = {mul};
     }
+    if (bn != nullptr) {
+      delete_nodes.insert(bn);
+      delete_nodes.insert(bn_bias);
+      delete_nodes.insert(bn_var);
+      delete_nodes.insert(bn_mean);
+      delete_nodes.insert(bn_scale);
+      delete_nodes.insert(bn_var_out);
+      delete_nodes.insert(bn_mean_out);
+      delete_nodes.insert(bn_saved_var);
+      delete_nodes.insert(bn_saved_mean);
+    }
     GraphSafeRemoveNodes(graph, delete_nodes);
     found_subgraph_count++;
   };
......
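As a sanity check on the folding above, the following minimal standalone C++ sketch (independent of Paddle; file and variable names here are illustrative, not part of the patch) applies the same per-channel recomputation to a toy fc + batch_norm and verifies that the fused output matches the unfused one:

// fold_bn_check.cc -- standalone sketch, not part of the patch.
// Verifies: batch_norm(x*W + b) == x*W' + b' after the folding above.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int k = 2, n = 3;                        // W is k x n (filter_h x filter_w)
  std::vector<float> W = {1.f, 2.f, 3.f,
                          4.f, 5.f, 6.f};
  std::vector<float> b = {0.1f, 0.2f, 0.3f};     // elementwise_add bias
  std::vector<float> gamma = {1.5f, 0.5f, 2.f};  // bn Scale
  std::vector<float> beta = {0.3f, -0.1f, 0.f};  // bn Bias
  std::vector<float> mean = {0.2f, 0.4f, 0.6f};  // bn Mean
  std::vector<float> var = {1.f, 2.f, 4.f};      // bn Variance
  const float eps = 1e-5f;
  std::vector<float> x = {0.5f, -1.f};           // one input row

  // Reference path: y = bn(x*W + b), channel by channel.
  std::vector<float> ref(n);
  for (int i = 0; i < n; ++i) {
    float acc = b[i];
    for (int j = 0; j < k; ++j) acc += x[j] * W[j * n + i];
    ref[i] = gamma[i] * (acc - mean[i]) / std::sqrt(var[i] + eps) + beta[i];
  }

  // Fold, mirroring the pass: s = gamma/sqrt(var+eps),
  // b' = (b - mean)*s + beta, W'[:,i] = W[:,i]*s.
  std::vector<float> Wf = W, bf(n);
  for (int i = 0; i < n; ++i) {
    float s = gamma[i] / std::sqrt(var[i] + eps);
    bf[i] = (b[i] - mean[i]) * s + beta[i];
    for (int j = 0; j < k; ++j) Wf[j * n + i] *= s;
  }

  // Fused path: y = x*W' + b'; should match ref up to float rounding.
  for (int i = 0; i < n; ++i) {
    float acc = bf[i];
    for (int j = 0; j < k; ++j) acc += x[j] * Wf[j * n + i];
    std::printf("channel %d: ref=%f fused=%f\n", i, ref[i], acc);
  }
  return 0;
}

Built with e.g. g++ -std=c++11 fold_bn_check.cc, both columns print identical values up to rounding, which is the invariant the pass relies on when it replaces the mul/add/batch_norm/act chain with a single fc_xpu op.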