提交 62257e33 编写于 作者: alinag's avatar alinag

add whole dimension case

test=develop
上级 c31afacb
...@@ -25,16 +25,49 @@ namespace kernels { ...@@ -25,16 +25,49 @@ namespace kernels {
namespace x86 { namespace x86 {
struct SumFunctor { struct SumFunctor {
template <typename X, typename Y, typename Dim> template <typename X, typename Y, typename XDim, typename Dim>
void operator()(X* x, Y* y, const Dim& dim, size_t d, size_t r_d) { void operator()(X* x, Y* y, const XDim& x_dim, const Dim& dims) {
for (int i = 0; i < dim[0]; i++) { if (dims[0] == 0) {
for (int k = 0; k < dim[2]; k++) { size_t h_size = x_dim[2];
auto output_temp = x[i * dim[1] * dim[2] + k]; size_t w_size = x_dim[1] * x_dim[2];
for (int j = 1; j < dim[1]; j++) { for (int i = 0; i < x_dim[1]; i++) {
int input_d = i * dim[1] * dim[2] + j * dim[2] + k; for (int k = 0; k < x_dim[2]; k++) {
output_temp = output_temp + x[input_d]; auto input_size = i * h_size + k;
auto output_temp = x[input_size];
for (int j = 1; j < x_dim[0]; j++) {
int input_d = input_size + j * w_size;
output_temp = output_temp + x[input_d];
}
y[i * h_size + k] = output_temp;
}
}
} else if (dims[0] == 1) {
size_t h_size = x_dim[1] * x_dim[2];
size_t w_size = x_dim[2];
for (int i = 0; i < x_dim[0]; i++) {
for (int k = 0; k < x_dim[2]; k++) {
auto input_size = i * h_size + k;
auto output_temp = x[input_size];
for (int j = 1; j < x_dim[1]; j++) {
int input_d = input_size + j * w_size;
output_temp = output_temp + x[input_d];
}
y[i * w_size + k] = output_temp;
}
}
} else {
size_t h_size = x_dim[1] * x_dim[2];
size_t w_size = x_dim[2];
for (int i = 0; i < x_dim[0]; i++) {
for (int k = 0; k < x_dim[1]; k++) {
auto input_size = i * h_size + k * w_size;
auto output_temp = x[input_size];
for (int j = 1; j < x_dim[2]; j++) {
int input_d = input_size + j;
output_temp = output_temp + x[input_d];
}
y[i * x_dim[1] + k] = output_temp;
} }
y[i * dim[2] + k] = output_temp;
} }
} }
} }
......
...@@ -92,7 +92,7 @@ void ReduceFunctorTensor(const lite::Tensor& input, ...@@ -92,7 +92,7 @@ void ReduceFunctorTensor(const lite::Tensor& input,
Functor functor; Functor functor;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(); T* output_data = output->mutable_data<T>();
functor(input_data, output_data, input.dims(), D, R_D); functor(input_data, output_data, input.dims(), dims);
} }
} // namespace x86 } // namespace x86
......
...@@ -52,12 +52,12 @@ inline int ConvOutputSize(int input_size, ...@@ -52,12 +52,12 @@ inline int ConvOutputSize(int input_size,
return output_size; return output_size;
} }
inline void UpdatePaddingAndDilation(std::vector<int>* paddings, void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations, std::vector<int>* dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::string padding_algorithm, const std::string padding_algorithm,
const lite::DDim data_dims, const lite::DDim data_dims,
const lite::DDim& ksize) { const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME" // when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") { if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册