未验证 提交 eca8dcc7 编写于 作者: Z Zhang Zheng 提交者: GitHub

Unify the implementation of activation operation (#32348)

上级 6f6e159a
......@@ -10,382 +10,719 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using float16 = paddle::platform::float16;
// Forward ReLU: returns the input when it is strictly positive, zero
// otherwise.
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // relu(x) = max(x, 0)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    if (x > zero) {
      return x;
    }
    return zero;
  }
};
template <typename T>
struct CudaVecType {
using type = T;
static constexpr int vecsize = 1;
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
// dx = dout * (out > 0)
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
return args[1] > zero ? args[0] : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <>
struct CudaVecType<platform::float16> {
using type = __half2;
static constexpr int vecsize = 2;
// Forward leaky ReLU: positive inputs pass through, all others are scaled
// by `alpha`.
template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Returns the (attribute name, storage pointer) pair for `alpha`.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // leakyrelu(x) = x > 0 ? x : alpha * x
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    if (x > zero) {
      return x;
    }
    return static_cast<T>(alpha) * x;
  }
};
template <>
struct CudaVecType<float> {
using type = float4;
static constexpr int vecsize = 4;
// Backward leaky ReLU: the upstream gradient passes through where the
// forward input was positive and is scaled by `alpha` elsewhere.
template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Returns the (attribute name, storage pointer) pair for `alpha`.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout * (x > 0 ? 1 : alpha)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    if (x > zero) {
      return dout;
    }
    return static_cast<T>(alpha) * dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
class BaseGPUFunctor {
public:
using ELEMENT_TYPE = T;
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// sigmoid(x) = 1 / (1 + exp(-x))
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(one / (one + exp(-x)));
}
};
using AttrPair = std::vector<std::pair<const char*, float*>>;
template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout * out * (1 - out)
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[1] * (one - args[1]);
}
AttrPair GetAttrs() { return AttrPair(); }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
/* ========================================================================== */
// Forward SiLU, evaluated in the compute type MPType and converted back
// to T on return.
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  // MPType is the type the arithmetic is carried out in.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    const auto denom = one + exp(-x);
    return static_cast<T>(x / denom);
  }
};
/* =========================== relu forward ============================ */
template <typename T>
class ReluGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2)
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
MPType temp = one / (one + exp(-x));
return static_cast<T>(dout * (temp * (one + x * (one - temp))));
}
public:
ReluGPUFunctor() { zero_ = static_cast<T>(0.0f); }
// for relu forward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type in) {
// relu forward : out = max(x, 0)
return in > zero_ ? in : zero_;
}
// when num % vecsize != 0 this func will be used
__device__ __forceinline__ T ComputeRemainder(const T in) {
// relu forward : out = max(x, 0)
return in > zero_ ? in : zero_;
}
};
// Vectorized ReLU forward for float: applies max(x, 0) to each of the four
// lanes of a float4.
//
// in: four packed input values.
// Returns the four activated values, packed in the same lane order.
template <>
__device__ __forceinline__ CudaVecType<float>::type
ReluGPUFunctor<float>::Compute(const CudaVecType<float>::type in) {
  // Select rather than multiply by the comparison result: `0 * x` is NaN
  // for x = -inf, whereas the scalar paths of this functor
  // (`in > zero_ ? in : zero_`) return zero for the same input. Using the
  // same select keeps the vector lanes and the scalar tail in agreement.
  return make_float4(in.x > zero_ ? in.x : zero_,
                     in.y > zero_ ? in.y : zero_,
                     in.z > zero_ ? in.z : zero_,
                     in.w > zero_ ? in.w : zero_);
}
// Vectorized ReLU forward for float16: processes two half values packed in
// one half2.
//
// in: two packed input values.
// Returns the two activated values, packed in the same lane order.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
ReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type in) {
// relu forward : out = max(in, 0)
// NOTE: this must be `#if defined(A) || B`. The previous `#ifdef A || B`
// form only tests whether A is defined and ignores the rest of the line, so
// the native half2 path was never selected for CUDA builds.
#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  // Native half2 path: __hgt2 yields 1.0 / 0.0 per lane, used as a mask.
  const half2 kzero = __float2half2_rn(0.0f);
  return __hmul2(__hgt2(in, kzero), in);
#else
  // Fallback: unpack to float, apply the same mask-multiply, repack.
  const float2 xx = __half22float2(in);
  return __floats2half2_rn((xx.x > 0.0f) * static_cast<float>(xx.x),
                           (xx.y > 0.0f) * static_cast<float>(xx.y));
#endif
}
/* ========================================================================== */
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
/* =========================== relu backward ============================
*/
// Forward log-sigmoid in the shifted form given below, evaluated in the
// compute type MPType.
template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x)))
  // For numerical stability, with m = max(-x, 0):
  //     logsigmoid(x) = -(m + log(exp(-m) + exp(-x - m)))
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    // m is -x, replaced by zero when x is strictly positive.
    MPType m = -x;
    if (x > zero) {
      m = zero;
    }
    return static_cast<T>(-m - log(exp(-m) + exp(-x - m)));
  }
};
template <typename T>
class ReluGradGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
// dx = dout * exp(-x) / (1 + exp(-x))
// For numerical stability:
// dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x,
// 0)))
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
MPType dout = static_cast<MPType>(args[0]);
MPType x = static_cast<MPType>(args[1]);
MPType temp1 = x > zero ? zero : -x;
MPType temp2 = exp(-x - temp1);
return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2)));
}
public:
ReluGradGPUFunctor() { zero_ = static_cast<T>(0.0f); }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise arctangent; the input is converted to MPType for the call
// and the result converted back to T.
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // atan(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    const auto result = atan(value);
    return static_cast<T>(result);
  }
};
// Backward arctangent, computed directly in T.
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    return dout / (one + x * x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Forward soft-shrink with threshold `lambda`, written without branches:
// each comparison result is converted to T and used as a multiplicative
// mask.
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Returns the (attribute name, storage pointer) pair for `lambda`.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    const T threshold = static_cast<T>(lambda);
    const T above_mask = static_cast<T>(x > threshold);
    const T below_mask = static_cast<T>(x < -threshold);
    return above_mask * (x - threshold) + below_mask * (x + threshold);
  }
};
// Backward soft-shrink: the gradient is zero when x lies inside
// [-lambda, lambda] and dout otherwise.
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  // Returns the (attribute name, storage pointer) pair for `lambda`.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = dout, if x > lambda or x < -lambda else 0
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[1];
    const T threshold = static_cast<T>(lambda);
    // The test is kept in its "inside the band" form so that an input for
    // which both comparisons are false still takes the dout branch.
    const bool inside_band = (x >= -threshold && x <= threshold);
    if (inside_band) {
      return zero;
    }
    return args[0];
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise ceiling, evaluated in MPType.
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // ceil(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    const auto rounded_up = ceil(value);
    return static_cast<T>(rounded_up);
  }
};
// Element-wise floor, evaluated in MPType.
template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // floor(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    const auto rounded_down = floor(value);
    return static_cast<T>(rounded_down);
  }
};
// Element-wise rounding via round(), evaluated in MPType.
template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // round(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    const auto rounded = round(value);
    return static_cast<T>(rounded);
  }
};
// Gradient functor shared by ceil, floor and round: always produces zero and
// never reads its arguments.
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T* args) const {
    const T zero_grad = static_cast<T>(0.0f);
    return zero_grad;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};
// Element-wise cosine, evaluated in MPType.
template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // cos(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType angle = static_cast<MPType>(args[0]);
    return static_cast<T>(cos(angle));
  }
};
// Backward cosine, evaluated in MPType.
template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * (-sin(x)), evaluated as (-dout) * sin(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType angle = static_cast<MPType>(args[1]);
    const auto sine = sin(angle);
    return static_cast<T>(-grad_out * sine);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise sine, evaluated in MPType.
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sin(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType angle = static_cast<MPType>(args[0]);
    return static_cast<T>(sin(angle));
  }
};
// Backward sine, evaluated in MPType.
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * cos(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType angle = static_cast<MPType>(args[1]);
    const auto cosine = cos(angle);
    return static_cast<T>(grad_out * cosine);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise tangent, evaluated in MPType.
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tan(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType angle = static_cast<MPType>(args[0]);
    return static_cast<T>(tan(angle));
  }
};
// Backward tangent, evaluated in MPType.
template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout / cos(x)^2
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType angle = static_cast<MPType>(args[1]);
    // cos(x) is evaluated once and squared.
    const auto cosine = cos(angle);
    return static_cast<T>(grad_out / (cosine * cosine));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// for relu backward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type out,
const typename CudaVecType<T>::type dout) {
return out > zero_ ? dout : zero_;
// Element-wise arcsine, evaluated in MPType.
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // asin(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(asin(value));
  }
};
// when num % vecsize != 0 this func will be used
__device__ __forceinline__ T ComputeRemainder(const T out, const T dout) {
// relu backward : dx = out > 0 ? dout : 0
return out > zero_ ? dout : zero_;
// Backward arcsine, evaluated in MPType.
template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout / sqrt(1 - x^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    const auto radicand = one - value * value;
    return static_cast<T>(grad_out / sqrt(radicand));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise arccosine, evaluated in MPType.
template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // acos(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(acos(value));
  }
};
// Backward arccosine, evaluated in MPType.
template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = -dout / sqrt(1 - x^2)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    const auto radicand = one - value * value;
    return static_cast<T>(-grad_out / sqrt(radicand));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise hyperbolic cosine, evaluated in MPType.
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // cosh(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(cosh(value));
  }
};
// Backward hyperbolic cosine, evaluated in MPType.
template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * sinh(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    const auto slope = sinh(value);
    return static_cast<T>(grad_out * slope);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise hyperbolic sine, evaluated in MPType.
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sinh(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(sinh(value));
  }
};
// Backward hyperbolic sine, evaluated in MPType.
template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * cosh(x)
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    const auto slope = cosh(value);
    return static_cast<T>(grad_out * slope);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise hyperbolic tangent, evaluated in MPType.
template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tanh(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(tanh(value));
  }
};
// Backward hyperbolic tangent, expressed through the forward output and
// computed directly in T.
template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * (1 - out^2)
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    const T grad_out = args[0];
    const T fwd_out = args[1];
    return grad_out * (one - fwd_out * fwd_out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// Vectorized ReLU backward for float: for each of the four float4 lanes,
// passes dout through where the forward output is positive and yields zero
// elsewhere.
//
// out:  four packed forward outputs.
// dout: four packed upstream gradients.
// Returns the four input gradients, packed in the same lane order.
template <>
__device__ __forceinline__ CudaVecType<float>::type
ReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type out,
                                   const CudaVecType<float>::type dout) {
  // relu backward : dx = out > 0 ? dout : 0
  // Select rather than multiply by the comparison result: `0 * dout` is NaN
  // when dout is inf or NaN, whereas the scalar paths of this functor
  // (`out > zero_ ? dout : zero_`) return zero in that case.
  return make_float4(out.x > zero_ ? dout.x : zero_,
                     out.y > zero_ ? dout.y : zero_,
                     out.z > zero_ ? dout.z : zero_,
                     out.w > zero_ ? dout.w : zero_);
}
// Vectorized ReLU backward for float16: processes two half values packed in
// one half2.
//
// out:  two packed forward outputs.
// dout: two packed upstream gradients.
// Returns the two input gradients, packed in the same lane order.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
ReluGradGPUFunctor<float16>::Compute(const CudaVecType<float16>::type out,
                                     const CudaVecType<float16>::type dout) {
// relu backward : dx = out > 0 ? dout : 0;
// NOTE: this must be `#if defined(A) || B`. The previous `#ifdef A || B`
// form only tests whether A is defined and ignores the rest of the line, so
// the native half2 path was never selected for CUDA builds.
#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  // Native half2 path: __hgt2 yields 1.0 / 0.0 per lane, used as a mask.
  const half2 kzero = __float2half2_rn(0.0f);
  return __hmul2(__hgt2(out, kzero), dout);
#else
  // Fallback: unpack to float, apply the same mask-multiply, repack.
  const float2 xx = __half22float2(out);
  const float2 yy = __half22float2(dout);
  return __floats2half2_rn((xx.x > 0.0f) * static_cast<float>(yy.x),
                           (xx.y > 0.0f) * static_cast<float>(yy.y));
#endif
}
// Element-wise reciprocal, computed directly in T.
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // reciprocal(x) = 1 / x
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    return one / x;
  }
};
/* ========================================================================== */
/* ======================== leaky relu forward ========================
*/
template <typename T>
class LeakyReluGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;
float alpha_;
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
// dx = -dout * out^2
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
return -args[0] * args[1] * args[1];
}
public:
LeakyReluGPUFunctor() { zero_ = static_cast<T>(0.0f); }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha_}};
}
// leakyrelu forward : out = x > 0 ? x : x * alpha
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type in) {
return in > zero_ ? in : static_cast<T>(alpha_) * in;
}
__device__ __forceinline__ T ComputeRemainder(const T in) {
// leakyrelu forward : out = x > 0 ? x : x * alpha
return in > zero_ ? in : static_cast<T>(alpha_) * in;
}
};
// Vectorized leaky ReLU forward for float: each float4 lane is kept when
// positive and multiplied by alpha_ otherwise.
template <>
__device__ __forceinline__ CudaVecType<float>::type
LeakyReluGPUFunctor<float>::Compute(const CudaVecType<float>::type in) {
  // leakyrelu forward : out = x > 0 ? x : x * alpha
  const float out_x = (in.x > zero_) ? in.x : in.x * alpha_;
  const float out_y = (in.y > zero_) ? in.y : in.y * alpha_;
  const float out_z = (in.z > zero_) ? in.z : in.z * alpha_;
  const float out_w = (in.w > zero_) ? in.w : in.w * alpha_;
  return make_float4(out_x, out_y, out_z, out_w);
}
// Vectorized leaky ReLU forward for float16: the half2 is unpacked to two
// floats, activated in float, and repacked with round-to-nearest.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
LeakyReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type in) {
  // leakyrelu forward : out = x > 0 ? x : x * alpha
  const float2 unpacked = __half22float2(in);
  const float lo = (unpacked.x > 0.0f) ? unpacked.x : unpacked.x * alpha_;
  const float hi = (unpacked.y > 0.0f) ? unpacked.y : unpacked.y * alpha_;
  return __floats2half2_rn(lo, hi);
}
/* ========================================================================== */
// Element-wise exponential, evaluated in MPType.
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // exp(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType exponent = static_cast<MPType>(args[0]);
    return static_cast<T>(exp(exponent));
  }
};
/* =========================== leaky relu backward =======================
*/
template <typename T>
class LeakyReluGradGPUFunctor : public BaseGPUFunctor<T> {
private:
T zero_;
float alpha_;
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[1];
}
public:
LeakyReluGradGPUFunctor() { zero_ = static_cast<T>(0.0f); }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha_}};
// Element-wise natural logarithm via log(), evaluated in MPType.
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // log(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(log(value));
  }
};
// for leaky relu backward when T is double
__device__ __forceinline__ typename CudaVecType<T>::type Compute(
const typename CudaVecType<T>::type in,
const typename CudaVecType<T>::type dout) {
// leakyrelu backward : dx = x > 0 ? dout : alpha * dout
return in > zero_ ? dout : static_cast<T>(alpha_) * dout;
template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
// dx = dout / x
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / args[1];
}
// when num % vecsize != 0 this func will be used
__device__ __forceinline__ T ComputeRemainder(const T in, const T dout) {
// leakyrelu backward : dx = x > 0 ? dout : alpha * dout
return in > zero_ ? dout : static_cast<T>(alpha_) * dout;
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Element-wise square, computed directly in T.
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // square(x) = x * x
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    return x * x;
  }
};
// Backward square, computed directly in T.
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // dx = dout * 2 * x, evaluated left to right
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const T grad_out = args[0];
    const T x = args[1];
    return grad_out * two * x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Vectorized leaky ReLU backward for float: for each float4 lane, dout is
// passed through where the forward input is positive and multiplied by
// alpha_ otherwise.
template <>
__device__ __forceinline__ CudaVecType<float>::type
LeakyReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type in,
                                        const CudaVecType<float>::type dout) {
  // leakyrelu backward : dx = x > 0 ? dout : alpha * dout
  const float dx_x = (in.x > zero_) ? dout.x : alpha_ * dout.x;
  const float dx_y = (in.y > zero_) ? dout.y : alpha_ * dout.y;
  const float dx_z = (in.z > zero_) ? dout.z : alpha_ * dout.z;
  const float dx_w = (in.w > zero_) ? dout.w : alpha_ * dout.w;
  return make_float4(dx_x, dx_y, dx_z, dx_w);
}
// Vectorized leaky ReLU backward for float16: both half2 operands are
// unpacked to float, the gradient is formed in float, and the result is
// repacked with round-to-nearest.
template <>
__device__ __forceinline__ CudaVecType<float16>::type LeakyReluGradGPUFunctor<
    float16>::Compute(const CudaVecType<float16>::type in,
                      const CudaVecType<float16>::type dout) {
  // leakyrelu backward : dx = x > 0 ? dout : alpha * dout
  const float2 x_pair = __half22float2(in);
  const float2 dout_pair = __half22float2(dout);
  const float lo = (x_pair.x > 0.0f) ? dout_pair.x : alpha_ * dout_pair.x;
  const float hi = (x_pair.y > 0.0f) ? dout_pair.y : alpha_ * dout_pair.y;
  return __floats2half2_rn(lo, hi);
}
// Element-wise square root, evaluated in MPType.
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // sqrt(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(sqrt(value));
  }
};
/* ========================================================================== */
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
T one_half = static_cast<T>(0.5f);
// dx = dout * 0.5 / out
// Inputs: args[0], the input dout
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
return one_half * args[0] / args[1];
}
// Vectorized element-wise backward kernel.
//
// The first (num / vecsize) * vecsize elements are processed as packed
// VecType values in a grid-stride loop through functor.Compute; the
// remaining num % vecsize scalars are processed by one thread through
// functor.ComputeRemainder.
//
// forward_data: forward tensor handed to the functor as its first operand.
// dout:         upstream gradient.
// dx:           output gradient.
// num:          number of T elements in each buffer.
// functor:      provides Compute(VecType, VecType) and
//               ComputeRemainder(T, T).
//
// NOTE(review): the three pointers are reinterpreted as VecType*, which
// presumably requires them to be aligned for VecType - confirm at the
// call site.
template <typename T, typename Functor>
__global__ void ActivationGradKernelVec(const T* forward_data, const T* dout,
                                        T* dx, int num, Functor functor) {
  using VecType = typename CudaVecType<T>::type;
  constexpr int vecsize = CudaVecType<T>::vecsize;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int stride = blockDim.x * gridDim.x;
  int loop = num / vecsize;
  int tail = num % vecsize;
  const VecType* in_forward = reinterpret_cast<const VecType*>(forward_data);
  const VecType* in_dout = reinterpret_cast<const VecType*>(dout);
  VecType* out = reinterpret_cast<VecType*>(dx);
  VecType forward_vec, dout_vec;

  for (int i = idx; i < loop; i += stride) {
// `#ifdef A || B` only tests A; `#if defined(A) || B` is required for the
// architecture check to take part in the condition.
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
    forward_vec = __ldg(in_forward + i);
    dout_vec = __ldg(in_dout + i);
#else
    forward_vec = in_forward[i];
    dout_vec = in_dout[i];
#endif
    out[i] = functor.Compute(forward_vec, dout_vec);
  }

  // Scalar tail. The owning thread is `loop % stride` rather than `loop`:
  // when the number of launched threads does not exceed `loop`, no thread
  // has index `loop` and the tail would be left unwritten. For
  // loop < stride this picks the same thread as before.
  if (tail > 0 && idx == loop % stride) {
    for (int i = num - tail; i < num; ++i) {
      dx[i] = functor.ComputeRemainder(forward_data[i], dout[i]);
    }
  }
}
// Vectorized element-wise forward kernel.
//
// The first (num / vecsize) * vecsize elements are processed as packed
// VecType values in a grid-stride loop through functor.Compute; the
// remaining num % vecsize scalars are processed by one thread through
// functor.ComputeRemainder.
//
// src:     input buffer.
// dst:     output buffer.
// num:     number of T elements in each buffer.
// functor: provides Compute(VecType) and ComputeRemainder(T).
//
// NOTE(review): src and dst are reinterpreted as VecType*, which presumably
// requires them to be aligned for VecType - confirm at the call site.
template <typename T, typename Functor>
__global__ void ActivationkernelVec(const T* src, T* dst, int num,
                                    Functor functor) {
  constexpr int vecsize = CudaVecType<T>::vecsize;
  using VecType = typename CudaVecType<T>::type;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int stride = blockDim.x * gridDim.x;
  int loop = num / vecsize;
  int tail = num % vecsize;
  const VecType* in = reinterpret_cast<const VecType*>(src);
  VecType* out = reinterpret_cast<VecType*>(dst);
  VecType x_vec;

  for (int i = idx; i < loop; i += stride) {
// `#ifdef A || B` only tests A; `#if defined(A) || B` is required for the
// architecture check to take part in the condition.
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
    x_vec = __ldg(in + i);
#else
    x_vec = in[i];
#endif
    out[i] = functor.Compute(x_vec);
  }

  // Scalar tail. The owning thread is `loop % stride` rather than `loop`:
  // when the number of launched threads does not exceed `loop`, no thread
  // has index `loop` and the tail would be left unwritten. For
  // loop < stride this picks the same thread as before.
  if (tail > 0 && idx == loop % stride) {
    for (int i = num - tail; i < num; ++i) {
      dst[i] = functor.ComputeRemainder(src[i]);
    }
  }
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// Element-wise reciprocal square root via rsqrt(), evaluated in MPType.
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // rsqrt(x)
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(rsqrt(value));
  }
};
// Backward reciprocal square root, expressed through the forward output and
// computed directly in T.
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // dx = -0.5 * dout * out^3, evaluated left to right
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    const T grad_out = args[0];
    const T fwd_out = args[1];
    return minus_one_half * grad_out * fwd_out * fwd_out * fwd_out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename DeviceContext, typename Functor>
class ActivationGPUKernel
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = nullptr;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* x = nullptr;
framework::Tensor* out = nullptr;
ExtractActivationTensor(context, &in_x, &out);
auto& dev_ctx = context.template device_context<DeviceContext>();
int num = in_x->numel();
const T* input_data = in_x->data<T>();
T* output_data = out->mutable_data<T>(dev_ctx.GetPlace(),
static_cast<size_t>(num * sizeof(T)));
int block = 512;
#ifdef __HIPCC__
block = 256;
#endif
Functor functor;
ExtractActivationTensor(ctx, &x, &out);
out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = Functor();
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
*attr.second = ctx.Attr<float>(attr.first);
}
constexpr int vecsize = CudaVecType<T>::vecsize;
int grid = max((num / vecsize + block - 1) / block, 1);
auto stream = context.cuda_device_context().stream();
ActivationkernelVec<T, Functor><<<grid, block, 0, stream>>>(
input_data, output_data, num, functor);
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins, &outs,
functor);
}
};
template <typename DeviceContext, typename Functor>
class ActivationGradGPUKernel
class ActivationGradCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *x, *out, *d_out;
framework::Tensor* d_x = nullptr;
x = out = d_out = nullptr;
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &x, &out, &d_out,
ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
&d_x);
int numel = d_out->numel();
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* dx_data = d_x->mutable_data<T>(
dev_ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto* dout_data = d_out->data<T>();
d_x->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto functor = Functor();
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
std::vector<const framework::Tensor*> ins = {d_out};
std::vector<framework::Tensor*> outs = {d_x};
auto* forward_data = dout_data;
if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
// Only need forward output Out
forward_data = out->data<T>();
ins.push_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
&outs, functor);
} else if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(kDepX)) {
// Only need forward input X
forward_data = x->data<T>();
}
int block = 512;
#ifdef __HIPCC__
block = 256;
#endif
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
ins.push_back(x);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
&outs, functor);
} else {
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins,
&outs, functor);
}
constexpr int vecsize = CudaVecType<T>::vecsize;
int grid = max((numel / vecsize + block - 1) / block, 1);
auto stream = context.cuda_device_context().stream();
ActivationGradKernelVec<T, Functor><<<grid, block, 0, stream>>>(
forward_data, dout_data, dx_data, numel, functor);
}
};
......@@ -395,12 +732,13 @@ class ActivationGradGPUKernel
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Registers the forward kernel of one activation op on CUDADeviceContext for
// float, double and plat::float16 by binding ops::functor<T> through
// ops::ActivationKernel, followed by a second REGISTER_OP_CUDA_KERNEL call
// that binds ops::grad_functor<T> through ops::ActivationGradKernel.
// NOTE(review): two #define lines and two spellings (plat:: and
// paddle::platform::) of the float/double entries appear back to back here,
// and the start of the grad registration is not visible -- confirm which
// variant is intended before relying on this macro.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>, \
ops::ActivationKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
......@@ -410,28 +748,28 @@ namespace plat = paddle::platform;
ops::grad_functor<double>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
// Passes REGISTER_ACTIVATION_CUDA_KERNEL to FOR_EACH_ACTIVATION_OP.
// NOTE(review): FOR_EACH_ACTIVATION_OP is defined outside this file section;
// presumably it expands the given macro once per listed op. Any op it lists
// that is also registered individually further down would then be registered
// twice -- confirm this line is still wanted.
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
// Registers the forward and backward CUDA kernels of one activation op for
// float, double and plat::float16.
//   act_type     - op name; the backward op is registered as act_type##_grad.
//   op_name      - not referenced by the macro body.
//   functor      - forward functor template, bound via ops::ActivationCudaKernel.
//   grad_functor - backward functor template, bound via
//                  ops::ActivationGradCudaKernel.
// Only one #define and one set of template arguments is kept: a second
// #define inside the continuation lines, or two kernel classes listed for the
// same type, would make the macro body ill-formed.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,        \
                                        grad_functor)                      \
  REGISTER_OP_CUDA_KERNEL(                                                 \
      act_type, ops::ActivationCudaKernel<plat::CUDADeviceContext,         \
                                          ops::functor<float>>,            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<double>>,                     \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<plat::float16>>);             \
  REGISTER_OP_CUDA_KERNEL(                                                 \
      act_type##_grad,                                                     \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<float>>,             \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<double>>,            \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<plat::float16>>);
/* ======================== leaky relu register ============================ */
// leaky_relu and leaky_relu_grad are registered exactly once, using the
// element-wise CUDA functors (CudaLeakyReluFunctor is the one whose
// operator() takes `const T* args`). A second registration of the same op
// name with the LeakyReluGPUFunctor pair is dropped to avoid registering the
// op twice.
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
                                CudaLeakyReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
leaky_relu_grad_grad,
......@@ -444,7 +782,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* ======================== elu register ============================ */
// elu and elu_grad are registered exactly once. ELUFunctor / ELUGradFunctor
// carry no Cuda prefix, so they go through REGISTER_ACTIVATION_GPU_KERNEL like
// the other un-prefixed functors registered at the end of this file; the
// duplicate REGISTER_ACTIVATION_CUDA_KERNEL registration is dropped.
REGISTER_ACTIVATION_GPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
......@@ -456,7 +794,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== relu register ============================ */
// relu and relu_grad are registered exactly once, using CudaReluFunctor and
// CudaReluGradFunctor. The duplicate registration with the ReluGPUFunctor
// pair is dropped so the op name is not registered twice.
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
                                CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
......@@ -469,7 +808,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== tanh register ============================ */
// tanh and tanh_grad are registered exactly once with the Cuda-prefixed
// functors; the duplicate registration with TanhFunctor / TanhGradFunctor is
// dropped.
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, CudaTanhFunctor,
                                CudaTanhGradFunctor);
REGISTER_OP_CUDA_KERNEL(
tanh_grad_grad,
......@@ -482,7 +822,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== sqrt register ============================= */
// sqrt and sqrt_grad are registered exactly once with the Cuda-prefixed
// functors; the duplicate registration with SqrtFunctor / SqrtGradFunctor is
// dropped.
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);
REGISTER_OP_CUDA_KERNEL(
sqrt_grad_grad,
......@@ -496,7 +837,8 @@ REGISTER_OP_CUDA_KERNEL(
/* =========================== rsqrt register =============================
*/
// rsqrt and rsqrt_grad are registered exactly once with the Cuda-prefixed
// functors; the duplicate registration with RsqrtFunctor / RsqrtGradFunctor
// is dropped.
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);
REGISTER_OP_CUDA_KERNEL(
rsqrt_grad_grad,
......@@ -510,24 +852,28 @@ REGISTER_OP_CUDA_KERNEL(
/* =========================== square register ============================ */
// Forward `square` kernel for float, double, int, int64_t and plat::float16.
// Registered by hand (not through REGISTER_ACTIVATION_CUDA_KERNEL) because it
// also covers the two integer types. A single argument list is kept: every
// type is bound once, through ops::ActivationCudaKernel + CudaSquareFunctor.
REGISTER_OP_CUDA_KERNEL(
    square, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                      ops::CudaSquareFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<int>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaSquareFunctor<plat::float16>>);
// Backward `square_grad` kernel for float, double, int, int64_t and
// plat::float16. A single argument list is kept: every type is bound once,
// through ops::ActivationGradCudaKernel + CudaSquareGradFunctor.
REGISTER_OP_CUDA_KERNEL(
    square_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaSquareGradFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
square_grad_grad,
......@@ -564,27 +910,29 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================== exp register ============================ */
// Forward `exp` kernel. Each element type is bound exactly once:
//   float, double, plat::float16 -> ops::ActivationCudaKernel + CudaExpFunctor
//   int, int64_t                 -> ops::ActivationKernel + ExpFunctor
// NOTE(review): the integer types stay on ops::ActivationKernel here --
// confirm that is intentional rather than an incomplete migration.
REGISTER_OP_CUDA_KERNEL(
    exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                   ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<plat::float16>>);
// Backward `exp_grad` kernel for float, double, int, int64_t and
// plat::float16. A single argument list is kept: every type is bound once,
// through ops::ActivationGradCudaKernel + CudaExpGradFunctor.
REGISTER_OP_CUDA_KERNEL(
    exp_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                            ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== Log register ==================================*/
// log and log_grad are registered exactly once with the Cuda-prefixed
// functors; the duplicate registration with LogFunctor / LogGradFunctor is
// dropped.
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);
REGISTER_OP_CUDA_KERNEL(
log_grad_grad, ops::LogDoubleGradKernel<plat::CUDADeviceContext,
......@@ -594,3 +942,57 @@ REGISTER_OP_CUDA_KERNEL(
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
// Activations registered through REGISTER_ACTIVATION_CUDA_KERNEL. Each line
// names the op, its class-style name, and the Cuda-prefixed forward and
// backward functors.
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
CudaSigmoidGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(silu, Silu, CudaSiluFunctor,
CudaSiluGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,
CudaLogSigmoidGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(atan, Atan, CudaAtanFunctor,
CudaAtanGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(softshrink, SoftShrink, CudaSoftShrinkFunctor,
CudaSoftShrinkGradFunctor);
// ceil, floor and (further down) round share CudaZeroGradFunctor as their
// backward functor.
REGISTER_ACTIVATION_CUDA_KERNEL(ceil, Ceil, CudaCeilFunctor,
CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(floor, Floor, CudaFloorFunctor,
CudaZeroGradFunctor);
// Trigonometric and hyperbolic functions.
REGISTER_ACTIVATION_CUDA_KERNEL(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(acos, Acos, CudaAcosFunctor,
CudaAcosGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(asin, Asin, CudaAsinFunctor,
CudaAsinGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(sinh, Sinh, CudaSinhFunctor,
CudaSinhGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(cosh, Cosh, CudaCoshFunctor,
CudaCoshGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(round, Round, CudaRoundFunctor,
CudaZeroGradFunctor);
REGISTER_ACTIVATION_CUDA_KERNEL(reciprocal, Reciprocal, CudaReciprocalFunctor,
CudaReciprocalGradFunctor);
// Activations registered through REGISTER_ACTIVATION_GPU_KERNEL. Their
// functors carry no Cuda prefix.
// NOTE(review): presumably these are the functors declared in
// activation_op.h that have no element-wise CUDA counterpart yet -- confirm.
REGISTER_ACTIVATION_GPU_KERNEL(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(soft_relu, SoftRelu, SoftReluFunctor,
SoftReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(softplus, Softplus, SoftplusFunctor,
SoftplusGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(softsign, Softsign, SoftsignFunctor,
SoftsignGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(tanh_shrink, TanhShrink, TanhShrinkFunctor,
TanhShrinkGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_shrink, HardShrink, HardShrinkFunctor,
HardShrinkGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
HardSigmoidGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(swish, Swish, SwishFunctor, SwishGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(thresholded_relu, ThresholdedRelu,
ThresholdedReluFunctor,
ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_GPU_KERNEL(hard_swish, HardSwish, HardSwishFunctor,
HardSwishGradFunctor);
......@@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
// Forward hard-shrink: passes x through where x < -threshold or
// x > threshold and multiplies it by zero elsewhere.
//   d   - device the expression is evaluated on
//   x   - input expression
//   out - output expression, overwritten with the result
void operator()(Device d, X x, Out out) const {
  auto temp1 = x < static_cast<T>(threshold * -1.f);
  auto temp2 = x > static_cast<T>(threshold);
  // The two boolean masks are combined with a logical OR and cast to T to
  // form a 0/1 multiplier. Only this single assignment is kept: an earlier
  // store built from `temp1 + temp2` (arithmetic on boolean expressions)
  // was overwritten immediately and had no effect on the result.
  out.device(d) = x * (temp1 || temp2).template cast<T>();
}
};
......@@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
// Backward hard-shrink: dx = dout where x < -threshold or x > threshold,
// and dout multiplied by zero elsewhere.
//   d    - device the expression is evaluated on
//   x    - forward input expression
//   out  - forward output expression (not read here)
//   dout - gradient w.r.t. the forward output
//   dx   - gradient w.r.t. the forward input, overwritten with the result
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
  auto temp1 = x < static_cast<T>(threshold * -1.f);
  auto temp2 = x > static_cast<T>(threshold);
  // Same 0/1 mask as the forward pass, built with a logical OR. Only this
  // single assignment is kept: an earlier store built from `temp1 + temp2`
  // was overwritten immediately and had no effect on the result.
  dx.device(d) = dout * (temp1 || temp2).template cast<T>();
}
// Reports which forward tensor the backward pass needs: kDepX, i.e. the
// forward input X.
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册