From 1ce127cb451fae9ef09744f77f49e9583eed05a2 Mon Sep 17 00:00:00 2001 From: sangoly Date: Fri, 14 Jun 2019 16:17:34 +0800 Subject: [PATCH] fix conv_bias_relu fuse bug & x86 conv kernel bug (#18098) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix conv_bias_relu fuse bug,x86 conv kernel bug --- .../conv_elementwise_add_relu_fuse_pass.cc | 5 ++- ...onv_elementwise_add_relu_fuse_pass_test.cc | 4 +- .../fusion/conv_elementwise_add_relu_fuser.cc | 41 +++++++++++-------- .../fusion/conv_elementwise_add_relu_fuser.h | 3 ++ paddle/fluid/lite/core/mir/passes.h | 1 + paddle/fluid/lite/kernels/x86/conv_compute.cc | 2 +- paddle/fluid/lite/kernels/x86/relu_compute.cc | 2 +- paddle/fluid/lite/operators/conv_op.h | 21 ++++++---- paddle/fluid/lite/operators/op_params.h | 4 +- paddle/fluid/lite/operators/relu_op.cc | 2 +- 10 files changed, 50 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc index 065e8ceca3f..3110c7aa6d4 100644 --- a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc @@ -24,8 +24,11 @@ namespace mir { void ConvElementwiseAddReLUFusePass::Apply( const std::unique_ptr& graph) { - fusion::ConvElementwiseAddReLUFuser fuser; + fusion::ConvElementwiseAddReLUFuser fuser("conv2d"); fuser(graph.get()); + + fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d"); + depthwise_fuser(graph.get()); } } // namespace mir diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc index 2cde3d25a69..30991313ad3 100644 --- a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc @@ -85,7 +85,7 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, add_1->SetAttr("axis", 1); relu_1->SetType("relu"); - relu_1->SetInput("Input", {"add_1_out"}); + relu_1->SetInput("X", {"add_1_out"}); relu_1->SetOutput("Out", {"relu_1_out"}); conv2d_2->SetType("conv2d"); @@ -105,7 +105,7 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, add_2->SetAttr("axis", 1); relu_2->SetType("relu"); - relu_2->SetInput("Input", {"add_2_out"}); + relu_2->SetInput("X", {"add_2_out"}); relu_2->SetOutput("Out", {"out"}); program_desc->Flush(); diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc index c1322386348..497c8f4f0d3 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc @@ -23,21 +23,33 @@ namespace fusion { void ConvElementwiseAddReLUFuser::BuildPattern() { // create input nodes. - auto* input = VarNode("input"); - auto* filter = VarNode("filter"); - auto* bias = VarNode("bias"); + auto* input = + VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput(); + auto* filter = + VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput(); + auto* bias = + VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput(); // create op nodes - auto* conv2d = OpNode("conv2d", "conv2d"); - auto* add = OpNode("add", "elementwise_add"); - auto* relu = OpNode("relu", "relu"); + auto* conv2d = + OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate(); + auto* add = OpNode("add", "elementwise_add") + ->assert_is_op("elementwise_add") + ->AsIntermediate(); + auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate(); // create intermediate nodes - auto* conv2d_out = VarNode("conv2d_out"); - auto* add_out = VarNode("add_out"); + auto* conv2d_out = VarNode("conv2d_out") + ->assert_is_op_output(conv_type_, "Output") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + auto* add_out = VarNode("add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); // create output node - auto* out = VarNode("output"); + auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput(); // create topology. std::vector conv2d_inputs{filter, input}; @@ -45,19 +57,12 @@ void ConvElementwiseAddReLUFuser::BuildPattern() { conv2d_inputs >> *conv2d >> *conv2d_out; add_inputs >> *add >> *add_out; *add_out >> *relu >> *out; - - // Some op specialities. - conv2d_out->AsIntermediate(); - add_out->AsIntermediate(); - conv2d->AsIntermediate(); - add->AsIntermediate(); - relu->AsIntermediate(); } void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { auto op_desc = GenOpDesc(matched); - auto conv_op = LiteOpRegistry::Global().Create("conv2d"); + auto conv_op = LiteOpRegistry::Global().Create(conv_type_); auto conv_old = matched.at("conv2d")->stmt()->op; auto* scope = conv_old->scope(); auto& valid_places = conv_old->valid_places(); @@ -75,7 +80,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) { auto* desc = matched.at("conv2d")->stmt()->op_info(); cpp::OpDesc op_desc; - op_desc.SetType("conv2d"); + op_desc.SetType(conv_type_); op_desc.SetInput("Input", {matched.at("input")->arg()->name}); op_desc.SetInput("Filter", {matched.at("filter")->arg()->name}); op_desc.SetInput("Bias", {matched.at("bias")->arg()->name}); diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h index 5ba0ee26841..3e21368234f 100644 --- a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h @@ -25,11 +25,14 @@ namespace fusion { class ConvElementwiseAddReLUFuser : public FuseBase { public: + explicit ConvElementwiseAddReLUFuser(const std::string& conv_type) + : conv_type_(conv_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string conv_type_; }; } // namespace fusion diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index b65e1d53d07..6e329a19227 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -32,3 +32,4 @@ USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass); +USE_MIR_PASS(graph_visualze); diff --git a/paddle/fluid/lite/kernels/x86/conv_compute.cc b/paddle/fluid/lite/kernels/x86/conv_compute.cc index 35d0de82de4..b29161c1c60 100644 --- a/paddle/fluid/lite/kernels/x86/conv_compute.cc +++ b/paddle/fluid/lite/kernels/x86/conv_compute.cc @@ -105,7 +105,7 @@ class Conv2dCompute : public KernelLite { param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data())); lite::Tensor out_batch; out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize( - input_shape.data())); + output_matrix_shape.data())); for (int g = 0; g < param.groups; g++) { lite::Tensor in_slice; diff --git a/paddle/fluid/lite/kernels/x86/relu_compute.cc b/paddle/fluid/lite/kernels/x86/relu_compute.cc index 44b1f525ab0..52fffb57981 100644 --- a/paddle/fluid/lite/kernels/x86/relu_compute.cc +++ b/paddle/fluid/lite/kernels/x86/relu_compute.cc @@ -51,6 +51,6 @@ class ReluCompute : public KernelLite { REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::ReluCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h index e5ad7fe67f9..79726e0284b 100644 --- a/paddle/fluid/lite/operators/conv_op.h +++ b/paddle/fluid/lite/operators/conv_op.h @@ -73,19 +73,22 @@ class ConvOpLite : public OpLite { std::vector input_arg_names = op_desc.InputArgumentNames(); if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != input_arg_names.end()) { - auto bias_var = scope->FindVar(op_desc.Input("Bias").front()); - if (bias_var != nullptr) { - param_.bias = - const_cast(&(bias_var->Get())); + auto bias_arguments = op_desc.Input("Bias"); + if (bias_arguments.size() != 0) { + auto bias_var = scope->FindVar(bias_arguments.front()); + if (bias_var != nullptr) { + param_.bias = &bias_var->Get(); + } } } if (std::find(input_arg_names.begin(), input_arg_names.end(), "ResidualData") != input_arg_names.end()) { - auto residual_data_var = - scope->FindVar(op_desc.Input("ResidualData").front()); - if (residual_data_var != nullptr) { - param_.residualData = const_cast( - &(residual_data_var->Get())); + auto res_argument = op_desc.Input("ResidualData"); + if (res_argument.size() != 0) { + auto residual_data_var = scope->FindVar(res_argument.front()); + if (residual_data_var != nullptr) { + param_.residualData = &residual_data_var->Get(); + } } } diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 87b4c8dd5fb..78df0ce8a7a 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -124,8 +124,8 @@ struct ConcatParam { struct ConvParam { lite::Tensor* x{}; lite::Tensor* filter{}; - lite::Tensor* bias{}; - lite::Tensor* residualData{}; + const lite::Tensor* bias{}; + const lite::Tensor* residualData{}; lite::Tensor* output{}; std::vector strides{1, 1}; std::vector paddings{0, 0}; diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc index a588b1c8cbf..47251c72dfa 100644 --- a/paddle/fluid/lite/operators/relu_op.cc +++ b/paddle/fluid/lite/operators/relu_op.cc @@ -32,7 +32,7 @@ bool ReluOp::InferShape() const { bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.input = const_cast( - &scope->FindVar(opdesc.Input("Input").front())->Get()); + &scope->FindVar(opdesc.Input("X").front())->Get()); param_.output = scope->FindVar(opdesc.Output("Out").front())->GetMutable(); CHECK(param_.input); -- GitLab