提交 b79bf021 编写于 作者: Liu Yiqun

Optimize the InferShape of reduce_sum.

上级 b649344e
......@@ -29,39 +29,41 @@ bool ReduceOp::CheckShape() const {
}
bool ReduceOp::InferShape() const {
auto x_dims = param_.x->dims();
auto& x_dims = param_.x->dims();
auto x_rank = x_dims.size();
auto dims = param_.dim;
auto& dims = param_.dim;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
CHECK_LT(dims[i], x_rank)
<< "The dim should be in the range [-rank(input), rank(input).";
}
sort(dims.begin(), dims.end());
bool reduce_all = param_.reduce_all;
bool keep_dim = param_.keep_dim;
if (reduce_all) {
if (keep_dim)
param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1)));
param_.output->Resize(std::vector<int64_t>(x_rank, 1));
else
param_.output->Resize(lite::DDim(std::vector<int64_t>{1}));
param_.output->Resize(std::vector<int64_t>{1});
} else {
auto dims_vector = x_dims.Vectorize();
size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
DDim out_dims(out_rank);
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
out_dims[dims[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
sort(dims.begin(), dims.end());
int dim_index = 0;
int out_index = 0;
for (size_t i = 0; i < x_rank; ++i) {
if (dims[dim_index] == i) {
dim_index++;
} else {
out_dims[out_index++] = x_dims[i];
}
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = lite::DDim(dims_vector);
param_.output->Resize(out_dims);
if (dims[0] != 0) {
param_.output->set_lod(param_.x->lod());
......@@ -70,7 +72,7 @@ bool ReduceOp::InferShape() const {
return true;
}
bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
bool ReduceOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册