提交 14e04200 编写于 作者: L Liu Yiqun

Optimize the InferShape of concat by removing the use of std::vector.

上级 1abd9921
......@@ -27,11 +27,8 @@ bool ConcatOpLite::CheckShape() const {
}
bool ConcatOpLite::InferShape() const {
std::vector<lite::DDim> input_dims;
for (auto p : param_.x) {
input_dims.push_back(p->dims());
}
const size_t n = input_dims.size();
std::vector<lite::Tensor *> &inputs = param_.x;
const size_t n = inputs.size();
CHECK_GT_OR_FALSE(n, 0);
int axis = 0;
......@@ -42,17 +39,18 @@ bool ConcatOpLite::InferShape() const {
axis = axis_tensor_val[0];
}
if (axis < 0) {
axis += input_dims[0].size();
axis += inputs[0]->dims().size();
}
auto &out_dims = input_dims[0];
auto out_dims = inputs[0]->dims();
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
auto &input_dims_i = inputs[i]->dims();
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == static_cast<size_t>(axis)) {
out_dims[axis] += input_dims[i][j];
out_dims[axis] += input_dims_i[j];
} else {
CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
CHECK_EQ_OR_FALSE(out_dims[j], input_dims_i[j]);
}
}
}
......@@ -60,7 +58,7 @@ bool ConcatOpLite::InferShape() const {
out_dims[axis] = -1;
}
// Set output dims
param_.output->Resize(lite::DDim(out_dims));
param_.output->Resize(out_dims);
auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x[0]->lod();
return true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册