Commit 74a6980e authored by Yuan Shuai, committed by Xiaoyang LI

Fix conv_bn fuser with no elemwise op added, Fix conv_elem with original conv with conv_bias (#2211)

* Fix conv_bn fuser with no elemwise op added. test=develop

* fix match not bug. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn fuse pass. test=develop

* Fix conv-bn fuser consider the case: enable_int8=true without conv_bias. test=develop

* Fix back, consider enable_int8. test=develop

* Fix Bug for conv-bn quant pass. test=develop

* Fix conv-elemwise considering origin conv with conv_bias. test=develop

* Fix code format. test=develop

* simplify. test=develop

* simplify. test=develop

* Fix conv_elem. test=develop
Parent commit: 0f9f5fe6
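
For context before reading the diff: the reworked ConvBNFuser folds batch normalization directly into the convolution's weights and bias instead of appending an elementwise_add op after the conv. Below is a minimal standalone sketch of that folding for one output channel, assuming the usual formulation alpha = scale / sqrt(var + eps) and beta = bias - alpha * mean; the names are hypothetical and the snippet is not part of the commit.

    // Illustration only; not part of this commit.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const float eps = 1e-5f;
      // One conv output channel: three weights, a conv bias, and its BN stats.
      std::vector<float> conv_weight{0.5f, -1.0f, 2.0f};
      float conv_bias = 0.1f;
      float bn_scale = 0.8f, bn_bias = 0.2f, bn_mean = 0.3f, bn_var = 0.04f;

      // bn(y) = alpha * y + beta, with y = conv(x) = W.x + conv_bias
      float alpha = bn_scale / std::sqrt(bn_var + eps);
      float beta = bn_bias - alpha * bn_mean;

      // Fold BN into the conv: scale the weights, rewrite the bias.
      for (auto& w : conv_weight) w *= alpha;
      float fused_bias = alpha * conv_bias + beta;

      std::printf("fused weights: %g %g %g, fused bias: %g\n",
                  conv_weight[0], conv_weight[1], conv_weight[2], fused_bias);
      return 0;
    }

Because the rewritten conv produces the batch_norm output directly, the pass no longer needs a trailing elementwise_add node, which is why the registration in the first file below drops .BindKernel("elementwise_add").
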
@@ -16,6 +16,7 @@
 #include <memory>
 #include <vector>
 #include "lite/core/mir/fusion/conv_bn_fuser.h"
+#include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
 
 namespace paddle {
@@ -23,11 +24,19 @@ namespace lite {
 namespace mir {
 
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvBNFuser fuser("conv2d");
-  fuser(graph.get());
-
-  fusion::ConvBNFuser fuser2("depthwise_conv2d");
-  fuser2(graph.get());
+  // initialize fuser params
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  // start fusing using the params
+  for (auto conv_has_bias : conv_has_bias_cases) {
+    for (auto conv_type : conv_type_cases) {
+      VLOG(4) << "conv_has_bias:" << conv_has_bias
+              << " conv_type:" << conv_type;
+      fusion::ConvBNFuser fuser(conv_type, conv_has_bias);
+      fuser(graph.get());
+    }
+  }
 }
 
 } // namespace mir
@@ -35,5 +44,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 } // namespace paddle
 
 REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
-    .BindTargets({TARGET(kAny)})
-    .BindKernel("elementwise_add");
+    .BindTargets({TARGET(kAny)});
@@ -14,6 +14,7 @@
 #include "lite/core/mir/fusion/conv_bn_fuser.h"
 #include <memory>
+#include <unordered_set>
 #include <vector>
 
 namespace paddle {
@@ -30,7 +31,8 @@ void ConvBNFuser::BuildPattern() {
   auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
   auto* conv_out = VarNode("conv_out")
                        ->assert_is_op_output(conv_type_, "Output")
-                       ->assert_is_op_input("batch_norm", "X");
+                       ->assert_is_op_input("batch_norm", "X")
+                       ->AsIntermediate();
 
   auto* bn_scale = VarNode("bn_scale")
                        ->assert_is_op_input("batch_norm", "Scale")
@@ -61,34 +63,30 @@ void ConvBNFuser::BuildPattern() {
                            ->assert_is_op_output("batch_norm", "SavedVariance")
                            ->AsIntermediate();
 
-  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+  if (conv_has_bias_) {
+    auto* conv_bias = VarNode("conv_bias")
+                          ->assert_is_op_input(conv_type_, "Bias")
+                          ->AsInput()
+                          ->AsIntermediate();
+    conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out});
+  } else {
+    conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+  }
   bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
       .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
 }
 
 void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
-  auto op_desc = GenOpDesc(matched);
-  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
   auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv_op_desc = conv_instruct->mutable_op_info();
   auto conv = conv_instruct->op();
   auto* scope = conv->scope();
-  auto& valid_places = conv->valid_places();
-
-  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
-                           ->GetMutable<lite::Tensor>();
-  auto conv_weight_dims = conv_weight_t->dims();
-  size_t weight_num = conv_weight_t->data_size();
 
+  // bn
   auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
                         ->GetMutable<lite::Tensor>();
-  size_t bias_size = bn_scale_t->data_size();
   auto bn_scale_d = bn_scale_t->mutable_data<float>();
-  CHECK_EQ(bias_size, static_cast<size_t>(conv_weight_dims[0]))
-      << "The BN bias's size should be equal to the size of the first "
-      << "dim size of the conv weights";
 
   auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
                        ->GetMutable<lite::Tensor>();
   auto bn_mean_d = bn_mean_t->mutable_data<float>();
@@ -102,59 +100,102 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   auto bn_bias_d = bn_bias_t->mutable_data<float>();
   auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
 
-  auto conv_op_desc = conv_instruct->mutable_op_info();
+  // conv
+  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
+           static_cast<size_t>(conv_weight_t->dims()[0]))
+      << "The BN bias's size should be equal to the size of the first "
+      << "dim size of the conv weights";
+  size_t weight_num = conv_weight_t->data_size();
   bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
 
+  // compute BN alpha and beta
   Tensor alpha_tensor, beta_tensor;
   alpha_tensor.CopyDataFrom(*bn_bias_t);
   beta_tensor.CopyDataFrom(*bn_bias_t);
   auto alpha_data = alpha_tensor.mutable_data<float>();
   auto beta_data = beta_tensor.mutable_data<float>();
-  int h = bias_size;
-  int w = weight_num / bias_size;
+  // h == bias_size == out channel num of conv weight
+  // w == weight_num / bias_size == in channel num of conv weight
+  int h = bn_scale_t->data_size();
+  int w = weight_num / bn_scale_t->data_size();
   ComputeAlphaAndBeta(
       bn_scale_d, bn_mean_d, bn_var_d, alpha_data, beta_data, eps, h, w);
 
+  ///////////////////////////////////////////////////////////////////////////
+  // Compute ConvBNFuser
+  //
+  // Before fusion:
+  //   conv(x) = kx + z = y
+  //   bn(y)   = ay + b
+  // where a is `alpha_data` and b is `beta_data` from `ComputeAlphaAndBeta`.
+  //
+  // After fusion:
+  //   bn(conv(x)) = a(kx + z) + b = akx + az + b
+  //
+  // Note: h == bias_size == out channel num of conv weight
+  //       w == weight_num / bias_size == in channel num of conv weight
+  //       the int8 case differs slightly (see below)
+  ///////////////////////////////////////////////////////////////////////////
   if (enable_int8) {
     PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"),
                    "INT8 mode: Conv should has weight_scale attr");
+    auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
+    // compute new conv_weight for int8
     auto weight_scale =
         conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
-    for (int i = 0; i < h; i++) {
-      weight_scale[i] *= alpha_data[i];
+    for (unsigned int i = 0; i < h; ++i) {
+      weight_scale[i] *= fabsf(alpha_data[i]);
+      if (alpha_data[i] < 0.f) {
+        auto ptr_row = conv_weight_d + i * w;
+        for (unsigned int j = 0; j < w; ++j) {
+          ptr_row[j] *= -1;
+        }
+      }
     }
+    // Interface like this should be abandoned.
     conv_op_desc->SetAttr("weight_scale", weight_scale);
+    auto update_conv_desc = *conv_instruct->mutable_op_info();
+    conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
   } else {
+    // compute new conv_weight
     auto conv_weight_d = conv_weight_t->mutable_data<float>();
-    for (int i = 0; i < h; i++) {
-      for (int j = 0; j < w; j++) {
+    for (unsigned int i = 0; i < h; ++i) {    // h: conv2d output channels
+      for (unsigned int j = 0; j < w; ++j) {  // w: conv2d input channels
         conv_weight_d[i * w + j] *= alpha_data[i];
       }
     }
   }
-  for (int i = 0; i < bias_size; i++) {
+
+  // compute new conv_bias
+  if (conv_has_bias_) {
+    auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+    auto conv_bias_d = conv_bias_t->data<float>();
+    // bias_size == h == conv2d output channels
+    for (unsigned int i = 0; i < bn_bias_t->data_size(); ++i) {
+      bn_bias_d[i] += alpha_data[i] * conv_bias_d[i];
+    }
+  }
+  for (unsigned int i = 0; i < bn_bias_t->data_size(); ++i) {
     bn_bias_d[i] += beta_data[i];
   }
 
-  eltwise_op->Attach(op_desc, scope);
-  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
-
-  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
-  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
-}
+  conv_op_desc->SetType(conv_type_);
+  conv_op_desc->SetInput("Input", {matched.at("conv_input")->arg()->name});
+  conv_op_desc->SetInput("Filter", {matched.at("conv_weight")->arg()->name});
+  conv_op_desc->SetOutput("Output", {matched.at("bn_out")->arg()->name});
+  conv_op_desc->SetInput("Bias",
+                         {matched.at("bn_bias")->arg()->name});  // conv_bias
+  auto update_conv_desc = *conv_instruct->mutable_op_info();
+  conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
 
-cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
-  cpp::OpDesc op_desc;
-  op_desc.SetType("elementwise_add");
-  op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
-  op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
-  op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
-  op_desc.SetAttr("axis", 1);
-  return op_desc;
+  IR_NODE_LINK_TO(matched.at("bn_bias"), matched.at("conv2d"));
+  IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("bn_out"));
 }
 
 } // namespace fusion
......
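
A brief note on the int8 branch above, sketched with hypothetical names and not part of the commit: quantization scales are kept positive, so the pass folds |alpha| into weight_scale and compensates for a negative alpha by negating that output channel's int8 weights, leaving the dequantized product unchanged.

    // Illustration only; not part of this commit.
    #include <cmath>
    #include <cstdio>

    int main() {
      float alpha = -0.5f;         // BN alpha for one output channel
      float weight_scale = 0.02f;  // per-channel int8 dequantization scale
      signed char w = 40;          // one quantized weight

      weight_scale *= std::fabs(alpha);  // fold |alpha| into the scale
      if (alpha < 0.f) w = -w;           // flip the sign of the int8 weight

      // Equals alpha * (original scale * original weight) = -0.4
      std::printf("dequantized value: %g\n", weight_scale * w);
      return 0;
    }
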
...@@ -27,12 +27,12 @@ namespace fusion { ...@@ -27,12 +27,12 @@ namespace fusion {
class ConvBNFuser : public FuseBase { class ConvBNFuser : public FuseBase {
public: public:
explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {} explicit ConvBNFuser(const std::string& conv_type, const bool conv_has_bias)
: conv_type_(conv_type), conv_has_bias_(conv_has_bias) {}
void BuildPattern() override; void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private: private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
void ComputeAlphaAndBeta(float* scale_d, void ComputeAlphaAndBeta(float* scale_d,
float* mean_d, float* mean_d,
float* var_d, float* var_d,
...@@ -51,6 +51,7 @@ class ConvBNFuser : public FuseBase { ...@@ -51,6 +51,7 @@ class ConvBNFuser : public FuseBase {
private: private:
std::string conv_type_{"conv2d"}; std::string conv_type_{"conv2d"};
bool conv_has_bias_{false};
}; };
} // namespace fusion } // namespace fusion
......
@@ -23,14 +23,21 @@ namespace lite {
 namespace mir {
 
 void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvElementwiseFuser fuser("conv2d");
-  fuser(graph.get());
-
-  fusion::ConvElementwiseFuser depthwise_fuser("depthwise_conv2d");
-  depthwise_fuser(graph.get());
-
-  fusion::ConvElementwiseFuser conv2d_transpose_fuser("conv2d_transpose");
-  conv2d_transpose_fuser(graph.get());
+  // initialize fuser params
+  // note: the conv_has_bias = true case must be the first pattern to match
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
+  // start fusing using the params
+  for (auto conv_has_bias : conv_has_bias_cases) {
+    for (auto conv_type : conv_type_cases) {
+      VLOG(4) << "conv_has_bias:" << conv_has_bias
+              << " conv_type:" << conv_type;
+      fusion::ConvElementwiseFuser fuser(conv_type, conv_has_bias);
+      fuser(graph.get());
+    }
+  }
 }
 
 } // namespace mir
......
@@ -33,8 +33,7 @@ void ConvElementwiseFuser::BuildPattern() {
                          ->assert_is_persistable_var();
 
   // create op nodes
-  auto* conv2d =
-      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+  auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
   auto* add = OpNode("add", "elementwise_add")
                   ->assert_is_op("elementwise_add")
                   ->AsIntermediate();
@@ -51,6 +50,13 @@
   // create topology.
   std::vector<PMNode*> conv2d_inputs{filter, input};
+  // consider a special case: conv with bias
+  if (conv_has_bias_) {
+    PMNode* conv_bias = VarNode("conv_bias")
+                            ->assert_is_op_input(conv_type_, "Bias")
+                            ->AsIntermediate();
+    conv2d_inputs.emplace_back(conv_bias);
+  }
   std::vector<PMNode*> add_inputs{conv2d_out, bias};
   conv2d_inputs >> *conv2d >> *conv2d_out;
   add_inputs >> *add >> *add_out;
@@ -58,44 +64,49 @@
 void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph,
                                          const key2nodes_t& matched) {
-  auto op_desc = GenOpDesc(matched);
-  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
-  auto conv_old = matched.at("conv2d")->stmt()->op();
-  auto* scope = conv_old->scope();
-  auto& valid_places = conv_old->valid_places();
-  conv_op->Attach(op_desc, scope);
-
-  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
-
-  IR_NODE_LINK_TO(matched.at("input"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
-  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
-}
-
-cpp::OpDesc ConvElementwiseFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("conv2d")->stmt()->op_info();
-
-  cpp::OpDesc op_desc = *desc;
-  op_desc.SetType(conv_type_);
-  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
-  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
-  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
-  op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
-  // Other inputs. See operators/conv_op.h
-  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
-  if (std::find(input_arg_names.begin(),
-                input_arg_names.end(),
-                "ResidualData") != input_arg_names.end()) {
-    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
-  }
-  // Only consider strides, padding, groups, dilations for now
-  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
-  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
-  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
-  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
-  return op_desc;
+  auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv_op_desc = conv_instruct->mutable_op_info();
+  auto* scope = conv_instruct->op()->scope();
+
+  /////////////////////////////////////////////////////////////////////////
+  // ConvElementwiseFuser
+  // if the conv already has a `conv_bias`, fold it into the
+  // `elementwise_add` bias, which then becomes the fused conv's new Bias.
+  // if the conv has no `conv_bias`, the `elementwise_add` bias is used as
+  // the new Bias directly.
+  /////////////////////////////////////////////////////////////////////////
+  if (conv_has_bias_ == true && conv_op_desc->HasInput("Bias") &&
+      conv_op_desc->Input("Bias").size() > 0) {
+    auto conv_bias_var = scope->FindVar(conv_op_desc->Input("Bias").front());
+    if (conv_bias_var != nullptr) {
+      // conv bias
+      auto conv_bias_t = &(conv_bias_var->Get<lite::Tensor>());
+      auto conv_bias_d = conv_bias_t->data<float>();
+
+      // elementwise_add bias
+      auto elementwise_add_bias_t =
+          scope->FindVar(matched.at("bias")->arg()->name)
+              ->GetMutable<lite::Tensor>();
+      auto elementwise_add_bias_d =
+          elementwise_add_bias_t->mutable_data<float>();
+
+      for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) {
+        elementwise_add_bias_d[i] += conv_bias_d[i];
+      }
+    }
+  }
+
+  conv_op_desc->SetType(conv_type_);
+  conv_op_desc->SetInput("Input", {matched.at("input")->arg()->name});
+  conv_op_desc->SetInput("Filter", {matched.at("filter")->arg()->name});
+  conv_op_desc->SetOutput("Output", {matched.at("output")->arg()->name});
+  conv_op_desc->SetInput("Bias", {matched.at("bias")->arg()->name});
+  auto update_conv_desc = *conv_instruct->mutable_op_info();
+  conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
+
+  IR_NODE_LINK_TO(matched.at("bias"), matched.at("conv2d"));
+  IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("output"));
 }
 
 } // namespace fusion
......
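
The bias handling in ConvElementwiseFuser::InsertNewNode above boils down to a small merge, sketched here with hypothetical names (not part of the commit): when the matched conv already carries a Bias input, that bias is accumulated into the elementwise_add bias tensor, which is then wired in as the fused conv's Bias.

    // Illustration only; not part of this commit.
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> conv_bias{0.1f, -0.2f};            // existing conv Bias
      std::vector<float> elementwise_add_bias{1.0f, 2.0f};  // elementwise_add Y input

      // Mirrors the accumulation loop in InsertNewNode.
      for (size_t i = 0; i < conv_bias.size(); ++i) {
        elementwise_add_bias[i] += conv_bias[i];
      }

      std::printf("fused Bias: %g %g\n",
                  elementwise_add_bias[0], elementwise_add_bias[1]);
      return 0;
    }
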
@@ -25,16 +25,18 @@ namespace fusion {
 class ConvElementwiseFuser : public FuseBase {
  public:
-  explicit ConvElementwiseFuser(const std::string& conv_type) {
+  explicit ConvElementwiseFuser(const std::string& conv_type,
+                                const bool conv_has_bias) {
     conv_type_ = conv_type;
+    conv_has_bias_ = conv_has_bias;
   }
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
-  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  std::string conv_type_;
+  std::string conv_type_{"conv2d"};
+  bool conv_has_bias_{false};
 };
 
 } // namespace fusion
......