未验证 提交 12c5b1fe 编写于 作者: Z Zhang Ting 提交者: GitHub

[part 6]change type of function args (#38891)

上级 5fc8bbf7
...@@ -53,7 +53,7 @@ struct ExpFunctor { ...@@ -53,7 +53,7 @@ struct ExpFunctor {
HOSTDEVICE explicit inline ExpFunctor(int n) {} HOSTDEVICE explicit inline ExpFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx x) const {
return static_cast<Ty>(details::Exp(x)); return static_cast<Ty>(details::Exp(x));
} }
}; };
...@@ -67,7 +67,7 @@ struct IdentityFunctor { ...@@ -67,7 +67,7 @@ struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor(int n) {} HOSTDEVICE explicit inline IdentityFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx x) const {
return static_cast<Ty>(x); return static_cast<Ty>(x);
} }
}; };
...@@ -85,7 +85,7 @@ struct DivideFunctor { ...@@ -85,7 +85,7 @@ struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {} HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx x) const {
return static_cast<Ty>(static_cast<MPType>(x) * n_inv); return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
} }
...@@ -102,7 +102,7 @@ struct InverseFunctor { ...@@ -102,7 +102,7 @@ struct InverseFunctor {
HOSTDEVICE explicit inline InverseFunctor(int n) {} HOSTDEVICE explicit inline InverseFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx x) const {
return static_cast<Ty>(-x); return static_cast<Ty>(-x);
} }
}; };
...@@ -116,7 +116,7 @@ struct SquareFunctor { ...@@ -116,7 +116,7 @@ struct SquareFunctor {
HOSTDEVICE explicit inline SquareFunctor(int n) {} HOSTDEVICE explicit inline SquareFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { HOSTDEVICE inline Ty operator()(const Tx x) const {
return static_cast<Ty>(x) * static_cast<Ty>(x); return static_cast<Ty>(x) * static_cast<Ty>(x);
} }
}; };
...@@ -130,7 +130,7 @@ template <typename T> ...@@ -130,7 +130,7 @@ template <typename T>
struct MinFunctor { struct MinFunctor {
inline T initial() { return static_cast<T>(std::numeric_limits<T>::max()); } inline T initial() { return static_cast<T>(std::numeric_limits<T>::max()); }
__device__ __forceinline__ T operator()(const T& a, const T& b) const { __device__ __forceinline__ T operator()(const T a, const T b) const {
return (b < a) ? b : a; return (b < a) ? b : a;
} }
}; };
...@@ -144,7 +144,7 @@ struct MaxFunctor { ...@@ -144,7 +144,7 @@ struct MaxFunctor {
return static_cast<T>(std::numeric_limits<T>::lowest()); return static_cast<T>(std::numeric_limits<T>::lowest());
} }
__device__ __forceinline__ T operator()(const T& a, const T& b) const { __device__ __forceinline__ T operator()(const T a, const T b) const {
return (b > a) ? b : a; return (b > a) ? b : a;
} }
}; };
...@@ -156,7 +156,7 @@ template <typename T> ...@@ -156,7 +156,7 @@ template <typename T>
struct AddFunctor { struct AddFunctor {
inline T initial() { return static_cast<T>(0.0f); } inline T initial() { return static_cast<T>(0.0f); }
__device__ __forceinline__ T operator()(const T& a, const T& b) const { __device__ __forceinline__ T operator()(const T a, const T b) const {
return b + a; return b + a;
} }
}; };
...@@ -168,7 +168,7 @@ template <typename T> ...@@ -168,7 +168,7 @@ template <typename T>
struct MulFunctor { struct MulFunctor {
inline T initial() { return static_cast<T>(1.0f); } inline T initial() { return static_cast<T>(1.0f); }
__device__ __forceinline__ T operator()(const T& a, const T& b) const { __device__ __forceinline__ T operator()(const T a, const T b) const {
return b * a; return b * a;
} }
}; };
...@@ -180,7 +180,7 @@ template <typename T> ...@@ -180,7 +180,7 @@ template <typename T>
struct LogicalOrFunctor { struct LogicalOrFunctor {
inline T initial() { return static_cast<T>(false); } inline T initial() { return static_cast<T>(false); }
__device__ __forceinline__ T operator()(const T& a, const T& b) const { __device__ __forceinline__ T operator()(const T a, const T b) const {
return b || a; return b || a;
} }
}; };
...@@ -192,7 +192,7 @@ template <typename T> ...@@ -192,7 +192,7 @@ template <typename T>
struct LogicalAndFunctor { struct LogicalAndFunctor {
inline T initial() { return static_cast<T>(true); } inline T initial() { return static_cast<T>(true); }
__device__ __forceinline__ T operator()(const T& a, const T& b) const { __device__ __forceinline__ T operator()(const T a, const T b) const {
return b && a; return b && a;
} }
}; };
...@@ -204,7 +204,7 @@ template <typename T> ...@@ -204,7 +204,7 @@ template <typename T>
struct SubFunctor { struct SubFunctor {
inline T initial() { return static_cast<T>(0.0f); } inline T initial() { return static_cast<T>(0.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; } inline HOSTDEVICE T operator()(const T a, const T b) const { return a - b; }
}; };
/** /**
...@@ -214,7 +214,7 @@ template <typename T, typename Enable = void> ...@@ -214,7 +214,7 @@ template <typename T, typename Enable = void>
struct DivFunctor { struct DivFunctor {
inline T initial() { return static_cast<T>(1.0f); } inline T initial() { return static_cast<T>(1.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; }
}; };
template <typename T> template <typename T>
...@@ -222,7 +222,7 @@ struct DivFunctor<T, ...@@ -222,7 +222,7 @@ struct DivFunctor<T,
typename std::enable_if<std::is_integral<T>::value>::type> { typename std::enable_if<std::is_integral<T>::value>::type> {
inline T initial() { return static_cast<T>(1.0f); } inline T initial() { return static_cast<T>(1.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const { inline HOSTDEVICE T operator()(const T a, const T b) const {
// For int32/int64, need to check whether the divison is zero. // For int32/int64, need to check whether the divison is zero.
PADDLE_ENFORCE_NE(b, 0, PADDLE_ENFORCE_NE(b, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -239,7 +239,7 @@ template <typename T> ...@@ -239,7 +239,7 @@ template <typename T>
struct FloorDivFunctor { struct FloorDivFunctor {
inline T initial() { return static_cast<T>(1.0f); } inline T initial() { return static_cast<T>(1.0f); }
inline HOSTDEVICE T operator()(const T& a, const T& b) const { inline HOSTDEVICE T operator()(const T a, const T b) const {
PADDLE_ENFORCE_NE(b, 0, PADDLE_ENFORCE_NE(b, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Integer division by zero encountered " "Integer division by zero encountered "
......
...@@ -61,7 +61,7 @@ struct DivideFunctor { ...@@ -61,7 +61,7 @@ struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n) HOSTDEVICE explicit inline DivideFunctor(int n)
: n_inv(static_cast<T>(1.0 / n)) {} : n_inv(static_cast<T>(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } HOSTDEVICE inline T operator()(const T x) const { return x * n_inv; }
private: private:
T n_inv; T n_inv;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册