Unverified commit 4490e8af, authored by niuliling123, committed by GitHub

add leaky_relu forward and backward in activation_op.cu (#31841)

* add leaky_relu forward and backward in activation_op.cu
Parent 0b42f489
@@ -42,6 +42,10 @@ template <typename T>
 class BaseGPUFunctor {
  public:
   using ELEMENT_TYPE = T;
+
+  using AttrPair = std::vector<std::pair<const char*, float*>>;
+
+  AttrPair GetAttrs() { return AttrPair(); }
 };
 /* ========================================================================== */
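The functors below are written against CudaVecType<T>, whose definition sits above this hunk and is not shown in the diff. A minimal sketch of the presumed mapping, inferred from the make_float4/half2 arithmetic and the num % vecsize handling later in the file (illustrative only, not the actual declaration):

```cpp
// Sketch only: scalar fallback for double, packed vectors for float/float16.
template <typename T>
struct CudaVecTypeSketch {
  using type = T;                     // e.g. double stays scalar
  static constexpr int vecsize = 1;
};
template <>
struct CudaVecTypeSketch<float> {
  using type = float4;                // four lanes per 128-bit access
  static constexpr int vecsize = 4;
};
template <>
struct CudaVecTypeSketch<paddle::platform::float16> {
  using type = __half2;               // two lanes per 32-bit access
  static constexpr int vecsize = 2;
};
```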
@@ -57,42 +61,35 @@ class ReluGPUFunctor : public BaseGPUFunctor<T> {
   // for relu forward when T is double
   __device__ __forceinline__ typename CudaVecType<T>::type Compute(
-      const typename CudaVecType<T>::type* x);
+      const typename CudaVecType<T>::type in) {
+    // relu forward : out = max(x, 0)
+    return in > zero_ ? in : zero_;
+  }
   // when num % vecsize != 0 this func will be used
-  __device__ __forceinline__ T ComputeRemainder(const T x) {
-    return x > zero_ ? x : zero_;
+  __device__ __forceinline__ T ComputeRemainder(const T in) {
+    // relu forward : out = max(x, 0)
+    return in > zero_ ? in : zero_;
   }
 };

-template <>
-__device__ __forceinline__ CudaVecType<double>::type
-ReluGPUFunctor<double>::Compute(const CudaVecType<double>::type* x) {
-// relu forward : out = max(x, 0)
-#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350
-  return __ldg(x) > zero_ ? __ldg(x) : zero_;
-#else
-  return (*x) > zero_ ? (*x) : zero_;
-#endif
-}
-
 template <>
 __device__ __forceinline__ CudaVecType<float>::type
-ReluGPUFunctor<float>::Compute(const CudaVecType<float>::type* xx) {
-  // relu forward : out = max(xx, 0)
-  return make_float4((xx->x > zero_) * (xx->x), (xx->y > zero_) * (xx->y),
-                     (xx->z > zero_) * (xx->z), (xx->w > zero_) * (xx->w));
+ReluGPUFunctor<float>::Compute(const CudaVecType<float>::type in) {
+  // relu forward : out = max(in, 0)
+  return make_float4((in.x > zero_) * (in.x), (in.y > zero_) * (in.y),
+                     (in.z > zero_) * (in.z), (in.w > zero_) * (in.w));
 }
 template <>
 __device__ __forceinline__ CudaVecType<float16>::type
-ReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type* in) {
+ReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type in) {
   // relu forward : out = max(in, 0)
 #if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   const half2 kzero = __float2half2_rn(0.0f);
-  return __hmul2(__hgt2(__ldg(in), kzero), __ldg(in));
+  return __hmul2(__hgt2(in, kzero), in);
 #else
-  const float2 xx = __half22float2(*in);
+  const float2 xx = __half22float2(in);
   return __floats2half2_rn((xx.x > 0.0f) * static_cast<float>(xx.x),
                            (xx.y > 0.0f) * static_cast<float>(xx.y));
 #endif
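One note on the `(x > zero_) * x` pattern in the float and float16 paths above: the comparison evaluates to 1.0f or 0.0f, so the multiply reproduces max(x, 0) with no divergent branch. A scalar illustration of the same trick:

```cpp
// (x > 0.0f) converts to 1.0f or 0.0f, so the product equals max(x, 0).
__device__ __forceinline__ float relu_branchless(float x) {
  return (x > 0.0f) * x;
}
```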
@@ -112,8 +109,10 @@ class ReluGradGPUFunctor : public BaseGPUFunctor<T> {
   // for relu backward when T is double
   __device__ __forceinline__ typename CudaVecType<T>::type Compute(
-      const typename CudaVecType<T>::type* out,
-      const typename CudaVecType<T>::type* dout);
+      const typename CudaVecType<T>::type out,
+      const typename CudaVecType<T>::type dout) {
+    return out > zero_ ? dout : zero_;
+  }
   // when num % vecsize != 0 this func will be used
   __device__ __forceinline__ T ComputeRemainder(const T out, const T dout) {
@@ -124,44 +123,132 @@ class ReluGradGPUFunctor : public BaseGPUFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
-template <>
-__device__ __forceinline__ CudaVecType<double>::type
-ReluGradGPUFunctor<double>::Compute(const CudaVecType<double>::type* out,
-                                    const CudaVecType<double>::type* dout) {
-// relu backward : dx = out > 0 ? dout : 0;
-#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350
-  return __ldg(out) > zero_ ? __ldg(dout) : zero_;
-#else
-  return (*out) > zero_ ? (*dout) : zero_;
-#endif
-}
-
 template <>
 __device__ __forceinline__ CudaVecType<float>::type
-ReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type* out,
-                                   const CudaVecType<float>::type* dout) {
+ReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type out,
+                                   const CudaVecType<float>::type dout) {
   // relu backward : dx = out > 0 ? dout : 0;
-  return make_float4((out->x > zero_) * (dout->x), (out->y > zero_) * (dout->y),
-                     (out->z > zero_) * (dout->z),
-                     (out->w > zero_) * (dout->w));
+  return make_float4((out.x > zero_) * (dout.x), (out.y > zero_) * (dout.y),
+                     (out.z > zero_) * (dout.z), (out.w > zero_) * (dout.w));
 }
 template <>
 __device__ __forceinline__ CudaVecType<float16>::type
-ReluGradGPUFunctor<float16>::Compute(const CudaVecType<float16>::type* out,
-                                     const CudaVecType<float16>::type* dout) {
+ReluGradGPUFunctor<float16>::Compute(const CudaVecType<float16>::type out,
+                                     const CudaVecType<float16>::type dout) {
   // relu backward : dx = out > 0 ? dout : 0;
 #if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   const half2 kzero = __float2half2_rn(0.0f);
-  return __hmul2(__hgt2(__ldg(out), kzero), __ldg(dout));
+  return __hmul2(__hgt2(out, kzero), dout);
 #else
-  const float2 xx = __half22float2(*out);
-  const float2 yy = __half22float2(*dout);
+  const float2 xx = __half22float2(out);
+  const float2 yy = __half22float2(dout);
   return __floats2half2_rn((xx.x > 0.0f) * static_cast<float>(yy.x),
                            (xx.y > 0.0f) * static_cast<float>(yy.y));
 #endif
 }
+/* ========================================================================== */
+
+/* ========================    leaky relu forward    ======================= */
+template <typename T>
+class LeakyReluGPUFunctor : public BaseGPUFunctor<T> {
+ private:
+  T zero_;
+  float alpha_;
+
+ public:
+  LeakyReluGPUFunctor() { zero_ = static_cast<T>(0.0f); }
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha_}};
+  }
+  // leakyrelu forward : out = x > 0 ? x : x * alpha
+  __device__ __forceinline__ typename CudaVecType<T>::type Compute(
+      const typename CudaVecType<T>::type in) {
+    return in > zero_ ? in : static_cast<T>(alpha_) * in;
+  }
+
+  __device__ __forceinline__ T ComputeRemainder(const T in) {
+    // leakyrelu forward : out = x > 0 ? x : x * alpha
+    return in > zero_ ? in : static_cast<T>(alpha_) * in;
+  }
+};
+
+template <>
+__device__ __forceinline__ CudaVecType<float>::type
+LeakyReluGPUFunctor<float>::Compute(const CudaVecType<float>::type in) {
+  // leakyrelu forward : out = x > 0 ? x : x * alpha
+  return make_float4((in.x > zero_) ? (in.x) : (in.x) * alpha_,
+                     (in.y > zero_) ? (in.y) : (in.y) * alpha_,
+                     (in.z > zero_) ? (in.z) : (in.z) * alpha_,
+                     (in.w > zero_) ? (in.w) : (in.w) * alpha_);
+}
+
+template <>
+__device__ __forceinline__ CudaVecType<float16>::type
+LeakyReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type in) {
+  // leakyrelu forward : out = x > 0 ? x : x * alpha
+  const float2 xx = __half22float2(in);
+  return __floats2half2_rn((xx.x > 0.0f) ? xx.x : xx.x * alpha_,
+                           (xx.y > 0.0f) ? xx.y : xx.y * alpha_);
+}
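A scalar host-side reference of what the vectorized paths above compute, for intuition (the helper name is ours, not part of the patch):

```cpp
// Positive inputs pass through; negative inputs are scaled by alpha.
inline float leaky_relu_ref(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}
// leaky_relu_ref(3.0f, 0.02f)  == 3.0f
// leaky_relu_ref(-2.0f, 0.02f) == -0.04f
```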
+/* ========================================================================== */
+
+/* =======================    leaky relu backward    ======================= */
+template <typename T>
+class LeakyReluGradGPUFunctor : public BaseGPUFunctor<T> {
+ private:
+  T zero_;
+  float alpha_;
+
+ public:
+  LeakyReluGradGPUFunctor() { zero_ = static_cast<T>(0.0f); }
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha_}};
+  }
+
+  // for leaky relu backward when T is double
+  __device__ __forceinline__ typename CudaVecType<T>::type Compute(
+      const typename CudaVecType<T>::type in,
+      const typename CudaVecType<T>::type dout) {
+    // leakyrelu backward : dx = x > 0 ? dout : alpha * dout
+    return in > zero_ ? dout : static_cast<T>(alpha_) * dout;
+  }
+
+  // when num % vecsize != 0 this func will be used
+  __device__ __forceinline__ T ComputeRemainder(const T in, const T dout) {
+    // leakyrelu backward : dx = x > 0 ? dout : alpha * dout
+    return in > zero_ ? dout : static_cast<T>(alpha_) * dout;
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <>
+__device__ __forceinline__ CudaVecType<float>::type
+LeakyReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type in,
+                                        const CudaVecType<float>::type dout) {
+  // leakyrelu backward : dx = x > 0 ? dout : alpha * dout
+  return make_float4((in.x > zero_) ? (dout.x) : alpha_ * (dout.x),
+                     (in.y > zero_) ? (dout.y) : alpha_ * (dout.y),
+                     (in.z > zero_) ? (dout.z) : alpha_ * (dout.z),
+                     (in.w > zero_) ? (dout.w) : alpha_ * (dout.w));
+}
+
+template <>
+__device__ __forceinline__ CudaVecType<float16>::type
+LeakyReluGradGPUFunctor<float16>::Compute(
+    const CudaVecType<float16>::type in,
+    const CudaVecType<float16>::type dout) {
+  // leakyrelu backward : dx = x > 0 ? dout : alpha * dout
+  const float2 xx = __half22float2(in);
+  const float2 yy = __half22float2(dout);
+  return __floats2half2_rn((xx.x > 0.0f) ? yy.x : alpha_ * yy.x,
+                           (xx.y > 0.0f) ? yy.y : alpha_ * yy.y);
+}
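Note the dependency declarations: ReluGradGPUFunctor above reports FwdDeps() == kDepOut, since the relu gradient can be masked by the forward output alone, while LeakyReluGradGPUFunctor reports kDepX and reads the forward input. A scalar reference of the gradient, again illustrative only:

```cpp
// dx = dout where x > 0, alpha * dout elsewhere.
inline float leaky_relu_grad_ref(float x, float dout, float alpha) {
  return x > 0.0f ? dout : alpha * dout;
}
```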
 /* ========================================================================== */
 template <typename T, typename Functor>
@@ -176,14 +263,23 @@ __global__ void ActivationGradKernelVec(const T* forward_data, const T* dout,
   const VecType* in_forward = reinterpret_cast<const VecType*>(forward_data);
   const VecType* in_dout = reinterpret_cast<const VecType*>(dout);
   VecType* out = reinterpret_cast<VecType*>(dx);
+  VecType forward_vec, dout_vec;
+  T in_data, dout_data;
   for (int i = idx; i < loop; i += stride) {
-    out[i] = functor.Compute((in_forward + i), (in_dout + i));
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
+    forward_vec = __ldg(in_forward + i);
+    dout_vec = __ldg(in_dout + i);
+#else
+    forward_vec = in_forward[i];
+    dout_vec = in_dout[i];
+#endif
+    out[i] = functor.Compute(forward_vec, dout_vec);
   }

   while (idx == loop && tail) {
-    dx[num - tail] =
-        functor.ComputeRemainder(forward_data[num - tail], dout[num - tail]);
+    in_data = forward_data[num - tail];
+    dout_data = dout[num - tail];
+    dx[num - tail] = functor.ComputeRemainder(in_data, dout_data);
     --tail;
   }
 }
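For intuition on the loop/tail split in both kernels: only full vectors go through Compute, and the one thread whose idx equals loop mops up the remainder with ComputeRemainder. A worked instance, assuming float inputs:

```cpp
// num = 10 elements, float4 vectorization:
constexpr int num = 10;
constexpr int vecsize = 4;            // CudaVecType<float>::vecsize
constexpr int loop = num / vecsize;   // 2 vectors cover elements 0..7
constexpr int tail = num % vecsize;   // elements 8 and 9 handled one by one
```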
@@ -199,9 +295,14 @@ __global__ void ActivationkernelVec(const T* src, T* dst, int num,
   int tail = num % vecsize;
   const VecType* in = reinterpret_cast<const VecType*>(src);
   VecType* out = reinterpret_cast<VecType*>(dst);
+  VecType x_vec;
   for (int i = idx; i < loop; i += stride) {
-    out[i] = functor.Compute((in + i));
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
+    x_vec = __ldg(in + i);
+#else
+    x_vec = in[i];
+#endif
+    out[i] = functor.Compute(x_vec);
   }

   while (idx == loop && tail) {
@@ -231,6 +332,10 @@ class ActivationGPUKernel
     block = 256;
 #endif
     Functor functor;
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = context.Attr<float>(attr.first);
+    }
     constexpr int vecsize = CudaVecType<T>::vecsize;
     int grid = max((num / vecsize + block - 1) / block, 1);
     auto stream = context.cuda_device_context().stream();
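The GetAttrs() copy above is what lets a single kernel template serve both parameterless functors (the base-class GetAttrs() returns an empty vector) and parameterized ones like leaky relu. A standalone sketch of the pattern, with hypothetical names:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical stand-in for an attribute-carrying activation functor.
struct ToyLeakyRelu {
  float alpha_ = 0.0f;
  std::vector<std::pair<const char*, float*>> GetAttrs() {
    return {{"alpha", &alpha_}};
  }
};

int main() {
  ToyLeakyRelu functor;
  // Mirrors the kernel: write each named attribute through its exposed pointer.
  for (auto& attr : functor.GetAttrs()) {
    *attr.second = 0.02f;  // in the op this is context.Attr<float>(attr.first)
  }
  std::printf("alpha = %g\n", functor.alpha_);  // prints: alpha = 0.02
  return 0;
}
```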
@@ -270,7 +375,12 @@ class ActivationGradGPUKernel
 #ifdef __HIPCC__
     block = 256;
 #endif
     Functor functor;
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = context.Attr<float>(attr.first);
+    }
     constexpr int vecsize = CudaVecType<T>::vecsize;
     int grid = max((numel / vecsize + block - 1) / block, 1);
     auto stream = context.cuda_device_context().stream();
@@ -300,12 +410,28 @@ namespace plat = paddle::platform;
       ops::grad_functor<double>>,                                           \
       ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                 ops::grad_functor<plat::float16>>);
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);

+#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor,          \
+                                       grad_functor)                        \
+  REGISTER_OP_CUDA_KERNEL(                                                  \
+      act_type,                                                             \
+      ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,         \
+                               ops::functor<float>>,                        \
+      ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,         \
+                               ops::functor<double>>,                       \
+      ops::ActivationGPUKernel<plat::CUDADeviceContext,                     \
+                               ops::functor<plat::float16>>);               \
+  REGISTER_OP_CUDA_KERNEL(                                                  \
+      act_type##_grad,                                                      \
+      ops::ActivationGradGPUKernel<plat::CUDADeviceContext,                 \
+                                   ops::grad_functor<float>>,               \
+      ops::ActivationGradGPUKernel<plat::CUDADeviceContext,                 \
+                                   ops::grad_functor<double>>,              \
+      ops::ActivationGradGPUKernel<plat::CUDADeviceContext,                 \
+                                   ops::grad_functor<plat::float16>>);
+
 /* ======================== leaky relu register ============================ */
-REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
-                                LeakyReluGradFunctor);
+REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluGPUFunctor,
+                               LeakyReluGradGPUFunctor);

 REGISTER_OP_CUDA_KERNEL(
     leaky_relu_grad_grad,
@@ -330,21 +456,7 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */

 /* ===========================   relu register   =========================== */
-REGISTER_OP_CUDA_KERNEL(
-    relu, ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,
-                                   ops::ReluGPUFunctor<float>>,
-    ops::ActivationGPUKernel<paddle::platform::CUDADeviceContext,
-                             ops::ReluGPUFunctor<double>>,
-    ops::ActivationGPUKernel<plat::CUDADeviceContext,
-                             ops::ReluGPUFunctor<plat::float16>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    relu_grad,
-    ops::ActivationGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                 ops::ReluGradGPUFunctor<float>>,
-    ops::ActivationGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                 ops::ReluGradGPUFunctor<double>>,
-    ops::ActivationGradGPUKernel<plat::CUDADeviceContext,
-                                 ops::ReluGradGPUFunctor<plat::float16>>);
+REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor);

 REGISTER_OP_CUDA_KERNEL(
     relu_grad_grad,