From 2a344823ef3276d506ad3e84834fa87d9d92afbb Mon Sep 17 00:00:00 2001
From: Wilber
Date: Fri, 8 May 2020 13:41:32 +0800
Subject: [PATCH] add eltwise_activate fuse. test=develop (#3367)

* add eltwise_activate_fuse. test=develop
---
 lite/api/paddle_use_passes.h                  |   2 +-
 lite/backends/cuda/math/elementwise.cu        |  95 ++++++--
 lite/backends/cuda/math/elementwise.h         |  18 +-
 .../elementwise_add_activation_fuse_pass.cc   |  23 +-
 .../elementwise_add_activation_fuse_pass.h    |   2 +-
 .../elementwise_add_activation_fuser.cc       |  56 +++--
 .../fusion/elementwise_add_activation_fuser.h |  14 +-
 lite/core/optimizer.h                         |   2 +-
 lite/kernels/cuda/elementwise_compute.cu      | 221 ++++++++++++------
 lite/kernels/cuda/elementwise_compute.h       |  34 ++-
 .../fusion_elementwise_activation_ops.cc      |   2 -
 11 files changed, 331 insertions(+), 138 deletions(-)

diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 8cb4dbf192..6732b96873 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -37,7 +37,7 @@ USE_MIR_PASS(identity_dropout_eliminate_pass);
 USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
 USE_MIR_PASS(lite_conv_activation_fuse_pass);
 USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
-USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
+USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
 USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
diff --git a/lite/backends/cuda/math/elementwise.cu b/lite/backends/cuda/math/elementwise.cu
index 8f0ebd1f97..63e710b358 100644
--- a/lite/backends/cuda/math/elementwise.cu
+++ b/lite/backends/cuda/math/elementwise.cu
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "lite/backends/cuda/math/elementwise.h"
+#include "lite/utils/cp_logging.h"
 
 namespace paddle {
 namespace lite {
@@ -62,6 +63,52 @@ __global__ void elementwise_relu_kernel(const size_t total,
   }
 }
 
+template <typename Dtype>
+__global__ void elementwise_abs_kernel(const size_t total,
+                                       const Dtype* x_data,
+                                       const Dtype* y_data,
+                                       Dtype* out_data,
+                                       int pre,
+                                       int n,
+                                       int post,
+                                       BinaryOperation type) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < total) {
+    int idx = tid / post % n;
+    Dtype temp;
+#if __CUDA_ARCH__ >= 350
+    temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
+#else
+    temp = binary_calc(x_data[tid], y_data[idx], type);
+#endif
+    out_data[tid] = temp > 0 ? temp : -temp;
+  }
+}
+
+template <typename Dtype>
+__global__ void elementwise_tanh_kernel(const size_t total,
+                                        const Dtype* x_data,
+                                        const Dtype* y_data,
+                                        Dtype* out_data,
+                                        int pre,
+                                        int n,
+                                        int post,
+                                        BinaryOperation type) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < total) {
+    int idx = tid / post % n;
+    Dtype temp;
+#if __CUDA_ARCH__ >= 350
+    temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
+#else
+    temp = binary_calc(x_data[tid], y_data[idx], type);
+#endif
+    out_data[tid] = tanh(temp);
+  }
+}
+
 template <typename Dtype>
 __global__ void elementwise_add_kernel(const size_t total,
                                        const Dtype* x_data,
@@ -135,19 +182,30 @@ void elementwise(const Dtype* x_data,
 }
 
 template <typename Dtype>
-void elementwise_relu(const Dtype* x_data,
-                      const Dtype* y_data,
-                      Dtype* out_data,
-                      int pre,
-                      int n,
-                      int post,
-                      BinaryOperation type,
-                      cudaStream_t stream) {
+void elementwise_act(const Dtype* x_data,
+                     const Dtype* y_data,
+                     Dtype* out_data,
+                     int pre,
+                     int n,
+                     int post,
+                     std::string act,
+                     BinaryOperation type,
+                     cudaStream_t stream) {
   int num = pre * n * post;
   int thread = 256;
   int block = (num + thread - 1) / thread;
-  elementwise_relu_kernel<<<block, thread, 0, stream>>>(
-      num, x_data, y_data, out_data, pre, n, post, type);
+  if (act == "relu") {
+    elementwise_relu_kernel<<<block, thread, 0, stream>>>(
+        num, x_data, y_data, out_data, pre, n, post, type);
+  } else if (act == "tanh") {
+    elementwise_tanh_kernel<<<block, thread, 0, stream>>>(
+        num, x_data, y_data, out_data, pre, n, post, type);
+  } else if (act == "abs") {
+    elementwise_abs_kernel<<<block, thread, 0, stream>>>(
+        num, x_data, y_data, out_data, pre, n, post, type);
+  } else {
+    LOG(FATAL) << "unsupported activation type: " << act;
+  }
 }
 
 template void elementwise(const float*,
@@ -159,14 +217,15 @@ template void elementwise(const float*,
                           BinaryOperation,
                           cudaStream_t);
 
-template void elementwise_relu(const float*,
-                               const float*,
-                               float*,
-                               int,
-                               int,
-                               int,
-                               BinaryOperation,
-                               cudaStream_t);
+template void elementwise_act(const float* x_data,
+                              const float* y_data,
+                              float* out_data,
+                              int pre,
+                              int n,
+                              int post,
+                              std::string act,
+                              BinaryOperation type,
+                              cudaStream_t stream);
 
 template void elementwise_add(int num,
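For readers new to this code, the pre/n/post scheme used by all of the kernels above flattens a broadcast into three factors: x holds pre * n * post elements, y holds n elements, and each thread locates its y operand with idx = tid / post % n. Below is a minimal CPU sketch of that indexing (standalone and hypothetical; it mirrors the kernels but is not part of the patch):

    // CPU mirror of the CUDA kernels' broadcast indexing; names illustrative.
    #include <cassert>
    #include <vector>

    int main() {
      // Example: x shape [2, 3, 4], y shape [3] broadcast at axis 1,
      // so pre = 2, n = 3, post = 4.
      const int pre = 2, n = 3, post = 4;
      std::vector<float> x(pre * n * post, 1.0f);
      std::vector<float> y = {10.0f, 20.0f, 30.0f};
      std::vector<float> out(x.size());
      for (int tid = 0; tid < pre * n * post; ++tid) {
        int idx = tid / post % n;    // same formula as the kernels above
        out[tid] = x[tid] + y[idx];  // BinaryOperation::kADD
      }
      // Each run of `post` consecutive outputs shares one y element:
      assert(out[0] == 11.0f && out[5] == 21.0f && out[11] == 31.0f);
      return 0;
    }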
diff --git a/lite/backends/cuda/math/elementwise.h b/lite/backends/cuda/math/elementwise.h
index ce45d0544e..46412de235 100644
--- a/lite/backends/cuda/math/elementwise.h
+++ b/lite/backends/cuda/math/elementwise.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <string>
 #include "lite/backends/cuda/math/utils.h"
 
 namespace paddle {
@@ -33,14 +34,15 @@ void elementwise(const Dtype* x_data,
                  cudaStream_t stream);
 
 template <typename Dtype>
-void elementwise_relu(const Dtype* x_data,
-                      const Dtype* y_data,
-                      Dtype* out_data,
-                      int pre,
-                      int n,
-                      int post,
-                      BinaryOperation type,
-                      cudaStream_t stream);
+void elementwise_act(const Dtype* x_data,
+                     const Dtype* y_data,
+                     Dtype* out_data,
+                     int pre,
+                     int n,
+                     int post,
+                     std::string act,
+                     BinaryOperation type,
+                     cudaStream_t stream);
 
 template <typename Dtype>
 void elementwise_add(int num,
diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
index 1c2297710b..4de007bb17 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
@@ -22,20 +22,31 @@
 namespace paddle {
 namespace lite {
 namespace mir {
 
-void ElementwiseAddActivationFusePass::Apply(
+void ElementwiseActivationFusePass::Apply(
     const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ElementwiseAddActivationFuser fuser("relu");
-  fuser(graph.get());
+  // initialize fuser params
+  std::vector<std::string> elt_types{
+      "elementwise_add", "elementwise_sub", "elementwise_mul"};
+  std::vector<std::string> act_types{"relu", "abs", "tanh"};
+
+  // start fuse using params
+  for (auto elt_type : elt_types) {
+    for (auto act_type : act_types) {
+      fusion::ElementwiseActivationFuser fuser(elt_type, act_type);
+      fuser(graph.get());
+    }
+  }
 }
 
 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
-                  paddle::lite::mir::ElementwiseAddActivationFusePass)
+REGISTER_MIR_PASS(lite_elementwise_activation_fuse_pass,
+                  paddle::lite::mir::ElementwiseActivationFusePass)
     .BindTargets({TARGET(kAny)})
     .ExcludeTargets({TARGET(kXPU)})
     .ExcludeTargets({TARGET(kBM)})
     .ExcludeTargets({TARGET(kX86)})
-    .BindKernel("fusion_elementwise_add_activation");
+    .BindKernel("fusion_elementwise_add_activation")
+    .BindKernel("fusion_elementwise_sub_activation");
diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
index 299b6b89a0..bca8bd802b 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
@@ -22,7 +22,7 @@
 namespace paddle {
 namespace lite {
 namespace mir {
 
-class ElementwiseAddActivationFusePass : public ProgramPass {
+class ElementwiseActivationFusePass : public ProgramPass {
  public:
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
 };
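As a quick illustration, the nested loop in ElementwiseActivationFusePass::Apply above instantiates one fuser per (elementwise, activation) pair. The standalone sketch below (hypothetical, illustration only) prints the nine combinations and the fused op type each one targets:

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> elt_types{
          "elementwise_add", "elementwise_sub", "elementwise_mul"};
      std::vector<std::string> act_types{"relu", "abs", "tanh"};
      for (const auto& elt : elt_types)
        for (const auto& act : act_types)
          // e.g. "elementwise_add + relu -> fusion_elementwise_add_activation"
          printf("%s + %s -> fusion_%s_activation\n",
                 elt.c_str(), act.c_str(), elt.c_str());
      return 0;
    }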
diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuser.cc b/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
index 3c6bf4768b..28081748a7 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
@@ -21,21 +21,21 @@
 namespace lite {
 namespace mir {
 namespace fusion {
 
-void ElementwiseAddActivationFuser::BuildPattern() {
+void ElementwiseActivationFuser::BuildPattern() {
   // create input nodes.
-  auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput();
-  auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+  auto* x = VarNode("x")->assert_is_op_input(eltwise_type_, "X")->AsInput();
+  auto* y = VarNode("y")->assert_is_op_input(eltwise_type_, "Y")->AsInput();
 
   // create op nodes
-  auto* add = OpNode("add", "elementwise_add")
-                  ->assert_is_op("elementwise_add")
+  auto* elt = OpNode("elt", eltwise_type_)
+                  ->assert_is_op(eltwise_type_)
                   ->AsIntermediate();
   auto* act =
       OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
 
   // create intermediate nodes
-  auto* add_out = VarNode("add_out")
-                      ->assert_is_op_output("elementwise_add", "Out")
+  auto* elt_out = VarNode("add_out")
+                      ->assert_is_op_output(eltwise_type_, "Out")
                       ->assert_is_op_input(act_type_, "X")
                       ->AsIntermediate();
 
@@ -44,21 +44,29 @@ void ElementwiseAddActivationFuser::BuildPattern() {
       VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
 
   // create topology.
-  std::vector<PMNode*> add_inputs{x, y};
-  add_inputs >> *add >> *add_out;
-  *add_out >> *act >> *out;
+  std::vector<PMNode*> elt_inputs{x, y};
+  elt_inputs >> *elt >> *elt_out;
+  *elt_out >> *act >> *out;
 }
 
-void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
-                                                  const key2nodes_t& matched) {
+void ElementwiseActivationFuser::InsertNewNode(SSAGraph* graph,
+                                               const key2nodes_t& matched) {
   auto op_desc = GenOpDesc(matched);
-  auto op =
-      LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
-  auto old_op = matched.at("add")->stmt()->op();
+  std::shared_ptr<OpLite> op;
+  if (eltwise_type_ == "elementwise_add") {
+    op = LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
+  } else if (eltwise_type_ == "elementwise_sub") {
+    op = LiteOpRegistry::Global().Create("fusion_elementwise_sub_activation");
+  } else if (eltwise_type_ == "elementwise_mul") {
+    op = LiteOpRegistry::Global().Create("fusion_elementwise_mul_activation");
+  } else {
+    LOG(FATAL) << "unsupported elementwise type: " << eltwise_type_;
+  }
+
+  auto old_op = matched.at("elt")->stmt()->op();
   auto* scope = old_op->scope();
   auto& valid_places = old_op->valid_places();
   op->Attach(op_desc, scope);
-
   auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places);
 
   IR_NODE_LINK_TO(matched.at("x"), new_op_node);
@@ -66,12 +74,20 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }
 
-cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc(
-    const key2nodes_t& matched) {
-  auto* desc = matched.at("add")->stmt()->op_info();
+cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
+  auto* desc = matched.at("elt")->stmt()->op_info();
 
   cpp::OpDesc op_desc;
-  op_desc.SetType("fusion_elementwise_add_activation");
+  if (eltwise_type_ == "elementwise_add") {
+    op_desc.SetType("fusion_elementwise_add_activation");
+  } else if (eltwise_type_ == "elementwise_sub") {
+    op_desc.SetType("fusion_elementwise_sub_activation");
+  } else if (eltwise_type_ == "elementwise_mul") {
+    op_desc.SetType("fusion_elementwise_mul_activation");
+  } else {
+    LOG(FATAL) << "unsupported elementwise type: " << eltwise_type_;
+  }
+
   op_desc.SetInput("X", {matched.at("x")->arg()->name});
   op_desc.SetInput("Y", {matched.at("y")->arg()->name});
   op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
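InsertNewNode and GenOpDesc above encode the same elementwise-to-fused-op mapping in two if/else chains. A table-driven variant (a hypothetical refactor sketch, not part of the patch) would keep that mapping in one place:

    #include <map>
    #include <stdexcept>
    #include <string>

    // Hypothetical helper: single source of truth for the mapping used by
    // both InsertNewNode and GenOpDesc above.
    std::string FusedOpTypeFor(const std::string& eltwise_type) {
      static const std::map<std::string, std::string> kFusedOpType{
          {"elementwise_add", "fusion_elementwise_add_activation"},
          {"elementwise_sub", "fusion_elementwise_sub_activation"},
          {"elementwise_mul", "fusion_elementwise_mul_activation"}};
      auto it = kFusedOpType.find(eltwise_type);
      if (it == kFusedOpType.end())
        throw std::invalid_argument("unsupported elementwise type: " +
                                    eltwise_type);
      return it->second;
    }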
diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuser.h b/lite/core/mir/fusion/elementwise_add_activation_fuser.h
index 47bb2fcf82..ac56e7a675 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuser.h
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuser.h
@@ -23,15 +23,23 @@
 namespace lite {
 namespace mir {
 namespace fusion {
 
-class ElementwiseAddActivationFuser : public FuseBase {
+// Detect an elementwise op followed by an activation op, then merge them into
+// a single fusion_elementwise_*_activation op.
+// Example: elementwise_add + relu fuse.
+//   fusion::ElementwiseActivationFuser fuser("elementwise_add", "relu");
+//   fuser(graph.get());
+class ElementwiseActivationFuser : public FuseBase {
  public:
-  explicit ElementwiseAddActivationFuser(const std::string& act_type)
-      : act_type_(act_type) {}
+  explicit ElementwiseActivationFuser(const std::string& eltwise_type,
+                                      const std::string& act_type)
+      : eltwise_type_(eltwise_type), act_type_(act_type) {}
 
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  std::string eltwise_type_;
   std::string act_type_;
 };
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 3d71b5d62e..c095ec9697 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -74,7 +74,7 @@ class Optimizer {
          "lite_scale_activation_fuse_pass",  //
 #if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
     (defined LITE_WITH_ARM)
-         "lite_elementwise_add_activation_fuse_pass",  //
+         "lite_elementwise_activation_fuse_pass",  //
 #endif
          "__xpu__resnet_fuse_pass",
          "__xpu__multi_encoder_fuse_pass",
diff --git a/lite/kernels/cuda/elementwise_compute.cu b/lite/kernels/cuda/elementwise_compute.cu
index 02b7c8f7d9..310be5e94b 100644
--- a/lite/kernels/cuda/elementwise_compute.cu
+++ b/lite/kernels/cuda/elementwise_compute.cu
@@ -70,7 +70,30 @@ inline bool is_broadcast(const DDim& x_dims,
   return true;
 }
 
-#define ELEMENTWISE_COMPUTE(OP, WITH_RELU)                           \
+#define ELEMENTWISE_COMPUTE(OP)                                      \
+  auto& param = this->Param<param_t>();                              \
+  auto& ctx = this->ctx_->template As<CUDAContext>();                \
+  auto stream = ctx.exec_stream();                                   \
+  const lite::Tensor* x = param.X;                                   \
+  const lite::Tensor* y = param.Y;                                   \
+  lite::Tensor* out = param.Out;                                     \
+  int axis = param.axis;                                             \
+  auto* x_data = x->data<float>();                                   \
+  auto* y_data = y->data<float>();                                   \
+  auto out_data = out->mutable_data<float>(TARGET(kCUDA));           \
+  int pixel_num = x->numel();                                        \
+  int pre = 1;                                                       \
+  int n = pixel_num;                                                 \
+  int post = 1;                                                      \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
+    lite::cuda::math::elementwise(                                   \
+        x_data, y_data, out_data, pre, n, post, OP, stream);         \
+  } else {                                                           \
+    lite::cuda::math::elementwise(                                   \
+        x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);      \
+  }
+
+#define ELEMENTWISE_COMPUTE_ACT(OP)                                  \
   auto& param = this->Param<param_t>();                              \
   auto& ctx = this->ctx_->template As<CUDAContext>();                \
   auto stream = ctx.exec_stream();                                   \
@@ -85,25 +108,43 @@ inline bool is_broadcast(const DDim& x_dims,
   int pre = 1;                                                       \
   int n = pixel_num;                                                 \
   int post = 1;                                                      \
-  if (WITH_RELU) {                                                   \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
-      lite::cuda::math::elementwise_relu(                            \
-          x_data, y_data, out_data, pre, n, post, OP, stream);       \
-    } else {                                                         \
-      lite::cuda::math::elementwise_relu(                            \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
-    }                                                                \
+  auto act = param.act_type;                                         \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
+    lite::cuda::math::elementwise_act(                               \
+        x_data, y_data, out_data, pre, n, post, act, OP, stream);    \
   } else {                                                           \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
-      lite::cuda::math::elementwise(                                 \
-          x_data, y_data, out_data, pre, n, post, OP, stream);       \
-    } else {                                                         \
-      lite::cuda::math::elementwise(                                 \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
-    }                                                                \
+    lite::cuda::math::elementwise_act(                               \
+        x_data, y_data, out_data, 1, pixel_num, 1, act, OP, stream); \
   }
 
-#define ELEMENTWISE_COMPUTE_NHWC(OP, WITH_RELU)                      \
+#define ELEMENTWISE_COMPUTE_NHWC(OP)                                 \
+  std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};     \
+  auto& param = this->Param<param_t>();                              \
+  auto& ctx = this->ctx_->template As<CUDAContext>();                \
+  auto stream = ctx.exec_stream();                                   \
+  const lite::Tensor* x = param.X;                                   \
+  const lite::Tensor* y = param.Y;                                   \
+  lite::Tensor* out = param.Out;                                     \
+  int axis = param.axis;                                             \
+  if (axis < 0) axis = x->dims().size() - y->dims().size();          \
+  CHECK(axis >= 0) << "invalid axis of elementwise op";              \
+  axis = pos_map[axis];                                              \
+  auto* x_data = x->data<float>();                                   \
+  auto* y_data = y->data<float>();                                   \
+  auto out_data = out->mutable_data<float>(TARGET(kCUDA));           \
+  int pixel_num = x->numel();                                        \
+  int pre = 1;                                                       \
+  int n = pixel_num;                                                 \
+  int post = 1;                                                      \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
+    lite::cuda::math::elementwise(                                   \
+        x_data, y_data, out_data, pre, n, post, OP, stream);         \
+  } else {                                                           \
+    lite::cuda::math::elementwise(                                   \
+        x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);      \
+  }
+
+#define ELEMENTWISE_COMPUTE_ACT_NHWC(OP)                             \
   std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};     \
   auto& param = this->Param<param_t>();                              \
   auto& ctx = this->ctx_->template As<CUDAContext>();                \
@@ -122,80 +163,83 @@ inline bool is_broadcast(const DDim& x_dims,
   int pre = 1;                                                       \
   int n = pixel_num;                                                 \
   int post = 1;                                                      \
-  if (WITH_RELU) {                                                   \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
-      lite::cuda::math::elementwise_relu(                            \
-          x_data, y_data, out_data, pre, n, post, OP, stream);       \
-    } else {                                                         \
-      lite::cuda::math::elementwise_relu(                            \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
-    }                                                                \
+  auto act = param.act_type;                                         \
+  if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) {   \
+    lite::cuda::math::elementwise_act(                               \
+        x_data, y_data, out_data, pre, n, post, act, OP, stream);    \
   } else {                                                           \
-    if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
-      lite::cuda::math::elementwise(                                 \
-          x_data, y_data, out_data, pre, n, post, OP, stream);       \
-    } else {                                                         \
-      lite::cuda::math::elementwise(                                 \
-          x_data, y_data, out_data, 1, pixel_num, 1, OP, stream);    \
-    }                                                                \
+    lite::cuda::math::elementwise_act(                               \
+        x_data, y_data, out_data, 1, pixel_num, 1, act, OP, stream); \
   }
 
 void ElementwiseAddCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, false)
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
 void ElementwiseAddComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, false)
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
 void ElementwiseSubCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false)
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
 void ElementwiseSubComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false)
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
 void ElementwiseMulCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
+  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
 void ElementwiseMulComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, false)
+  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
-void ElementwiseAddReluCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, true)
+void ElementwiseAddActivationCompute::Run() {
+  ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
-void ElementwiseAddReluComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, true)
+void ElementwiseAddActivationComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kADD)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
+void ElementwiseSubActivationCompute::Run() {
+  ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kSUB)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+void ElementwiseSubActivationComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kSUB)
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
-void ElementwiseMulReluCompute::Run() {
-  ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, true)
+void ElementwiseMulActivationCompute::Run() {
+  ELEMENTWISE_COMPUTE_ACT(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
-void ElementwiseMulReluComputeNHWC::Run() {
-  ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, true)
+void ElementwiseMulActivationComputeNHWC::Run() {
+  ELEMENTWISE_COMPUTE_ACT_NHWC(lite::cuda::math::BinaryOperation::kMUL)
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
@@ -298,23 +342,25 @@ REGISTER_LITE_KERNEL(elementwise_mul,
                                        DATALAYOUT(kNHWC))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
-                     kCUDA,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::cuda::ElementwiseAddReluCompute,
-                     def)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_add_activation,
+    kCUDA,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::cuda::ElementwiseAddActivationCompute,
+    def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
-                     kCUDA,
-                     kFloat,
-                     kNHWC,
-                     paddle::lite::kernels::cuda::ElementwiseAddReluComputeNHWC,
-                     nhwc_format)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_add_activation,
+    kCUDA,
+    kFloat,
+    kNHWC,
+    paddle::lite::kernels::cuda::ElementwiseAddActivationComputeNHWC,
+    nhwc_format)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
@@ -329,23 +375,58 @@ REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
                                       DATALAYOUT(kNHWC))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
-                     kCUDA,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::cuda::ElementwiseMulReluCompute,
-                     def)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_sub_activation,
+    kCUDA,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::cuda::ElementwiseSubActivationCompute,
+    def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
-                     kCUDA,
-                     kFloat,
-                     kNHWC,
-                     paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC,
-                     nhwc_format)
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_sub_activation,
+    kCUDA,
+    kFloat,
+    kNHWC,
+    paddle::lite::kernels::cuda::ElementwiseSubActivationComputeNHWC,
+    nhwc_format)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kCUDA),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_mul_activation,
+    kCUDA,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::cuda::ElementwiseMulActivationCompute,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    fusion_elementwise_mul_activation,
+    kCUDA,
+    kFloat,
+    kNHWC,
+    paddle::lite::kernels::cuda::ElementwiseMulActivationComputeNHWC,
+    nhwc_format)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNHWC))})
    .BindInput("Y",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
    .Finalize();
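The pos_map trick in the NHWC macros deserves a note: the broadcast axis arrives expressed against NCHW dims, so for NHWC storage it must be translated to where that dimension actually lives (N stays at 0, C moves from 1 to 3, H from 2 to 1, W from 3 to 2). A tiny standalone sketch of the remap (hypothetical, illustration only):

    #include <cstdio>
    #include <map>

    int main() {
      // NCHW axis index -> NHWC axis index, as in the macros above.
      std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};
      // Example: y of shape [C] broadcast at NCHW axis 1; in NHWC storage the
      // channel dimension is innermost, i.e. axis 3.
      int axis = 1;
      printf("NCHW axis %d -> NHWC axis %d\n", axis, pos_map[axis]);  // 1 -> 3
      return 0;
    }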
diff --git a/lite/kernels/cuda/elementwise_compute.h b/lite/kernels/cuda/elementwise_compute.h
index bc9ffd5d27..b7558d94d4 100644
--- a/lite/kernels/cuda/elementwise_compute.h
+++ b/lite/kernels/cuda/elementwise_compute.h
@@ -74,40 +74,58 @@ class ElementwiseMulComputeNHWC
   virtual ~ElementwiseMulComputeNHWC() = default;
 };
 
-class ElementwiseAddReluCompute
+class ElementwiseAddActivationCompute
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
 
   void Run() override;
-  virtual ~ElementwiseAddReluCompute() = default;
+  virtual ~ElementwiseAddActivationCompute() = default;
 };
 
-class ElementwiseAddReluComputeNHWC
+class ElementwiseAddActivationComputeNHWC
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
 
   void Run() override;
-  virtual ~ElementwiseAddReluComputeNHWC() = default;
+  virtual ~ElementwiseAddActivationComputeNHWC() = default;
 };
 
-class ElementwiseMulReluCompute
+class ElementwiseSubActivationCompute
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
 
   void Run() override;
-  virtual ~ElementwiseMulReluCompute() = default;
+  virtual ~ElementwiseSubActivationCompute() = default;
 };
 
-class ElementwiseMulReluComputeNHWC
+class ElementwiseSubActivationComputeNHWC
     : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
  public:
   using param_t = operators::FusionElementwiseActivationParam;
 
   void Run() override;
-  virtual ~ElementwiseMulReluComputeNHWC() = default;
+  virtual ~ElementwiseSubActivationComputeNHWC() = default;
+};
+
+class ElementwiseMulActivationCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  void Run() override;
+  virtual ~ElementwiseMulActivationCompute() = default;
+};
+
+class ElementwiseMulActivationComputeNHWC
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  void Run() override;
+  virtual ~ElementwiseMulActivationComputeNHWC() = default;
 };
 
 }  // namespace cuda
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); -REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation, - kCUDA, - kFloat, - kNHWC, - paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC, - nhwc_format) +REGISTER_LITE_KERNEL( + fusion_elementwise_sub_activation, + kCUDA, + kFloat, + kNHWC, + paddle::lite::kernels::cuda::ElementwiseSubActivationComputeNHWC, + nhwc_format) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindInput("Y", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + fusion_elementwise_mul_activation, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ElementwiseMulActivationCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + fusion_elementwise_mul_activation, + kCUDA, + kFloat, + kNHWC, + paddle::lite::kernels::cuda::ElementwiseMulActivationComputeNHWC, + nhwc_format) .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), diff --git a/lite/kernels/cuda/elementwise_compute.h b/lite/kernels/cuda/elementwise_compute.h index bc9ffd5d27..b7558d94d4 100644 --- a/lite/kernels/cuda/elementwise_compute.h +++ b/lite/kernels/cuda/elementwise_compute.h @@ -74,40 +74,58 @@ class ElementwiseMulComputeNHWC virtual ~ElementwiseMulComputeNHWC() = default; }; -class ElementwiseAddReluCompute +class ElementwiseAddActivationCompute : public KernelLite { public: using param_t = operators::FusionElementwiseActivationParam; void Run() override; - virtual ~ElementwiseAddReluCompute() = default; + virtual ~ElementwiseAddActivationCompute() = default; }; -class ElementwiseAddReluComputeNHWC +class ElementwiseAddActivationComputeNHWC : public KernelLite { public: using param_t = operators::FusionElementwiseActivationParam; void Run() override; - virtual ~ElementwiseAddReluComputeNHWC() = default; + virtual ~ElementwiseAddActivationComputeNHWC() = default; }; -class ElementwiseMulReluCompute +class ElementwiseSubActivationCompute : public KernelLite { public: using param_t = operators::FusionElementwiseActivationParam; void Run() override; - virtual ~ElementwiseMulReluCompute() = default; + virtual ~ElementwiseSubActivationCompute() = default; }; -class ElementwiseMulReluComputeNHWC +class ElementwiseSubActivationComputeNHWC : public KernelLite { public: using param_t = operators::FusionElementwiseActivationParam; void Run() override; - virtual ~ElementwiseMulReluComputeNHWC() = default; + virtual ~ElementwiseSubActivationComputeNHWC() = default; +}; + +class ElementwiseMulActivationCompute + : public KernelLite { + public: + using param_t = operators::FusionElementwiseActivationParam; + + void Run() override; + virtual ~ElementwiseMulActivationCompute() = default; +}; + +class ElementwiseMulActivationComputeNHWC + : public KernelLite { + public: + using param_t = operators::FusionElementwiseActivationParam; + + void Run() override; + virtual ~ElementwiseMulActivationComputeNHWC() = default; }; } // namespace cuda diff --git a/lite/operators/fusion_elementwise_activation_ops.cc b/lite/operators/fusion_elementwise_activation_ops.cc index dfe3bda6c6..59d641c371 100644 --- 
--
GitLab