提交 6eba5bd2 编写于 作者: H hjchen2

Fix direct copy and refine split ut

test=develop
上级 5857fb30
......@@ -20,30 +20,59 @@ namespace paddle {
namespace inference {
namespace tensorrt {
TEST(split_op, test) {
template <int BatchSize, int Axis>
void TensorRTSplitTest(const std::vector<int> &in_shape,
const std::vector<int> &sections) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
TRTConvertValidation validator(BatchSize + 1, parameters, scope, 10000);
auto make_dim = [](const std::vector<int> &shape) {
nvinfer1::DimsCHW dim;
dim.c() = shape[0];
dim.h() = shape[1];
dim.w() = shape[2];
return dim;
};
validator.DeclInputVar("split_input", make_dim(in_shape));
std::vector<std::string> output_vars;
for (size_t i = 0; i < sections.size(); ++i) {
auto out_shape = in_shape;
out_shape[Axis - 1] = sections[i];
std::string output_name = "split_out" + std::to_string(i);
validator.DeclOutputVar(output_name, make_dim(out_shape));
output_vars.push_back(output_name);
}
// Prepare Op description
framework::OpDesc desc;
desc.SetType("split");
desc.SetInput("X", {"split_input"});
desc.SetOutput("Out", {"split_out1", "split_out2"});
desc.SetOutput("Out", output_vars);
int num = 0;
int axis = 1;
std::vector<int> output_lengths = {2, 1};
desc.SetAttr("axis", axis);
desc.SetAttr("num", num);
desc.SetAttr("sections", output_lengths);
desc.SetAttr("axis", Axis);
desc.SetAttr("num", 0);
desc.SetAttr("sections", sections);
validator.SetOp(*desc.Proto());
validator.Execute(1);
validator.Execute(BatchSize);
}
TEST(split_op, test_same_shape_batch1) {
TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2});
}
TEST(split_op, test_different_shape_batch1) {
TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1});
}
TEST(split_op, test_same_shape_batch10) {
TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2});
}
TEST(split_op, test_different_shape_batch10) {
TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1});
}
} // namespace tensorrt
......
......@@ -138,11 +138,12 @@ inline void Split(cudaStream_t stream, const bool same_shape,
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace, cudaStream_t stream) {
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
if (axis_ == -1 && this->getNbOutputs() < 10) {
if (((batchSize == 1 && axis_ == 0) || axis_ == -1) &&
this->getNbOutputs() < 10) {
float** output_ptrs = reinterpret_cast<float**>(outputs);
int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
? sizeof(__half)
: sizeof(float);
? sizeof(float)
: sizeof(__half);
for (int i = 0; i < this->getNbOutputs(); ++i) {
PADDLE_ENFORCE(
cudaMemcpyAsync(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册