提交 1adda8e0 编写于 作者: H hjchen2

Add more unit tests for split plugin

test=develop
上级 6eba5bd2
...@@ -19,9 +19,6 @@ namespace paddle { ...@@ -19,9 +19,6 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
/*
* SplitOp.
*/
class SplitOpConverter : public OpConverter { class SplitOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
...@@ -40,15 +37,11 @@ class SplitOpConverter : public OpConverter { ...@@ -40,15 +37,11 @@ class SplitOpConverter : public OpConverter {
int axis = boost::get<int>(op_desc.GetAttr("axis")); int axis = boost::get<int>(op_desc.GetAttr("axis"));
std::vector<int> output_lengths = std::vector<int> output_lengths =
boost::get<std::vector<int>>(op_desc.GetAttr("sections")); boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
// PADDLE_ENFORCE(axis != 0); // split on batch is not supported in TensorRT
if (axis < 0) { PADDLE_ENFORCE(axis != 0);
axis += input_dims.nbDims; axis += (axis < 0) ? input_dims.nbDims : -1;
} else {
axis -= 1;
}
PADDLE_ENFORCE(output_lengths.size() == output_num); PADDLE_ENFORCE(output_lengths.size() == output_num);
//
plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths); plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths);
nvinfer1::IPluginLayer* layer = nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin); engine_->AddPlugin(&input, input_num, plugin);
......
...@@ -59,21 +59,54 @@ void TensorRTSplitTest(const std::vector<int> &in_shape, ...@@ -59,21 +59,54 @@ void TensorRTSplitTest(const std::vector<int> &in_shape,
validator.Execute(BatchSize); validator.Execute(BatchSize);
} }
TEST(split_op, test_same_shape_batch1) { // batch = 0, axis = 1, same shape
TEST(split_op, test_same_shape_axis1_batch1) {
TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2}); TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2});
} }
// batch = 0, axis = 1, different shape
TEST(split_op, test_different_shape_batch1) { TEST(split_op, test_different_shape_axis1_batch1) {
TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1}); TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1});
} }
// batch = 10, axis = 1, same shape
TEST(split_op, test_same_shape_batch10) { TEST(split_op, test_same_shape_axis1_batch10) {
TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2}); TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2});
} }
// batch = 10, axis = 1, different shape
TEST(split_op, test_different_shape_batch10) { TEST(split_op, test_different_shape_axis1_batch10) {
TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1}); TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1});
} }
// batch = 0, axis = 2, same shape
TEST(split_op, test_same_shape_axis2_batch1) {
TensorRTSplitTest<1, 2>({3, 4, 2}, {2, 2});
}
// batch = 0, axis = 2, different shape
TEST(split_op, test_different_shape_axis2_batch1) {
TensorRTSplitTest<1, 2>({3, 3, 2}, {2, 1});
}
// batch = 10, axis = 2, same shape
TEST(split_op, test_same_shape_axis2_batch10) {
TensorRTSplitTest<10, 2>({3, 4, 2}, {2, 2});
}
// batch = 10, axis = 2, different shape
TEST(split_op, test_different_shape_axis2_batch10) {
TensorRTSplitTest<10, 2>({3, 3, 2}, {2, 1});
}
// batch = 0, axis = 3, same shape
TEST(split_op, test_same_shape_axis3_batch1) {
TensorRTSplitTest<1, 3>({3, 2, 4}, {2, 2});
}
// batch = 0, axis = 3, different shape
TEST(split_op, test_different_shape_axis3_batch1) {
TensorRTSplitTest<1, 3>({3, 2, 3}, {2, 1});
}
// batch = 10, axis = 3, same shape
TEST(split_op, test_same_shape_axis3_batch10) {
TensorRTSplitTest<10, 3>({3, 2, 4}, {2, 2});
}
// batch = 10, axis = 3, different shape
TEST(split_op, test_different_shape_axis3_batch10) {
TensorRTSplitTest<10, 3>({3, 2, 3}, {2, 1});
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册