未验证 提交 df515255 编写于 作者: Z Zhang Ting 提交者: GitHub

modify DivideFunctor to match ElementwiseSameDims template (#39041)

上级 b47fb764
......@@ -59,8 +59,7 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
return;
}
using MT = typename details::MPTypeTrait<T>::Type;
using Div = kernel_primitives::DivideFunctor<T, MT>;
using Div = kernel_primitives::DivideFunctor<T, T>;
std::vector<int> reduce_dims;
reduce_dims.reserve(rank);
for (decltype(rank) i = 0; i < rank; ++i) {
......
......@@ -52,21 +52,6 @@ namespace pten {
dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>()); \
}
/**
* Util Functors
*/
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n)
: n_inv(static_cast<T>(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T x) const { return x * n_inv; }
private:
T n_inv;
};
/**
* Kernels
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册