diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc index 6620c76318f99092236d4009037f2ce01b295164..ae5b1b98060a4e73b2d1761d4edafb152f364070 100644 --- a/paddle/fluid/inference/tensorrt/convert/split_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc @@ -19,9 +19,6 @@ namespace paddle { namespace inference { namespace tensorrt { -/* - * SplitOp. - */ class SplitOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, @@ -40,16 +37,11 @@ class SplitOpConverter : public OpConverter { int axis = boost::get(op_desc.GetAttr("axis")); std::vector output_lengths = boost::get>(op_desc.GetAttr("sections")); + // split on batch is not supported in TensorRT PADDLE_ENFORCE(axis != 0); - if (axis < 0) { - axis += input_dims.nbDims; - } else { - axis -= 1; - } + axis += (axis < 0) ? input_dims.nbDims : -1; PADDLE_ENFORCE(output_lengths.size() == output_num); - - // plugin::SplitPlugin* plugin = new plugin::SplitPlugin(axis, output_lengths); nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, input_num, plugin); diff --git a/paddle/fluid/inference/tensorrt/convert/test_split_op.cc b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc index f81d011552c152c2df79e1a272f34b954ae2a3a1..5aacc5c600dd1371e3865adc888bb8e24640e7d9 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_split_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc @@ -20,30 +20,92 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(split_op, test) { +template +void TensorRTSplitTest(const std::vector &in_shape, + const std::vector §ions) { std::unordered_set parameters({""}); framework::Scope scope; - TRTConvertValidation validator(10, parameters, scope, 1000); - validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2)); - validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2)); - validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2)); + TRTConvertValidation validator(BatchSize + 1, parameters, scope, 10000); + + auto make_dim = [](const std::vector &shape) { + nvinfer1::DimsCHW dim; + dim.c() = shape[0]; + dim.h() = shape[1]; + dim.w() = shape[2]; + return dim; + }; + validator.DeclInputVar("split_input", make_dim(in_shape)); + std::vector output_vars; + for (size_t i = 0; i < sections.size(); ++i) { + auto out_shape = in_shape; + out_shape[Axis - 1] = sections[i]; + std::string output_name = "split_out" + std::to_string(i); + validator.DeclOutputVar(output_name, make_dim(out_shape)); + output_vars.push_back(output_name); + } // Prepare Op description framework::OpDesc desc; desc.SetType("split"); desc.SetInput("X", {"split_input"}); - desc.SetOutput("Out", {"split_out1", "split_out2"}); + desc.SetOutput("Out", output_vars); - int num = 0; - int axis = 1; - std::vector output_lengths = {2, 1}; - desc.SetAttr("axis", axis); - desc.SetAttr("num", num); - desc.SetAttr("sections", output_lengths); + desc.SetAttr("axis", Axis); + desc.SetAttr("num", 0); + desc.SetAttr("sections", sections); validator.SetOp(*desc.Proto()); - validator.Execute(1); + validator.Execute(BatchSize); +} + +// batch = 0, axis = 1, same shape +TEST(split_op, test_same_shape_axis1_batch1) { + TensorRTSplitTest<1, 1>({4, 2, 2}, {2, 2}); +} +// batch = 0, axis = 1, different shape +TEST(split_op, test_different_shape_axis1_batch1) { + TensorRTSplitTest<1, 1>({3, 2, 2}, {2, 1}); +} +// batch = 10, axis = 1, same shape +TEST(split_op, test_same_shape_axis1_batch10) { + TensorRTSplitTest<10, 1>({4, 2, 2}, {2, 2}); +} +// batch = 10, axis = 1, different shape +TEST(split_op, test_different_shape_axis1_batch10) { + TensorRTSplitTest<10, 1>({3, 2, 2}, {2, 1}); +} +// batch = 0, axis = 2, same shape +TEST(split_op, test_same_shape_axis2_batch1) { + TensorRTSplitTest<1, 2>({3, 4, 2}, {2, 2}); +} +// batch = 0, axis = 2, different shape +TEST(split_op, test_different_shape_axis2_batch1) { + TensorRTSplitTest<1, 2>({3, 3, 2}, {2, 1}); +} +// batch = 10, axis = 2, same shape +TEST(split_op, test_same_shape_axis2_batch10) { + TensorRTSplitTest<10, 2>({3, 4, 2}, {2, 2}); +} +// batch = 10, axis = 2, different shape +TEST(split_op, test_different_shape_axis2_batch10) { + TensorRTSplitTest<10, 2>({3, 3, 2}, {2, 1}); +} +// batch = 0, axis = 3, same shape +TEST(split_op, test_same_shape_axis3_batch1) { + TensorRTSplitTest<1, 3>({3, 2, 4}, {2, 2}); +} +// batch = 0, axis = 3, different shape +TEST(split_op, test_different_shape_axis3_batch1) { + TensorRTSplitTest<1, 3>({3, 2, 3}, {2, 1}); +} +// batch = 10, axis = 3, same shape +TEST(split_op, test_same_shape_axis3_batch10) { + TensorRTSplitTest<10, 3>({3, 2, 4}, {2, 2}); +} +// batch = 10, axis = 3, different shape +TEST(split_op, test_different_shape_axis3_batch10) { + TensorRTSplitTest<10, 3>({3, 2, 3}, {2, 1}); } } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index 4adea2db1ee80fb20adba3cf4141a6485a1065a0..de61ace59e299a1f51940e4b433a0133d4fbe7ff 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" namespace paddle { @@ -19,6 +21,52 @@ namespace inference { namespace tensorrt { namespace plugin { +// copied from operators::math::SplitFunctor +template +__global__ void SplitKernel(const T* input_data, const int in_row, + const int in_col, const int* out_cols, + int out_cols_size, T** outputs_data) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int curr_segment = 0; + int curr_offset = out_cols[0]; + for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { + int curr_col_offset = out_cols[curr_segment + 1]; + while (curr_col_offset <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + curr_col_offset = out_cols[curr_segment + 1]; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* output_ptr = outputs_data[curr_segment]; + if (output_ptr != nullptr) { + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * segment_width + local_col] = + input_data[tid_y * in_col + tid_x]; + } + } +} + +template +__global__ void SplitKernel(const T* input_data, const int in_row, + const int in_col, const int fixed_out_col, + T** outputs_data) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { + int split = tid_x / fixed_out_col; + int in_offset = tid_x - split * fixed_out_col; + T* output_ptr = outputs_data[split]; + if (output_ptr != nullptr) { + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * fixed_out_col + in_offset] = + input_data[tid_y * in_col + tid_x]; + } + } +} + nvinfer1::Dims SplitPlugin::getOutputDimensions( int index, const nvinfer1::Dims* input_dims, int num_inputs) { PADDLE_ENFORCE_EQ(num_inputs, 1); @@ -31,48 +79,96 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions( int SplitPlugin::initialize() { PADDLE_ENFORCE_LE(axis_, nvinfer1::Dims::MAX_DIMS); - + // notice input dims is [C, H, W] + nvinfer1::Dims dims = this->getInputDims(0); + outer_rows_ = 1; + inner_cols_ = 1; + for (int i = 0; i < axis_; ++i) { + outer_rows_ *= dims.d[i]; + } + for (int i = axis_ + 1; i < dims.nbDims; ++i) { + inner_cols_ *= dims.d[i]; + } + same_shape_ = true; std::vector segment_offsets(1, 0); for (int i = 0; i < this->getNbOutputs(); ++i) { - segment_offsets.push_back(segment_offsets.back() + output_length_[i]); + if (output_length_[i] != output_length_[0]) { + same_shape_ = false; + } + segment_offsets.push_back(segment_offsets.back() + + output_length_[i] * inner_cols_); } - segment_offsets_ = segment_offsets; - nvinfer1::Dims dims = this->getInputDims(0); - nx_ = 1; - for (int i = dims.nbDims - 1; i > axis_; --i) { - nx_ *= dims.d[i]; + inner_cols_ *= dims.d[axis_]; + d_segment_offsets_ = segment_offsets; + segment_offsets_ = std::move(segment_offsets); + d_output_ptrs_.resize(this->getNbOutputs(), nullptr); + return 0; +} + +template +inline void Split(cudaStream_t stream, const bool same_shape, + const int outer_rows, const int inner_cols, + const std::vector& segment_offsets, + const int* d_segment_offsets, const T* input, T** outputs) { + const int kThreadsPerBlock = 1024; + const int kMaxBlocks = 65535; + int block_cols = kThreadsPerBlock; + if (inner_cols < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((inner_cols + 31) >> 5) << 5; } - ny_ = dims.d[axis_]; - nz_ = 1; - for (int i = axis_ - 1; i >= 0; --i) { - nz_ *= dims.d[i]; + int block_rows = kThreadsPerBlock / block_cols; + dim3 block_size = dim3(block_cols, block_rows, 1); + + int grid_cols = + std::min((inner_cols + block_cols - 1) / block_cols, kMaxBlocks); + int grid_rows = + std::min(kMaxBlocks / grid_cols, std::max(outer_rows / block_rows, 1)); + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + if (same_shape) { + SplitKernel<<>>( + input, outer_rows, inner_cols, segment_offsets[1], outputs); + } else { + SplitKernel<<>>( + input, outer_rows, inner_cols, d_segment_offsets, + static_cast(segment_offsets.size()), outputs); } - return 0; } int SplitPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) { - auto const& input_dims = this->getInputDims(0); - int input_size = 0; - float const* idata = reinterpret_cast(inputs[0]); - float** odatas = reinterpret_cast(outputs); - - // kernel impl here. - int inputBatchOffset = nx_ * ny_ * nz_; - for (size_t i = 0; i < this->getNbOutputs(); i++) { - for (size_t j = 0; j < batchSize; j++) { - cudaMemcpyAsync( - odatas[i] + - j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * - sizeof(float), - inputs[0] + - (inputBatchOffset * j + segment_offsets_[i] * nx_) * - sizeof(float), - (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float), - cudaMemcpyDeviceToDevice, stream); + float const* input_ptr = reinterpret_cast(inputs[0]); + if (((batchSize == 1 && axis_ == 0) || axis_ == -1) && + this->getNbOutputs() < 10) { + float** output_ptrs = reinterpret_cast(outputs); + int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT) + ? sizeof(float) + : sizeof(__half); + for (int i = 0; i < this->getNbOutputs(); ++i) { + PADDLE_ENFORCE( + cudaMemcpyAsync( + output_ptrs[i], input_ptr + segment_offsets_[i], + (segment_offsets_[i + 1] - segment_offsets_[i]) * data_type_size, + cudaMemcpyDeviceToDevice, stream) == cudaSuccess); + } + } else { + outer_rows_ *= batchSize; + const int* d_segment_offsets_ptr = + thrust::raw_pointer_cast(&d_segment_offsets_[0]); + float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]); + PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, outputs, + this->getNbOutputs() * sizeof(float*), + cudaMemcpyHostToDevice, + stream) == cudaSuccess); + if (this->getDataType() == nvinfer1::DataType::kFLOAT) { + Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_, + d_segment_offsets_ptr, input_ptr, output_ptrs); + } else { + Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_, + d_segment_offsets_ptr, (__half*)input_ptr, // NOLINT + (__half**)output_ptrs); // NOLINT } } - return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h index b5b6e69992b057a1478f61457b4ae6f5f1619b4d..6f028d3d72ae3cc7d96c6782b734cdbf1243c06c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" @@ -25,7 +26,7 @@ namespace plugin { class SplitPlugin : public PluginTensorRT { public: SplitPlugin(int axis, std::vector const &output_lengths) - : axis_(axis), output_length_(output_lengths) {} + : axis_(axis), same_shape_(true), output_length_(output_lengths) {} SplitPlugin(void const *serial_data, size_t serial_length) { deserializeBase(serial_data, serial_length); @@ -60,9 +61,13 @@ class SplitPlugin : public PluginTensorRT { } int axis_; + int outer_rows_; + int inner_cols_; + bool same_shape_; std::vector output_length_; - int nx_, ny_, nz_; std::vector segment_offsets_; + thrust::device_vector d_segment_offsets_; + thrust::device_vector d_output_ptrs_; }; } // namespace plugin