未验证 提交 7a0602c8 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix tensorrt output shape error (#29308) (#29344)

* fix tensorrt output shape error

* fix unittest tensorrt_engine_op_test

* fix code style for unitest
上级 67e12faa
......@@ -151,9 +151,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
std::vector<int> origin_output_dims;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
origin_output_dims.push_back(x->Var()->GetShape().size());
}
std::unordered_map<std::string, std::string> output_name_map;
......@@ -224,6 +226,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("workspace_size", Get<int>("workspace_size"));
op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
op_desc->SetAttr("output_name_mapping", output_mapping);
op_desc->SetAttr("origin_output_dims", origin_output_dims);
op_desc->SetAttr("parameters", params);
// we record all inputs' shapes in attr to check if they are consistent
......
......@@ -288,6 +288,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
// Bind output tensor to TRT.
int output_index = 0;
std::vector<int> origin_output_dims =
Attr<std::vector<int>>("origin_output_dims");
VLOG(4) << "TensorRT Engine Op Outputs:";
for (const auto &y : Outputs("Ys")) {
const int bind_index =
......@@ -306,7 +308,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto dims = trt_context->getBindingDimensions(bind_index);
int nb_dims = dims.nbDims;
for (; nb_dims > 0; nb_dims--) {
if (dims.d[nb_dims - 1] != 1) break;
// some 'x 1' of shape is normal, no need to remove it
if (dims.d[nb_dims - 1] != 1 ||
nb_dims == origin_output_dims[output_index])
break;
}
for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]);
#endif
......
......@@ -109,6 +109,7 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
......@@ -210,6 +211,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册