提交 1ce127cb 编写于 作者: S sangoly 提交者: GitHub

fix conv_bias_relu fuse bug & x86 conv kernel bug (#18098)

fix conv_bias_relu fuse bug,x86 conv kernel bug
上级 d1f539c8
...@@ -24,8 +24,11 @@ namespace mir { ...@@ -24,8 +24,11 @@ namespace mir {
void ConvElementwiseAddReLUFusePass::Apply( void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) { const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser; fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get()); fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -85,7 +85,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -85,7 +85,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
add_1->SetAttr("axis", 1); add_1->SetAttr("axis", 1);
relu_1->SetType("relu"); relu_1->SetType("relu");
relu_1->SetInput("Input", {"add_1_out"}); relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"}); relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d"); conv2d_2->SetType("conv2d");
...@@ -105,7 +105,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -105,7 +105,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
add_2->SetAttr("axis", 1); add_2->SetAttr("axis", 1);
relu_2->SetType("relu"); relu_2->SetType("relu");
relu_2->SetInput("Input", {"add_2_out"}); relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"}); relu_2->SetOutput("Out", {"out"});
program_desc->Flush(); program_desc->Flush();
......
...@@ -23,21 +23,33 @@ namespace fusion { ...@@ -23,21 +23,33 @@ namespace fusion {
void ConvElementwiseAddReLUFuser::BuildPattern() { void ConvElementwiseAddReLUFuser::BuildPattern() {
// create input nodes. // create input nodes.
auto* input = VarNode("input"); auto* input =
auto* filter = VarNode("filter"); VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* bias = VarNode("bias"); auto* filter =
VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
auto* bias =
VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
// create op nodes // create op nodes
auto* conv2d = OpNode("conv2d", "conv2d"); auto* conv2d =
auto* add = OpNode("add", "elementwise_add"); OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
auto* relu = OpNode("relu", "relu"); auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
// create intermediate nodes // create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out"); auto* conv2d_out = VarNode("conv2d_out")
auto* add_out = VarNode("add_out"); ->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
// create output node // create output node
auto* out = VarNode("output"); auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
// create topology. // create topology.
std::vector<PMNode*> conv2d_inputs{filter, input}; std::vector<PMNode*> conv2d_inputs{filter, input};
...@@ -45,19 +57,12 @@ void ConvElementwiseAddReLUFuser::BuildPattern() { ...@@ -45,19 +57,12 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
conv2d_inputs >> *conv2d >> *conv2d_out; conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out; add_inputs >> *add >> *add_out;
*add_out >> *relu >> *out; *add_out >> *relu >> *out;
// Some op specialities.
conv2d_out->AsIntermediate();
add_out->AsIntermediate();
conv2d->AsIntermediate();
add->AsIntermediate();
relu->AsIntermediate();
} }
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph, void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) { const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create("conv2d"); auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op; auto conv_old = matched.at("conv2d")->stmt()->op;
auto* scope = conv_old->scope(); auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places(); auto& valid_places = conv_old->valid_places();
...@@ -75,7 +80,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -75,7 +80,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info(); auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
op_desc.SetType("conv2d"); op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name}); op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name}); op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name}); op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
......
...@@ -25,11 +25,14 @@ namespace fusion { ...@@ -25,11 +25,14 @@ namespace fusion {
class ConvElementwiseAddReLUFuser : public FuseBase { class ConvElementwiseAddReLUFuser : public FuseBase {
public: public:
explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
: conv_type_(conv_type) {}
void BuildPattern() override; void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private: private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
}; };
} // namespace fusion } // namespace fusion
......
...@@ -32,3 +32,4 @@ USE_MIR_PASS(io_copy_kernel_pick_pass); ...@@ -32,3 +32,4 @@ USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
...@@ -105,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -105,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data())); param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch; lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize( out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
input_shape.data())); output_matrix_shape.data()));
for (int g = 0; g < param.groups; g++) { for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice; lite::Tensor in_slice;
......
...@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW, REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ReluCompute<float>, def) paddle::lite::kernels::x86::ReluCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -73,19 +73,22 @@ class ConvOpLite : public OpLite { ...@@ -73,19 +73,22 @@ class ConvOpLite : public OpLite {
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames(); std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) { input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front()); auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() != 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) { if (bias_var != nullptr) {
param_.bias = param_.bias = &bias_var->Get<lite::Tensor>();
const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>())); }
} }
} }
if (std::find(input_arg_names.begin(), input_arg_names.end(), if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) { "ResidualData") != input_arg_names.end()) {
auto residual_data_var = auto res_argument = op_desc.Input("ResidualData");
scope->FindVar(op_desc.Input("ResidualData").front()); if (res_argument.size() != 0) {
auto residual_data_var = scope->FindVar(res_argument.front());
if (residual_data_var != nullptr) { if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor*>( param_.residualData = &residual_data_var->Get<lite::Tensor>();
&(residual_data_var->Get<lite::Tensor>())); }
} }
} }
......
...@@ -124,8 +124,8 @@ struct ConcatParam { ...@@ -124,8 +124,8 @@ struct ConcatParam {
struct ConvParam { struct ConvParam {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* filter{}; lite::Tensor* filter{};
lite::Tensor* bias{}; const lite::Tensor* bias{};
lite::Tensor* residualData{}; const lite::Tensor* residualData{};
lite::Tensor* output{}; lite::Tensor* output{};
std::vector<int> strides{1, 1}; std::vector<int> strides{1, 1};
std::vector<int> paddings{0, 0}; std::vector<int> paddings{0, 0};
......
...@@ -32,7 +32,7 @@ bool ReluOp::InferShape() const { ...@@ -32,7 +32,7 @@ bool ReluOp::InferShape() const {
bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<lite::Tensor *>( param_.input = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>()); &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output = param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.input); CHECK(param_.input);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册