From 615d8a226403961bfa435e52ba22e6ab197a39c7 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Thu, 18 Feb 2021 22:20:16 -0800 Subject: [PATCH] Modify relu native implementation 2 (#30996) * Modify relu native implementation * fix GPU performance --- cmake/cuda.cmake | 2 ++ paddle/fluid/operators/activation_op.cc | 2 +- paddle/fluid/operators/activation_op.cu | 2 +- paddle/fluid/operators/activation_op.h | 12 +++++++- .../operators/fused/fused_bn_activation_op.cu | 2 +- paddle/fluid/operators/gru_unit_op.h | 29 ++++++++++++------- paddle/fluid/operators/lstmp_op.h | 21 +++++++++----- paddle/fluid/operators/rnn_op.h | 2 +- 8 files changed, 49 insertions(+), 23 deletions(-) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index f373951ccb..2f4f5449f4 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -216,6 +216,8 @@ endif(WIN32) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -w") # Set :expt-relaxed-constexpr to suppress Eigen warnings set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +# Set :expt-extended-lambda to enable HOSTDEVICE annotation on lambdas +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") if(WIN32) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"") diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 3643fd926d..785d6daaec 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -1061,7 +1061,7 @@ REGISTER_OPERATOR( ops::ActivationOpDoubleGrad2::FwdDeps()>, ops::ActivationDoubleGradOpInplaceInferer); -REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); +REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor); REGISTER_OP_CPU_KERNEL( relu_grad_grad, diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 3677739917..2033081af2 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -60,7 +60,7 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== relu register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); +REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluCUDAFunctor, ReluGradFunctor); REGISTER_OP_CUDA_KERNEL( relu_grad_grad, diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 483f5cc2e5..289cc70392 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -318,7 +318,17 @@ struct ExpGradFunctor : public BaseActivationFunctor { // relu(x) = max(x, 0) template -struct ReluFunctor : public BaseActivationFunctor { +struct ReluCPUFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) { + return v > static_cast(0) ? v : static_cast(0); + }); + } +}; + +template +struct ReluCUDAFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { out.device(d) = x.cwiseMax(static_cast(0)); diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.cu b/paddle/fluid/operators/fused/fused_bn_activation_op.cu index 32eaf11809..9339ae8e47 100644 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.cu +++ b/paddle/fluid/operators/fused/fused_bn_activation_op.cu @@ -93,7 +93,7 @@ class FusedBatchNormActKernel auto y_v = framework::EigenVector::Flatten(*y); auto &dev = *dev_ctx.eigen_device(); if (act_type == "relu") { - ReluFunctor()(dev, x_v, y_v); + ReluCUDAFunctor()(dev, x_v, y_v); } else { PADDLE_THROW( platform::errors::Unimplemented("Unsupported activation type")); diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index 4865a02c52..2d1a89f9ae 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/place.h" namespace paddle { namespace operators { @@ -37,19 +38,24 @@ template class GRUUnitKernel : public framework::OpKernel { public: template - void ActCompute(const int act_type, const Device& d, X x, Y y) const { - if (act_type == identity) + void ActCompute(const int act_type, const Device& d, X x, Y y, + platform::Place place) const { + if (act_type == identity) { y.device(d) = x; - else if (act_type == sigmoid) + } else if (act_type == sigmoid) { SigmoidFunctor()(d, x, y); - else if (act_type == tanh) + } else if (act_type == tanh) { TanhFunctor()(d, x, y); - else if (act_type == relu) - ReluFunctor()(d, x, y); - else + } else if (act_type == relu) { + if (place == platform::CPUPlace()) + ReluCPUFunctor()(d, x, y); + else + ReluCUDAFunctor()(d, x, y); + } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported activation type, only supports identity, sigmoid, tanh " "and relu.")); + } } void Compute(const framework::ExecutionContext& context) const override { @@ -97,11 +103,13 @@ class GRUUnitKernel : public framework::OpKernel { Eigen::array extents{{batch_size, frame_size}}; Eigen::array u_offsets{{0, 0}}; ActCompute(context.Attr("gate_activation"), place, - g.slice(u_offsets, extents), g.slice(u_offsets, extents)); + g.slice(u_offsets, extents), g.slice(u_offsets, extents), + context.GetPlace()); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets{{0, frame_size}}; ActCompute(context.Attr("gate_activation"), place, - g.slice(r_offsets, extents), g.slice(r_offsets, extents)); + g.slice(r_offsets, extents), g.slice(r_offsets, extents), + context.GetPlace()); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, @@ -111,7 +119,8 @@ class GRUUnitKernel : public framework::OpKernel { Eigen::array c_offsets{{0, frame_size * 2}}; ActCompute(context.Attr("activation"), place, - g.slice(c_offsets, extents), g.slice(c_offsets, extents)); + g.slice(c_offsets, extents), g.slice(c_offsets, extents), + context.GetPlace()); auto c = g.slice(c_offsets, extents); // output candidate // calculate final output diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index a2d1d5295b..5a6ac42f45 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/transform.h" namespace paddle { @@ -81,18 +82,22 @@ class LSTMPKernel : public framework::OpKernel { public: template void ActCompute(const math::detail::ActivationType act_type, const Device& d, - X x, Y y) const { - if (act_type == math::detail::ActivationType::kIdentity) + X x, Y y, platform::Place place) const { + if (act_type == math::detail::ActivationType::kIdentity) { y.device(d) = x; - else if (act_type == math::detail::ActivationType::kSigmoid) + } else if (act_type == math::detail::ActivationType::kSigmoid) { SigmoidFunctor()(d, x, y); - else if (act_type == math::detail::ActivationType::kTanh) + } else if (act_type == math::detail::ActivationType::kTanh) { TanhFunctor()(d, x, y); - else if (act_type == math::detail::ActivationType::kReLU) - ReluFunctor()(d, x, y); - else + } else if (act_type == math::detail::ActivationType::kReLU) { + if (place == platform::CPUPlace()) + ReluCPUFunctor()(d, x, y); + else + ReluCUDAFunctor()(d, x, y); + } else { PADDLE_THROW( platform::errors::InvalidArgument("unsupported activation type")); + } } void Compute(const framework::ExecutionContext& ctx) const override { @@ -225,7 +230,7 @@ class LSTMPKernel : public framework::OpKernel { &proj_t, static_cast(0.0)); if (proj_act != math::detail::ActivationType::kIdentity) { auto proj_t_dev = EigenMatrix::From(proj_t); - ActCompute(cell_act, place, proj_t_dev, proj_t_dev); + ActCompute(cell_act, place, proj_t_dev, proj_t_dev, ctx.GetPlace()); } if (proj_clip && proj_clip > 0.0) { T* x_data = proj_t.data(); diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index b993f5ac17..2b223e24cf 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -979,7 +979,7 @@ class RNNCPUKernel : public framework::OpKernel { } else if (is_rnn_relu(ctx)) { gate_num = 1; RnnFunc< - SimpleRNNCell, + SimpleRNNCell, Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], nullptr, sequence_length, state[0], nullptr, output, dropout_mask, num_layers, gate_num, -- GitLab