From 810eb08974c0efe49150645b05c44c2421d9d62c Mon Sep 17 00:00:00 2001 From: Wilber Date: Fri, 20 Dec 2019 11:23:12 +0800 Subject: [PATCH] add var_conv_2d_relu pass test=develop (#2631) add var_conv_2d + relu fuse pass --- lite/api/paddle_use_passes.h | 1 + lite/core/mir/CMakeLists.txt | 1 + lite/core/mir/fusion/CMakeLists.txt | 4 + .../var_conv_2d_activation_fuse_pass.cc | 40 ++++++++++ .../fusion/var_conv_2d_activation_fuse_pass.h | 32 ++++++++ .../fusion/var_conv_2d_activation_fuser.cc | 80 +++++++++++++++++++ .../mir/fusion/var_conv_2d_activation_fuser.h | 44 ++++++++++ lite/core/optimizer.h | 1 + lite/kernels/cuda/var_conv_2d_compute.cu | 4 + lite/operators/op_params.h | 2 + lite/operators/var_conv_2d_op.cc | 4 + 11 files changed, 213 insertions(+) create mode 100644 lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc create mode 100644 lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h create mode 100644 lite/core/mir/fusion/var_conv_2d_activation_fuser.cc create mode 100644 lite/core/mir/fusion/var_conv_2d_activation_fuser.h diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index ba3e056627..ac29cdda01 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -34,6 +34,7 @@ USE_MIR_PASS(lite_interpolate_fuse_pass); USE_MIR_PASS(identity_scale_eliminate_pass); USE_MIR_PASS(lite_conv_elementwise_fuse_pass); USE_MIR_PASS(lite_conv_activation_fuse_pass); +USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_quant_dequant_fuse_pass); USE_MIR_PASS(type_precision_cast_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 8d5b3e6876..810ff0f875 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -16,6 +16,7 @@ lite_cc_library(mir_passes fusion/interpolate_fuse_pass.cc fusion/conv_elementwise_fuse_pass.cc fusion/conv_activation_fuse_pass.cc + 
fusion/var_conv_2d_activation_fuse_pass.cc fusion/conv_bn_fuse_pass.cc fusion/elementwise_add_activation_fuse_pass.cc fusion/quant_dequant_fuse_pass.cc diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index 5ac5283755..8699470955 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -10,6 +10,9 @@ lite_cc_library(fuse_conv_elementwise lite_cc_library(fuse_conv_activation SRCS conv_activation_fuser.cc DEPS pattern_matcher_high_api) +lite_cc_library(fuse_var_conv_activation + SRCS var_conv_2d_activation_fuser.cc + DEPS pattern_matcher_high_api) lite_cc_library(fuse_conv_bn SRCS conv_bn_fuser.cc DEPS pattern_matcher_high_api) @@ -31,6 +34,7 @@ set(mir_fusers fuse_shuffle_channel fuse_conv_elementwise fuse_conv_activation + fuse_var_conv_activation fuse_conv_bn fuse_quant_dequant fuse_elementwise_add_activation diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc new file mode 100644 index 0000000000..0ce2248cbc --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h" +#include <memory> +#include <vector> +#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void VarConv2dActivationFusePass::Apply( + const std::unique_ptr<SSAGraph>& graph) { + std::vector<std::string> act_types{"relu"}; + for (auto act_type : act_types) { + fusion::VarConvActivationFuser fuser(act_type, "var_conv_2d"); + fuser(graph.get()); + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_var_conv_2d_activation_fuse_pass, + paddle::lite::mir::VarConv2dActivationFusePass) + .BindTargets({TARGET(kCUDA)}); diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h new file mode 100644 index 0000000000..7616aadef3 --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class VarConv2dActivationFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr<SSAGraph>& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc new file mode 100644 index 0000000000..eabd97ae45 --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h" +#include <memory> +#include <vector> + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void VarConvActivationFuser::BuildPattern() { + // create nodes. 
+ auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput(); + auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput(); + + auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate(); + + auto* act = OpNode("act", act_type_)->AsIntermediate(); + + auto* conv2d_out = VarNode("conv2d_out") + ->assert_is_op_output(conv_type_, "Out") + ->assert_is_op_input(act_type_, "X") + ->AsIntermediate(); + auto* conv2d_out_1 = VarNode("conv2d_out_1") + ->assert_is_op_output(conv_type_, "Col") + ->AsIntermediate(); + + auto* out = + VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput(); + + // create topology. + std::vector<PMNode*> conv2d_inputs{filter, input}; + conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out; + *conv2d >> *conv2d_out_1; +} + +void VarConvActivationFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto conv_op = LiteOpRegistry::Global().Create(conv_type_); + auto conv_old = matched.at("var_conv_2d")->stmt()->op(); + auto* scope = conv_old->scope(); + auto& valid_places = conv_old->valid_places(); + conv_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places); + + IR_NODE_LINK_TO(matched.at("X"), new_op_node); + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc VarConvActivationFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc = *matched.at("var_conv_2d")->stmt()->op_info(); + op_desc.SetOutput("Out", {matched.at("output")->arg()->name}); + cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info(); + + if (act_type_ == "relu") { + op_desc.SetAttr("fuse_relu", true); + } + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.h b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h new file 
mode 100644 index 0000000000..68bc89f7d1 --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <memory> +#include <string> +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class VarConvActivationFuser : public FuseBase { + public: + explicit VarConvActivationFuser(const std::string& act_type, + const std::string& conv_type) + : act_type_(act_type), conv_type_(conv_type) {} + +  void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string act_type_; + std::string conv_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 6b2d7f5b18..58060f5bb5 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -62,6 +62,7 @@ class Optimizer { // TODO(Superjomn) Refine the fusion related design to select fusion // kernels for devices automatically. 
"lite_conv_activation_fuse_pass", // + "lite_var_conv_2d_activation_fuse_pass", // "lite_fc_fuse_pass", // "lite_shuffle_channel_fuse_pass", // "lite_transpose_softmax_transpose_fuse_pass", // diff --git a/lite/kernels/cuda/var_conv_2d_compute.cu b/lite/kernels/cuda/var_conv_2d_compute.cu index 92a59876a6..1e42635934 100644 --- a/lite/kernels/cuda/var_conv_2d_compute.cu +++ b/lite/kernels/cuda/var_conv_2d_compute.cu @@ -73,6 +73,10 @@ void VarConv2DCompute::PrepareForRun() { (*conv_param_.paddings.get())[i * 2 + 1], conv_param_.strides[i])); } + if (param.fuse_relu) { + conv_param_.activation_param.has_active = true; + conv_param_.activation_param.active_type = lite_api::ActivationType::kRelu; + } conv_param_.output->Resize({output_shape}); conv_impl_.reset(new lite::cuda::math::CudnnConv2D<float>); conv_impl_->init(conv_param_, &context); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index dc010b96c0..bd2ba937ea 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -864,6 +864,8 @@ struct VarConv2DParam { int stride_w; int kernel_h; int kernel_w; + + bool fuse_relu{false}; }; /// ----------------------- shape operators ---------------------- diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc index d87c871a32..51f43c7099 100644 --- a/lite/operators/var_conv_2d_op.cc +++ b/lite/operators/var_conv_2d_op.cc @@ -48,6 +48,10 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.kernel_w = opdesc.GetAttr<int>("KernelW"); param_.stride_h = opdesc.GetAttr<int>("StrideH"); param_.stride_w = opdesc.GetAttr<int>("StrideW"); + + if (opdesc.HasAttr("fuse_relu")) { + param_.fuse_relu = opdesc.GetAttr<bool>("fuse_relu"); + } return true; } -- GitLab