diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index ba3e0566277f74fe6dd09b52d77b07aba14a458d..ac29cdda019c29ee208df391e0c637dc07329abe 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -34,6 +34,7 @@ USE_MIR_PASS(lite_interpolate_fuse_pass);
 USE_MIR_PASS(identity_scale_eliminate_pass);
 USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
 USE_MIR_PASS(lite_conv_activation_fuse_pass);
+USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
 USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
 USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index 8d5b3e6876ace41e7e63835e9c847bd381c16635..810ff0f875168da1c4411471b7ea3ea6617a9b4f 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -16,6 +16,7 @@ lite_cc_library(mir_passes
        fusion/interpolate_fuse_pass.cc
        fusion/conv_elementwise_fuse_pass.cc
        fusion/conv_activation_fuse_pass.cc
+       fusion/var_conv_2d_activation_fuse_pass.cc
        fusion/conv_bn_fuse_pass.cc
        fusion/elementwise_add_activation_fuse_pass.cc
        fusion/quant_dequant_fuse_pass.cc
diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt
index 5ac52837551f0b78d67dfe1733fe354ee2cf7f01..8699470955b663fc2562074e99529def72836794 100644
--- a/lite/core/mir/fusion/CMakeLists.txt
+++ b/lite/core/mir/fusion/CMakeLists.txt
@@ -10,6 +10,9 @@ lite_cc_library(fuse_conv_elementwise
 lite_cc_library(fuse_conv_activation
         SRCS conv_activation_fuser.cc
         DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_var_conv_activation
+        SRCS var_conv_2d_activation_fuser.cc
+        DEPS pattern_matcher_high_api)
 lite_cc_library(fuse_conv_bn
         SRCS conv_bn_fuser.cc
         DEPS pattern_matcher_high_api)
@@ -31,6 +34,7 @@ set(mir_fusers
     fuse_shuffle_channel
     fuse_conv_elementwise
     fuse_conv_activation
+    fuse_var_conv_activation
     fuse_conv_bn
     fuse_quant_dequant
     fuse_elementwise_add_activation
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ce2248cbc23d8887a22f94c14b2507fb0cacbed
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void VarConv2dActivationFusePass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  std::vector<std::string> act_types{"relu"};
+  for (auto act_type : act_types) {
+    fusion::VarConvActivationFuser fuser(act_type, "var_conv_2d");
+    fuser(graph.get());
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_var_conv_2d_activation_fuse_pass,
+                  paddle::lite::mir::VarConv2dActivationFusePass)
+    .BindTargets({TARGET(kCUDA)});
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..7616aadef340d3e4d6bc11534dd839c91fe9ed1d
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class VarConv2dActivationFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..eabd97ae4513b84c9c002aa1587d45cce6b22e21
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void VarConvActivationFuser::BuildPattern() {
+  // create nodes.
+  auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput();
+  auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput();
+
+  auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate();
+
+  auto* act = OpNode("act", act_type_)->AsIntermediate();
+
+  auto* conv2d_out = VarNode("conv2d_out")
+                         ->assert_is_op_output(conv_type_, "Out")
+                         ->assert_is_op_input(act_type_, "X")
+                         ->AsIntermediate();
+  auto* conv2d_out_1 = VarNode("conv2d_out_1")
+                           ->assert_is_op_output(conv_type_, "Col")
+                           ->AsIntermediate();
+
+  auto* out =
+      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
+
+  // create topology.
+  std::vector<PMNode*> conv2d_inputs{filter, input};
+  conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
+  *conv2d >> *conv2d_out_1;
+}
+
+void VarConvActivationFuser::InsertNewNode(SSAGraph* graph,
+                                           const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
+  auto conv_old = matched.at("var_conv_2d")->stmt()->op();
+  auto* scope = conv_old->scope();
+  auto& valid_places = conv_old->valid_places();
+  conv_op->Attach(op_desc, scope);
+
+  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("X"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
+}
+
+cpp::OpDesc VarConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc = *matched.at("var_conv_2d")->stmt()->op_info();
+  op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
+  cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
+
+  if (act_type_ == "relu") {
+    op_desc.SetAttr("fuse_relu", true);
+  }
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.h b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..68bc89f7d13d38dc07814f3296a25bfd7dea0248
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class VarConvActivationFuser : public FuseBase {
+ public:
+  explicit VarConvActivationFuser(const std::string& act_type,
+                                  const std::string& conv_type)
+      : act_type_(act_type), conv_type_(conv_type) {}
+
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  std::string act_type_;
+  std::string conv_type_;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 6b2d7f5b18df8230af856d68a8301ee3cb929900..58060f5bb599b7b9854e5c9b53f24d733af22f15 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -62,6 +62,7 @@ class Optimizer {
           // TODO(Superjomn) Refine the fusion related design to select fusion
           // kernels for devices automatically.
           "lite_conv_activation_fuse_pass",               //
+          "lite_var_conv_2d_activation_fuse_pass",        //
           "lite_fc_fuse_pass",                            //
           "lite_shuffle_channel_fuse_pass",               //
           "lite_transpose_softmax_transpose_fuse_pass",   //
diff --git a/lite/kernels/cuda/var_conv_2d_compute.cu b/lite/kernels/cuda/var_conv_2d_compute.cu
index 92a59876a635ea87022f90184fafc7ea4e9919af..1e42635934b67b28fca29808f484be53292d74cf 100644
--- a/lite/kernels/cuda/var_conv_2d_compute.cu
+++ b/lite/kernels/cuda/var_conv_2d_compute.cu
@@ -73,6 +73,10 @@ void VarConv2DCompute::PrepareForRun() {
         (*conv_param_.paddings.get())[i * 2 + 1],
         conv_param_.strides[i]));
   }
+  if (param.fuse_relu) {
+    conv_param_.activation_param.has_active = true;
+    conv_param_.activation_param.active_type = lite_api::ActivationType::kRelu;
+  }
   conv_param_.output->Resize({output_shape});
   conv_impl_.reset(new lite::cuda::math::CudnnConv2D<float>);
   conv_impl_->init(conv_param_, &context);
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index dc010b96c0ea69fc8be1f30b1b146dc39352cc28..bd2ba937ea11038ce67da790d2733f5ba6d53b54 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -864,6 +864,8 @@ struct VarConv2DParam {
   int stride_w;
   int kernel_h;
   int kernel_w;
+
+  bool fuse_relu{false};
 };
 
 /// ----------------------- shape operators ----------------------
diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc
index d87c871a3252ad75c066953d3648b55038a2a21c..51f43c709990d7ac1e664336e252ed684479b783 100644
--- a/lite/operators/var_conv_2d_op.cc
+++ b/lite/operators/var_conv_2d_op.cc
@@ -48,6 +48,10 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.kernel_w = opdesc.GetAttr<int>("KernelW");
   param_.stride_h = opdesc.GetAttr<int>("StrideH");
   param_.stride_w = opdesc.GetAttr<int>("StrideW");
+
+  if (opdesc.HasAttr("fuse_relu")) {
+    param_.fuse_relu = opdesc.GetAttr<bool>("fuse_relu");
+  }
   return true;
 }
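
Note for reviewers: the net effect of the new pass is that the subgraph X -> var_conv_2d -> Out -> relu -> output is rewritten into a single var_conv_2d whose op desc carries fuse_relu = true, and the kCUDA kernel then applies the ReLU as part of the convolution epilogue instead of launching a separate activation kernel. The standalone sketch below is illustrative only (plain C++, not the Paddle-Lite API; the conv1d_reference helper and its fuse_relu flag are hypothetical names) and shows the equivalence the fusion relies on: applying ReLU inside the convolution's output loop gives the same result as running the convolution and then a separate ReLU.

// Illustrative sketch: why folding ReLU into the conv epilogue is
// numerically equivalent to conv followed by a separate relu.
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical 1-D, single-channel, stride-1, no-padding reference conv.
std::vector<float> conv1d_reference(const std::vector<float>& x,
                                    const std::vector<float>& w,
                                    bool fuse_relu) {
  std::vector<float> out(x.size() - w.size() + 1, 0.f);
  for (size_t i = 0; i < out.size(); ++i) {
    for (size_t k = 0; k < w.size(); ++k) out[i] += x[i + k] * w[k];
    // Epilogue: clamp in place when the activation is fused.
    if (fuse_relu) out[i] = std::max(out[i], 0.f);
  }
  return out;
}

int main() {
  std::vector<float> x{1.f, -2.f, 3.f, -4.f, 5.f};
  std::vector<float> w{0.5f, -1.f};
  // Unfused path: convolution, then a separate ReLU over the output.
  auto unfused = conv1d_reference(x, w, /*fuse_relu=*/false);
  for (auto& v : unfused) v = std::max(v, 0.f);
  // Fused path: ReLU applied inside the convolution epilogue.
  auto fused = conv1d_reference(x, w, /*fuse_relu=*/true);
  assert(unfused == fused);
  return 0;
}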