未验证 提交 eea85814 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

fix compile bug of windows cuda11.5 (#41433)

上级 633ac4e6
...@@ -1878,12 +1878,17 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> { ...@@ -1878,12 +1878,17 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
template <typename T> template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> { struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type; // exp(x) = expf(x)
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(expf(static_cast<float>(x)));
}
};
template <>
struct CudaExpFunctor<double> : public BaseActivationFunctor<double> {
// exp(x) = exp(x) // exp(x) = exp(x)
__device__ __forceinline__ T operator()(const T arg_x) const { __device__ __forceinline__ double operator()(const double x) const {
MPType x = static_cast<MPType>(arg_x); return exp(x);
return static_cast<T>(exp(x));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册