diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index 1941a88bbf50fa0da9a8fd22db3fa9146d242957..343d10475b4edcc7a05d320aa64bafdfcf893801 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -27,11 +27,8 @@ bool ConcatOpLite::CheckShape() const { } bool ConcatOpLite::InferShape() const { - std::vector input_dims; - for (auto p : param_.x) { - input_dims.push_back(p->dims()); - } - const size_t n = input_dims.size(); + std::vector &inputs = param_.x; + const size_t n = inputs.size(); CHECK_GT_OR_FALSE(n, 0); int axis = 0; @@ -42,17 +39,18 @@ bool ConcatOpLite::InferShape() const { axis = axis_tensor_val[0]; } if (axis < 0) { - axis += input_dims[0].size(); + axis += inputs[0]->dims().size(); } - auto &out_dims = input_dims[0]; + auto out_dims = inputs[0]->dims(); size_t in_zero_dims_size = out_dims.size(); for (size_t i = 1; i < n; i++) { + auto &input_dims_i = inputs[i]->dims(); for (size_t j = 0; j < in_zero_dims_size; j++) { if (j == static_cast(axis)) { - out_dims[axis] += input_dims[i][j]; + out_dims[axis] += input_dims_i[j]; } else { - CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]); + CHECK_EQ_OR_FALSE(out_dims[j], input_dims_i[j]); } } } @@ -60,7 +58,7 @@ bool ConcatOpLite::InferShape() const { out_dims[axis] = -1; } // Set output dims - param_.output->Resize(lite::DDim(out_dims)); + param_.output->Resize(out_dims); auto out_lod = param_.output->mutable_lod(); *out_lod = param_.x[0]->lod(); return true;