未验证 提交 df515255 编写于 作者: Z Zhang Ting 提交者: GitHub

modify DivideFunctor to match ElementwiseSameDims template (#39041)

上级 b47fb764
...@@ -59,8 +59,7 @@ class MeanCUDAKernel : public framework::OpKernel<T> { ...@@ -59,8 +59,7 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
return; return;
} }
using MT = typename details::MPTypeTrait<T>::Type; using Div = kernel_primitives::DivideFunctor<T, T>;
using Div = kernel_primitives::DivideFunctor<T, MT>;
std::vector<int> reduce_dims; std::vector<int> reduce_dims;
reduce_dims.reserve(rank); reduce_dims.reserve(rank);
for (decltype(rank) i = 0; i < rank; ++i) { for (decltype(rank) i = 0; i < rank; ++i) {
......
...@@ -52,21 +52,6 @@ namespace pten { ...@@ -52,21 +52,6 @@ namespace pten {
dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>()); \ dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>()); \
} }
/**
* Util Functors
*/
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n)
: n_inv(static_cast<T>(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T x) const { return x * n_inv; }
private:
T n_inv;
};
/** /**
* Kernels * Kernels
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册