未验证 提交 7f502b7c 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[cherry-pick2.3]fix compile bug of windows cuda11.5 (#41464)

cherry-pick

fix compile bug of windows cuda11.5 #41433
上级 5b85f3dc
......@@ -1878,12 +1878,17 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// exp(x) = expf(x)
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(expf(static_cast<float>(x)));
}
};
template <>
struct CudaExpFunctor<double> : public BaseActivationFunctor<double> {
// exp(x) = exp(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(exp(x));
__device__ __forceinline__ double operator()(const double x) const {
return exp(x);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册