diff --git a/lite/kernels/x86/reduce_compute.h b/lite/kernels/x86/reduce_compute.h index d3aeee05b1b662097c26d158d3d904da6c3260f1..5192484fbcd355ef354dfc94a87eb21ae01257f4 100644 --- a/lite/kernels/x86/reduce_compute.h +++ b/lite/kernels/x86/reduce_compute.h @@ -25,11 +25,6 @@ namespace kernels { namespace x86 { struct SumFunctor { - template - void operator()(X* x, Y* y, const Dim& dim) { - y->device(lite::fluid::EigenDeviceType()) = x->sum(dim); - } - template void operator()(X* x, Y* y, const Dim& dim, size_t d, size_t r_d) { for (int i = 0; i < dim[0]; i++) { @@ -43,8 +38,23 @@ struct SumFunctor { } } } + + template + void operator()(X* x, Y* y, const Dim& dim) { + y->device(lite::fluid::EigenDeviceType()) = x->sum(dim); + } }; +#define HANDLE_DIMT(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + paddle::lite::kernels::x86::ReduceFunctorTensor( \ + *input, output, dims, keep_dim); \ + } + #define HANDLE_DIM(NDIM, RDIM) \ if (ndim == NDIM && rdim == RDIM) { \ paddle::lite::kernels::x86:: \