Unverified commit 20c3224d — authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Add gather_nd trt converter. (#47589)

* add_gather_nd_

* add_gather_nd_

* add_gather_nd_
Parent commit: 827fd5cd
......@@ -24,13 +24,24 @@ class GatherNdOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin";
framework::OpDesc op_desc(op, nullptr);
auto input = engine_->GetITensor(op_desc.Input("X")[0]);
auto index = engine_->GetITensor(op_desc.Input("Index")[0]);
auto output_name = op_desc.Output("Out")[0];
// AddGatherV2 is supported by the trt version of 8.2.
#if IS_TRT_VERSION_GE(8200)
VLOG(3) << "convert gather_nd op to tensorrt gather_nd layer";
auto layer = TRT_ENGINE_ADD_LAYER(
engine_, GatherV2, *input, *index, nvinfer1::GatherMode::kND);
layer->setNbElementWiseDims(0);
RreplenishLayerAndOutput(layer, "gather_nd", {output_name}, test_mode);
#else
VLOG(4) << "convert a paddle gather_nd op to tensorrt gather_nd plugin";
// Declare inputs
std::vector<nvinfer1::ITensor*> inputs;
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto* index = engine_->GetITensor(op_desc.Input("Index")[0]);
inputs.emplace_back(input);
inputs.emplace_back(index);
......@@ -41,7 +52,6 @@ class GatherNdOpConverter : public OpConverter {
layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
std::string layer_name = "gather_nd (Output: ";
auto output_name = op_desc.Output("Out")[0];
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer_name += output_name;
......@@ -49,6 +59,7 @@ class GatherNdOpConverter : public OpConverter {
engine_->DeclareOutput(output_name);
}
layer->setName((layer_name + ")").c_str());
#endif
}
};
......
......@@ -566,9 +566,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto index_var_name = desc.Input("Index")[0];
auto* x_var_desc = block->FindVar(x_var_name);
auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype.
......@@ -578,6 +577,9 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
#if IS_TRT_VERSION_LT(8200)
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto index_shape = index_var_desc->GetShape();
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() <= 2) {
......@@ -591,6 +593,7 @@ struct SimpleOpTypeSetTeller : public Teller {
<< " ] not equal to x dims size [" << x_shape.size() << "]";
return false;
}
#endif
}
if (op_type == "anchor_generator") {
......
......@@ -69,11 +69,11 @@ class TrtConvertGatherNdTest_dim_4_1(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"input_data": [2, 32, 64, 64],
"index_data": [1],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [1],
}
self.dynamic_shape.opt_input_shape = {
......@@ -159,11 +159,11 @@ class TrtConvertGatherNdTest_dim_4_1_2(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"input_data": [2, 32, 64, 64],
"index_data": [2],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [2],
}
self.dynamic_shape.opt_input_shape = {
......@@ -249,11 +249,11 @@ class TrtConvertGatherNdTest_dim_4_2(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"input_data": [2, 32, 64, 64],
"index_data": [2, 2],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [2, 2],
}
self.dynamic_shape.opt_input_shape = {
......@@ -339,11 +339,11 @@ class TrtConvertGatherNdTest_dim_4_3(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 8, 8, 8],
"input_data": [2, 32, 64, 64],
"index_data": [2, 2, 4],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 64, 64],
"input_data": [2, 32, 64, 64],
"index_data": [2, 2, 4],
}
self.dynamic_shape.opt_input_shape = {
......@@ -429,15 +429,15 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 4],
"input_data": [2, 32],
"index_data": [2, 2],
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 64],
"input_data": [2, 32],
"index_data": [2, 2],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 8],
"input_data": [2, 32],
"index_data": [2, 2],
}
......@@ -521,15 +521,15 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data": [1, 4, 4],
"index_data": [1, 1, 1],
"input_data": [16, 32, 256],
"index_data": [2, 2, 2],
}
self.dynamic_shape.max_input_shape = {
"input_data": [16, 64, 512],
"index_data": [4, 2, 4],
"input_data": [16, 32, 256],
"index_data": [2, 2, 2],
}
self.dynamic_shape.opt_input_shape = {
"input_data": [2, 8, 64],
"input_data": [16, 32, 256],
"index_data": [2, 2, 2],
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.