未验证 提交 75a2f9d5 编写于 作者: W wenbin 提交者: GitHub

fix concat axis bug (#50951)

* fix concat bug

* recommit for ci
上级 6c471ed0
...@@ -45,10 +45,11 @@ class ConcatOpConverter : public OpConverter { ...@@ -45,10 +45,11 @@ class ConcatOpConverter : public OpConverter {
itensors.push_back(engine_->GetITensor(input_name)); itensors.push_back(engine_->GetITensor(input_name));
} }
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis")); int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis == -1) { if (axis < 0) {
axis = (engine_->GetITensor(op_desc.Input("X").front())->getDimensions()) axis = engine_->GetITensor(op_desc.Input("X").front())
.nbDims - ->getDimensions()
1; .nbDims +
axis;
} else { } else {
if (!engine_->with_dynamic_shape()) { if (!engine_->with_dynamic_shape()) {
axis = axis - 1; // Remove batch dim axis = axis - 1; // Remove batch dim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册