diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc index e986b0ca5412f8380cccc9f981e5e4069ffcdabc..0aada7293e87b27839428fe2eb957c3f2738208f 100644 --- a/lite/operators/reduce_ops.cc +++ b/lite/operators/reduce_ops.cc @@ -29,39 +29,41 @@ bool ReduceOp::CheckShape() const { } bool ReduceOp::InferShape() const { - auto x_dims = param_.x->dims(); + auto& x_dims = param_.x->dims(); auto x_rank = x_dims.size(); - auto dims = param_.dim; + auto& dims = param_.dim; for (size_t i = 0; i < dims.size(); ++i) { if (dims[i] < 0) dims[i] = x_rank + dims[i]; CHECK_LT(dims[i], x_rank) << "The dim should be in the range [-rank(input), rank(input)."; } - sort(dims.begin(), dims.end()); bool reduce_all = param_.reduce_all; bool keep_dim = param_.keep_dim; if (reduce_all) { if (keep_dim) - param_.output->Resize(lite::DDim(std::vector(x_rank, 1))); + param_.output->Resize(std::vector(x_rank, 1)); else - param_.output->Resize(lite::DDim(std::vector{1})); + param_.output->Resize(std::vector{1}); } else { - auto dims_vector = x_dims.Vectorize(); + size_t out_rank = keep_dim ? x_rank : x_rank - dims.size(); + DDim out_dims(out_rank); if (keep_dim) { for (size_t i = 0; i < dims.size(); ++i) { - dims_vector[dims[i]] = 1; + out_dims[dims[i]] = 1; } } else { - const int kDelFlag = -2; - for (size_t i = 0; i < dims.size(); ++i) { - dims_vector[dims[i]] = kDelFlag; + sort(dims.begin(), dims.end()); + int dim_index = 0; + int out_index = 0; + for (size_t i = 0; i < x_rank; ++i) { + if (dims[dim_index] == i) { + dim_index++; + } else { + out_dims[out_index++] = x_dims[i]; + } } - dims_vector.erase( - remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); } - auto out_dims = lite::DDim(dims_vector); param_.output->Resize(out_dims); if (dims[0] != 0) { param_.output->set_lod(param_.x->lod()); @@ -70,7 +72,7 @@ bool ReduceOp::InferShape() const { return true; } -bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { +bool ReduceOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { param_.x = scope->FindVar(opdesc.Input("X").front())->GetMutable(); param_.output =