未验证 提交 20c3224d 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add gather_nd trt converter. (#47589)

* add_gather_nd_

* add_gather_nd_

* add_gather_nd_
上级 827fd5cd
...@@ -24,13 +24,24 @@ class GatherNdOpConverter : public OpConverter { ...@@ -24,13 +24,24 @@ class GatherNdOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, const framework::Scope& scope,
bool test_mode) override { bool test_mode) override {
VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto input = engine_->GetITensor(op_desc.Input("X")[0]);
auto index = engine_->GetITensor(op_desc.Input("Index")[0]);
auto output_name = op_desc.Output("Out")[0];
// The native GatherV2 layer (GatherMode::kND) requires TensorRT 8.2 or newer.
#if IS_TRT_VERSION_GE(8200)
VLOG(3) << "convert gather_nd op to tensorrt gather_nd layer";
auto layer = TRT_ENGINE_ADD_LAYER(
engine_, GatherV2, *input, *index, nvinfer1::GatherMode::kND);
layer->setNbElementWiseDims(0);
RreplenishLayerAndOutput(layer, "gather_nd", {output_name}, test_mode);
#else
VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin";
// Declare inputs // Declare inputs
std::vector<nvinfer1::ITensor*> inputs; std::vector<nvinfer1::ITensor*> inputs;
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto* index = engine_->GetITensor(op_desc.Input("Index")[0]);
inputs.emplace_back(input); inputs.emplace_back(input);
inputs.emplace_back(index); inputs.emplace_back(index);
...@@ -41,7 +52,6 @@ class GatherNdOpConverter : public OpConverter { ...@@ -41,7 +52,6 @@ class GatherNdOpConverter : public OpConverter {
layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin); layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
std::string layer_name = "gather_nd (Output: "; std::string layer_name = "gather_nd (Output: ";
auto output_name = op_desc.Output("Out")[0];
layer->getOutput(0)->setName(output_name.c_str()); layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0)); engine_->SetITensor(output_name, layer->getOutput(0));
layer_name += output_name; layer_name += output_name;
...@@ -49,6 +59,7 @@ class GatherNdOpConverter : public OpConverter { ...@@ -49,6 +59,7 @@ class GatherNdOpConverter : public OpConverter {
engine_->DeclareOutput(output_name); engine_->DeclareOutput(output_name);
} }
layer->setName((layer_name + ")").c_str()); layer->setName((layer_name + ")").c_str());
#endif
} }
}; };
......
...@@ -566,9 +566,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -566,9 +566,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass."; "the pass.";
return false; return false;
} }
auto x_var_name = desc.Input("X")[0];
auto index_var_name = desc.Input("Index")[0]; auto index_var_name = desc.Input("Index")[0];
auto* x_var_desc = block->FindVar(x_var_name);
auto* index_var_desc = block->FindVar(index_var_name); auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype. // The index input must be int32 datatype.
...@@ -578,6 +577,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -578,6 +577,9 @@ struct SimpleOpTypeSetTeller : public Teller {
return false; return false;
} }
#if IS_TRT_VERSION_LT(8200)
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto index_shape = index_var_desc->GetShape(); const auto index_shape = index_var_desc->GetShape();
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() <= 2) { if (x_shape.size() <= 2) {
...@@ -591,6 +593,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -591,6 +593,7 @@ struct SimpleOpTypeSetTeller : public Teller {
<< " ] not equal to x dims size [" << x_shape.size() << "]"; << " ] not equal to x dims size [" << x_shape.size() << "]";
return false; return false;
} }
#endif
} }
if (op_type == "anchor_generator") { if (op_type == "anchor_generator") {
......
...@@ -69,11 +69,11 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest): ...@@ -69,11 +69,11 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float): ) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8], "input_data": [2, 32, 64, 64],
"index_data": [1], "index_data": [1],
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64], "input_data": [2, 32, 64, 64],
"index_data": [1], "index_data": [1],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
...@@ -159,11 +159,11 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest): ...@@ -159,11 +159,11 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float): ) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8], "input_data": [2, 32, 64, 64],
"index_data": [2], "index_data": [2],
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64], "input_data": [2, 32, 64, 64],
"index_data": [2], "index_data": [2],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
...@@ -249,11 +249,11 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest): ...@@ -249,11 +249,11 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float): ) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8], "input_data": [2, 32, 64, 64],
"index_data": [2, 2], "index_data": [2, 2],
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64], "input_data": [2, 32, 64, 64],
"index_data": [2, 2], "index_data": [2, 2],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
...@@ -339,11 +339,11 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest): ...@@ -339,11 +339,11 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float): ) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8], "input_data": [2, 32, 64, 64],
"index_data": [2, 2, 4], "index_data": [2, 2, 4],
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64], "input_data": [2, 32, 64, 64],
"index_data": [2, 2, 4], "index_data": [2, 2, 4],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
...@@ -429,15 +429,15 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): ...@@ -429,15 +429,15 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float): ) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
"input_data": [1, 4], "input_data": [2, 32],
"index_data": [2, 2], "index_data": [2, 2],
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [4, 64], "input_data": [2, 32],
"index_data": [2, 2], "index_data": [2, 2],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
"input_data": [2, 8], "input_data": [2, 32],
"index_data": [2, 2], "index_data": [2, 2],
} }
...@@ -521,15 +521,15 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest): ...@@ -521,15 +521,15 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float): ) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = { self.dynamic_shape.min_input_shape = {
"input_data": [1, 4, 4], "input_data": [16, 32, 256],
"index_data": [1, 1, 1], "index_data": [2, 2, 2],
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [16, 64, 512], "input_data": [16, 32, 256],
"index_data": [4, 2, 4], "index_data": [2, 2, 2],
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
"input_data": [2, 8, 64], "input_data": [16, 32, 256],
"index_data": [2, 2, 2], "index_data": [2, 2, 2],
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册