未验证 提交 4a52c0cc 编写于 作者: W Wangzheee 提交者: GitHub

[paddle-trt] fix_teller_reshape (#34583)

上级 6c8a10a2
...@@ -703,8 +703,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -703,8 +703,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false; return false;
} }
// Paddle-TRT does not support the input tensors: Shape and ShapeTensor // Paddle-TRT does not support the input tensors: Shape and ShapeTensor
if (desc.Input("Shape").size() >= 1 || auto reshape_inputs = desc.Inputs();
desc.Input("ShapeTensor").size() >= 1) { if (reshape_inputs.find("Shape") != reshape_inputs.end() ||
reshape_inputs.find("ShapeTensor") != reshape_inputs.end()) {
return false; return false;
} }
std::vector<int> shape = std::vector<int> shape =
......
...@@ -80,6 +80,33 @@ class TRTReshapeTest1(TRTReshapeTest): ...@@ -80,6 +80,33 @@ class TRTReshapeTest1(TRTReshapeTest):
class TRTReshapeTest2(TRTReshapeTest): class TRTReshapeTest2(TRTReshapeTest):
def setUp(self):
self.bs = 2
self.input_shape = [23, 13, 24]
self.reshape = [2, 0, -1, 12]
self.data_shape = [
self.bs, self.input_shape[0], self.input_shape[1],
self.input_shape[2]
]
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name='data', shape=self.data_shape, dtype='float32')
actual_reshape = fluid.data(
name='actual_reshape', shape=[4], dtype='int32')
reshape_out = fluid.layers.reshape(
x=data, shape=self.reshape, actual_shape=actual_reshape)
out = fluid.layers.batch_norm(reshape_out, is_test=True)
self.feeds = {
'data': np.random.random(self.data_shape).astype('float32'),
'actual_reshape': np.array([2, 0, -1, 12]).astype('int32')
}
self.enable_trt = True
self.trt_parameters = TRTReshapeTest.TensorRTParam(
1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
class TRTReshapeTest3(TRTReshapeTest):
def setUp(self): def setUp(self):
self.bs = 1 self.bs = 1
self.input_shape = [14, 48, 27] self.input_shape = [14, 48, 27]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册