未验证 提交 98c4c780 编写于 作者: W Wojciech Uss 提交者: GitHub

Modify relu native implementation 2 (#30996) (#31348)

上级 325bfc37
......@@ -216,6 +216,8 @@ endif(WIN32)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
# Set :expt-extended-lambda to enable HOSTDEVICE annotation on lambdas
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
if(WIN32)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"")
......
......@@ -1051,7 +1051,7 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
relu_grad_grad,
......
......@@ -60,7 +60,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluCUDAFunctor, ReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
......
......@@ -318,7 +318,17 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
return v > static_cast<T>(0) ? v : static_cast<T>(0);
});
}
};
template <typename T>
struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0));
......
......@@ -93,7 +93,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
auto y_v = framework::EigenVector<T>::Flatten(*y);
auto &dev = *dev_ctx.eigen_device();
if (act_type == "relu") {
ReluFunctor<T>()(dev, x_v, y_v);
ReluCUDAFunctor<T>()(dev, x_v, y_v);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported activation type"));
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
......@@ -37,19 +38,24 @@ template <typename DeviceContext, typename T>
class GRUUnitKernel : public framework::OpKernel<T> {
public:
template <typename Device, typename X, typename Y>
void ActCompute(const int act_type, const Device& d, X x, Y y) const {
if (act_type == identity)
void ActCompute(const int act_type, const Device& d, X x, Y y,
platform::Place place) const {
if (act_type == identity) {
y.device(d) = x;
else if (act_type == sigmoid)
} else if (act_type == sigmoid) {
SigmoidFunctor<T>()(d, x, y);
else if (act_type == tanh)
} else if (act_type == tanh) {
TanhFunctor<T>()(d, x, y);
else if (act_type == relu)
ReluFunctor<T>()(d, x, y);
else
} else if (act_type == relu) {
if (place == platform::CPUPlace())
ReluCPUFunctor<T>()(d, x, y);
else
ReluCUDAFunctor<T>()(d, x, y);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported activation type, only supports identity, sigmoid, tanh "
"and relu."));
}
}
void Compute(const framework::ExecutionContext& context) const override {
......@@ -97,11 +103,13 @@ class GRUUnitKernel : public framework::OpKernel<T> {
Eigen::array<int, 2> extents{{batch_size, frame_size}};
Eigen::array<int, 2> u_offsets{{0, 0}};
ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(u_offsets, extents), g.slice(u_offsets, extents));
g.slice(u_offsets, extents), g.slice(u_offsets, extents),
context.GetPlace());
auto u = g.slice(u_offsets, extents); // update gate
Eigen::array<int, 2> r_offsets{{0, frame_size}};
ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(r_offsets, extents), g.slice(r_offsets, extents));
g.slice(r_offsets, extents), g.slice(r_offsets, extents),
context.GetPlace());
auto r = g.slice(r_offsets, extents); // reset gate
r_h_p.device(place) = r * h_p; // reset previous hidden state
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
......@@ -111,7 +119,8 @@ class GRUUnitKernel : public framework::OpKernel<T> {
Eigen::array<int, 2> c_offsets{{0, frame_size * 2}};
ActCompute(context.Attr<int>("activation"), place,
g.slice(c_offsets, extents), g.slice(c_offsets, extents));
g.slice(c_offsets, extents), g.slice(c_offsets, extents),
context.GetPlace());
auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
......@@ -81,18 +82,22 @@ class LSTMPKernel : public framework::OpKernel<T> {
public:
template <typename Device, typename X, typename Y>
void ActCompute(const math::detail::ActivationType act_type, const Device& d,
X x, Y y) const {
if (act_type == math::detail::ActivationType::kIdentity)
X x, Y y, platform::Place place) const {
if (act_type == math::detail::ActivationType::kIdentity) {
y.device(d) = x;
else if (act_type == math::detail::ActivationType::kSigmoid)
} else if (act_type == math::detail::ActivationType::kSigmoid) {
SigmoidFunctor<T>()(d, x, y);
else if (act_type == math::detail::ActivationType::kTanh)
} else if (act_type == math::detail::ActivationType::kTanh) {
TanhFunctor<T>()(d, x, y);
else if (act_type == math::detail::ActivationType::kReLU)
ReluFunctor<T>()(d, x, y);
else
} else if (act_type == math::detail::ActivationType::kReLU) {
if (place == platform::CPUPlace())
ReluCPUFunctor<T>()(d, x, y);
else
ReluCUDAFunctor<T>()(d, x, y);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("unsupported activation type"));
}
}
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -225,7 +230,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
&proj_t, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj_t_dev = EigenMatrix<T>::From(proj_t);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev, ctx.GetPlace());
}
if (proj_clip && proj_clip > 0.0) {
T* x_data = proj_t.data<T>();
......
......@@ -979,7 +979,7 @@ class RNNCPUKernel : public framework::OpKernel<T> {
} else if (is_rnn_relu(ctx)) {
gate_num = 1;
RnnFunc<
SimpleRNNCell<T, ReluFunctor, math::detail::ActivationType::kReLU>,
SimpleRNNCell<T, ReluCPUFunctor, math::detail::ActivationType::kReLU>,
Layer, SingleLayer, BidirLayer, T>(
ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
state[0], nullptr, output, dropout_mask, num_layers, gate_num,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册