未验证 提交 be6f1fb4 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix bug in reduce_ops when keep_dim is true. (#2906)

test=develop
上级 0c3a74f3
......@@ -63,7 +63,19 @@ void ReduceFunctor(const lite::Tensor& input,
auto out = EigenScalar<T>::From(output);
functor(&x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, output->dims());
std::vector<DDim::value_type> out_dims;
if (keep_dim) {
// Construct the squeezed dims.
const int kDelFlag = -2;
out_dims = output->dims().Vectorize();
for (size_t i = 0; i < dims.size(); ++i) {
out_dims[reduce_dim[i]] = kDelFlag;
}
out_dims.erase(remove(out_dims.begin(), out_dims.end(), kDelFlag),
out_dims.end());
}
auto out = EigenTensor<T, (D - R_D)>::From(
*output, keep_dim ? DDim(out_dims) : output->dims());
functor(&x, &out, reduce_dim);
}
}
......
......@@ -50,20 +50,18 @@ bool ReduceOp::InferShape() const {
} else {
size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
std::vector<DDim::value_type> out_dims(out_rank);
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
out_dims[dims[i]] = 1;
}
} else {
sort(dims.begin(), dims.end());
int dim_index = 0;
int out_index = 0;
for (size_t i = 0; i < x_rank; ++i) {
if (dims[dim_index] == static_cast<DDim::value_type>(i)) {
dim_index++;
} else {
out_dims[out_index++] = x_dims[i];
sort(dims.begin(), dims.end());
int dim_index = 0;
int out_index = 0;
for (size_t i = 0; i < x_rank; ++i) {
if (dim_index < dims.size() &&
dims[dim_index] == static_cast<DDim::value_type>(i)) {
if (keep_dim) {
out_dims[out_index++] = 1;
}
dim_index++;
} else {
out_dims[out_index++] = x_dims[i];
}
}
param_.output->Resize(out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册