未验证 提交 80f5e25e 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]fix gather_nd (#49266)

* fix reshape, gather_nd
上级 8386c609
......@@ -42,6 +42,7 @@ class StackOpConverter : public OpConverter {
auto input = op_desc.Input("X");
int input_num = input.size();
std::vector<nvinfer1::ITensor*> inputs;
auto output_name = op_desc.Output("Y").front();
for (int i = 0; i < input_num; ++i) {
inputs.push_back(engine_->GetITensor(input[i]));
......@@ -76,13 +77,15 @@ class StackOpConverter : public OpConverter {
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[i]);
reshape_layer->setInput(1, *after_shape_tensor);
inputs[i] = reshape_layer->getOutput(0);
reshape_layer->setName(("stack: reshape: (Output( " + std::to_string(i) +
" )" + output_name + ")")
.c_str());
}
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Concatenation, inputs.data(), inputs.size());
layer->setAxis(axis);
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
}
};
......
......@@ -609,18 +609,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}
#if IS_TRT_VERSION_LT(8200)
auto index_var_name = desc.Input("Index")[0];
auto* index_var_desc = block->FindVar(index_var_name);
// The index input must be int32 datatype.
if (index_var_desc->GetDataType() !=
paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
VLOG(3) << "gather_nd op Index input data type must be int32";
return false;
}
#if IS_TRT_VERSION_LT(8200)
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto index_shape = index_var_desc->GetShape();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册