未验证 提交 a2b39320 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

shufflechannelfix (#44516)

上级 4f86092b
...@@ -46,26 +46,23 @@ class ShuffleChannelOpConverter : public OpConverter { ...@@ -46,26 +46,23 @@ class ShuffleChannelOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(8000) #if IS_TRT_VERSION_GE(8000)
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
auto* input_shape_tensor = Shape(input); auto* input_shape_tensor = Shape(input);
auto* batch_shape_tensor = GetEleTensorOfShape(input_shape_tensor, 0);
auto* channel_shape_tensor = GetEleTensorOfShape(input_shape_tensor, 1); auto* channel_shape_tensor = GetEleTensorOfShape(input_shape_tensor, 1);
auto* group_tensor = auto* group_tensor =
Add1DConstantLayer(group, output_name + "_group_tensor_"); Add1DConstantLayer(group, output_name + "_group_tensor_");
auto* new_channel_shape_tensor = Div(channel_shape_tensor, group_tensor); auto* new_channel_shape_tensor = Div(channel_shape_tensor, group_tensor);
std::vector<int32_t> shape_dim3{0, 2, 3}; std::vector<int32_t> shape_dim2{2, 3};
auto* shape_dim3_tensor = Gather(input_shape_tensor, shape_dim3); auto* shape_dim2_tensor = Gather(input_shape_tensor, shape_dim2);
std::vector<nvinfer1::ITensor*> itensors; std::vector<nvinfer1::ITensor*> itensors;
itensors.push_back(shape_dim3_tensor); itensors.push_back(batch_shape_tensor);
itensors.push_back(group_tensor); itensors.push_back(group_tensor);
itensors.push_back(new_channel_shape_tensor); itensors.push_back(new_channel_shape_tensor);
itensors.push_back(shape_dim2_tensor);
auto* reshape_tensor = Concat(itensors); auto* reshape_tensor = Concat(itensors);
auto* reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *reshape_tensor);
nvinfer1::Permutation transpose_new_input{0, 3, 4, 1, 2};
reshape_layer->setSecondTranspose(transpose_new_input);
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setInput(1, *(reshape_layer->getOutput(0))); layer->setInput(1, *(reshape_tensor));
nvinfer1::Permutation transpose_embed{0, 2, 1, 3, 4}; nvinfer1::Permutation transpose_embed{0, 2, 1, 3, 4};
layer->setSecondTranspose(transpose_embed); layer->setSecondTranspose(transpose_embed);
auto* output = layer->getOutput(0); auto* output = layer->getOutput(0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册