diff --git a/lite/kernels/x86/reduce_op_function.h b/lite/kernels/x86/reduce_op_function.h
index 46e1248e070350ca82c73b639f8a924958460901..179a06164dc4aa73683ba8803bce1f7733bae141 100644
--- a/lite/kernels/x86/reduce_op_function.h
+++ b/lite/kernels/x86/reduce_op_function.h
@@ -63,7 +63,19 @@ void ReduceFunctor(const lite::Tensor& input,
     auto out = EigenScalar<T>::From(output);
     functor(&x, &out, reduce_dim);
   } else {
-    auto out = EigenTensor<T, (D - R_D)>::From(*output, output->dims());
+    std::vector<int64_t> out_dims;
+    if (keep_dim) {
+      // Construct the squeezed dims.
+      const int kDelFlag = -2;
+      out_dims = output->dims().Vectorize();
+      for (size_t i = 0; i < dims.size(); ++i) {
+        out_dims[reduce_dim[i]] = kDelFlag;
+      }
+      out_dims.erase(remove(out_dims.begin(), out_dims.end(), kDelFlag),
+                     out_dims.end());
+    }
+    auto out = EigenTensor<T, (D - R_D)>::From(
+        *output, keep_dim ? DDim(out_dims) : output->dims());
     functor(&x, &out, reduce_dim);
   }
 }
diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc
index 3f0de174715a6fd718694fb31e9d7cb7c08cf2f9..e2cc56b416dd166e6b22a0c642907844ab964cc5 100644
--- a/lite/operators/reduce_ops.cc
+++ b/lite/operators/reduce_ops.cc
@@ -50,20 +50,18 @@ bool ReduceOp::InferShape() const {
   } else {
     size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
     std::vector<int64_t> out_dims(out_rank);
-    if (keep_dim) {
-      for (size_t i = 0; i < dims.size(); ++i) {
-        out_dims[dims[i]] = 1;
-      }
-    } else {
-      sort(dims.begin(), dims.end());
-      int dim_index = 0;
-      int out_index = 0;
-      for (size_t i = 0; i < x_rank; ++i) {
-        if (dims[dim_index] == static_cast<int>(i)) {
-          dim_index++;
-        } else {
-          out_dims[out_index++] = x_dims[i];
+    sort(dims.begin(), dims.end());
+    int dim_index = 0;
+    int out_index = 0;
+    for (size_t i = 0; i < x_rank; ++i) {
+      if (dim_index < dims.size() &&
+          dims[dim_index] == static_cast<int>(i)) {
+        if (keep_dim) {
+          out_dims[out_index++] = 1;
         }
+        dim_index++;
+      } else {
+        out_dims[out_index++] = x_dims[i];
       }
     }
     param_.output->Resize(out_dims);
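
The reduce_ops.cc hunk unifies the keep_dim and non-keep_dim paths of InferShape into one loop: reduced axes become 1 when keep_dim is set and are dropped otherwise, while all remaining axes are copied from x_dims (the old keep_dim branch left them at their default value of 0). The kernel hunk then re-squeezes the size-1 reduced axes before handing the output to the Eigen functor, so the Eigen expression always sees a (D - R_D)-rank tensor. Below is a minimal standalone sketch, not part of the patch, of the shape rule the new InferShape loop implements; the helper name InferReducedDims and the plain std::vector types are illustrative assumptions rather than Paddle-Lite APIs.

// Sketch of the output-shape rule implemented by the new InferShape loop.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> InferReducedDims(const std::vector<int64_t>& x_dims,
                                      std::vector<int> reduce_dims,
                                      bool keep_dim) {
  std::sort(reduce_dims.begin(), reduce_dims.end());
  std::vector<int64_t> out_dims;
  size_t dim_index = 0;
  for (size_t i = 0; i < x_dims.size(); ++i) {
    if (dim_index < reduce_dims.size() &&
        reduce_dims[dim_index] == static_cast<int>(i)) {
      // Reduced axis: keep it as 1 when keep_dim is set, drop it otherwise.
      if (keep_dim) out_dims.push_back(1);
      ++dim_index;
    } else {
      // Non-reduced axis: copy it through unchanged.
      out_dims.push_back(x_dims[i]);
    }
  }
  return out_dims;
}

int main() {
  // Reducing axis 1 of a [2, 3, 4] tensor:
  //   keep_dim = true  -> 2 1 4
  //   keep_dim = false -> 2 4
  for (bool keep_dim : {true, false}) {
    for (int64_t d : InferReducedDims({2, 3, 4}, {1}, keep_dim)) {
      std::cout << d << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}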