recover_padding_plugin.cu 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Scatters packed (unpadded) token rows back into a zero-padded layout.
//
// Expected launch layout (as configured by RecoverPaddingPlugin::enqueue):
//   gridDim.x              -> sequence index b
//   gridDim.y              -> position inside the padded sequence
//   gridDim.z * blockDim.x -> hidden size (row stride of input0 and output);
//                             the block size must divide the hidden size.
// No shared memory is used.
//
// Arguments:
//   input0 - packed rows; sequence b's token t lives in row input1[b] + t.
//   input1 - offsets with (gridDim.x + 1) entries; sequence b occupies packed
//            rows [input1[b], input1[b + 1]).
//   output - padded rows, row index b * gridDim.y + t; rows with
//            t >= sequence length are written as zeros.
__global__ void RecoverPaddingKernel(const float* __restrict__ input0,
                                     const int32_t* __restrict__ input1,
                                     float* __restrict__ output) {
  // 64-bit offsets: batch * seq_len * hidden can exceed the 32-bit range.
  const int64_t hidden = static_cast<int64_t>(gridDim.z) * blockDim.x;
  const int64_t col =
      static_cast<int64_t>(blockIdx.z) * blockDim.x + threadIdx.x;
  const int64_t word_id =
      static_cast<int64_t>(blockIdx.x) * gridDim.y + blockIdx.y;

  const int32_t seq_begin = input1[blockIdx.x];
  const int32_t seq_length = input1[blockIdx.x + 1] - seq_begin;

  // Signed comparison: a (malformed) negative length must zero-fill rather
  // than wrap to a huge unsigned value and read out of bounds.
  float value = 0.0f;
  if (static_cast<int32_t>(blockIdx.y) < seq_length) {
    value =
        input0[(static_cast<int64_t>(seq_begin) + blockIdx.y) * hidden + col];
  }
  output[word_id * hidden + col] = value;
}

// The single output has the same element type as input 0 (the packed
// feature tensor).
nvinfer1::DataType RecoverPaddingPlugin::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_types[0];
}

// Computes the rank-3 output shape [batch, max_seq_len, hidden]:
//   d[0] = inputs[1].d[0] - 1  (input 1 holds one more offset than there are
//                               sequences)
//   d[1] = inputs[2].d[1]      (padded sequence length taken from input 2)
//   d[2] = inputs[0].d[1]      (feature width of the packed input 0)
nvinfer1::DimsExprs RecoverPaddingPlugin::getOutputDimensions(
    int outputIndex,
    const nvinfer1::DimsExprs* inputs,
    int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output_dims{};
  output_dims.nbDims = 3;
  const auto* one = exprBuilder.constant(1);
  output_dims.d[0] = exprBuilder.operation(
      nvinfer1::DimensionOperation::kSUB, *inputs[1].d[0], *one);
  output_dims.d[1] = inputs[2].d[1];
  output_dims.d[2] = inputs[0].d[1];
  return output_dims;
}

// Accepted I/O formats (all linear layout):
//   pos 1 (PosId offsets)        -> kINT32
//   pos 0, pos 2 (mask_id), out  -> kFLOAT
// Half precision and INT8 are not supported by this plugin.
bool RecoverPaddingPlugin::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc* inOut,
    int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs,
                    3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs,
                    getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));
  // Only the offset tensor is integer; every other slot is float32.
  const nvinfer1::DataType expected_type =
      (pos == 1) ? nvinfer1::DataType::kINT32 : nvinfer1::DataType::kFLOAT;
  return inOut[pos].type == expected_type &&
         inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}

// No shape-dependent state is cached; all sizes are read from the tensor
// descriptors in enqueue().
void RecoverPaddingPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs,
    int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs,
    int nbOutputs) TRT_NOEXCEPT {}

// The plugin does not use cuDNN, cuBLAS or the TensorRT allocator, so the
// execution-context resources are ignored.
void RecoverPaddingPlugin::attachToContext(
    cudnnContext* cudnnContext,
    cublasContext* cublasContext,
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}

void RecoverPaddingPlugin::detachFromContext() TRT_NOEXCEPT {
  // attachToContext() keeps no handles, so there is nothing to release here.
}

void RecoverPaddingPlugin::terminate() TRT_NOEXCEPT {
  // Intentionally empty: no resources are owned by this plugin instance.
}

// Launches RecoverPaddingKernel on `stream` to expand the packed input 0 into
// the padded output [batch, max_seq_len, hidden].
//   inputs[0] - packed float features, hidden = inputDesc[0].dims.d[1]
//   inputs[1] - int32 offsets (pos_id), batch = inputDesc[1].dims.d[0] - 1
//   inputs[2] - only its shape is used: max_seq_len = inputDesc[2].dims.d[1]
// Returns 0 on success and 1 if the kernel launch reported an error.
int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                  const nvinfer1::PluginTensorDesc* outputDesc,
                                  const void* const* inputs,
                                  void* const* outputs,
                                  void* workspace,
                                  cudaStream_t stream) TRT_NOEXCEPT {
  const auto& input0_desc = inputDesc[0];
  const auto& input1_desc = inputDesc[1];
  const auto& input2_desc = inputDesc[2];
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);

  const int32_t batch = static_cast<int32_t>(input1_desc.dims.d[0]) - 1;
  const int32_t max_seq_len = static_cast<int32_t>(input2_desc.dims.d[1]);
  const int32_t hidden = static_cast<int32_t>(input0_desc.dims.d[1]);

  // Empty output: nothing to write, and a zero-sized grid is a launch error.
  if (batch <= 0 || max_seq_len <= 0 || hidden <= 0) {
    return 0;
  }

  // The kernel uses gridDim.z * blockDim.x as the row stride, so the block
  // size must divide `hidden` exactly. Use the largest divisor <= 256 rather
  // than assuming hidden % 256 == 0 (otherwise strides are wrong and the tail
  // is dropped, or the z-dimension becomes 0 for hidden < 256).
  constexpr int32_t kMaxThreads = 256;
  int32_t num_threads = hidden < kMaxThreads ? hidden : kMaxThreads;
  while (hidden % num_threads != 0) {
    --num_threads;  // terminates at 1 at the latest
  }

  const dim3 num_blocks(batch, max_seq_len, hidden / num_threads);
  RecoverPaddingKernel<<<num_blocks, num_threads, 0, stream>>>(
      input0, input1, output);
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle