未验证 提交 a250c56c 编写于 作者: Z Zhang Ting 提交者: GitHub

[part 4]change type of function args (#38888)

上级 86434818
...@@ -25,14 +25,14 @@ namespace operators { ...@@ -25,14 +25,14 @@ namespace operators {
// functors to use with ElementwiseComputeEx // functors to use with ElementwiseComputeEx
template <typename T> template <typename T>
struct RealAndImagToComplexFunctor { struct RealAndImagToComplexFunctor {
inline HOSTDEVICE platform::complex<T> operator()(const T& x, const T& y) { inline HOSTDEVICE platform::complex<T> operator()(const T x, const T y) {
return platform::complex<T>(x, y); return platform::complex<T>(x, y);
} }
}; };
template <typename T> template <typename T>
struct ImagAndRealToComplexFunctor { struct ImagAndRealToComplexFunctor {
inline HOSTDEVICE platform::complex<T> operator()(const T& y, const T& x) { inline HOSTDEVICE platform::complex<T> operator()(const T y, const T x) {
return platform::complex<T>(x, y); return platform::complex<T>(x, y);
} }
}; };
......
...@@ -28,7 +28,7 @@ struct LabelSmoothFunctor { ...@@ -28,7 +28,7 @@ struct LabelSmoothFunctor {
label_dim = static_cast<T>(label_dim_data); label_dim = static_cast<T>(label_dim_data);
} }
__device__ __forceinline__ T operator()(const T& x) const { __device__ __forceinline__ T operator()(const T x) const {
return (static_cast<T>(1 - epsilon) * x + return (static_cast<T>(1 - epsilon) * x +
static_cast<T>(epsilon / label_dim)); static_cast<T>(epsilon / label_dim));
} }
...@@ -42,7 +42,7 @@ struct LabelSmoothGradFunctor { ...@@ -42,7 +42,7 @@ struct LabelSmoothGradFunctor {
epsilon = static_cast<T>(epsilon_data); epsilon = static_cast<T>(epsilon_data);
} }
__device__ __forceinline__ T operator()(const T& x) const { __device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(1 - epsilon) * x; return static_cast<T>(1 - epsilon) * x;
} }
}; };
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
template <typename T> template <typename T>
struct CudaLgammaFunctor { struct CudaLgammaFunctor {
__device__ __forceinline__ T operator()(const T& x) const { __device__ __forceinline__ T operator()(const T x) const {
return Eigen::numext::lgamma(x); return Eigen::numext::lgamma(x);
} }
}; };
......
...@@ -48,17 +48,17 @@ static DDim RemoveLastDim(const DDim& dim) { ...@@ -48,17 +48,17 @@ static DDim RemoveLastDim(const DDim& dim) {
template <typename T> template <typename T>
struct GreaterThanFunctor { struct GreaterThanFunctor {
HOSTDEVICE int operator()(const T& a, const T& b) const { return a > b; } HOSTDEVICE int operator()(const T a, const T b) const { return a > b; }
}; };
template <typename T> template <typename T>
struct LessThanFunctor { struct LessThanFunctor {
HOSTDEVICE int operator()(const T& a, const T& b) const { return a < b; } HOSTDEVICE int operator()(const T a, const T b) const { return a < b; }
}; };
template <typename T> template <typename T>
struct GreaterElementFunctor { struct GreaterElementFunctor {
HOSTDEVICE T operator()(const T& a, const T& b) const { HOSTDEVICE T operator()(const T a, const T b) const {
if (a > b) { if (a > b) {
return a; return a;
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册