Unverified commit b2160e73 authored by Zhang Zheng, committed by GitHub

add other 15 activation ops (#32622)

Parent 74682530
......@@ -663,6 +663,640 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// log1p(x) = log(1 + x)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log(one + x));
}
};
template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (one + args[1]);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// log2(x) = log(x) / log(2)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log2(x));
}
};
template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));
// dx = dout / (x * log(2))
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (args[1] * log_two);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// log10(x) = log(x) / log(10)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log10(x));
}
};
template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));
// dx = dout / (x * log(10))
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (args[1] * log_ten);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
float t_min;
float t_max;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
// brelu(x) = min(max(x, t_min), t_max)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T t_min_cast = static_cast<T>(t_min);
T t_max_cast = static_cast<T>(t_max);
T temp_max = x > t_min_cast ? x : t_min_cast;
T temp_min = temp_max < t_max_cast ? temp_max : t_max_cast;
return temp_min;
}
};
template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float t_min;
float t_max;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
// dx = (x > t_min && x < t_max) ? dout : 0
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T dout = args[0];
T x = args[1];
T t_min_cast = static_cast<T>(t_min);
T t_max_cast = static_cast<T>(t_max);
return (x > t_min_cast && x < t_max_cast) ? dout : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
// Inputs: args[0], the input x
// threshold should not be negative
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
MPType t = static_cast<MPType>(threshold);
MPType temp_min = x < t ? x : t;
MPType temp_max = temp_min > -t ? temp_min : -t;
return static_cast<T>(log(one + exp(temp_max)));
}
};
template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
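// (since out = log(1 + exp(z)), d(out)/dz = exp(z) / (1 + exp(z)) = 1 - exp(-out))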
// Inputs: args[0], the input dout
// args[1], the input out
// threshold should not be negative
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType out = static_cast<MPType>(args[1]);
MPType t = static_cast<MPType>(threshold);
return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
: static_cast<T>(0.0f);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// stanh(x) = b * tanh(a * x)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
return static_cast<T>(b * tanh(a * x));
}
};
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
MPType temp = tanh(a * x);
return static_cast<T>(dout * a * b * (one - temp * temp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// softplus(x) = x, when x * beta > threshold
// log(1 + exp(x * beta)) / beta, otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * b;
return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
}
};
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * b;
return x_beta > t ? args[0] : static_cast<T>(dout / (one + exp(-x_beta)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// softsign(x) = x / (1 + abs(x))
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (one + abs(args[0]));
}
};
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + abs(x))^2
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T temp = one + abs(args[1]);
return args[0] / (temp * temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// relu6(x) = min(max(0, x), threshold), threshold = 6 by default
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T t = static_cast<T>(threshold);
return args[0] <= zero ? zero : (args[0] < t ? args[0] : t);
}
};
template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = (out > 0 && out < t) ? dout : 0
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
T t = static_cast<T>(threshold);
return (args[1] > zero && args[1] < t) ? args[0] : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// tanhshrink(x) = x - tanh(x)
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(x - tanh(x));
}
};
template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * tanh(x)^2
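// (d/dx tanh(x) = 1 - tanh(x)^2, so d/dx (x - tanh(x)) = tanh(x)^2)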
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout * tanh(x) * tanh(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// hardshrink(x) = (x > -threshold && x < threshold) ? 0 : x
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T t = static_cast<T>(threshold);
return (x > -t && x < t) ? zero : x;
}
};
template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = (x > -threshold && x < threshold) ? 0 : dout
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[1];
T t = static_cast<T>(threshold);
return (x > -t && x < t) ? zero : args[0];
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
float slope;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"slope", &slope}, {"offset", &offset}};
}
// hard_sigmoid(x) = 0, when x * slope + offset <= 0
// 1, when x * slope + offset >= 1
// x * slope + offset, otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T temp = args[0] * static_cast<T>(slope) + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < one ? temp_max : one;
return temp_min;
}
};
template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
float slope;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"slope", &slope}, {"offset", &offset}};
}
// dx = (out > 0 && out < 1) ? dout * slope : 0
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
T out = args[1];
return (out > zero && out < one) ? args[0] * static_cast<T>(slope) : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// swish(x) = x / (1 + exp(-beta * x))
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x)));
}
};
template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// dx = dout * ((1 + exp(-b * x) + b * x * exp(-b * x)) / (1 + exp(-b * x))^2)
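// Writing s = 1 / (1 + exp(-b * x)), this equals s + b * x * s * (1 - s),
// which is what temp2 + temp3 below compute.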
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
MPType b = static_cast<MPType>(beta);
MPType temp1 = one / (one + exp(-b * x));
MPType out = x * temp1;
MPType temp2 = b * out;
MPType temp3 = temp1 * (one - temp2);
return static_cast<T>(dout * (temp2 + temp3));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// thresholded_relu(x) = x > threshold ? x : 0
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] > static_cast<T>(threshold) ? args[0] : zero;
}
};
template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = x > threshold ? dout : 0
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[1] > static_cast<T>(threshold) ? args[0] : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// hard_swish(x) = 0, when x <= -offset
// x , when x >= threshold - offset
// x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T t = static_cast<T>(threshold);
T temp = x + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < t ? temp_max : t;
return temp_min * x / static_cast<T>(scale);
}
};
template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
T two = static_cast<T>(2.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
// Inputs: args[0], the input dout
// args[1], the input x
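// temp1 = (x + offset > 0) and temp2 = (x + offset < threshold) act as branch
// selectors below, yielding 0, dout * (2 * x + offset) / scale, or dout,
// matching the three cases above.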
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[1];
T o = static_cast<T>(offset);
T s = static_cast<T>(scale);
T temp1 = static_cast<T>(x + o > zero);
T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
return args[0] * (temp1 * temp2 * (two * x + o) / s + one - temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
CT x = static_cast<CT>(args[0]);
CT temp = static_cast<CT>(alpha) * (exp(x) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <= 0
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
temp_a_neg * temp_x_pos * (one + a * exp(x))));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......@@ -732,30 +1366,35 @@ class ActivationGradCudaKernel
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
act_type, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>, \
ops::ActivationKernel<plat::CUDADeviceContext, \
ops::ActivationCudaKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \
act_type##_grad, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>, \
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<int>>, \
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<int64_t>>, \
ops::ActivationCudaKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
......@@ -764,6 +1403,10 @@ namespace plat = paddle::platform;
ops::grad_functor<float>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<int>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<int64_t>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
......@@ -782,7 +1425,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* ======================== elu register ============================ */
REGISTER_ACTIVATION_GPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
......@@ -851,29 +1494,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== square register ============================ */
REGISTER_OP_CUDA_KERNEL(
square, ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareFunctor<float>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareFunctor<int>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareFunctor<int64_t>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
square_grad,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareGradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareGradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareGradFunctor<int>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareGradFunctor<int64_t>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaSquareGradFunctor<plat::float16>>);
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
CudaSquareGradFunctor);
REGISTER_OP_CUDA_KERNEL(
square_grad_grad,
......@@ -890,7 +1512,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* ========================== pow register ============================ */
REGISTER_OP_CUDA_KERNEL(
pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
......@@ -908,7 +1529,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* ========================== exp register ============================ */
REGISTER_OP_CUDA_KERNEL(
exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpFunctor<float>>,
......@@ -943,56 +1563,44 @@ REGISTER_OP_CUDA_KERNEL(
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
CudaSigmoidGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor,
CudaSiluGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,
CudaLogSigmoidGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor,
CudaAtanGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor,
CudaSoftShrinkGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CudaCeilFunctor,
CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, CudaFloorFunctor,
CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, CudaAcosFunctor,
CudaAcosGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, CudaAsinFunctor,
CudaAsinGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, CudaSinhFunctor,
CudaSinhGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CudaCoshFunctor,
CudaCoshGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor,
CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor,
CudaReciprocalGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, SoftReluFunctor,
SoftReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, SoftplusFunctor,
SoftplusGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, SoftsignFunctor,
SoftsignGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor,
TanhShrinkGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor,
HardShrinkGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
HardSigmoidGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(thresholded_relu, ThresholdedRelu,
ThresholdedReluFunctor,
ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor,
HardSwishGradFunctor);
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor); \
__macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor); \
__macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \
CudaLogSigmoidGradFunctor); \
__macro(atan, Atan, CudaAtanFunctor, CudaAtanGradFunctor); \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \
__macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \
__macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \
__macro(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); \
__macro(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); \
__macro(acos, Acos, CudaAcosFunctor, CudaAcosGradFunctor); \
__macro(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); \
__macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor); \
__macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor); \
__macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor); \
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
__macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor); \
__macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor); \
__macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); \
__macro(brelu, BRelu, CudaBReluFunctor, CudaBReluGradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \
__macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \
__macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor); \
__macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor); \
__macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, \
CudaTanhShrinkGradFunctor); \
__macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \
CudaHardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \
CudaHardSigmoidGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor, \
CudaThresholdedReluGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
CudaHardSwishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
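// A minimal sketch of the calling convention assumed by the functors above
// (hypothetical launch code, not taken from this commit): the element-wise
// kernel is assumed to pack its operands into a small array and call
// operator()(const T* args), so forward functors see args = {x} and grad
// functors see args = {dout, x} or {dout, out} depending on FwdDeps().
template <typename T, typename Functor>
__global__ void ElementwiseActivationSketch(const T* x, T* out, int num,
Functor functor) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num) {
T args[1] = {x[idx]};
out[idx] = functor(args);
}
}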
......@@ -2718,7 +2718,7 @@ create_test_act_fp16_class(TestRelu)
create_test_act_fp16_class(TestGelu)
create_test_act_fp16_class(TestBRelu)
create_test_act_fp16_class(TestRelu6)
create_test_act_fp16_class(TestSoftRelu)
create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85)
create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestReciprocal)
create_test_act_fp16_class(TestLog)
......@@ -2736,7 +2736,7 @@ create_test_act_fp16_class(TestSoftplus)
create_test_act_fp16_class(TestSoftsign)
create_test_act_fp16_class(TestThresholdedRelu)
create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestSwish, grad_atol=0.85)
create_test_act_fp16_class(TestHardSwish)
if __name__ == "__main__":
......