Commit 14e04200 authored by Liu Yiqun

Optimize the InferShape of concat by removing the use of std::vector.

Parent 1abd9921
@@ -27,11 +27,8 @@ bool ConcatOpLite::CheckShape() const {
 }
 
 bool ConcatOpLite::InferShape() const {
-  std::vector<lite::DDim> input_dims;
-  for (auto p : param_.x) {
-    input_dims.push_back(p->dims());
-  }
-  const size_t n = input_dims.size();
+  std::vector<lite::Tensor *> &inputs = param_.x;
+  const size_t n = inputs.size();
   CHECK_GT_OR_FALSE(n, 0);
 
   int axis = 0;
@@ -42,17 +39,18 @@ bool ConcatOpLite::InferShape() const {
     axis = axis_tensor_val[0];
   }
   if (axis < 0) {
-    axis += input_dims[0].size();
+    axis += inputs[0]->dims().size();
   }
 
-  auto &out_dims = input_dims[0];
+  auto out_dims = inputs[0]->dims();
   size_t in_zero_dims_size = out_dims.size();
   for (size_t i = 1; i < n; i++) {
+    auto &input_dims_i = inputs[i]->dims();
     for (size_t j = 0; j < in_zero_dims_size; j++) {
       if (j == static_cast<size_t>(axis)) {
-        out_dims[axis] += input_dims[i][j];
+        out_dims[axis] += input_dims_i[j];
       } else {
-        CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
+        CHECK_EQ_OR_FALSE(out_dims[j], input_dims_i[j]);
       }
     }
   }
@@ -60,7 +58,7 @@ bool ConcatOpLite::InferShape() const {
     out_dims[axis] = -1;
   }
   // Set output dims
-  param_.output->Resize(lite::DDim(out_dims));
+  param_.output->Resize(out_dims);
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x[0]->lod();
   return true;
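The idea of the change: instead of first copying every input's dims into a temporary std::vector<lite::DDim>, the shape inference reads each input's dims in place through the tensor pointers, so only the output shape is ever copied. Below is a minimal self-contained sketch of that pattern; Tensor, DDim, and InferConcatShape here are simplified stand-ins for illustration, not Paddle-Lite's actual types or API.

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-ins for lite::Tensor / lite::DDim (illustration only).
using DDim = std::vector<int64_t>;
struct Tensor {
  DDim dims_;
  const DDim &dims() const { return dims_; }
};

// Infer the concat output shape by reading each input's dims in place,
// rather than materializing a temporary vector of DDim copies.
DDim InferConcatShape(const std::vector<Tensor *> &inputs, int axis) {
  assert(!inputs.empty());
  DDim out_dims = inputs[0]->dims();  // the only copy: the output shape
  if (axis < 0) axis += static_cast<int>(out_dims.size());
  for (size_t i = 1; i < inputs.size(); ++i) {
    const DDim &in_dims = inputs[i]->dims();  // reference, no copy
    for (size_t j = 0; j < out_dims.size(); ++j) {
      if (j == static_cast<size_t>(axis)) {
        out_dims[j] += in_dims[j];  // sizes add up along the concat axis
      } else {
        assert(out_dims[j] == in_dims[j]);  // other dims must match
      }
    }
  }
  return out_dims;
}

int main() {
  Tensor a{{2, 3, 4}}, b{{2, 5, 4}};
  std::vector<Tensor *> inputs{&a, &b};
  DDim out = InferConcatShape(inputs, 1);
  assert(out == (DDim{2, 8, 4}));  // concat along axis 1: 3 + 5 = 8
  return 0;
}

Since InferShape runs on every inference when shapes are dynamic, dropping the per-call vector of DDim copies removes n heap-backed copies per concat op, which is presumably the motivation for the commit.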