提交 c210374c 编写于 作者: alinag's avatar alinag

add sumfunctor using tensor computation

test=develop
上级 15d7e8b3
......@@ -29,6 +29,20 @@ struct SumFunctor {
void operator()(X* x, Y* y, const Dim& dim) {
y->device(lite::fluid::EigenDeviceType<TARGET(kX86)>()) = x->sum(dim);
}
template <typename X, typename Y, typename Dim>
void operator()(X* x, Y* y, const Dim& dim, size_t d, size_t r_d) {
for (int i = 0; i < dim[0]; i++) {
for (int k = 0; k < dim[2]; k++) {
auto output_temp = x[i * dim[1] * dim[2] + k];
for (int j = 1; j < dim[1]; j++) {
int input_d = i * dim[1] * dim[2] + j * dim[2] + k;
output_temp = output_temp + x[input_d];
}
y[i * dim[2] + k] = output_temp;
}
}
}
};
#define HANDLE_DIM(NDIM, RDIM) \
......@@ -68,7 +82,7 @@ class ReduceSumCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
HANDLE_DIM(4, 2);
HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIMT(3, 1);
HANDLE_DIM(2, 1);
HANDLE_DIM(1, 1);
}
......
......@@ -46,24 +46,6 @@ void ReduceFunctor(const lite::Tensor& input,
lite::Tensor* output,
const std::vector<int>& dims,
bool keep_dim) {
auto te = strstr(typeid(Functor).name(), "SumFunctor");
if (D == 3 && R_D == 1 && te != NULL) {
const lite::DDim& input_dims = input.dims();
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>();
for (int i = 0; i < input_dims[0]; i++) {
for (int k = 0; k < input_dims[2]; k++) {
int out_d = i * input_dims[2] + k;
T output_temp = 0;
for (int j = 0; j < input_dims[1]; j++) {
int input_d =
i * input_dims[1] * input_dims[2] + j * input_dims[2] + k;
output_temp = output_temp + input_data[input_d];
}
output_data[out_d] = output_temp;
}
}
} else {
auto x = EigenTensor<T, D>::From(input);
auto reduce_dim = Eigen::array<int, R_D>();
......@@ -84,7 +66,21 @@ void ReduceFunctor(const lite::Tensor& input,
auto out = EigenTensor<T, (D - R_D)>::From(*output, output->dims());
functor(&x, &out, reduce_dim);
}
}
}
template <lite::TargetType Target,
typename T,
size_t D,
size_t R_D,
typename Functor>
void ReduceFunctorTensor(const lite::Tensor& input,
lite::Tensor* output,
const std::vector<int>& dims,
bool keep_dim) {
Functor functor;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>();
functor(input_data, output_data, input.dims(), D, R_D);
}
} // namespace x86
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册