提交 b79bf021 编写于 作者: L Liu Yiqun

Optimize the InferShape of reduce_sum.

上级 b649344e
...@@ -29,39 +29,41 @@ bool ReduceOp::CheckShape() const { ...@@ -29,39 +29,41 @@ bool ReduceOp::CheckShape() const {
} }
bool ReduceOp::InferShape() const { bool ReduceOp::InferShape() const {
auto x_dims = param_.x->dims(); auto& x_dims = param_.x->dims();
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
auto dims = param_.dim; auto& dims = param_.dim;
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i]; if (dims[i] < 0) dims[i] = x_rank + dims[i];
CHECK_LT(dims[i], x_rank) CHECK_LT(dims[i], x_rank)
<< "The dim should be in the range [-rank(input), rank(input)."; << "The dim should be in the range [-rank(input), rank(input).";
} }
sort(dims.begin(), dims.end());
bool reduce_all = param_.reduce_all; bool reduce_all = param_.reduce_all;
bool keep_dim = param_.keep_dim; bool keep_dim = param_.keep_dim;
if (reduce_all) { if (reduce_all) {
if (keep_dim) if (keep_dim)
param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1))); param_.output->Resize(std::vector<int64_t>(x_rank, 1));
else else
param_.output->Resize(lite::DDim(std::vector<int64_t>{1})); param_.output->Resize(std::vector<int64_t>{1});
} else { } else {
auto dims_vector = x_dims.Vectorize(); size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
DDim out_dims(out_rank);
if (keep_dim) { if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1; out_dims[dims[i]] = 1;
} }
} else { } else {
const int kDelFlag = -2; sort(dims.begin(), dims.end());
for (size_t i = 0; i < dims.size(); ++i) { int dim_index = 0;
dims_vector[dims[i]] = kDelFlag; int out_index = 0;
for (size_t i = 0; i < x_rank; ++i) {
if (dims[dim_index] == i) {
dim_index++;
} else {
out_dims[out_index++] = x_dims[i];
}
} }
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
} }
auto out_dims = lite::DDim(dims_vector);
param_.output->Resize(out_dims); param_.output->Resize(out_dims);
if (dims[0] != 0) { if (dims[0] != 0) {
param_.output->set_lod(param_.x->lod()); param_.output->set_lod(param_.x->lod());
...@@ -70,7 +72,7 @@ bool ReduceOp::InferShape() const { ...@@ -70,7 +72,7 @@ bool ReduceOp::InferShape() const {
return true; return true;
} }
bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool ReduceOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
param_.x = param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output = param_.output =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册