未验证 提交 615d8a22 编写于 作者: W Wojciech Uss 提交者: GitHub

Modify relu native implementation 2 (#30996)

* Modify relu native implementation

* fix GPU performance
上级 9401173e
...@@ -216,6 +216,8 @@ endif(WIN32) ...@@ -216,6 +216,8 @@ endif(WIN32)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -w") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings # Set :expt-relaxed-constexpr to suppress Eigen warnings
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
# Set :expt-extended-lambda to enable HOSTDEVICE annotation on lambdas
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
if(WIN32) if(WIN32)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"")
......
...@@ -1061,7 +1061,7 @@ REGISTER_OPERATOR( ...@@ -1061,7 +1061,7 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer); ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
relu_grad_grad, relu_grad_grad,
......
...@@ -60,7 +60,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -60,7 +60,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */ /* ========================================================================== */
/* =========================== relu register ============================ */ /* =========================== relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluCUDAFunctor, ReluGradFunctor);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
relu_grad_grad, relu_grad_grad,
......
...@@ -318,7 +318,17 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> { ...@@ -318,7 +318,17 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
// relu(x) = max(x, 0) // relu(x) = max(x, 0)
template <typename T> template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> { struct ReluCPUFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
return v > static_cast<T>(0) ? v : static_cast<T>(0);
});
}
};
template <typename T>
struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const { void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0)); out.device(d) = x.cwiseMax(static_cast<T>(0));
......
...@@ -93,7 +93,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T> ...@@ -93,7 +93,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
auto y_v = framework::EigenVector<T>::Flatten(*y); auto y_v = framework::EigenVector<T>::Flatten(*y);
auto &dev = *dev_ctx.eigen_device(); auto &dev = *dev_ctx.eigen_device();
if (act_type == "relu") { if (act_type == "relu") {
ReluFunctor<T>()(dev, x_v, y_v); ReluCUDAFunctor<T>()(dev, x_v, y_v);
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unimplemented("Unsupported activation type")); platform::errors::Unimplemented("Unsupported activation type"));
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -37,19 +38,24 @@ template <typename DeviceContext, typename T> ...@@ -37,19 +38,24 @@ template <typename DeviceContext, typename T>
class GRUUnitKernel : public framework::OpKernel<T> { class GRUUnitKernel : public framework::OpKernel<T> {
public: public:
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void ActCompute(const int act_type, const Device& d, X x, Y y) const { void ActCompute(const int act_type, const Device& d, X x, Y y,
if (act_type == identity) platform::Place place) const {
if (act_type == identity) {
y.device(d) = x; y.device(d) = x;
else if (act_type == sigmoid) } else if (act_type == sigmoid) {
SigmoidFunctor<T>()(d, x, y); SigmoidFunctor<T>()(d, x, y);
else if (act_type == tanh) } else if (act_type == tanh) {
TanhFunctor<T>()(d, x, y); TanhFunctor<T>()(d, x, y);
else if (act_type == relu) } else if (act_type == relu) {
ReluFunctor<T>()(d, x, y); if (place == platform::CPUPlace())
else ReluCPUFunctor<T>()(d, x, y);
else
ReluCUDAFunctor<T>()(d, x, y);
} else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported activation type, only supports identity, sigmoid, tanh " "Unsupported activation type, only supports identity, sigmoid, tanh "
"and relu.")); "and relu."));
}
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -97,11 +103,13 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -97,11 +103,13 @@ class GRUUnitKernel : public framework::OpKernel<T> {
Eigen::array<int, 2> extents{{batch_size, frame_size}}; Eigen::array<int, 2> extents{{batch_size, frame_size}};
Eigen::array<int, 2> u_offsets{{0, 0}}; Eigen::array<int, 2> u_offsets{{0, 0}};
ActCompute(context.Attr<int>("gate_activation"), place, ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(u_offsets, extents), g.slice(u_offsets, extents)); g.slice(u_offsets, extents), g.slice(u_offsets, extents),
context.GetPlace());
auto u = g.slice(u_offsets, extents); // update gate auto u = g.slice(u_offsets, extents); // update gate
Eigen::array<int, 2> r_offsets{{0, frame_size}}; Eigen::array<int, 2> r_offsets{{0, frame_size}};
ActCompute(context.Attr<int>("gate_activation"), place, ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(r_offsets, extents), g.slice(r_offsets, extents)); g.slice(r_offsets, extents), g.slice(r_offsets, extents),
context.GetPlace());
auto r = g.slice(r_offsets, extents); // reset gate auto r = g.slice(r_offsets, extents); // reset gate
r_h_p.device(place) = r * h_p; // reset previous hidden state r_h_p.device(place) = r * h_p; // reset previous hidden state
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
...@@ -111,7 +119,8 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -111,7 +119,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
Eigen::array<int, 2> c_offsets{{0, frame_size * 2}}; Eigen::array<int, 2> c_offsets{{0, frame_size * 2}};
ActCompute(context.Attr<int>("activation"), place, ActCompute(context.Attr<int>("activation"), place,
g.slice(c_offsets, extents), g.slice(c_offsets, extents)); g.slice(c_offsets, extents), g.slice(c_offsets, extents),
context.GetPlace());
auto c = g.slice(c_offsets, extents); // output candidate auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output // calculate final output
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -81,18 +82,22 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -81,18 +82,22 @@ class LSTMPKernel : public framework::OpKernel<T> {
public: public:
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void ActCompute(const math::detail::ActivationType act_type, const Device& d, void ActCompute(const math::detail::ActivationType act_type, const Device& d,
X x, Y y) const { X x, Y y, platform::Place place) const {
if (act_type == math::detail::ActivationType::kIdentity) if (act_type == math::detail::ActivationType::kIdentity) {
y.device(d) = x; y.device(d) = x;
else if (act_type == math::detail::ActivationType::kSigmoid) } else if (act_type == math::detail::ActivationType::kSigmoid) {
SigmoidFunctor<T>()(d, x, y); SigmoidFunctor<T>()(d, x, y);
else if (act_type == math::detail::ActivationType::kTanh) } else if (act_type == math::detail::ActivationType::kTanh) {
TanhFunctor<T>()(d, x, y); TanhFunctor<T>()(d, x, y);
else if (act_type == math::detail::ActivationType::kReLU) } else if (act_type == math::detail::ActivationType::kReLU) {
ReluFunctor<T>()(d, x, y); if (place == platform::CPUPlace())
else ReluCPUFunctor<T>()(d, x, y);
else
ReluCUDAFunctor<T>()(d, x, y);
} else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::InvalidArgument("unsupported activation type")); platform::errors::InvalidArgument("unsupported activation type"));
}
} }
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -225,7 +230,7 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -225,7 +230,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
&proj_t, static_cast<T>(0.0)); &proj_t, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) { if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj_t_dev = EigenMatrix<T>::From(proj_t); auto proj_t_dev = EigenMatrix<T>::From(proj_t);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev); ActCompute(cell_act, place, proj_t_dev, proj_t_dev, ctx.GetPlace());
} }
if (proj_clip && proj_clip > 0.0) { if (proj_clip && proj_clip > 0.0) {
T* x_data = proj_t.data<T>(); T* x_data = proj_t.data<T>();
......
...@@ -979,7 +979,7 @@ class RNNCPUKernel : public framework::OpKernel<T> { ...@@ -979,7 +979,7 @@ class RNNCPUKernel : public framework::OpKernel<T> {
} else if (is_rnn_relu(ctx)) { } else if (is_rnn_relu(ctx)) {
gate_num = 1; gate_num = 1;
RnnFunc< RnnFunc<
SimpleRNNCell<T, ReluFunctor, math::detail::ActivationType::kReLU>, SimpleRNNCell<T, ReluCPUFunctor, math::detail::ActivationType::kReLU>,
Layer, SingleLayer, BidirLayer, T>( Layer, SingleLayer, BidirLayer, T>(
ctx, input, weight_list, pre_state[0], nullptr, sequence_length, ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
state[0], nullptr, output, dropout_mask, num_layers, gate_num, state[0], nullptr, output, dropout_mask, num_layers, gate_num,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部