diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
index d7e274b146f48b0fb5154a4e585eea724dd1dbc0..25c8a217251d25f0f7b4a37c4c656c535810b76e 100644
--- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -16,6 +16,7 @@
 #include <memory>
 #include <vector>
 #include "lite/core/mir/fusion/conv_bn_fuser.h"
+#include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
 
 namespace paddle {
@@ -23,11 +24,19 @@ namespace lite {
 namespace mir {
 
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvBNFuser fuser("conv2d");
-  fuser(graph.get());
+  // initialize fuser params
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
 
-  fusion::ConvBNFuser fuser2("depthwise_conv2d");
-  fuser2(graph.get());
+  // start fuse using params
+  for (auto conv_has_bias : conv_has_bias_cases) {
+    for (auto conv_type : conv_type_cases) {
+      VLOG(4) << "conv_has_bias:" << conv_has_bias
+              << " conv_type:" << conv_type;
+      fusion::ConvBNFuser fuser(conv_type, conv_has_bias);
+      fuser(graph.get());
+    }
+  }
 }
 
 }  // namespace mir
@@ -35,5 +44,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 }  // namespace paddle
 
 REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
-    .BindTargets({TARGET(kAny)})
-    .BindKernel("elementwise_add");
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc
index 77ad8237fe8108c8b9d19d09bf45b724f6c0ca2d..099fc55583925988c37d85f965900f7e4dfa1e98 100644
--- a/lite/core/mir/fusion/conv_bn_fuser.cc
+++ b/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -14,6 +14,7 @@
 
 #include "lite/core/mir/fusion/conv_bn_fuser.h"
 #include <memory>
+#include <string>
 #include <vector>
 
 namespace paddle {
@@ -30,7 +31,8 @@ void ConvBNFuser::BuildPattern() {
   auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
   auto* conv_out = VarNode("conv_out")
                        ->assert_is_op_output(conv_type_, "Output")
-                       ->assert_is_op_input("batch_norm", "X");
+                       ->assert_is_op_input("batch_norm", "X")
+                       ->AsIntermediate();
 
   auto* bn_scale = VarNode("bn_scale")
                        ->assert_is_op_input("batch_norm", "Scale")
@@ -61,34 +63,30 @@ void ConvBNFuser::BuildPattern() {
                            ->assert_is_op_output("batch_norm", "SavedVariance")
                            ->AsIntermediate();
 
-  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+  if (conv_has_bias_) {
+    auto* conv_bias = VarNode("conv_bias")
+                          ->assert_is_op_input(conv_type_, "Bias")
+                          ->AsInput()
+                          ->AsIntermediate();
+    conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out});
+  } else {
+    conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+  }
 
   bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
       .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
 }
 
 void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
-  auto op_desc = GenOpDesc(matched);
-  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
-
   auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv_op_desc = conv_instruct->mutable_op_info();
   auto conv = conv_instruct->op();
   auto* scope = conv->scope();
-  auto& valid_places = conv->valid_places();
-
-  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
-                           ->GetMutable<lite::Tensor>();
-  auto conv_weight_dims = conv_weight_t->dims();
-  size_t weight_num = conv_weight_t->data_size();
 
+  // bn
   auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
                         ->GetMutable<lite::Tensor>();
-  size_t bias_size = bn_scale_t->data_size();
   auto bn_scale_d = bn_scale_t->mutable_data<float>();
-  CHECK_EQ(bias_size, static_cast<size_t>(conv_weight_dims[0]))
-      << "The BN bias's size should be equal to the size of the first "
-      << "dim size of the conv weights";
-
   auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
                        ->GetMutable<lite::Tensor>();
   auto bn_mean_d = bn_mean_t->mutable_data<float>();
@@ -102,59 +100,102 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   auto bn_bias_d = bn_bias_t->mutable_data<float>();
   auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
 
-  auto conv_op_desc = conv_instruct->mutable_op_info();
-
+  // conv
+  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  CHECK_EQ(static_cast<int>(bn_scale_t->data_size()),
+           static_cast<int>(conv_weight_t->dims()[0]))
+      << "The BN bias's size should be equal to the first dim size of the "
+      << "conv weights";
+  size_t weight_num = conv_weight_t->data_size();
   bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
+
+  // compute BN alpha and beta
   Tensor alpha_tensor, beta_tensor;
   alpha_tensor.CopyDataFrom(*bn_bias_t);
   beta_tensor.CopyDataFrom(*bn_bias_t);
   auto alpha_data = alpha_tensor.mutable_data<float>();
   auto beta_data = beta_tensor.mutable_data<float>();
 
-  int h = bias_size;
-  int w = weight_num / bias_size;
+  int h =
+      bn_scale_t
+          ->data_size();  // h == bias_size == out channel num of conv weight
+  int w = weight_num /
+          (bn_scale_t->data_size());  // w == weight_num / bias_size
+                                      // (in channels * kernel area)
+
   ComputeAlphaAndBeta(
       bn_scale_d, bn_mean_d, bn_var_d, alpha_data, beta_data, eps, h, w);
 
+  ///////////////////////////////////////////////////////////////////////////////
+  // Compute ConvBNFuser
+  // Before fusion:
+  //
+  //   conv(x) = kx + z = y
+  //   bn(y) = ay + b
+  //
+  // Note: `alpha_data` is a, `beta_data` is b from `ComputeAlphaAndBeta`
+  //
+  // After fusion:
+  //
+  //   bn(conv(x)) = a(kx + z) + b = akx + az + b
+  //
+  // Note: h == bias_size == out channel num of conv weight
+  //       w == weight_num / bias_size (in channels * kernel area)
+  //       the int8 case differs slightly (see below)
+  ///////////////////////////////////////////////////////////////////////////////
   if (enable_int8) {
     PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"),
                    "INT8 mode: Conv should have weight_scale attr");
+    auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
+    // compute new conv_weight for int8
     auto weight_scale =
         conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
-    for (int i = 0; i < h; i++) {
-      weight_scale[i] *= alpha_data[i];
+    for (unsigned int i = 0; i < h; ++i) {
+      weight_scale[i] *= fabsf(alpha_data[i]);
+      if (alpha_data[i] < 0.f) {
+        auto ptr_row = conv_weight_d + i * w;
+        for (unsigned int j = 0; j < w; ++j) {
+          ptr_row[j] *= -1;
+        }
+      }
     }
-    // Interface like this should be abandoned.
     conv_op_desc->SetAttr("weight_scale", weight_scale);
-    auto update_conv_desc = *conv_instruct->mutable_op_info();
-    conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
   } else {
+    // compute new conv_weight
     auto conv_weight_d = conv_weight_t->mutable_data<float>();
-    for (int i = 0; i < h; i++) {
-      for (int j = 0; j < w; j++) {
+    for (unsigned int i = 0; i < h; ++i) {    // h: conv2d output channels
+      for (unsigned int j = 0; j < w; ++j) {  // w: in channels * kernel area
         conv_weight_d[i * w + j] *= alpha_data[i];
       }
     }
   }
-  for (int i = 0; i < bias_size; i++) {
+
+  // compute new conv_bias
+  if (conv_has_bias_) {
+    auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+    auto conv_bias_d = conv_bias_t->data<float>();
+    for (unsigned int i = 0; i < bn_bias_t->data_size();
+         ++i) {  // bias_size == h == conv2d output channels
+      bn_bias_d[i] += alpha_data[i] * conv_bias_d[i];
+    }
+  }
+  for (unsigned int i = 0; i < bn_bias_t->data_size(); ++i) {
     bn_bias_d[i] += beta_data[i];
   }
 
-  eltwise_op->Attach(op_desc, scope);
-  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
-
-  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
-  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
-}
-
-cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
-  cpp::OpDesc op_desc;
-  op_desc.SetType("elementwise_add");
-  op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
-  op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
-  op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
-  op_desc.SetAttr("axis", 1);
-  return op_desc;
+  conv_op_desc->SetType(conv_type_);
+  conv_op_desc->SetInput("Input", {matched.at("conv_input")->arg()->name});
+  conv_op_desc->SetInput("Filter", {matched.at("conv_weight")->arg()->name});
+  conv_op_desc->SetOutput("Output", {matched.at("bn_out")->arg()->name});
+  conv_op_desc->SetInput("Bias",
+                         {matched.at("bn_bias")->arg()->name});  // conv_bias
+  auto update_conv_desc = *conv_instruct->mutable_op_info();
+  conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
+
+  IR_NODE_LINK_TO(matched.at("bn_bias"), matched.at("conv2d"));
+  IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("bn_out"));
 }
 
 }  // namespace fusion
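The comment block above is the whole trick of this patch: instead of emitting an elementwise_add, the BN statistics are folded directly into the conv weights and bias. As a sanity check, here is a minimal standalone sketch of that fold in plain C++ (FoldBNIntoConv, its signature, and the std::vector buffers are illustrative only, not the Lite API): per output channel i, alpha[i] = scale[i] / sqrt(var[i] + eps) and beta[i] = bn_bias[i] - alpha[i] * mean[i], so bn(conv(x)) = (alpha .* W) x + (alpha .* z + beta).

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of the conv+BN fold, assuming OIHW weights flattened to h rows
// (output channels) of w values (in channels * kernel area) each.
void FoldBNIntoConv(std::vector<float>* weight,         // h * w, scaled in place
                    std::vector<float>* bias,           // h, conv bias z (zeros if absent)
                    const std::vector<float>& scale,    // BN Scale
                    const std::vector<float>& mean,     // BN Mean
                    const std::vector<float>& var,      // BN Variance
                    const std::vector<float>& bn_bias,  // BN Bias
                    float eps) {
  const std::size_t h = scale.size();
  const std::size_t w = weight->size() / h;
  for (std::size_t i = 0; i < h; ++i) {
    // bn(y) = alpha * y + beta, per output channel
    const float alpha = scale[i] / std::sqrt(var[i] + eps);
    const float beta = bn_bias[i] - alpha * mean[i];
    for (std::size_t j = 0; j < w; ++j) {
      (*weight)[i * w + j] *= alpha;  // akx
    }
    (*bias)[i] = alpha * (*bias)[i] + beta;  // az + b
  }
}

The int8 branch cannot rescale quantized weights this way, which is why the patch folds |alpha[i]| into weight_scale[i] instead and flips the sign of row i of the int8 weights when alpha[i] < 0.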
diff --git a/lite/core/mir/fusion/conv_bn_fuser.h b/lite/core/mir/fusion/conv_bn_fuser.h
index 575207e2ee12e3a68987aab3eb3325a78dc6cda1..8bd8c0ce0600bb68667d96d07d43fa3028b5a856 100644
--- a/lite/core/mir/fusion/conv_bn_fuser.h
+++ b/lite/core/mir/fusion/conv_bn_fuser.h
@@ -27,12 +27,12 @@ namespace fusion {
 
 class ConvBNFuser : public FuseBase {
  public:
-  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
+  explicit ConvBNFuser(const std::string& conv_type, const bool conv_has_bias)
+      : conv_type_(conv_type), conv_has_bias_(conv_has_bias) {}
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
-  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
   void ComputeAlphaAndBeta(float* scale_d,
                            float* mean_d,
                            float* var_d,
@@ -51,6 +51,7 @@ class ConvBNFuser : public FuseBase {
 
  private:
   std::string conv_type_{"conv2d"};
+  bool conv_has_bias_{false};
 };
 
 }  // namespace fusion
diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
index 2ff3631ba31a807f215822fa25198c39776ea572..fd9aadc5d01c2cb3b6c7a3e888503072a0798725 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
@@ -23,14 +23,21 @@ namespace lite {
 namespace mir {
 
 void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvElementwiseFuser fuser("conv2d");
-  fuser(graph.get());
+  // initialize fuser params
+  // note: the conv_has_bias == true cases must be matched first
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
 
-  fusion::ConvElementwiseFuser depthwise_fuser("depthwise_conv2d");
-  depthwise_fuser(graph.get());
-
-  fusion::ConvElementwiseFuser conv2d_transpose_fuser("conv2d_transpose");
-  conv2d_transpose_fuser(graph.get());
+  // start fuse using params
+  for (auto conv_has_bias : conv_has_bias_cases) {
+    for (auto conv_type : conv_type_cases) {
+      VLOG(4) << "conv_has_bias:" << conv_has_bias
+              << " conv_type:" << conv_type;
+      fusion::ConvElementwiseFuser fuser(conv_type, conv_has_bias);
+      fuser(graph.get());
+    }
+  }
 }
 
 }  // namespace mir
diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc
index abc78edda88e008945e9d184b02e5feef3e5a4b1..22ec1fa0d22378adf3776c6bb391f50fde376b7a 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuser.cc
+++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc
@@ -33,8 +33,7 @@ void ConvElementwiseFuser::BuildPattern() {
                    ->assert_is_persistable_var();
 
   // create op nodes
-  auto* conv2d =
-      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+  auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
   auto* add = OpNode("add", "elementwise_add")
                   ->assert_is_op("elementwise_add")
                   ->AsIntermediate();
@@ -51,6 +50,13 @@ void ConvElementwiseFuser::BuildPattern() {
 
   // create topology.
   std::vector<PMNode*> conv2d_inputs{filter, input};
+  // consider a special case: conv with bias
+  if (conv_has_bias_) {
+    PMNode* conv_bias = VarNode("conv_bias")
+                            ->assert_is_op_input(conv_type_, "Bias")
+                            ->AsIntermediate();
+    conv2d_inputs.emplace_back(conv_bias);
+  }
   std::vector<PMNode*> add_inputs{conv2d_out, bias};
   conv2d_inputs >> *conv2d >> *conv2d_out;
   add_inputs >> *add >> *add_out;
@@ -58,44 +64,49 @@ void ConvElementwiseFuser::BuildPattern() {
 
 void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph,
                                          const key2nodes_t& matched) {
-  auto op_desc = GenOpDesc(matched);
-  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
-  auto conv_old = matched.at("conv2d")->stmt()->op();
-  auto* scope = conv_old->scope();
-  auto& valid_places = conv_old->valid_places();
-  conv_op->Attach(op_desc, scope);
-
-  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
+  auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv_op_desc = conv_instruct->mutable_op_info();
+  auto* scope = conv_instruct->op()->scope();
 
-  IR_NODE_LINK_TO(matched.at("input"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
-  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
-}
+  /////////////////////////////////////////////////////////////////////////////////////
+  // ConvElementwiseFuser
+  //   if `conv_bias` exists, accumulate it into `elementwise_add_bias`;
+  //   the sum becomes `new_conv_bias`.
+  //   if `conv_bias` does not exist, `elementwise_add_bias` is used as
+  //   `new_conv_bias` directly.
+  /////////////////////////////////////////////////////////////////////////////////////
 
-cpp::OpDesc ConvElementwiseFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("conv2d")->stmt()->op_info();
+  if (conv_has_bias_ == true && conv_op_desc->HasInput("Bias") &&
+      conv_op_desc->Input("Bias").size() > 0) {
+    auto conv_bias_var = scope->FindVar(conv_op_desc->Input("Bias").front());
+    if (conv_bias_var != nullptr) {
+      // conv bias
+      auto conv_bias_t = &(conv_bias_var->Get<lite::Tensor>());
+      auto conv_bias_d = conv_bias_t->data<float>();
 
-  cpp::OpDesc op_desc = *desc;
-  op_desc.SetType(conv_type_);
-  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
-  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
-  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
-  op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
-  // Other inputs. See operators/conv_op.h
-  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
+      // elementwise_add bias
+      auto elementwise_add_bias_t =
+          scope->FindVar(matched.at("bias")->arg()->name)
+              ->GetMutable<lite::Tensor>();
+      auto elementwise_add_bias_d =
+          elementwise_add_bias_t->mutable_data<float>();
 
-  if (std::find(input_arg_names.begin(),
-                input_arg_names.end(),
-                "ResidualData") != input_arg_names.end()) {
-    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
+      for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) {
+        elementwise_add_bias_d[i] += conv_bias_d[i];
+      }
+    }
   }
-  // Only consider strides, padding, groups, dilations for now
-  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
-  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
-  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
-  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
-  return op_desc;
+
+  conv_op_desc->SetType(conv_type_);
+  conv_op_desc->SetInput("Input", {matched.at("input")->arg()->name});
+  conv_op_desc->SetInput("Filter", {matched.at("filter")->arg()->name});
+  conv_op_desc->SetOutput("Output", {matched.at("output")->arg()->name});
+  conv_op_desc->SetInput("Bias", {matched.at("bias")->arg()->name});
+  auto update_conv_desc = *conv_instruct->mutable_op_info();
+  conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
+
+  IR_NODE_LINK_TO(matched.at("bias"), matched.at("conv2d"));
+  IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("output"));
 }
 
 }  // namespace fusion
diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.h b/lite/core/mir/fusion/conv_elementwise_fuser.h
index 4514fc5010b5c40f31a69c4459f0a26f33d6046a..fdcb5d8912d87c61f13c47e5ef07b926a96d7272 100644
--- a/lite/core/mir/fusion/conv_elementwise_fuser.h
+++ b/lite/core/mir/fusion/conv_elementwise_fuser.h
@@ -25,16 +25,18 @@ namespace fusion {
 
 class ConvElementwiseFuser : public FuseBase {
  public:
-  explicit ConvElementwiseFuser(const std::string& conv_type) {
+  explicit ConvElementwiseFuser(const std::string& conv_type,
+                                const bool conv_has_bias) {
     conv_type_ = conv_type;
+    conv_has_bias_ = conv_has_bias;
  }
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
-  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  std::string conv_type_;
+  std::string conv_type_{"conv2d"};
+  bool conv_has_bias_{false};
 };
 
 }  // namespace fusion
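The bias merge in ConvElementwiseFuser::InsertNewNode is the identity (Wx + z) + b = Wx + (z + b): the conv's old bias z, if any, is accumulated into the elementwise_add bias b, and that tensor is then attached to the fused conv as its Bias input. A minimal sketch of the merge step (MergeConvBias and its vectors are hypothetical names, not the Lite API):

#include <cstddef>
#include <vector>

// Merge an existing conv bias z into the elementwise_add bias b.
// After the merge, `eltwise_bias` holds z + b, one value per output channel,
// and becomes the fused conv's new Bias.
void MergeConvBias(std::vector<float>* eltwise_bias,       // b
                   const std::vector<float>& conv_bias) {  // z, empty if absent
  for (std::size_t i = 0; i < conv_bias.size(); ++i) {
    (*eltwise_bias)[i] += conv_bias[i];
  }
}

This also suggests why both passes iterate the conv_has_bias == true cases first: a conv that already carries a Bias input should be captured by the bias-aware pattern, since if the bias-free pattern matched it first, the existing bias would presumably be left out of the merge.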