未验证 提交 d91352c0 编写于 作者: P Pei Yang 提交者: GitHub

[Paddle-TRT]Fix flatten converter when batch_size > 1 (#33768)

* fix trt flatten converter when batch_size > 1

* change ut to same dynamic shape
上级 0f59d4e6
......@@ -53,10 +53,19 @@ class FlattenOpConverter : public OpConverter {
layer->setReshapeDimensions(flatten_dim);
} else {
auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
nvinfer1::Dims start_dim, size_dim, stride_dim;
start_dim.nbDims = 1;
size_dim.nbDims = 1;
stride_dim.nbDims = 1;
start_dim.d[0] = 1;
size_dim.d[0] = dims - 1;
stride_dim.d[0] = 1;
auto* slice_layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *(shape_layer->getOutput(0)),
start_dim, size_dim, stride_dim);
uint32_t reduce_dim = 1;
auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *(shape_layer->getOutput(0)),
engine_, Reduce, *(slice_layer->getOutput(0)),
nvinfer1::ReduceOperation::kPROD, reduce_dim, true);
int32_t* constant_weight_data = new int32_t[1];
constant_weight_data[0] = -1;
......
......@@ -134,6 +134,19 @@ inline size_t ProductDim(const nvinfer1::Dims& dims) {
return v;
}
inline void PrintITensorShape(nvinfer1::ITensor* X) {
auto dims = X->getDimensions();
auto name = X->getName();
std::cout << "ITensor " << name << " shape: [";
for (int i = 0; i < dims.nbDims; i++) {
if (i == dims.nbDims - 1)
std::cout << dims.d[i];
else
std::cout << dims.d[i] << ", ";
}
std::cout << "]\n";
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -63,10 +63,10 @@ class TRTFlattenDynamicTest(InferencePassTest):
self.trt_parameters = TRTFlattenDynamicTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.dynamic_shape_params = TRTFlattenDynamicTest.DynamicShapeParam({
'data': [1, 6, 8, 8],
'flatten_0.tmp_0': [1, 6 * 8 * 8]
}, {'data': [3, 6, 128, 128],
'flatten_0.tmp_0': [3, 6 * 128 * 128]}, {
'data': [2, 6, 64, 64],
'flatten_0.tmp_0': [2, 6 * 64 * 64]
}, {'data': [2, 6, 64, 64],
'flatten_0.tmp_0': [2, 6 * 64 * 64]}, {
'data': [2, 6, 64, 64],
'flatten_0.tmp_0': [2, 6 * 64 * 64]
}, False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册