diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index 74b021bf1af2716af31cc7e7599bf1047f44f36c..65708c4c1d1eb0633d94653aa55d3cf4b6f201a5 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -44,11 +44,13 @@ struct FcXPUPattern : public PatternBase { const std::string& name_scope, const std::string& mul_type, bool with_bias, + bool with_bn, const std::string& act_type); // declare operator node's name PATTERN_DECL_NODE(mul); PATTERN_DECL_NODE(add); + PATTERN_DECL_NODE(bn); PATTERN_DECL_NODE(act); // declare variable node's name PATTERN_DECL_NODE(mul_x); @@ -56,11 +58,21 @@ struct FcXPUPattern : public PatternBase { PATTERN_DECL_NODE(mul_out); PATTERN_DECL_NODE(bias); PATTERN_DECL_NODE(add_out); + PATTERN_DECL_NODE(bn_bias); + PATTERN_DECL_NODE(bn_mean); + PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_var); + PATTERN_DECL_NODE(bn_out); + PATTERN_DECL_NODE(bn_var_out); + PATTERN_DECL_NODE(bn_mean_out); + PATTERN_DECL_NODE(bn_saved_var); + PATTERN_DECL_NODE(bn_saved_mean); PATTERN_DECL_NODE(act_out); private: std::string mul_type_; bool with_bias_{false}; + bool with_bn_{false}; std::string act_type_; }; @@ -68,10 +80,12 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern, const std::string& name_scope, const std::string& mul_type, bool with_bias, + bool with_bn, const std::string& act_type) : PatternBase(pattern, name_scope, name_scope), mul_type_(mul_type), with_bias_(with_bias), + with_bn_(with_bn), act_type_(act_type) { auto* mul_x = pattern->NewNode(mul_x_repr()) ->assert_is_op_input(mul_type_, "X") @@ -118,13 +132,57 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern, } else { add_out = mul_out; } + PDNode* bn = nullptr; + PDNode* bn_bias = nullptr; + PDNode* bn_mean = nullptr; + PDNode* bn_scale = nullptr; + PDNode* bn_var = nullptr; + PDNode* bn_out = nullptr; + PDNode* bn_mean_out = nullptr; + PDNode* bn_saved_mean = nullptr; + PDNode* bn_var_out = nullptr; + PDNode* bn_saved_var = nullptr; + if (with_bn_) { + add_out->assert_is_op_input("batch_norm", "X"); + bn_bias = pattern->NewNode(bn_bias_repr()) + ->assert_is_op_input("batch_norm", "Bias") + ->assert_has_n_outputs(1); + bn_mean = pattern->NewNode(bn_mean_repr()) + ->assert_is_op_input("batch_norm", "Mean") + ->assert_has_n_outputs(1); + bn_scale = pattern->NewNode(bn_scale_repr()) + ->assert_is_op_input("batch_norm", "Scale") + ->assert_has_n_outputs(1); + bn_var = pattern->NewNode(bn_var_repr()) + ->assert_is_op_input("batch_norm", "Variance") + ->assert_has_n_outputs(1); + bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); + bn_out = + pattern->NewNode(bn_out_repr())->assert_is_op_output("batch_norm", "Y"); + if (!act_type_.empty()) { + bn_out->assert_has_n_outputs(1); + } + bn_mean_out = pattern->NewNode(bn_mean_out_repr()) + ->assert_is_op_output("batch_norm", "MeanOut"); + bn_saved_mean = pattern->NewNode(bn_saved_mean_repr()) + ->assert_is_op_output("batch_norm", "SavedMean"); + bn_var_out = pattern->NewNode(bn_var_out_repr()) + ->assert_is_op_output("batch_norm", "VarianceOut"); + bn_saved_var = pattern->NewNode(bn_saved_var_repr()) + ->assert_is_op_output("batch_norm", "SavedVariance"); + bn->LinksFrom({add_out, bn_bias, bn_mean, bn_scale, bn_var}) + .LinksTo( + {bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var}); + } else { + bn_out = add_out; + } if (!act_type_.empty()) { - add_out->assert_is_op_input(act_type_, "X"); + bn_out->assert_is_op_input(act_type_, "X"); act = pattern->NewNode(act_repr())->assert_is_op(act_type_); act_out = pattern->NewNode(act_out_repr()) ->assert_is_op_output(act_type_, "Out") ->assert_var_not_persistable(); - act->LinksFrom({add_out}).LinksTo({act_out}); + act->LinksFrom({bn_out}).LinksTo({act_out}); } } @@ -151,6 +209,12 @@ Origin subgraph: elementwise_add_out | | + batch_norm + | + | + batch_norm_out + | + | act | | @@ -174,6 +238,7 @@ class FcXPUFusePass : public FusePassBase { int ApplyImpl(ir::Graph* graph, const std::string& mul_type, bool with_bias, + bool with_bn, const std::string& act_type) const; const std::string name_scope_{"fc_xpu_fuse_pass"}; @@ -187,13 +252,16 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { int found_subgraph_count = 0; for (auto mul_type : {"mul", "matmul", "matmul_v2"}) { for (auto with_bias : {true, false}) { - for (auto act_type : { - "relu", - "gelu", - "tanh", - "", - }) { - found_subgraph_count += ApplyImpl(graph, mul_type, with_bias, act_type); + for (auto with_bn : {true, false}) { + for (auto act_type : { + "relu", + "gelu", + "tanh", + "", + }) { + found_subgraph_count += + ApplyImpl(graph, mul_type, with_bias, with_bn, act_type); + } } } } @@ -203,10 +271,15 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { int FcXPUFusePass::ApplyImpl(ir::Graph* graph, const std::string& mul_type, bool with_bias, + bool with_bn, const std::string& act_type) const { GraphPatternDetector gpd; - patterns::FcXPUPattern pattern( - gpd.mutable_pattern(), name_scope_, mul_type, with_bias, act_type); + patterns::FcXPUPattern pattern(gpd.mutable_pattern(), + name_scope_, + mul_type, + with_bias, + with_bn, + act_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -219,30 +292,100 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(bias); GET_IR_NODE(add); GET_IR_NODE(add_out); + GET_IR_NODE(bn); + GET_IR_NODE(bn_bias); + GET_IR_NODE(bn_mean); + GET_IR_NODE(bn_scale); + GET_IR_NODE(bn_var); + GET_IR_NODE(bn_out); + GET_IR_NODE(bn_var_out); + GET_IR_NODE(bn_mean_out); + GET_IR_NODE(bn_saved_var); + GET_IR_NODE(bn_saved_mean); GET_IR_NODE(act); GET_IR_NODE(act_out); auto* block = mul->Op()->Block(); auto* scope = param_scope(); + auto* filter_t = + scope->FindVar(mul_w->Name())->GetMutable(); + // filter fp16 --> fp32 + auto tensor_type = filter_t->dtype(); + if (tensor_type == phi::DataType::FLOAT16) { + CastToFp32(filter_t, nullptr); + } + auto filter_dims = filter_t->dims(); + bool transpose_w = false; if (mul_type == "matmul") { transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); } else if (mul_type == "matmul_v2") { transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); } + + bool has_bias = with_bn || with_bias; + Node* fusion_bias_node = nullptr; + if (has_bias) { + if (bias != nullptr) { + PrepareBias(graph, scope, block, bias, &fusion_bias_node); + } + if (bn != nullptr) { + auto bn_bias_t = + scope->Var(bn_bias->Name())->GetMutable(); + auto bn_scale_t = + scope->Var(bn_scale->Name())->GetMutable(); + auto bn_mean_t = + scope->Var(bn_mean->Name())->GetMutable(); + auto bn_var_t = + scope->Var(bn_var->Name())->GetMutable(); + float* mul_w_ptr = filter_t->data(); + float* bn_scale_ptr = bn_scale_t->data(); + float* bn_bias_ptr = bn_bias_t->data(); + float* bn_mean_ptr = bn_mean_t->data(); + float* bn_var_ptr = bn_var_t->data(); + auto mean_len = bn_mean_t->numel(); + auto filter_h = filter_dims[0]; + auto filter_w = filter_dims[1]; + float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); + if (fusion_bias_node == nullptr) { // prev node is conv + PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); + } + auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) + ->GetMutable(); + float* fusion_bias_ptr = fusion_bias_t->data(); + // recompute bias and weights + if (bias == nullptr) { + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + for (int j = 0; j < filter_h; j++) { + mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i]; + } + } + } else { + for (int i = 0; i < mean_len; ++i) { + bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); + bn_bias_ptr[i] += + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; + for (int j = 0; j < filter_h; j++) { + mul_w_ptr[j * filter_w + i] *= bn_scale_ptr[i]; + } + } + memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float)); + } + } + } + Node* mul_w_int16 = nullptr; Node* mul_w_max = nullptr; PrepareWeight( graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w); - Node* bias_fp32 = nullptr; - if (bias != nullptr) { - PrepareBias(graph, scope, block, bias, &bias_fp32); - } - std::string fc_out_name; if (act_out) { fc_out_name = act_out->Name(); + } else if (bn) { + fc_out_name = bn_out->Name(); } else if (add_out) { fc_out_name = add_out->Name(); } else { @@ -258,8 +401,8 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, fc_xpu_op_desc.SetInput("x", {mul_x->Name()}); fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()}); fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()}); - if (bias_fp32) { - fc_xpu_op_desc.SetInput("bias", {bias_fp32->Name()}); + if (has_bias) { + fc_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); } fc_xpu_op_desc.SetAttr( "in_num_col_dims", @@ -294,9 +437,13 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, IR_NODE_LINK_TO(mul_x, fc_xpu); IR_NODE_LINK_TO(mul_w_int16, fc_xpu); IR_NODE_LINK_TO(mul_w_max, fc_xpu); - SAFE_IR_NODE_LINK_TO(bias_fp32, fc_xpu); + if (bias || bn) { + SAFE_IR_NODE_LINK_TO(fusion_bias_node, fc_xpu); + } if (act_out) { IR_NODE_LINK_TO(fc_xpu, act_out); + } else if (bn_out) { + IR_NODE_LINK_TO(fc_xpu, bn_out); } else if (add_out) { IR_NODE_LINK_TO(fc_xpu, add_out); } else { @@ -315,6 +462,17 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, } else { delete_nodes = {mul}; } + if (bn != nullptr) { + delete_nodes.insert(bn); + delete_nodes.insert(bn_bias); + delete_nodes.insert(bn_var); + delete_nodes.insert(bn_mean); + delete_nodes.insert(bn_scale); + delete_nodes.insert(bn_var_out); + delete_nodes.insert(bn_mean_out); + delete_nodes.insert(bn_saved_var); + delete_nodes.insert(bn_saved_mean); + } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; };