Commit d7015d6c authored by Yiqun Liu, committed by GitHub

Fix bug in reduce_ops when keep_dim is true. (#2906)

test=develop
Parent e4489484
...
@@ -63,7 +63,19 @@ void ReduceFunctor(const lite::Tensor& input,
     auto out = EigenScalar<T>::From(output);
     functor(&x, &out, reduce_dim);
   } else {
-    auto out = EigenTensor<T, (D - R_D)>::From(*output, output->dims());
+    std::vector<DDim::value_type> out_dims;
+    if (keep_dim) {
+      // Construct the squeezed dims.
+      const int kDelFlag = -2;
+      out_dims = output->dims().Vectorize();
+      for (size_t i = 0; i < dims.size(); ++i) {
+        out_dims[reduce_dim[i]] = kDelFlag;
+      }
+      out_dims.erase(remove(out_dims.begin(), out_dims.end(), kDelFlag),
+                     out_dims.end());
+    }
+    auto out = EigenTensor<T, (D - R_D)>::From(
+        *output, keep_dim ? DDim(out_dims) : output->dims());
     functor(&x, &out, reduce_dim);
   }
 }
...
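For reference, the following is a standalone sketch of the squeeze step this kernel change performs, not the PaddlePaddle-Lite API; the helper name `SqueezeReducedDims` and the shapes are invented for illustration. When keep_dim is true the output tensor still carries size-1 reduced axes, so the kernel marks those axes with a sentinel and erases them before building the rank-(D - R_D) Eigen view.

```cpp
// Minimal, hypothetical sketch of the squeeze-dims technique used above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> SqueezeReducedDims(std::vector<int64_t> dims,
                                        const std::vector<int>& reduce_dims) {
  const int64_t kDelFlag = -2;  // sentinel that can never be a real extent
  for (int d : reduce_dims) {
    dims[d] = kDelFlag;  // mark every reduced axis
  }
  // Erase the marked axes so only the kept extents remain.
  dims.erase(std::remove(dims.begin(), dims.end(), kDelFlag), dims.end());
  return dims;
}

int main() {
  // Reducing axis 1 with keep_dim=true yields shape {2, 1, 4};
  // the Eigen functor needs the squeezed shape {2, 4}.
  std::vector<int64_t> kept = {2, 1, 4};
  for (int64_t d : SqueezeReducedDims(kept, {1})) std::cout << d << " ";
  std::cout << "\n";  // prints: 2 4
  return 0;
}
```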
...
@@ -50,20 +50,18 @@ bool ReduceOp::InferShape() const {
   } else {
     size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
     std::vector<DDim::value_type> out_dims(out_rank);
-    if (keep_dim) {
-      for (size_t i = 0; i < dims.size(); ++i) {
-        out_dims[dims[i]] = 1;
-      }
-    } else {
-      sort(dims.begin(), dims.end());
-      int dim_index = 0;
-      int out_index = 0;
-      for (size_t i = 0; i < x_rank; ++i) {
-        if (dims[dim_index] == static_cast<DDim::value_type>(i)) {
-          dim_index++;
-        } else {
-          out_dims[out_index++] = x_dims[i];
-        }
-      }
+    sort(dims.begin(), dims.end());
+    int dim_index = 0;
+    int out_index = 0;
+    for (size_t i = 0; i < x_rank; ++i) {
+      if (dim_index < dims.size() &&
+          dims[dim_index] == static_cast<DDim::value_type>(i)) {
+        if (keep_dim) {
+          out_dims[out_index++] = 1;
+        }
+        dim_index++;
+      } else {
+        out_dims[out_index++] = x_dims[i];
+      }
     }
     param_.output->Resize(out_dims);
...
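Below is a standalone sketch of the corrected shape-inference logic; `InferReduceDims` is a hypothetical helper, not part of the Lite operator interface. The single pass bounds-checks `dim_index`, emits a 1 for each reduced axis when keep_dim is true, and copies the remaining extents through, which is the behavior the old keep_dim branch failed to produce.

```cpp
// Hypothetical, self-contained sketch of the fixed InferShape computation.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> InferReduceDims(const std::vector<int64_t>& x_dims,
                                     std::vector<int> dims,
                                     bool keep_dim) {
  std::sort(dims.begin(), dims.end());
  std::vector<int64_t> out_dims;
  size_t dim_index = 0;
  for (size_t i = 0; i < x_dims.size(); ++i) {
    if (dim_index < dims.size() && dims[dim_index] == static_cast<int>(i)) {
      if (keep_dim) out_dims.push_back(1);  // reduced axis kept as size 1
      ++dim_index;
    } else {
      out_dims.push_back(x_dims[i]);  // non-reduced axis copied through
    }
  }
  return out_dims;
}

int main() {
  std::vector<int64_t> x = {2, 3, 4};
  for (int64_t d : InferReduceDims(x, {1}, true)) std::cout << d << " ";
  std::cout << "\n";  // prints: 2 1 4
  for (int64_t d : InferReduceDims(x, {1}, false)) std::cout << d << " ";
  std::cout << "\n";  // prints: 2 4
  return 0;
}
```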