未验证 提交 12bf0502 编写于 作者: Y Yiqun Liu 提交者: GitHub

Implement FunctionTraits to support two kinds of elementwise functor and...

Implement FunctionTraits to support two kinds of elementwise functor and remove some old codes for broadcast. (#35688)
上级 3493c46e
...@@ -24,15 +24,15 @@ struct CudaAbsFunctor; ...@@ -24,15 +24,15 @@ struct CudaAbsFunctor;
template <typename T> template <typename T>
struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> { struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> {
__device__ __forceinline__ math::Real<T> operator()(const T* args) const { __device__ __forceinline__ math::Real<T> operator()(const T& x) const {
return abs(args[0]); return abs(x);
} }
}; };
template <typename T> template <typename T>
struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> { struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> {
__device__ __forceinline__ T operator()(const T* args) const { __device__ __forceinline__ T operator()(const T& x) const {
return std::abs(args[0]); return std::abs(x);
} }
}; };
......
...@@ -24,9 +24,8 @@ struct CudaReluFunctor : public BaseActivationFunctor<T> { ...@@ -24,9 +24,8 @@ struct CudaReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f); T zero = static_cast<T>(0.0f);
// relu(x) = max(x, 0) // relu(x) = max(x, 0)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const { return x > zero ? x : zero;
return args[0] > zero ? args[0] : zero;
} }
}; };
...@@ -35,10 +34,8 @@ struct CudaReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -35,10 +34,8 @@ struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f); T zero = static_cast<T>(0.0f);
// dx = dout * (out > 0) // dx = dout * (out > 0)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return out > zero ? dout : zero;
__device__ __forceinline__ T operator()(const T* args) const {
return args[1] > zero ? args[0] : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -54,9 +51,8 @@ struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> { ...@@ -54,9 +51,8 @@ struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
} }
// leakyrelu(x) = x > 0 ? x : alpha * x // leakyrelu(x) = x > 0 ? x : alpha * x
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const { return x > zero ? x : static_cast<T>(alpha) * x;
return args[0] > zero ? args[0] : static_cast<T>(alpha) * args[0];
} }
}; };
...@@ -70,10 +66,8 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -70,10 +66,8 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = dout * (x > 0 ? 1 : alpha) // dx = dout * (x > 0 ? 1 : alpha)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return x > zero ? dout : static_cast<T>(alpha) * dout;
__device__ __forceinline__ T operator()(const T* args) const {
return args[1] > zero ? args[0] : static_cast<T>(alpha) * args[0];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -85,9 +79,8 @@ struct CudaSigmoidFunctor : public BaseActivationFunctor<T> { ...@@ -85,9 +79,8 @@ struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
MPType one = static_cast<MPType>(1.0f); MPType one = static_cast<MPType>(1.0f);
// sigmoid(x) = 1 / (1 + exp(-x)) // sigmoid(x) = 1 / (1 + exp(-x))
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(one / (one + exp(-x))); return static_cast<T>(one / (one + exp(-x)));
} }
}; };
...@@ -97,10 +90,8 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -97,10 +90,8 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// dx = dout * out * (1 - out) // dx = dout * out * (1 - out)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return dout * out * (one - out);
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[1] * (one - args[1]);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -108,14 +99,12 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -108,14 +99,12 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> { struct CudaSiluFunctor : public BaseActivationFunctor<T> {
// MPType means Compute Type
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f); MPType one = static_cast<MPType>(1.0f);
// silu(x) = x / (1 + exp(-x)) // silu(x) = x / (1 + exp(-x))
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(x / (one + exp(-x))); return static_cast<T>(x / (one + exp(-x)));
} }
}; };
...@@ -126,11 +115,10 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> { ...@@ -126,11 +115,10 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
MPType one = static_cast<MPType>(1.0f); MPType one = static_cast<MPType>(1.0f);
// dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2) // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
MPType temp = one / (one + exp(-x)); MPType temp = one / (one + exp(-x));
return static_cast<T>(dout * (temp * (one + x * (one - temp)))); return static_cast<T>(dout * (temp * (one + x * (one - temp))));
} }
...@@ -147,9 +135,8 @@ struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> { ...@@ -147,9 +135,8 @@ struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
// For numerical stability, // For numerical stability,
// logsigmoid(x) = // logsigmoid(x) =
// - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
MPType temp = x > zero ? zero : -x; MPType temp = x > zero ? zero : -x;
return static_cast<T>(-temp - log(exp(-temp) + exp(-x - temp))); return static_cast<T>(-temp - log(exp(-temp) + exp(-x - temp)));
} }
...@@ -164,11 +151,10 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -164,11 +151,10 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
// For numerical stability: // For numerical stability:
// dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x, // dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x,
// 0))) // 0)))
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
MPType temp1 = x > zero ? zero : -x; MPType temp1 = x > zero ? zero : -x;
MPType temp2 = exp(-x - temp1); MPType temp2 = exp(-x - temp1);
return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2))); return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2)));
...@@ -182,9 +168,8 @@ struct CudaAtanFunctor : public BaseActivationFunctor<T> { ...@@ -182,9 +168,8 @@ struct CudaAtanFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// atan(x) = atan(x) // atan(x) = atan(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(atan(x)); return static_cast<T>(atan(x));
} }
}; };
...@@ -194,10 +179,8 @@ struct CudaAtanGradFunctor : public BaseActivationFunctor<T> { ...@@ -194,10 +179,8 @@ struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// dx = dout / (1 + x^2) // dx = dout / (1 + x^2)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return dout / (one + x * x);
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (one + args[1] * args[1]);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -214,9 +197,7 @@ struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -214,9 +197,7 @@ struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
// softshrink(x) = x - lambda, if x > lambda; // softshrink(x) = x - lambda, if x > lambda;
// x + lambda, if x < -lambda; // x + lambda, if x < -lambda;
// 0, otherwise. // 0, otherwise.
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T l = static_cast<T>(lambda); T l = static_cast<T>(lambda);
T temp1 = static_cast<T>(x > l); T temp1 = static_cast<T>(x > l);
T temp2 = static_cast<T>(x < -l); T temp2 = static_cast<T>(x < -l);
...@@ -234,12 +215,9 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -234,12 +215,9 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = dout, if x > lambda or x < -lambda else 0 // dx = dout, if x > lambda or x < -lambda else 0
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[1];
T l = static_cast<T>(lambda); T l = static_cast<T>(lambda);
return (x >= -l && x <= l) ? zero : args[0]; return (x >= -l && x <= l) ? zero : dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -250,9 +228,8 @@ struct CudaCeilFunctor : public BaseActivationFunctor<T> { ...@@ -250,9 +228,8 @@ struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// ceil(x) = ceil(x) // ceil(x) = ceil(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(ceil(x)); return static_cast<T>(ceil(x));
} }
}; };
...@@ -262,9 +239,8 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> { ...@@ -262,9 +239,8 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// floor(x) = floor(x) // floor(x) = floor(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(floor(x)); return static_cast<T>(floor(x));
} }
}; };
...@@ -274,17 +250,16 @@ struct CudaRoundFunctor : public BaseActivationFunctor<T> { ...@@ -274,17 +250,16 @@ struct CudaRoundFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// round(x) = round(x) // round(x) = round(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(round(x)); return static_cast<T>(round(x));
} }
}; };
// grad functor for ceil, floor and round // GradFunctor for ceil, floor and round
template <typename T> template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> { struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
__device__ __forceinline__ T operator()(const T* args) const { __device__ __forceinline__ T operator()(const T& x) const {
return static_cast<T>(0.0f); return static_cast<T>(0.0f);
} }
...@@ -296,9 +271,8 @@ struct CudaCosFunctor : public BaseActivationFunctor<T> { ...@@ -296,9 +271,8 @@ struct CudaCosFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// cos(x) = cos(x) // cos(x) = cos(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(cos(x)); return static_cast<T>(cos(x));
} }
}; };
...@@ -308,11 +282,10 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> { ...@@ -308,11 +282,10 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * (-sin(x)) // dx = dout * (-sin(x))
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(-dout * sin(x)); return static_cast<T>(-dout * sin(x));
} }
...@@ -324,9 +297,8 @@ struct CudaSinFunctor : public BaseActivationFunctor<T> { ...@@ -324,9 +297,8 @@ struct CudaSinFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// sin(x) = sin(x) // sin(x) = sin(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(sin(x)); return static_cast<T>(sin(x));
} }
}; };
...@@ -336,11 +308,10 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> { ...@@ -336,11 +308,10 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * cos(x) // dx = dout * cos(x)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout * cos(x)); return static_cast<T>(dout * cos(x));
} }
...@@ -352,9 +323,8 @@ struct CudaTanFunctor : public BaseActivationFunctor<T> { ...@@ -352,9 +323,8 @@ struct CudaTanFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// tan(x) = tan(x) // tan(x) = tan(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(tan(x)); return static_cast<T>(tan(x));
} }
}; };
...@@ -364,11 +334,10 @@ struct CudaTanGradFunctor : public BaseActivationFunctor<T> { ...@@ -364,11 +334,10 @@ struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout / cos(x)^2 // dx = dout / cos(x)^2
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout / (cos(x) * cos(x))); return static_cast<T>(dout / (cos(x) * cos(x)));
} }
...@@ -380,9 +349,8 @@ struct CudaAsinFunctor : public BaseActivationFunctor<T> { ...@@ -380,9 +349,8 @@ struct CudaAsinFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// asin(x) = asin(x) // asin(x) = asin(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(asin(x)); return static_cast<T>(asin(x));
} }
}; };
...@@ -393,11 +361,10 @@ struct CudaAsinGradFunctor : public BaseActivationFunctor<T> { ...@@ -393,11 +361,10 @@ struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
MPType one = static_cast<MPType>(1.0f); MPType one = static_cast<MPType>(1.0f);
// dx = dout / sqrt(1 - x^2) // dx = dout / sqrt(1 - x^2)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout / sqrt(one - x * x)); return static_cast<T>(dout / sqrt(one - x * x));
} }
...@@ -409,9 +376,8 @@ struct CudaAcosFunctor : public BaseActivationFunctor<T> { ...@@ -409,9 +376,8 @@ struct CudaAcosFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// acos(x) = acos(x) // acos(x) = acos(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(acos(x)); return static_cast<T>(acos(x));
} }
}; };
...@@ -422,11 +388,10 @@ struct CudaAcosGradFunctor : public BaseActivationFunctor<T> { ...@@ -422,11 +388,10 @@ struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
MPType one = static_cast<MPType>(1.0f); MPType one = static_cast<MPType>(1.0f);
// dx = -dout / sqrt(1 - x^2) // dx = -dout / sqrt(1 - x^2)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(-dout / sqrt(one - x * x)); return static_cast<T>(-dout / sqrt(one - x * x));
} }
...@@ -438,9 +403,8 @@ struct CudaCoshFunctor : public BaseActivationFunctor<T> { ...@@ -438,9 +403,8 @@ struct CudaCoshFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// cosh(x) = cosh(x) // cosh(x) = cosh(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(cosh(x)); return static_cast<T>(cosh(x));
} }
}; };
...@@ -450,11 +414,10 @@ struct CudaCoshGradFunctor : public BaseActivationFunctor<T> { ...@@ -450,11 +414,10 @@ struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * sinh(x) // dx = dout * sinh(x)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout * sinh(x)); return static_cast<T>(dout * sinh(x));
} }
...@@ -466,9 +429,8 @@ struct CudaSinhFunctor : public BaseActivationFunctor<T> { ...@@ -466,9 +429,8 @@ struct CudaSinhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// sinh(x) = sinh(x) // sinh(x) = sinh(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(sinh(x)); return static_cast<T>(sinh(x));
} }
}; };
...@@ -478,11 +440,10 @@ struct CudaSinhGradFunctor : public BaseActivationFunctor<T> { ...@@ -478,11 +440,10 @@ struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * cosh(x) // dx = dout * cosh(x)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout * cosh(x)); return static_cast<T>(dout * cosh(x));
} }
...@@ -494,9 +455,8 @@ struct CudaTanhFunctor : public BaseActivationFunctor<T> { ...@@ -494,9 +455,8 @@ struct CudaTanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// tanh(x) = tanh(x) // tanh(x) = tanh(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(tanh(x)); return static_cast<T>(tanh(x));
} }
}; };
...@@ -506,11 +466,7 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -506,11 +466,7 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// dx = dout * (1 - out^2) // dx = dout * (1 - out^2)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
T dout = static_cast<T>(args[0]);
T out = static_cast<T>(args[1]);
return dout * (one - out * out); return dout * (one - out * out);
} }
...@@ -522,19 +478,14 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> { ...@@ -522,19 +478,14 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// reciprocal(x) = 1 / x // reciprocal(x) = 1 / x
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const { return one / x; }
__device__ __forceinline__ T operator()(const T* args) const {
return one / args[0];
}
}; };
template <typename T> template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> { struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
// dx = -dout * out^2 // dx = -dout * out^2
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return -dout * out * out;
__device__ __forceinline__ T operator()(const T* args) const {
return -args[0] * args[1] * args[1];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -545,9 +496,8 @@ struct CudaExpFunctor : public BaseActivationFunctor<T> { ...@@ -545,9 +496,8 @@ struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// exp(x) = exp(x) // exp(x) = exp(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(exp(x)); return static_cast<T>(exp(x));
} }
}; };
...@@ -555,10 +505,8 @@ struct CudaExpFunctor : public BaseActivationFunctor<T> { ...@@ -555,10 +505,8 @@ struct CudaExpFunctor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> { struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out // dx = dout * out
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return dout * out;
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[1];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -569,9 +517,8 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> { ...@@ -569,9 +517,8 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// expm1(x) = expm1(x) // expm1(x) = expm1(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(expm1(x)); return static_cast<T>(expm1(x));
} }
}; };
...@@ -579,10 +526,8 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> { ...@@ -579,10 +526,8 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> { struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out // dx = dout * out
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return dout * out + dout;
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[1] + args[0];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -593,9 +538,8 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> { ...@@ -593,9 +538,8 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// log(x) = log(x) // log(x) = log(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log(x)); return static_cast<T>(log(x));
} }
}; };
...@@ -603,10 +547,8 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> { ...@@ -603,10 +547,8 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> { struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
// dx = dout / x // dx = dout / x
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return dout / x;
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / args[1];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -615,10 +557,7 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> { ...@@ -615,10 +557,7 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> { struct CudaSquareFunctor : public BaseActivationFunctor<T> {
// square(x) = x * x // square(x) = x * x
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const { return x * x; }
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * args[0];
}
}; };
template <typename T> template <typename T>
...@@ -626,10 +565,8 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> { ...@@ -626,10 +565,8 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
T two = static_cast<T>(2.0f); T two = static_cast<T>(2.0f);
// dx = dout * 2 * x // dx = dout * 2 * x
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return dout * two * x;
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] * two * args[1];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -640,9 +577,8 @@ struct CudaSqrtFunctor : public BaseActivationFunctor<T> { ...@@ -640,9 +577,8 @@ struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// sqrt(x) = sqrt(x) // sqrt(x) = sqrt(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(sqrt(x)); return static_cast<T>(sqrt(x));
} }
}; };
...@@ -652,10 +588,8 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> { ...@@ -652,10 +588,8 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
T one_half = static_cast<T>(0.5f); T one_half = static_cast<T>(0.5f);
// dx = dout * 0.5 / out // dx = dout * 0.5 / out
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return one_half * dout / out;
__device__ __forceinline__ T operator()(const T* args) const {
return one_half * args[0] / args[1];
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -666,9 +600,8 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> { ...@@ -666,9 +600,8 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// rsqrt(x) = rsqrt(x) // rsqrt(x) = rsqrt(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(rsqrt(x)); return static_cast<T>(rsqrt(x));
} }
}; };
...@@ -677,12 +610,9 @@ template <typename T> ...@@ -677,12 +610,9 @@ template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> { struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
T minus_one_half = static_cast<T>(-0.5f); T minus_one_half = static_cast<T>(-0.5f);
// dx = dout * -0.5 / out^3 // dx = -0.5 * dout * out^3
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return minus_one_half * dout * out * out * out;
__device__ __forceinline__ T operator()(const T* args) const {
T out = args[1];
return minus_one_half * args[0] * out * out * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -694,9 +624,8 @@ struct CudaLog1pFunctor : public BaseActivationFunctor<T> { ...@@ -694,9 +624,8 @@ struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
MPType one = static_cast<MPType>(1.0f); MPType one = static_cast<MPType>(1.0f);
// log1p(x) = log(1 + x) // log1p(x) = log(1 + x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log(one + x)); return static_cast<T>(log(one + x));
} }
}; };
...@@ -706,10 +635,8 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> { ...@@ -706,10 +635,8 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// dx = dout / (1 + x) // dx = dout / (1 + x)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return dout / (one + x);
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (one + args[1]);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -720,9 +647,8 @@ struct CudaLog2Functor : public BaseActivationFunctor<T> { ...@@ -720,9 +647,8 @@ struct CudaLog2Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// log2(x) = log2(x) // log2(x) = log2(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log2(x)); return static_cast<T>(log2(x));
} }
}; };
...@@ -733,10 +659,8 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> { ...@@ -733,10 +659,8 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
T log_two = static_cast<T>(log(static_cast<MPType>(2.0f))); T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));
// dx = dout / (x * log(2)) // dx = dout / (x * log(2))
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return dout / (x * log_two);
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (args[1] * log_two);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -747,9 +671,8 @@ struct CudaLog10Functor : public BaseActivationFunctor<T> { ...@@ -747,9 +671,8 @@ struct CudaLog10Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// log10(x) = log10(x) // log10(x) = log10(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(log10(x)); return static_cast<T>(log10(x));
} }
}; };
...@@ -760,10 +683,8 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> { ...@@ -760,10 +683,8 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f))); T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));
// dx = dout / (x * log(10)) // dx = dout / (x * log(10))
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return dout / (x * log_ten);
__device__ __forceinline__ T operator()(const T* args) const {
return args[0] / (args[1] * log_ten);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -779,9 +700,7 @@ struct CudaBReluFunctor : public BaseActivationFunctor<T> { ...@@ -779,9 +700,7 @@ struct CudaBReluFunctor : public BaseActivationFunctor<T> {
} }
// brelu(x) = min(max(x, t_min), t_max) // brelu(x) = min(max(x, t_min), t_max)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T t_min_cast = static_cast<T>(t_min); T t_min_cast = static_cast<T>(t_min);
T t_max_cast = static_cast<T>(t_max); T t_max_cast = static_cast<T>(t_max);
T temp_max = x > t_min_cast ? x : t_min_cast; T temp_max = x > t_min_cast ? x : t_min_cast;
...@@ -801,11 +720,7 @@ struct CudaBReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -801,11 +720,7 @@ struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = (x > t_min && x < t_max) ? dout : 0 // dx = (x > t_min && x < t_max) ? dout : 0
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T dout = args[0];
T x = args[1];
T t_min_cast = static_cast<T>(t_min); T t_min_cast = static_cast<T>(t_min);
T t_max_cast = static_cast<T>(t_max); T t_max_cast = static_cast<T>(t_max);
return (x > t_min_cast && x < t_max_cast) ? dout : zero; return (x > t_min_cast && x < t_max_cast) ? dout : zero;
...@@ -825,10 +740,9 @@ struct CudaSoftReluFunctor : public BaseActivationFunctor<T> { ...@@ -825,10 +740,9 @@ struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
} }
// soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold))) // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
// Inputs: args[0], the input x
// threshold should not be negative // threshold should not be negative
__device__ __forceinline__ T operator()(const T* args) const { __device__ __forceinline__ T operator()(const T& arg_x) const {
MPType x = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType t = static_cast<MPType>(threshold); MPType t = static_cast<MPType>(threshold);
MPType temp_min = x < t ? x : t; MPType temp_min = x < t ? x : t;
MPType temp_max = temp_min > -t ? temp_min : -t; MPType temp_max = temp_min > -t ? temp_min : -t;
...@@ -847,12 +761,11 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -847,12 +761,11 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0 // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
// Inputs: args[0], the input dout
// args[1], the input out
// threshold should not be negative // threshold should not be negative
__device__ __forceinline__ T operator()(const T* args) const { __device__ __forceinline__ T operator()(const T& arg_dout,
MPType dout = static_cast<MPType>(args[0]); const T& arg_out) const {
MPType out = static_cast<MPType>(args[1]); MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
MPType t = static_cast<MPType>(threshold); MPType t = static_cast<MPType>(threshold);
return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out))) return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
: static_cast<T>(0.0f); : static_cast<T>(0.0f);
...@@ -872,9 +785,8 @@ struct CudaSTanhFunctor : public BaseActivationFunctor<T> { ...@@ -872,9 +785,8 @@ struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
} }
// stanh(x) = b * tanh(a * x) // stanh(x) = b * tanh(a * x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
MPType a = static_cast<MPType>(scale_a); MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b); MPType b = static_cast<MPType>(scale_b);
return static_cast<T>(b * tanh(a * x)); return static_cast<T>(b * tanh(a * x));
...@@ -893,11 +805,10 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -893,11 +805,10 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x)) // dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
MPType a = static_cast<MPType>(scale_a); MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b); MPType b = static_cast<MPType>(scale_b);
MPType temp = tanh(a * x); MPType temp = tanh(a * x);
...@@ -919,9 +830,8 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> { ...@@ -919,9 +830,8 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
} }
// softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
MPType b = static_cast<MPType>(beta); MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold); MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta; MPType x_beta = x * beta;
...@@ -941,15 +851,14 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> { ...@@ -941,15 +851,14 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x)) // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
MPType b = static_cast<MPType>(beta); MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold); MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta; MPType x_beta = x * beta;
return x_beta > t ? args[0] : static_cast<T>(dout / (one + exp(-x_beta))); return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -960,9 +869,8 @@ struct CudaSoftsignFunctor : public BaseActivationFunctor<T> { ...@@ -960,9 +869,8 @@ struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// softsign(x) = x / (1 + abs(x)) // softsign(x) = x / (1 + abs(x))
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const { return x / (one + abs(x));
return args[0] / (one + abs(args[0]));
} }
}; };
...@@ -971,11 +879,9 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> { ...@@ -971,11 +879,9 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
// dx = dout / (1 + abs(x))^2 // dx = dout / (1 + abs(x))^2
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x T temp = one + abs(x);
__device__ __forceinline__ T operator()(const T* args) const { return dout / (temp * temp);
T temp = one + abs(args[1]);
return args[0] / (temp * temp);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -991,10 +897,9 @@ struct CudaRelu6Functor : public BaseActivationFunctor<T> { ...@@ -991,10 +897,9 @@ struct CudaRelu6Functor : public BaseActivationFunctor<T> {
} }
// relu6(x) = min(max(0, x), 6) // relu6(x) = min(max(0, x), 6)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const {
T t = static_cast<T>(threshold); T t = static_cast<T>(threshold);
return args[0] <= zero ? zero : (args[0] < t ? args[0] : t); return x <= zero ? zero : (x < t ? x : t);
} }
}; };
...@@ -1008,11 +913,9 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -1008,11 +913,9 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
} }
// dx = (out > 0 && out < t) ? dout : 0 // dx = (out > 0 && out < t) ? dout : 0
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out
__device__ __forceinline__ T operator()(const T* args) const {
T t = static_cast<T>(threshold); T t = static_cast<T>(threshold);
return (args[1] > zero && args[1] < t) ? args[0] : zero; return (out > zero && out < t) ? dout : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -1023,9 +926,8 @@ struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -1023,9 +926,8 @@ struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// tanhshrink(x) = x - tanh(x) // tanhshrink(x) = x - tanh(x)
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
return static_cast<T>(x - tanh(x)); return static_cast<T>(x - tanh(x));
} }
}; };
...@@ -1035,11 +937,10 @@ struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -1035,11 +937,10 @@ struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * tanh(x)^2 // dx = dout * tanh(x)^2
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
return static_cast<T>(dout * tanh(x) * tanh(x)); return static_cast<T>(dout * tanh(x) * tanh(x));
} }
...@@ -1056,9 +957,7 @@ struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -1056,9 +957,7 @@ struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
} }
// hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x // hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T t = static_cast<T>(threshold); T t = static_cast<T>(threshold);
return (x > -t && x < t) ? zero : x; return (x > -t && x < t) ? zero : x;
} }
...@@ -1074,12 +973,9 @@ struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -1074,12 +973,9 @@ struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = (x > -threshold && x < threshold) ? 0 : dout // dx = (x > -threshold && x < threshold) ? 0 : dout
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[1];
T t = static_cast<T>(threshold); T t = static_cast<T>(threshold);
return (x > -t && x < t) ? zero : args[0]; return (x > -t && x < t) ? zero : dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -1099,9 +995,8 @@ struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> { ...@@ -1099,9 +995,8 @@ struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
// hard_sigmoid(x) = 0, when x <= -3 // hard_sigmoid(x) = 0, when x <= -3
// 1, when x >= 3 // 1, when x >= 3
// x * slope + offset, otherwise // x * slope + offset, otherwise
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const { T temp = x * static_cast<T>(slope) + static_cast<T>(offset);
T temp = args[0] * static_cast<T>(slope) + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero; T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < one ? temp_max : one; T temp_min = temp_max < one ? temp_max : one;
return temp_min; return temp_min;
...@@ -1120,11 +1015,8 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -1120,11 +1015,8 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = (out > 0 && out < 1) ? dout * slope : 0 // dx = (out > 0 && out < 1) ? dout * slope : 0
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
// args[1], the input out return (out > zero && out < one) ? dout * static_cast<T>(slope) : zero;
__device__ __forceinline__ T operator()(const T* args) const {
T out = args[1];
return (out > zero && out < one) ? args[0] * static_cast<T>(slope) : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
...@@ -1141,9 +1033,8 @@ struct CudaSwishFunctor : public BaseActivationFunctor<T> { ...@@ -1141,9 +1033,8 @@ struct CudaSwishFunctor : public BaseActivationFunctor<T> {
} }
// swish(x) = x / (1 + exp(-beta * x)) // swish(x) = x / (1 + exp(-beta * x))
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[0]);
MPType b = static_cast<MPType>(beta); MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x))); return static_cast<T>(x / (one + exp(-b * x)));
} }
...@@ -1160,11 +1051,10 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1160,11 +1051,10 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2) // dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2)
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
MPType b = static_cast<MPType>(beta); MPType b = static_cast<MPType>(beta);
MPType temp1 = one / (one + exp(-b * x)); MPType temp1 = one / (one + exp(-b * x));
MPType out = x * temp1; MPType out = x * temp1;
...@@ -1186,9 +1076,8 @@ struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> { ...@@ -1186,9 +1076,8 @@ struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
} }
// thresholded_relu(x) = x > threshold ? x : 0 // thresholded_relu(x) = x > threshold ? x : 0
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const { return x > static_cast<T>(threshold) ? x : zero;
return args[0] > static_cast<T>(threshold) ? args[0] : zero;
} }
}; };
...@@ -1202,10 +1091,8 @@ struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1202,10 +1091,8 @@ struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
} }
// dx = x > threshold ? dout : 0 // dx = x > threshold ? dout : 0
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x return x > static_cast<T>(threshold) ? dout : zero;
__device__ __forceinline__ T operator()(const T* args) const {
return args[1] > static_cast<T>(threshold) ? args[0] : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -1226,9 +1113,7 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> { ...@@ -1226,9 +1113,7 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
// x , when x >= threshold - offset // x , when x >= threshold - offset
// x * (x + offset) / scale, otherwise // x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default // threshold = scale = 6, offset = 3 by default
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& x) const {
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[0];
T t = static_cast<T>(threshold); T t = static_cast<T>(threshold);
T temp = x + static_cast<T>(offset); T temp = x + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero; T temp_max = temp > zero ? temp : zero;
...@@ -1254,15 +1139,12 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1254,15 +1139,12 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
// dout , when x >= threshold - offset // dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise // dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default // threshold = scale = 6, offset = 3 by default
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
// args[1], the input x
__device__ __forceinline__ T operator()(const T* args) const {
T x = args[1];
T o = static_cast<T>(offset); T o = static_cast<T>(offset);
T s = static_cast<T>(scale); T s = static_cast<T>(scale);
T temp1 = static_cast<T>(x + o > zero); T temp1 = static_cast<T>(x + o > zero);
T temp2 = static_cast<T>(x + o < static_cast<T>(threshold)); T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
return args[0] * (temp1 * temp2 * (two * x + o) / s + one - temp2); return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
...@@ -1280,9 +1162,8 @@ struct CudaELUFunctor : public BaseActivationFunctor<T> { ...@@ -1280,9 +1162,8 @@ struct CudaELUFunctor : public BaseActivationFunctor<T> {
} }
// elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1)) // elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
// Inputs: args[0], the input x __device__ __forceinline__ T operator()(const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { CT x = static_cast<CT>(arg_x);
CT x = static_cast<CT>(args[0]);
CT temp = static_cast<CT>(alpha) * (exp(x) - one); CT temp = static_cast<CT>(alpha) * (exp(x) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp); CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res); return static_cast<T>(res);
...@@ -1304,11 +1185,10 @@ struct CudaELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -1304,11 +1185,10 @@ struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0 // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0 // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <=0 // dx = 0, if alpha <= 0 and x <=0
// Inputs: args[0], the input dout __device__ __forceinline__ T operator()(const T& arg_dout,
// args[1], the input x const T& arg_x) const {
__device__ __forceinline__ T operator()(const T* args) const { MPType dout = static_cast<MPType>(arg_dout);
MPType dout = static_cast<MPType>(args[0]); MPType x = static_cast<MPType>(arg_x);
MPType x = static_cast<MPType>(args[1]);
MPType a = static_cast<MPType>(alpha); MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f); MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f); MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
......
...@@ -18,60 +18,46 @@ limitations under the License. */ ...@@ -18,60 +18,46 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ template <typename Functor>
template <typename T> \ class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
struct Bitwise##func##CUDAFunctor { \ : public framework::OpKernel<typename Functor::ELEM_TYPE> {
using ELEM_TYPE = T; \ public:
HOSTDEVICE T operator()(const T* args) const { \ void Compute(const framework::ExecutionContext& ctx) const override {
return args[0] expr args[1]; \ using T = typename Functor::ELEM_TYPE;
} \
}; \
\
template <> \
struct Bitwise##func##CUDAFunctor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool* args) const { \
return args[0] bool_expr args[1]; \
} \
};
BITWISE_BINARY_FUNCTOR(And, &, &&)
BITWISE_BINARY_FUNCTOR(Or, |, ||)
BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
#undef BITWISE_BINARY_FUNCTOR
template <typename T> auto* x = ctx.Input<framework::Tensor>("X");
struct BitwiseNotCUDAFunctor { auto* y = ctx.Input<framework::Tensor>("Y");
using ELEM_TYPE = T; auto* out = ctx.Output<framework::Tensor>("Out");
HOSTDEVICE T operator()(const T* args) const { return ~args[0]; } out->mutable_data<T>(ctx.GetPlace());
};
template <> auto functor = Functor();
struct BitwiseNotCUDAFunctor<bool> { std::vector<const framework::Tensor*> ins = {x, y};
using ELEM_TYPE = bool; std::vector<framework::Tensor*> outs = {out};
HOSTDEVICE bool operator()(const bool* args) const { return !args[0]; } const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, -1, functor);
}
}; };
template <typename Functor> template <typename Functor>
class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor> class UnaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
public: public:
using T = typename Functor::ELEM_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto functor = Functor(); auto functor = Functor();
std::vector<const framework::Tensor*> ins; std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs; std::vector<framework::Tensor*> outs = {out};
const auto& cuda_ctx = const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs); LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
cuda_ctx, ins, &outs, functor);
if (ins.size() == 1) {
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
cuda_ctx, ins, &outs, axis, functor);
} else {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, functor);
}
} }
}; };
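// Illustrative sketch only (not part of this diff): the BitwiseXxxFunctor types
// referenced by the registrations below are assumed to be shared host/device
// functors with reference arguments, roughly:
//   template <typename T>
//   struct BitwiseAndFunctor {
//     using ELEM_TYPE = T;
//     HOSTDEVICE T operator()(const T& a, const T& b) const { return a & b; }
//   };
//   template <>
//   struct BitwiseAndFunctor<bool> {
//     using ELEM_TYPE = bool;
//     HOSTDEVICE bool operator()(const bool& a, const bool& b) const {
//       return a && b;
//     }
//   };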
...@@ -81,7 +67,7 @@ class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor> ...@@ -81,7 +67,7 @@ class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
namespace ops = ::paddle::operators; namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform; namespace plat = ::paddle::platform;
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndCUDAFunctor); REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrCUDAFunctor); REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorCUDAFunctor); REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotCUDAFunctor); REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotFunctor);
...@@ -17,9 +17,6 @@ limitations under the License. */ ...@@ -17,9 +17,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,23 +35,6 @@ struct BitwiseAdd { ...@@ -38,23 +35,6 @@ struct BitwiseAdd {
} }
}; };
template <typename T, typename Enable = void>
struct CudaEqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return (args[0] == args[1]);
}
};
template <typename T>
struct CudaEqualReduceFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
}
};
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
...@@ -97,6 +77,9 @@ class CompareReduceOpKernel ...@@ -97,6 +77,9 @@ class CompareReduceOpKernel
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ #define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \ REGISTER_OP_CUDA_KERNEL( \
op_type, \ op_type, \
...@@ -109,5 +92,5 @@ class CompareReduceOpKernel ...@@ -109,5 +92,5 @@ class CompareReduceOpKernel
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<double>>); ops::functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor) REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, EqualReduceFunctor)
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL #undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
...@@ -21,46 +21,11 @@ namespace plat = paddle::platform; ...@@ -21,46 +21,11 @@ namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \
template <typename T, typename Enable = void> \
struct func { \
using ELEMENT_TYPE = T; \
inline HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
};
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=)
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
template <typename T>
struct CudaEqualFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
}
};
template <typename T>
struct CudaNotEqualFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const {
return fabs(static_cast<double>(args[0] - args[1])) > 1e-8;
}
};
template <typename Functor, typename InverseFunctor> template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor> class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
public: public:
using InT = typename Functor::ELEMENT_TYPE; using InT = typename Functor::ELEM_TYPE;
using OutT = bool; using OutT = bool;
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto functor = Functor(); auto functor = Functor();
...@@ -87,10 +52,10 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor> ...@@ -87,10 +52,10 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>); ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor) REGISTER_CUDA_COMPARE_KERNEL(equal, EqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor) REGISTER_CUDA_COMPARE_KERNEL(not_equal, NotEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor) REGISTER_CUDA_COMPARE_KERNEL(less_than, LessThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor) REGISTER_CUDA_COMPARE_KERNEL(less_equal, LessEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor) REGISTER_CUDA_COMPARE_KERNEL(greater_than, GreaterThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor) REGISTER_CUDA_COMPARE_KERNEL(greater_equal, GreaterEqualFunctor)
#undef REGISTER_CUDA_COMPARE_KERNEL #undef REGISTER_CUDA_COMPARE_KERNEL
...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
...@@ -24,21 +25,6 @@ namespace plat = paddle::platform; ...@@ -24,21 +25,6 @@ namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
/*
input: an array;
return: the result of the math functor
1. For Unary Op, the length of input array is 1,
e.g. Relu: return args[0] > 0 ? args[0] : 0;
2. For Binary Op, the length of input array is 2,
e.g. Add: return args[0] expr args[1];
*/
template <typename T>
struct CudaAddFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] + args[1];
}
};
template <typename T> template <typename T>
class ElementwiseAddKernel<platform::CUDADeviceContext, T> class ElementwiseAddKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -51,7 +37,7 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T> ...@@ -51,7 +37,7 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs); int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>( LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>()); cuda_ctx, ins, &outs, axis, AddFunctor<T>());
} }
}; };
......
...@@ -24,13 +24,6 @@ namespace plat = paddle::platform; ...@@ -24,13 +24,6 @@ namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct CudaMulFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] * args[1];
}
};
template <typename T> template <typename T>
class ElementwiseMulKernel<platform::CUDADeviceContext, T> class ElementwiseMulKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -44,7 +37,7 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T> ...@@ -44,7 +37,7 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows); int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>( LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>()); cuda_ctx, ins, &outs, axis, MulFunctor<T>());
} }
}; };
......
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define MAX_INPUT_NUM 3 // the max num of ET for BroadcacstConfig
namespace kps = paddle::operators::kernel_primitives; namespace kps = paddle::operators::kernel_primitives;
struct DimensionsTransform { struct DimensionsTransform {
...@@ -46,10 +45,9 @@ struct DimensionsTransform { ...@@ -46,10 +45,9 @@ struct DimensionsTransform {
axis++; axis++;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth dimension of input tensor is expected to be equal " "The %d-th dimension of input tensor is expected to be equal "
"with" "with the %d-th dimension of output tensor %d or 1, but "
"the %dth dimension of output tensor %d or 1, but recieved " "recieved %d.",
"%d.\n",
in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx])); in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
} }
} while (in_idx < in_dim.size()); } while (in_idx < in_dim.size());
...@@ -61,10 +59,9 @@ struct DimensionsTransform { ...@@ -61,10 +59,9 @@ struct DimensionsTransform {
in_idx++; in_idx++;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth dimension of input tensor is expected to be equal " "The %d-th dimension of input tensor is expected to be equal "
"with" "with the %d-th dimension of output tensor %d or 1, but "
"the %dth dimension of output tensor %d or 1, but recieved " "recieved %d.",
"%d.\n",
in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx])); in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
} }
} while (in_idx < dim_size); } while (in_idx < dim_size);
...@@ -165,79 +162,71 @@ struct DimensionsTransform { ...@@ -165,79 +162,71 @@ struct DimensionsTransform {
} }
}; };
template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false> template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData( __device__ __forceinline__ void LoadData(
T *dst, const T *__restrict__ src, uint32_t block_offset, T *dst, const T *__restrict__ src, uint32_t block_offset,
const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num, const kps::details::BroadcastConfig<Rank> &config, int numel, int num,
bool need_broadcast) { bool need_broadcast) {
// numel: total number of output elements // numel: total number of output elements
// num: how many elements will be dealt with this time // num: how many elements will be dealt with this time
if (need_broadcast) { if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>( kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
dst, src, block_offset, config, numel, 1, 1); config, numel, 1, 1);
} else { } else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num); kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
} }
} }
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize, template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int VecSize, typename Functor, bool IsBoundary = false> int Rank, bool IsBoundary = false>
__device__ void DealSegment( __device__ void DealSegment(
const framework::Array<const InT *__restrict__, ET> &in, OutT *out, const framework::Array<const InT *__restrict__, Arity> &ins, OutT *out,
const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel, const framework::Array<bool, Arity> &use_broadcast, uint32_t numel,
const framework::Array<kps::details::BroadcastConfig<ShapeSize>, const framework::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
MAX_INPUT_NUM> &configlists,
int num, Functor func) { int num, Functor func) {
InT args[ET][VecSize]; InT args[Arity][VecSize];
OutT result[VecSize]; OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize; int block_offset = blockIdx.x * blockDim.x * VecSize;
// load
#pragma unroll #pragma unroll
for (int i = 0; i < ET; i++) { for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f)); kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset, LoadData<InT, VecSize, Rank, IsBoundary>(args[i], ins[i], block_offset,
configlists[i], numel, num, configs[i], numel, num,
use_broadcast[i]); use_broadcast[i]);
} }
// compute
if (ET == kUnary) { const bool kCallElementwiseAny =
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0], platform::FunctionTraits<Functor>::has_pointer_args;
func); ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
} else if (ET == kBinary) { kCallElementwiseAny>()(func, args, result);
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
} else {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
// compute
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result, kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
num); num);
} }
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize, template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int VecSize, typename Functor> int Rank>
__global__ void BroadcastKernel( __global__ void BroadcastKernel(
framework::Array<const InT *__restrict__, ET> in, OutT *out, framework::Array<const InT *__restrict__, Arity> ins, OutT *out,
framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel, framework::Array<bool, Arity> use_broadcast, uint32_t numel,
framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM> framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
configlists,
int main_tid, int tail_tid, Functor func) { int main_tid, int tail_tid, Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize; int block_offset = blockIdx.x * blockDim.x * VecSize;
// data offset of this block // data offset of this block
if (blockIdx.x < main_tid) { if (blockIdx.x < main_tid) {
int num = blockDim.x * VecSize; // blockIdx.x < main_tid int num = blockDim.x * VecSize; // blockIdx.x < main_tid
DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>( DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
in, out, use_broadcast, numel, configlists, num, func); ins, out, use_broadcast, numel, configs, num, func);
} else { // remainder } else { // remainder
int num = tail_tid; int num = tail_tid;
DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>( DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
in, out, use_broadcast, numel, configlists, num, func); ins, out, use_broadcast, numel, configs, num, func);
} }
} }
template <typename InT, typename OutT, ElementwiseType ET, int VecSize, template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Size, typename Functor> int Rank>
void LaunchKernel(const platform::CUDADeviceContext &ctx, void LaunchKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, const std::vector<const framework::Tensor *> &ins,
framework::Tensor *out, Functor func, framework::Tensor *out, Functor func,
...@@ -251,53 +240,58 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx, ...@@ -251,53 +240,58 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx,
auto stream = ctx.stream(); auto stream = ctx.stream();
OutT *out_data = out->data<OutT>(); OutT *out_data = out->data<OutT>();
framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM> framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
configlists; framework::Array<bool, Arity> use_broadcast;
framework::Array<bool, MAX_INPUT_NUM> use_broadcast; framework::Array<const InT *__restrict__, Arity> ins_data;
framework::Array<const InT *__restrict__, ET> ins_data;
for (int i = 0; i < ET; i++) { for (int i = 0; i < Arity; i++) {
use_broadcast[i] = (ins[i]->numel() != numel); use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = ins[i]->data<InT>(); ins_data[i] = ins[i]->data<InT>();
if (use_broadcast[i]) { if (use_broadcast[i]) {
// get the broadcast config, // get the broadcast config,
// if data shape is [m, n], then you should set data_dim = {n, m} // if data shape is [m, n], then you should set data_dim = {n, m}
// e.g., if out's shape is [3, 45, 1], then out_dims = {1, 45, 3} // e.g., if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}
configlists[i] = kps::details::BroadcastConfig<Size>( configs[i] = kps::details::BroadcastConfig<Rank>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
} }
} }
BroadcastKernel<ET, InT, OutT, Size, VecSize, BroadcastKernel<InT, OutT, Functor, Arity, VecSize,
Functor><<<blocks, threads, 0, stream>>>( Rank><<<blocks, threads, 0, stream>>>(
ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid, ins_data, out_data, use_broadcast, numel, configs, main_tid, tail_tid,
func); func);
} }
template <typename InT, typename OutT, ElementwiseType ET, int VecSize, template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
typename Functor> void LaunchBroadcastKernelForDifferentVecSize(
void LaunchBroadcastKernelForDifferentDimSize(
const platform::CUDADeviceContext &ctx, const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out, const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) { int axis, Functor func) {
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis); const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
#define DIM_SIZE(size) \
case size: { \ #define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \ case rank: { \
merge_dims); \ LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(ctx, ins, out, \
func, merge_dims); \
} break; } break;
switch (merge_dims.dim_size) { switch (merge_dims.dim_size) {
DIM_SIZE(1); CALL_BROADCAST_FOR_DIM_SIZE(1);
DIM_SIZE(2); CALL_BROADCAST_FOR_DIM_SIZE(2);
DIM_SIZE(3); CALL_BROADCAST_FOR_DIM_SIZE(3);
DIM_SIZE(4); CALL_BROADCAST_FOR_DIM_SIZE(4);
DIM_SIZE(5); CALL_BROADCAST_FOR_DIM_SIZE(5);
DIM_SIZE(6); CALL_BROADCAST_FOR_DIM_SIZE(6);
DIM_SIZE(7); CALL_BROADCAST_FOR_DIM_SIZE(7);
DIM_SIZE(8); CALL_BROADCAST_FOR_DIM_SIZE(8);
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.\n",
merge_dims.dim_size, framework::DDim::kMaxRank));
}
} }
#undef DIM_SIZE #undef CALL_BROADCAST_FOR_DIM_SIZE
} }
template <ElementwiseType ET, typename InT, typename OutT, typename Functor> template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
...@@ -305,11 +299,21 @@ void LaunchBroadcastElementwiseCudaKernel( ...@@ -305,11 +299,21 @@ void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx, const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) { std::vector<framework::Tensor *> *outs, int axis, Functor func) {
PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary, using Traits = platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
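// Note: has_pointer_args is true for legacy functors whose operator() takes a
// pointer of inputs (e.g. "T operator()(const T* args)"); their arity cannot be
// deduced from the signature, so the ElementwiseType ET passed by the caller is
// used instead. New-style functors with reference arguments report their real
// arity through Traits::arity.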
PADDLE_ENFORCE_EQ(ins.size(), kArity,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Currently, only Support binary calculation, " "The number of inputs is expected to be equal to the "
"but received %d input tensors.\n", "arity of functor. But recieved: the number of inputs "
static_cast<int>(ET))); "is %d, the arity of functor is %d.",
ins.size(), kArity));
PADDLE_ENFORCE_EQ(kArity, 2,
platform::errors::InvalidArgument(
"Currently only broadcast of binary is supported and "
"verified, but received %d.",
kArity));
int in_vec_size = 4; int in_vec_size = 4;
framework::Tensor *out = (*outs)[0]; framework::Tensor *out = (*outs)[0];
for (auto *in : ins) { for (auto *in : ins) {
...@@ -322,18 +326,18 @@ void LaunchBroadcastElementwiseCudaKernel( ...@@ -322,18 +326,18 @@ void LaunchBroadcastElementwiseCudaKernel(
switch (vec_size) { switch (vec_size) {
case 4: { case 4: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out, LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
axis, func); ctx, ins, out, axis, func);
break; break;
} }
case 2: { case 2: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out, LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
axis, func); ctx, ins, out, axis, func);
break; break;
} }
case 1: { case 1: {
LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out, LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
axis, func); ctx, ins, out, axis, func);
break; break;
} }
default: { default: {
...@@ -369,7 +373,5 @@ void LaunchElementwiseCudaKernel( ...@@ -369,7 +373,5 @@ void LaunchElementwiseCudaKernel(
} }
} }
#undef MAX_INPUT_NUM
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -37,8 +37,10 @@ limitations under the License. */ ...@@ -37,8 +37,10 @@ limitations under the License. */
#endif #endif
#include <thrust/iterator/iterator_adaptor.h> #include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr int ELEMWISE_MAX_BLOCK_DIM = 256; constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
#else #else
...@@ -278,128 +280,6 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x, ...@@ -278,128 +280,6 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
} }
} }
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Functor, typename T, typename OutType>
__global__ void ElementwiseKernel(const T *__restrict__ x_data,
const T *__restrict__ y_data,
OutType *__restrict__ out_data, int n,
int post, const size_t total, Functor func) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = tid; i < total; i += stride) {
int idx = i / post % n;
out_data[i] = func(x_data[i], y_data[idx]);
}
}
template <typename Functor, typename T, typename OutType>
void ComputeElementwiseCUDA(const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
int pre, int n, int post,
const platform::CUDADeviceContext &ctx,
Functor func, const bool is_xsize_larger = true) {
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
int numel = pre * n * post;
int threads = 256;
int blocks = (numel + threads - 1) / threads;
if (is_xsize_larger) {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
x_data, y_data, out_data, n, post, numel, func);
} else {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
y_data, x_data, out_data, n, post, numel, func);
}
}
template <typename Functor, typename T, typename OutType = T>
__global__ void CommonForwardBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
const int *out_dims_array, const T *x, const T *y, OutType *out,
int out_size, int max_dim, Functor func, const bool is_xsize_larger) {
for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
out_index < out_size; out_index += blockDim.x * gridDim.x) {
int x_index = 0;
int y_index = 0;
int out_index_quotient = out_index;
int remainder = 0;
#pragma unroll
for (int i = max_dim - 1; i >= 0; --i) {
GetDivMod(out_index_quotient, out_dims_array[i], &out_index_quotient,
&remainder);
x_index += remainder * x_strides_array[i];
y_index += remainder * y_strides_array[i];
}
if (is_xsize_larger) {
out[out_index] = func(x[x_index], y[y_index]);
} else {
out[out_index] = func(y[y_index], x[x_index]);
}
}
}
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCUDA(
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, int *x_dims_array, int *y_dims_array,
int *out_dims_array, int max_dim, const platform::CUDADeviceContext &ctx,
Functor func, const bool is_xsize_larger = true) {
const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto cplace = platform::CPUPlace();
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
int x_stride = 1;
int y_stride = 1;
for (int i = max_dim - 1; i >= 0; i--) {
x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
x_stride *= x_dims_array[i];
y_stride *= y_dims_array[i];
}
int bytes = max_dim * sizeof(int);
auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
int *x_strides_array_gpu =
reinterpret_cast<int *>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
bytes, ctx.stream());
auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
int *y_strides_array_gpu =
reinterpret_cast<int *>(y_strides_array_tmp->ptr());
memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
bytes, ctx.stream());
auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
ctx.stream());
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
dim3 gird_size = dim3(
(out_size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
CommonForwardBroadcastCUDAKernel<
Functor, T, OutType><<<gird_size, block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
y_data, out_data, out_size, max_dim, func, is_xsize_larger);
}
#endif // __NVCC__ or __HIPCC__
template <typename T, typename DX_OP, typename DY_OP> template <typename T, typename DX_OP, typename DY_OP>
void CommonGradBroadcastCPU( void CommonGradBroadcastCPU(
const framework::Tensor &x, const framework::Tensor &y, const framework::Tensor &x, const framework::Tensor &y,
...@@ -1917,21 +1797,10 @@ void CommonElementwiseBroadcastForward( ...@@ -1917,21 +1797,10 @@ void CommonElementwiseBroadcastForward(
y_dims_array.data(), out_dims_array.data(), max_dim, y_dims_array.data(), out_dims_array.data(), max_dim,
axis); axis);
if (platform::is_gpu_place(ctx.GetPlace())) { CommonForwardBroadcastCPU<Functor, T, OutType>(
#if defined(__NVCC__) || defined(__HIPCC__) x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(),
CommonForwardBroadcastCUDA<Functor, T, OutType>( max_dim, ctx.template device_context<platform::CPUDeviceContext>(), func,
x, y, z, x_dims_array.data(), y_dims_array.data(), is_xsize_larger);
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), func,
is_xsize_larger);
#endif
} else {
CommonForwardBroadcastCPU<Functor, T, OutType>(
x, y, z, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CPUDeviceContext>(), func,
is_xsize_larger);
}
} }
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP> template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
...@@ -1975,12 +1844,35 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx, ...@@ -1975,12 +1844,35 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
} }
} }
// This is a common implementation of binary computation with broadcast support,
// covering both CPU and GPU.
// - The CPU implementation cannot handle the case where x needs broadcast, so
//   this function has to be called with both an XxxFunctor and an
//   XxxInverseFunctor (an illustrative sketch of such a pair follows), like
//   paddle/fluid/operators/elementwise/elementwise_add_op.h#L49 - L55.
// - The GPU implementation supports all broadcast cases, so there is no need to
//   define or call with an XxxInverseFunctor.
// TODO(liuyiqun): optimize the CPU implementation to support all broadcast
// cases and avoid the need of XxxInverseFunctor.
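// Illustrative sketch of such a functor pair (an assumed shape, following the
// reference-argument convention used elsewhere in this change):
//   template <typename T>
//   struct SubFunctor {
//     inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; }
//   };
//   template <typename T>
//   struct InverseSubFunctor {
//     // used on CPU when it is x, not y, that needs broadcast
//     inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; }
//   };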
template <typename Functor, typename DeviceContext, typename T, template <typename Functor, typename DeviceContext, typename T,
typename OutType = T> typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx, void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func, const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) { framework::Tensor *z) {
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor *> ins = {x, y};
std::vector<framework::Tensor *> outs = {z};
z->mutable_data<OutType>(ctx.GetPlace());
const auto &dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, OutType>(
dev_ctx, ins, &outs, axis, func);
#endif
return;
}
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
bool is_xsize_larger = true; bool is_xsize_larger = true;
...@@ -2029,15 +1921,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx, ...@@ -2029,15 +1921,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
return; return;
} }
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
ComputeElementwiseCUDA<Functor, T, OutType>(
x, y, z, pre, n, post,
ctx.template device_context<platform::CUDADeviceContext>(), func,
is_xsize_larger);
#endif
return;
}
if (post == 1) { if (post == 1) {
functor.RunRowWise(n, pre); functor.RunRowWise(n, pre);
return; return;
......
...@@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/function_traits.h"
#ifdef __HIPCC__ #ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256 #define ELEMENTWISE_BLOCK_SIZE 256
...@@ -28,7 +29,8 @@ namespace paddle { ...@@ -28,7 +29,8 @@ namespace paddle {
namespace operators { namespace operators {
namespace kps = paddle::operators::kernel_primitives; namespace kps = paddle::operators::kernel_primitives;
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
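// Note (illustrative, not part of this diff): kAny (-1) is presumably reserved
// for functors whose number of inputs is not one of the fixed arities above;
// see ElementwisePrimitiveCaller and kps::ElementwiseAny below.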
/* /*
* According to NVIDIA, if number of threads per block is 64/128/256/512, * According to NVIDIA, if number of threads per block is 64/128/256/512,
...@@ -55,8 +57,9 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx, ...@@ -55,8 +57,9 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
} }
template <typename InT, typename OutT> template <typename InT, typename OutT>
int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins, int GetVectorizedSizeForTensors(
const std::vector<framework::Tensor *> &outs) { const std::vector<const framework::Tensor *> &ins,
const std::vector<framework::Tensor *> &outs) {
int vec_size = 4; int vec_size = 4;
for (auto iter = ins.begin(); iter != ins.end(); ++iter) { for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
vec_size = std::min<int>(vec_size, vec_size = std::min<int>(vec_size,
...@@ -69,56 +72,88 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins, ...@@ -69,56 +72,88 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
return vec_size; return vec_size;
} }
template <ElementwiseType ET, int VecSize, typename InT, typename OutT, template <typename InT, typename OutT, int VecSize, typename Functor, int Arity,
typename Functor, bool IsBoundary> bool CallElementwiseAny = false>
__device__ void DealSegment( struct ElementwisePrimitiveCaller {
const framework::Array<const InT *__restrict__, ET> &in, OutT *out, int num, __device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
Functor func) { OutT *result);
int data_offset = VecSize * blockIdx.x * blockDim.x; };
InT args[ET][VecSize];
OutT result[VecSize]; template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
// load data struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
#pragma unroll __device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
for (int i = 0; i < ET; i++) { OutT *result) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f)); kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(result, args,
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset, func);
num);
} }
};
// compute template <typename InT, typename OutT, int VecSize, typename Functor>
if (ET == kUnary) { struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0], kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func); func);
} else if (ET == kBinary) { }
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0], kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func); args[1], func);
} else { }
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline OutT operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>( kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func); result, args[0], args[1], args[2], func);
} }
};
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
bool IsBoundary>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, Arity> &in, OutT *out,
int num, Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
// store int data_offset = VecSize * blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
num);
}
const bool kCallElementwiseAny =
platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
kCallElementwiseAny>()(func, args, result);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result, kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
num); num);
} }
template <ElementwiseType ET, int VecSize, typename InT, typename OutT, template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
typename Functor>
__global__ void ElementVectorizeKernel( __global__ void ElementVectorizeKernel(
framework::Array<const InT *__restrict__, ET> in, OutT *out, int size, framework::Array<const InT *__restrict__, Arity> ins, OutT *out, int size,
Functor func) { Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x; int data_offset = VecSize * blockIdx.x * blockDim.x;
int num = size - data_offset; int num = size - data_offset;
// the number of elements to deal with this time // the number of elements to deal with this time
if (VecSize * blockDim.x > num) { // remainder segment if (VecSize * blockDim.x > num) { // remainder segment
DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func); DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func);
} else { // complete segment } else { // complete segment
DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func); DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
} }
} }
template <ElementwiseType ET, typename InT, typename OutT, typename Functor, template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
int VecSize>
void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx, void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, std::vector<framework::Tensor *> *outs,
...@@ -129,14 +164,14 @@ void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx, ...@@ -129,14 +164,14 @@ void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
auto stream = ctx.stream(); auto stream = ctx.stream();
OutT *out = (*outs)[0]->data<OutT>(); OutT *out_data = (*outs)[0]->data<OutT>();
framework::Array<const InT *__restrict__, ET> in; framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < ET; i++) { for (int i = 0; i < Arity; i++) {
in[i] = ins[i]->data<InT>(); ins_data[i] = ins[i]->data<InT>();
} }
ElementVectorizeKernel<ET, VecSize, InT, OutT, ElementVectorizeKernel<InT, OutT, Functor, Arity,
Functor><<<grid_size, block_size, 0, stream>>>( VecSize><<<grid_size, block_size, 0, stream>>>(
in, out, numel, func); ins_data, out_data, numel, func);
} }
template <ElementwiseType ET, typename InT, typename OutT, typename Functor> template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
...@@ -144,17 +179,30 @@ void LaunchSameDimsElementwiseCudaKernel( ...@@ -144,17 +179,30 @@ void LaunchSameDimsElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx, const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) { std::vector<framework::Tensor *> *outs, Functor func) {
using Traits = platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(), kArity,
platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(), kArity));
// calculate the max vec_size for all ins and outs // calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs); int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
switch (vec_size) { switch (vec_size) {
case 4: case 4:
ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(ctx, ins, outs, func); ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>(ctx, ins, outs,
func);
break; break;
case 2: case 2:
ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(ctx, ins, outs, func); ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>(ctx, ins, outs,
func);
break; break;
case 1: case 1:
ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(ctx, ins, outs, func); ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>(ctx, ins, outs,
func);
break; break;
default: { default: {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -22,13 +22,6 @@ namespace plat = paddle::platform; ...@@ -22,13 +22,6 @@ namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct CudaSubFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] - args[1];
}
};
template <typename T> template <typename T>
class ElementwiseSubKernel<platform::CUDADeviceContext, T> class ElementwiseSubKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -41,7 +34,7 @@ class ElementwiseSubKernel<platform::CUDADeviceContext, T> ...@@ -41,7 +34,7 @@ class ElementwiseSubKernel<platform::CUDADeviceContext, T>
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs); int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>( LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaSubFunctor<T>()); cuda_ctx, ins, &outs, axis, SubFunctor<T>());
} }
}; };
......
...@@ -52,10 +52,8 @@ template <typename T> ...@@ -52,10 +52,8 @@ template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType; using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T> template <typename T>
struct CudaAddFunctor { struct AddFunctor {
inline HOSTDEVICE T operator()(const T* args) const { inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
return args[0] + args[1];
}
}; };
template <typename InT, typename OutT, int ShapeSize, int VecSize, template <typename InT, typename OutT, int ShapeSize, int VecSize,
...@@ -128,7 +126,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, ...@@ -128,7 +126,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
std::vector<int64_t> out_dims = {n, m}; std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2); configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
auto func = CudaAddFunctor<T>(); auto func = AddFunctor<T>();
auto stream = ctx.stream(); auto stream = ctx.stream();
switch (vec_size) { switch (vec_size) {
case 4: { case 4: {
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#endif #endif
// #include <algorithm>
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -135,53 +134,114 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { ...@@ -135,53 +134,114 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
} // namespace details } // namespace details
/*************************** Compute Function****************************/ /**
* @brief unary function
* @param
* T: data type of in
* OutT: data type of out
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: compute functor which has an operator() as follows
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a) const {
* return ...;
* }
* };
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute(in[idx]));
}
}
/** /**
* @brief binary function, in1 and in2 have same shape * @brief binary function, in1 and in2 have same shape
* @param * @param
* T: data type of in1, in2 * T: data type of in1, in2
* OutT: data type of out * OutT: data type of out
* NX: the cols of in1, in2 * NX: the cols of in1, in2
* NY: the rows of in1, in2 * NY: the rows of in1, in2
* BlockSize: the config of this device * BlockSize: the config of this device
* OpFunc: compute functor eg: in1 + in2, in1 - in2 * OpFunc: compute functor which has an operator() as follows
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a, const T& b) const {
* return ...;
* }
* };
*/ */
template <typename T, typename OutT, int NX, int NY, int BlockSize, template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc> class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1, __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
const T* in2, const T* in2,
OpFunc compute) { OpFunc compute) {
T args[2];
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) { for (int idx = 0; idx < NX * NY; ++idx) {
args[0] = in1[idx]; out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
args[1] = in2[idx];
out[idx] = static_cast<OutT>(compute(args));
} }
} }
/** /**
* @brief ternary function, in1, in2 and in3 have same shape * @brief ternary function, in1, in2 and in3 have same shape
* @param * @param
* T: data type of in1, in2, in3 * T: data type of in1, in2, in3
* OutT: data type of out * OutT: data type of out
* NX: the cols of in1, in2 * NX: the cols of in1, in2
* NY: the rows of in1, in2 * NY: the rows of in1, in2
* BlockSize: the config of this device * BlockSize: the config of this device
* OpFunc: compute functor eg: out = in1 * in2 + in3 * OpFunc: compute functor which has an operator() as follows
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a, const T& b, const T& c) const {
* return ...;
* }
* };
*/ */
template <typename T, typename OutT, int NX, int NY, int BlockSize, template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc> class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1, __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
const T* in2, const T* in3, const T* in2, const T* in3,
OpFunc compute) { OpFunc compute) {
T args[3];
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) { for (int idx = 0; idx < NX * NY; ++idx) {
args[0] = in1[idx]; out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
args[1] = in2[idx]; }
args[2] = in3[idx]; }
/**
* @brief a general function for elementwise computation, all inputs have
* the same shape.
* @param
* T: data type of the inputs
* OutT: data type of out
* NX: the cols of each input
* NY: the rows of each input
* BlockSize: the config of this device
* OpFunc: compute functor which has an operator() as follows
* template <typename T, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T* args) const {
* return ...;
* }
* };
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize, int Arity,
class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
OpFunc compute) {
T args[Arity];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
for (int j = 0; j < Arity; ++j) {
args[j] = ins[j][idx];
}
out[idx] = static_cast<OutT>(compute(args)); out[idx] = static_cast<OutT>(compute(args));
} }
} }
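// Usage sketch (illustrative, not part of this diff): ElementwiseAny expects a
// functor that consumes the gathered args array, e.g. a hypothetical
//   template <typename T>
//   struct AddThreeFunctor {
//     HOSTDEVICE T operator()(const T* args) const {
//       return args[0] + args[1] + args[2];
//     }
//   };
// invoked as ElementwiseAny<T, T, NX, NY, BlockSize, 3>(out, ins,
// AddThreeFunctor<T>()).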
...@@ -189,7 +249,7 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1, ...@@ -189,7 +249,7 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
/** /**
* @brief cycle binary function, in1's shape size is [1, NX], in2's shape size * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
* is [NY, NX], out's shape size is [NY, NX] * is [NY, NX], out's shape size is [NY, NX]
* @param * @param
* T: data type of in1, in2 * T: data type of in1, in2
* OutT: data type of out * OutT: data type of out
* NX: the cols of in1, in2 * NX: the cols of in1, in2
...@@ -211,26 +271,6 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1, ...@@ -211,26 +271,6 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
} }
} }
/**
* @brief unary function
* @param:
* T: data type of in
* OutT: data type of out
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: compute functor eg: relu, exp
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute(in + idx));
}
}
/** /**
* @brief reduce function, in's shape size is [NX, NY]. * @brief reduce function, in's shape size is [NX, NY].
* If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1], * If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
...@@ -238,7 +278,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in, ...@@ -238,7 +278,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
* shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was * shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
* split, BlockYReduce will be called. If reduce_last_dim is true and * split, BlockYReduce will be called. If reduce_last_dim is true and
* reduce_num was split, BlockXReduce will be called * reduce_num was split, BlockXReduce will be called
* @typename * @typename
* T: data type of in * T: data type of in
* NX: the cols of in * NX: the cols of in
* NY: the rows of in * NY: the rows of in
......
...@@ -15,18 +15,14 @@ ...@@ -15,18 +15,14 @@
#include <unsupported/Eigen/SpecialFunctions> #include <unsupported/Eigen/SpecialFunctions>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/lgamma_op.h" #include "paddle/fluid/operators/lgamma_op.h"
#include "paddle/fluid/operators/math/complex_functors.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T, typename Enable = void>
struct CudaLgammaFunctor;
template <typename T> template <typename T>
struct CudaLgammaFunctor<T, math::NoComplex<T, math::Real<T>>> { struct CudaLgammaFunctor {
__device__ __forceinline__ T operator()(const T* args) const { __device__ __forceinline__ T operator()(const T& x) const {
return Eigen::numext::lgamma(args[0]); return Eigen::numext::lgamma(x);
} }
}; };
...@@ -37,15 +33,14 @@ class LgammaKernel<platform::CUDADeviceContext, T> ...@@ -37,15 +33,14 @@ class LgammaKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X"); const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out"); Tensor* out = context.Output<Tensor>("Out");
out->mutable_data<math::Real<T>>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.device_context<platform::CUDADeviceContext>(); auto& dev_ctx = context.device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x}; std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out}; std::vector<framework::Tensor*> outs = {out};
auto functor = CudaLgammaFunctor<T>(); auto functor = CudaLgammaFunctor<T>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
math::Real<T>>(dev_ctx, ins, &outs, dev_ctx, ins, &outs, functor);
functor);
} }
}; };
......
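As a usage sketch only (CudaAddFunctor and LaunchAdd are hypothetical and not part of this PR), a binary functor written with reference arguments can be launched through the same helper, assuming the binary case mirrors the unary launch above with ElementwiseType::kBinary and two input tensors:

template <typename T>
struct CudaAddFunctor {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

template <typename T>
void LaunchAdd(const platform::CUDADeviceContext& dev_ctx,
               const framework::Tensor* x, const framework::Tensor* y,
               framework::Tensor* out) {
  std::vector<const framework::Tensor*> ins = {x, y};
  std::vector<framework::Tensor*> outs = {out};
  // FunctionTraits lets the launch helper deduce that this functor takes two
  // reference arguments rather than a single pointer to packed args.
  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
      dev_ctx, ins, &outs, CudaAddFunctor<T>());
}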
...@@ -129,17 +129,10 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> { ...@@ -129,17 +129,10 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k), compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace()); context.GetPlace());
int axis = -1; int axis = -1;
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CUDADeviceContext, T,
ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CUDADeviceContext, int64_t>(context, &eigenvalue_tensor, &tol_tensor,
T, int64_t>(context, &eigenvalue_tensor, &tol_tensor, axis, GreaterThanFunctor<T>(),
axis, GreaterThanFunctor<T>(), &compare_result);
&compare_result);
} else {
ElementwiseComputeEx<LessThanFunctor<T>, platform::CUDADeviceContext, T,
int64_t>(context, &eigenvalue_tensor, &tol_tensor,
axis, LessThanFunctor<T>(),
&compare_result);
}
auto dito_int = auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext, math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
int64_t>(context); int64_t>(context);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <Eigen/src/Core/util/Constants.h> #include <Eigen/src/Core/util/Constants.h>
#include <Eigen/Dense> #include <Eigen/Dense>
#include <Eigen/SVD> #include <Eigen/SVD>
...@@ -296,14 +297,23 @@ struct DeviceIndependenceTensorOperations { ...@@ -296,14 +297,23 @@ struct DeviceIndependenceTensorOperations {
framework::Tensor ret; framework::Tensor ret;
std::vector<int> out_shape = GetBroadcastShape({&x, &y}); std::vector<int> out_shape = GetBroadcastShape({&x, &y});
ret.Resize(framework::make_ddim(out_shape)); ret.Resize(framework::make_ddim(out_shape));
if (x.dims().size() >= y.dims().size()) { if (platform::is_gpu_place(context.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
// For GPU, there is no need to define XxxInverseFunctor and call
// ElementwiseComputeEx in two branches.
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>( ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
context, &x, &y, -1, SubFunctor<T>(), &ret); context, &x, &y, -1, SubFunctor<T>(), &ret);
#endif
} else { } else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>( if (x.dims().size() >= y.dims().size()) {
// This is copyed from elementwise_sub, which means we ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
// need reverse will xrank < yrank context, &x, &y, -1, SubFunctor<T>(), &ret);
context, &x, &y, -1, InverseSubFunctor<T>(), &ret); } else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
                                                                // This is copied from elementwise_sub, which means we
                                                                // need to reverse when xrank < yrank
context, &x, &y, -1, InverseSubFunctor<T>(), &ret);
}
} }
return ret; return ret;
} }
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <tuple>
#include <type_traits>
namespace paddle {
namespace platform {
// Declare a template class with a single template parameter.
template <typename>
struct FunctionTraits;
// A forwarding trait allowing functors (objects which have an operator())
// to be used with this traits class.
template <typename T>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
// Partial specializations of FunctionTraits for pointers to member
// functions, covering both const and non-const operator().
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const>
: public FunctionTraits<ReturnType(Args...)> {};
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...)>
: public FunctionTraits<ReturnType(Args...)> {};
// The base implementation for plain function types.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType(Args...)> {
static const size_t arity = sizeof...(Args);
static const bool has_pointer_args =
(arity == 1) &&
(std::is_pointer<
typename std::tuple_element<0, std::tuple<Args...>>::type>::value);
};
} // namespace platform
} // namespace paddle
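A minimal host-side sketch (the functor names are hypothetical, and it assumes the header above is included) of how FunctionTraits tells the two elementwise functor styles apart at compile time: a functor taking individual reference arguments reports its real arity, while the legacy pointer-args style reports arity 1 with has_pointer_args set.

struct RefArgsFunctor {
  float operator()(const float& x, const float& y) const { return x + y; }
};

struct PointerArgsFunctor {
  float operator()(const float* args) const { return args[0]; }
};

void CheckFunctorTraits() {
  using RefTraits = paddle::platform::FunctionTraits<RefArgsFunctor>;
  using PtrTraits = paddle::platform::FunctionTraits<PointerArgsFunctor>;
  static_assert(RefTraits::arity == 2, "two reference arguments");
  static_assert(!RefTraits::has_pointer_args, "new reference-args style");
  static_assert(PtrTraits::arity == 1, "one packed pointer argument");
  static_assert(PtrTraits::has_pointer_args, "legacy pointer-args style");
}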