Commit 74a6980e authored by: Yuan Shuai; committed by: Xiaoyang LI

Fix conv_bn fuser so that no elementwise op is added; fix conv_elem for an original conv that already has conv_bias (#2211)

* Fix conv_bn fuser with no elemwise op added. test=develop

* fix match not bug. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn pass. test=develop

* Fix conv-bn fuse pass. test=develop

* Fix conv-bn fuser consider the case: enable_int8=true without conv_bias. test=develop

* Fix back, consider enable_int8. test=develop

* Fix Bug for conv-bn quant pass. test=develop

* Fix conv-elemwise considering origin conv with conv_bias. test=develop

* Fix code format. test=develop

* simplify. test=develop

* simplify. test=develop

* Fix conv_elem. test=develop
Parent 0f9f5fe6
......@@ -16,6 +16,7 @@
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/conv_bn_fuser.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
......@@ -23,11 +24,19 @@ namespace lite {
namespace mir {
void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvBNFuser fuser("conv2d");
fuser(graph.get());
// initialize fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
fusion::ConvBNFuser fuser2("depthwise_conv2d");
fuser2(graph.get());
// start fusion using the params
for (auto conv_has_bias : conv_has_bias_cases) {
for (auto conv_type : conv_type_cases) {
VLOG(4) << "conv_has_bias:" << conv_has_bias
<< " conv_type:" << conv_type;
fusion::ConvBNFuser fuser(conv_type, conv_has_bias);
fuser(graph.get());
}
}
}
} // namespace mir
......@@ -35,5 +44,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.BindTargets({TARGET(kAny)})
.BindKernel("elementwise_add");
.BindTargets({TARGET(kAny)});
......@@ -14,6 +14,7 @@
#include "lite/core/mir/fusion/conv_bn_fuser.h"
#include <memory>
#include <unordered_set>
#include <vector>
namespace paddle {
......@@ -30,7 +31,8 @@ void ConvBNFuser::BuildPattern() {
auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
auto* conv_out = VarNode("conv_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("batch_norm", "X");
->assert_is_op_input("batch_norm", "X")
->AsIntermediate();
auto* bn_scale = VarNode("bn_scale")
->assert_is_op_input("batch_norm", "Scale")
......@@ -61,34 +63,30 @@ void ConvBNFuser::BuildPattern() {
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
if (conv_has_bias_) {
auto* conv_bias = VarNode("conv_bias")
->assert_is_op_input(conv_type_, "Bias")
->AsInput()
->AsIntermediate();
conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out});
} else {
conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
}
bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
.LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
}
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv_instruct = matched.at("conv2d")->stmt();
auto conv_op_desc = conv_instruct->mutable_op_info();
auto conv = conv_instruct->op();
auto* scope = conv->scope();
auto& valid_places = conv->valid_places();
auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
->GetMutable<lite::Tensor>();
auto conv_weight_dims = conv_weight_t->dims();
size_t weight_num = conv_weight_t->data_size();
// bn
auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
->GetMutable<lite::Tensor>();
size_t bias_size = bn_scale_t->data_size();
auto bn_scale_d = bn_scale_t->mutable_data<float>();
CHECK_EQ(bias_size, static_cast<size_t>(conv_weight_dims[0]))
<< "The BN bias's size should be equal to the size of the first "
<< "dim size of the conv weights";
auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_mean_d = bn_mean_t->mutable_data<float>();
......@@ -102,59 +100,102 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto bn_bias_d = bn_bias_t->mutable_data<float>();
auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
auto conv_op_desc = conv_instruct->mutable_op_info();
// conv
auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
->GetMutable<lite::Tensor>();
CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
static_cast<size_t>(conv_weight_t->dims()[0]))
<< "The BN bias's size should be equal to the size of the first "
<< "dim size of the conv weights";
size_t weight_num = conv_weight_t->data_size();
bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
// compute BN alpha and beta
Tensor alpha_tensor, beta_tensor;
alpha_tensor.CopyDataFrom(*bn_bias_t);
beta_tensor.CopyDataFrom(*bn_bias_t);
auto alpha_data = alpha_tensor.mutable_data<float>();
auto beta_data = beta_tensor.mutable_data<float>();
int h = bias_size;
int w = weight_num / bias_size;
int h =
bn_scale_t
->data_size(); // h == bias_size == out channel num of conv weight
int w = weight_num /
(bn_scale_t->data_size()); // w = `conv_weight_num` / bias_size = in
// channel num of conv weight
ComputeAlphaAndBeta(
bn_scale_d, bn_mean_d, bn_var_d, alpha_data, beta_data, eps, h, w);
///////////////////////////////////////////////////////////////////////////////
// Compute ConvBNFuser
// Before fusion
//
// conv(x) = kx + z = y
// bn(y) = ay + b
//
// Note: `alpha_data` is a, `beta_data` is b from `ComputeAlphaAndBeta`
//
// After fusion:
//
// bn(conv(x)) = a(kx + z) + b = akx + az + b
//
// Note: h == bias_size == out channel num of conv weight
// w = `conv_weight_num` / bias_size = in channel num of conv weight
// the int8 path differs slightly, see below
///////////////////////////////////////////////////////////////////////////////
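// Worked example (illustrative numbers, not from this patch): for one
// output channel with bn scale s = 0.5, variance v = 3.0, eps = 1.0 and
// mean m = 2.0, a = s / sqrt(v + eps) = 0.5 / 2.0 = 0.25; every weight in
// that channel is scaled by 0.25, and a * z + b is folded into the new
// conv bias, so the batch_norm op can be erased from the graph.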
if (enable_int8) {
PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"),
"INT8 mode: Conv should has weight_scale attr");
auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
// compute new conv_weight for int8
auto weight_scale =
conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
for (int i = 0; i < h; i++) {
weight_scale[i] *= alpha_data[i];
for (unsigned int i = 0; i < h; ++i) {
weight_scale[i] *= fabsf(alpha_data[i]);
if (alpha_data[i] < 0.f) {
auto ptr_row = conv_weight_d + i * w;
for (unsigned int j = 0; j < w; ++j) {
ptr_row[j] *= -1;
}
}
}
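// Added note: the per-channel weight scales are kept positive by taking
// fabsf(alpha); when alpha is negative, its sign is pushed into the stored
// int8 weights instead, presumably because the int8 kernels assume
// non-negative per-channel scales.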
// Interface like this should be abandoned.
conv_op_desc->SetAttr("weight_scale", weight_scale);
auto update_conv_desc = *conv_instruct->mutable_op_info();
conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
} else {
// compute new conv_weight
auto conv_weight_d = conv_weight_t->mutable_data<float>();
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
for (unsigned int i = 0; i < h; ++i) { // h: conv2d output channels
for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels
conv_weight_d[i * w + j] *= alpha_data[i];
}
}
}
for (int i = 0; i < bias_size; i++) {
// compute new conv_bias
if (conv_has_bias_) {
auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name)
->GetMutable<lite::Tensor>();
auto conv_bias_d = conv_bias_t->data<float>();
for (unsigned int i = 0; i < bn_bias_t->data_size();
++i) { // bias_size == h == conv2d output channels
bn_bias_d[i] += alpha_data[i] * conv_bias_d[i];
}
}
for (unsigned int i = 0; i < bn_bias_t->data_size(); ++i) {
bn_bias_d[i] += beta_data[i];
}
eltwise_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
}
conv_op_desc->SetType(conv_type_);
conv_op_desc->SetInput("Input", {matched.at("conv_input")->arg()->name});
conv_op_desc->SetInput("Filter", {matched.at("conv_weight")->arg()->name});
conv_op_desc->SetOutput("Output", {matched.at("bn_out")->arg()->name});
conv_op_desc->SetInput("Bias",
{matched.at("bn_bias")->arg()->name}); // conv_bias
auto update_conv_desc = *conv_instruct->mutable_op_info();
conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("elementwise_add");
op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
op_desc.SetAttr("axis", 1);
return op_desc;
IR_NODE_LINK_TO(matched.at("bn_bias"), matched.at("conv2d"));
IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("bn_out"));
}
} // namespace fusion
......
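For reference, the float-path fold performed above can be written as a standalone function. The following is a minimal sketch using the standard BN-folding formulas; `FoldBnIntoConv` and its parameter layout are illustrative, not part of this patch:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Fold batch_norm(scale, mean, variance, bn_bias, eps) into the preceding
// conv. weights is [out_c, w] flattened row-major; bias is [out_c].
void FoldBnIntoConv(std::vector<float>* weights, std::vector<float>* bias,
                    const std::vector<float>& scale,
                    const std::vector<float>& mean,
                    const std::vector<float>& variance,
                    const std::vector<float>& bn_bias, float eps) {
  const std::size_t out_c = scale.size();
  const std::size_t w = weights->size() / out_c;  // elements per out channel
  for (std::size_t i = 0; i < out_c; ++i) {
    const float alpha = scale[i] / std::sqrt(variance[i] + eps);  // a
    for (std::size_t j = 0; j < w; ++j) {
      (*weights)[i * w + j] *= alpha;  // the akx term
    }
    // the az + b term: new bias = alpha * (old conv bias - mean) + bn bias
    (*bias)[i] = alpha * ((*bias)[i] - mean[i]) + bn_bias[i];
  }
}
```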
......@@ -27,12 +27,12 @@ namespace fusion {
class ConvBNFuser : public FuseBase {
public:
explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
explicit ConvBNFuser(const std::string& conv_type, const bool conv_has_bias)
: conv_type_(conv_type), conv_has_bias_(conv_has_bias) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
void ComputeAlphaAndBeta(float* scale_d,
float* mean_d,
float* var_d,
......@@ -51,6 +51,7 @@ class ConvBNFuser : public FuseBase {
private:
std::string conv_type_{"conv2d"};
bool conv_has_bias_{false};
};
} // namespace fusion
......
......@@ -23,14 +23,21 @@ namespace lite {
namespace mir {
void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
// initialize fuser params
// note: the conv_has_bias == true case must be matched first
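// (added explanation: the bias-free pattern would presumably also match a
// conv that does have a Bias input and fuse it without folding that bias,
// so the with-bias fuser has to run first)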
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{
"conv2d", "depthwise_conv2d", "conv2d_transpose"};
fusion::ConvElementwiseFuser conv2d_transpose_fuser("conv2d_transpose");
conv2d_transpose_fuser(graph.get());
// start fusion using the params
for (auto conv_has_bias : conv_has_bias_cases) {
for (auto conv_type : conv_type_cases) {
VLOG(4) << "conv_has_bias:" << conv_has_bias
<< " conv_type:" << conv_type;
fusion::ConvElementwiseFuser fuser(conv_type, conv_has_bias);
fuser(graph.get());
}
}
}
} // namespace mir
......
......@@ -33,8 +33,7 @@ void ConvElementwiseFuser::BuildPattern() {
->assert_is_persistable_var();
// create op nodes
auto* conv2d =
OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
......@@ -51,6 +50,13 @@ void ConvElementwiseFuser::BuildPattern() {
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
// consider a special case: conv with bias
if (conv_has_bias_) {
PMNode* conv_bias = VarNode("conv_bias")
->assert_is_op_input(conv_type_, "Bias")
->AsIntermediate();
conv2d_inputs.emplace_back(conv_bias);
}
std::vector<PMNode*> add_inputs{conv2d_out, bias};
conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out;
......@@ -58,44 +64,49 @@ void ConvElementwiseFuser::BuildPattern() {
void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op();
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
auto conv_instruct = matched.at("conv2d")->stmt();
auto conv_op_desc = conv_instruct->mutable_op_info();
auto* scope = conv_instruct->op()->scope();
auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
IR_NODE_LINK_TO(matched.at("input"), new_op_node);
IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
/////////////////////////////////////////////////////////////////////////////////////
// ConvElementwiseFuser
// If `conv_bias` exists, fold the old `conv_bias` into the
// `elementwise_add` bias, which then becomes the new conv bias.
// If `conv_bias` does not exist, use the `elementwise_add` bias
// directly as the new conv bias.
/////////////////////////////////////////////////////////////////////////////////////
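// Example (illustrative numbers): with conv_bias = {1, 2} and an
// elementwise_add bias of {10, 20}, the fused conv's Bias input becomes
// {11, 22} and the elementwise_add op is erased.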
cpp::OpDesc ConvElementwiseFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
if (conv_has_bias_ == true && conv_op_desc->HasInput("Bias") &&
conv_op_desc->Input("Bias").size() > 0) {
auto conv_bias_var = scope->FindVar(conv_op_desc->Input("Bias").front());
if (conv_bias_var != nullptr) {
// conv bias
auto conv_bias_t = &(conv_bias_var->Get<lite::Tensor>());
auto conv_bias_d = conv_bias_t->data<float>();
cpp::OpDesc op_desc = *desc;
op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
// Other inputs. See operators/conv_op.h
std::vector<std::string> input_arg_names = desc->InputArgumentNames();
// elementwise_add bias
auto elementwise_add_bias_t =
scope->FindVar(matched.at("bias")->arg()->name)
->GetMutable<lite::Tensor>();
auto elementwise_add_bias_d =
elementwise_add_bias_t->mutable_data<float>();
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) {
elementwise_add_bias_d[i] += conv_bias_d[i];
}
// Only consider strides, padding, groups, dilations for now
op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
return op_desc;
}
}
conv_op_desc->SetType(conv_type_);
conv_op_desc->SetInput("Input", {matched.at("input")->arg()->name});
conv_op_desc->SetInput("Filter", {matched.at("filter")->arg()->name});
conv_op_desc->SetOutput("Output", {matched.at("output")->arg()->name});
conv_op_desc->SetInput("Bias", {matched.at("bias")->arg()->name});
auto update_conv_desc = *conv_instruct->mutable_op_info();
conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
IR_NODE_LINK_TO(matched.at("bias"), matched.at("conv2d"));
IR_OP_VAR_LINK(matched.at("conv2d"), matched.at("output"));
}
} // namespace fusion
......
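The bias handling above reduces to the following. This is a hedged sketch; `MergeBiasIntoConv` is a hypothetical helper, not code from this patch:

```cpp
#include <cstddef>
#include <vector>

// The elementwise_add bias becomes the conv's Bias input; if the conv
// already had a bias, it is accumulated into that tensor first.
std::vector<float> MergeBiasIntoConv(const std::vector<float>& elem_bias,
                                     const std::vector<float>* conv_bias) {
  std::vector<float> new_bias = elem_bias;
  if (conv_bias != nullptr) {
    for (std::size_t i = 0; i < new_bias.size(); ++i) {
      new_bias[i] += (*conv_bias)[i];
    }
  }
  return new_bias;
}
```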
......@@ -25,16 +25,18 @@ namespace fusion {
class ConvElementwiseFuser : public FuseBase {
public:
explicit ConvElementwiseFuser(const std::string& conv_type) {
explicit ConvElementwiseFuser(const std::string& conv_type,
const bool conv_has_bias) {
conv_type_ = conv_type;
conv_has_bias_ = conv_has_bias;
}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
std::string conv_type_{"conv2d"};
bool conv_has_bias_{false};
};
} // namespace fusion
......