elementwise_op_plugin.cu 9.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

namespace details {
// Binary functors applied element-wise by elementwise_kernel.
//
// They are marked __host__ __device__ so the same functor can also be
// evaluated on the host (e.g. from unit tests); device behavior is unchanged.
template <typename T>
struct Add {
  __host__ __device__ T operator()(const T &a, const T &b) const {
    return a + b;
  }
};

template <typename T>
struct Mul {
  __host__ __device__ T operator()(const T &a, const T &b) const {
    return a * b;
  }
};
}  // namespace details

// Element-wise binary op with Y broadcast along the "middle" span of X.
//
// X and out are treated as flat arrays of `total` elements. Element `i` of X
// is combined with Y[(i / post) % n], i.e. each Y value is held for `post`
// consecutive elements and Y repeats every `n * post` elements.
// `pre` is accepted for interface compatibility but is not read.
//
// Launch: 1-D grid of 1-D blocks. The grid-stride loop makes any grid size
// correct; the callers in this file launch ceil(total / 256) blocks of 256
// threads. Requires post > 0 and n > 0.
template <typename T, typename Operator>
__global__ void elementwise_kernel(const size_t total, const T *x_data,
                                   const T *y_data, T *out_data, int pre, int n,
                                   int post, Operator op) {
  // Index in size_t: blockIdx.x * blockDim.x is 32-bit unsigned arithmetic,
  // and narrowing it into a signed int overflows for large tensors (and made
  // the bound check a signed/unsigned comparison).
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tid < total; tid += stride) {
    const size_t idx = (tid / static_cast<size_t>(post)) % static_cast<size_t>(n);
#if __CUDA_ARCH__ >= 350
    // sm_35+: route the read-only inputs through the read-only data cache.
    out_data[tid] = op(__ldg(x_data + tid), __ldg(y_data + idx));
#else
    out_data[tid] = op(x_data[tid], y_data[idx]);
#endif
  }
}

// Output shape of the static-shape plugin.
//
// The plugin has exactly one output (index 0) and two inputs; the output has
// the same dims as input X (input_dims[0]).
nvinfer1::Dims ElementWisePlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *input_dims, int num_inputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "There is only one output in TRT elementwise "
                                  "op plugin, but got output index: %d.",
                                  index));
  PADDLE_ENFORCE_EQ(num_inputs, 2, platform::errors::InvalidArgument(
                                       "There are 2 inputs in TRT elementwise "
                                       "op plugin, but got input number: %d.",
                                       num_inputs));
  PADDLE_ENFORCE_NOT_NULL(
      input_dims,
      platform::errors::InvalidArgument(
          "The input dims of TRT elementwise op plugin should not be null."));
  return input_dims[0];
}

// Prepares the static-shape plugin for execution.
//
// Normalizes axis_ (-1 aligns Y with the trailing dims of X), drops trailing
// size-1 dims from Y, checks that the remaining Y dims equal the X dims that
// start at axis_, and caches the flattened sizes used by enqueue():
//   prev_size_ = product of X dims before axis_,
//   midd_size_ = product of the trimmed Y dims,
//   post_size_ = product of the X dims after the Y span.
// Note: this mutates axis_ and dims_y_.
// Returns 0; invalid configurations are reported via PADDLE_ENFORCE.
int ElementWisePlugin::initialize() TRT_NOEXCEPT {
  PADDLE_ENFORCE_GT(dims_y_.nbDims, 0,
                    platform::errors::InvalidArgument(
                        "The dimension of input Y of TRT elementwise op plugin "
                        "should be greater than 0, but got %d.",
                        dims_y_.nbDims));

  axis_ = (axis_ == -1) ? dims_x_.nbDims - dims_y_.nbDims : axis_;
  // Any other negative axis (or -1 with fewer X dims than Y dims) would make
  // dims_x_.d[i + axis_] below read out of bounds, so reject it up front.
  PADDLE_ENFORCE_GE(axis_, 0,
                    platform::errors::InvalidArgument(
                        "We expect [axis] >= 0 in TRT elementwise op plugin, "
                        "but got [axis] = %d.",
                        axis_));

  // Trailing size-1 dims of Y do not change its element count; dropping them
  // lets the matching X dims be folded into post_size_ (broadcast over).
  int trimed_nb_dims = dims_y_.nbDims;
  for (; trimed_nb_dims > 0; --trimed_nb_dims) {
    if (dims_y_.d[trimed_nb_dims - 1] != 1) {
      break;
    }
  }
  dims_y_.nbDims = trimed_nb_dims;

  PADDLE_ENFORCE_GE(dims_x_.nbDims, dims_y_.nbDims + axis_,
                    platform::errors::InvalidArgument(
                        "We expect [number of x dims] >= [number of y dims + "
                        "axis] in TRT elementwise op plugin, but got [number "
                        "of x dims] = %d, [number of y dims + axis] = %d.",
                        dims_x_.nbDims, dims_y_.nbDims + axis_));
  PADDLE_ENFORCE_LT(
      axis_, dims_x_.nbDims,
      platform::errors::InvalidArgument("We expect [axis] < [number of x dims] "
                                        "in TRT elementwise op plugin, but got "
                                        "[axis] = %d, [number of x dims] = %d.",
                                        axis_, dims_x_.nbDims));

  prev_size_ = 1;
  midd_size_ = 1;
  post_size_ = 1;
  for (int i = 0; i < axis_; ++i) {
    prev_size_ *= dims_x_.d[i];
  }

  for (int i = 0; i < dims_y_.nbDims; ++i) {
    PADDLE_ENFORCE_EQ(dims_x_.d[i + axis_], dims_y_.d[i],
                      platform::errors::InvalidArgument(
                          "Broadcast dimension mismatch. The dims of input Y "
                          "should be a subsequence of X."));
    midd_size_ *= dims_y_.d[i];
  }

  for (int i = axis_ + dims_y_.nbDims; i < dims_x_.nbDims; ++i) {
    post_size_ *= dims_x_.d[i];
  }
  return 0;
}

// Runs the static-shape (implicit batch) plugin on `stream`.
//
// Uses the sizes cached by initialize(): X and the output hold
// batch_size * prev_size_ * midd_size_ * post_size_ floats, and element i of
// X is combined with Y[(i / post_size_) % (batch_size * midd_size_)].
// Returns 0 on success and non-zero if cudaGetLastError() reports an error.
int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                               void **outputs, void *workspace,
#else
                               void *const *outputs, void *workspace,
#endif
                               cudaStream_t stream) TRT_NOEXCEPT {
  if (type_ != "add" && type_ != "mul") {
    PADDLE_THROW(platform::errors::Fatal(
        "The %s type elementwise is not implemented in trt plugin.", type_));
  }

  const float *x = reinterpret_cast<const float *>(inputs[0]);
  const float *y = reinterpret_cast<const float *>(inputs[1]);
  float *out = reinterpret_cast<float *>(outputs[0]);

  const int num = batch_size * prev_size_ * midd_size_ * post_size_;
  // An empty tensor needs no work. Launching a grid of 0 blocks would instead
  // fail with cudaErrorInvalidConfiguration and be reported as an error.
  if (num == 0) {
    return 0;
  }
  const int thread = 256;
  const int block = (num + thread - 1) / thread;

  // NOTE(review): (i / post_size_) % (batch_size * midd_size_) equals
  // b * midd_size_ + m only when prev_size_ == 1. With prev_size_ > 1 and
  // batch_size > 1 the batch component of the Y index looks wrong — confirm
  // how the converter that builds this plugin chooses axis_.
  const int y_period = batch_size * midd_size_;
  if (type_ == "add") {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size_, y_period, post_size_,
        details::Add<float>());
  } else {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size_, y_period, post_size_,
        details::Mul<float>());
  }

  return cudaGetLastError() != cudaSuccess;
}

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

// The dynamic-shape plugin derives all of its sizes from the runtime tensor
// descriptors in enqueue(), so there is nothing to prepare here.
int ElementwisePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

// Number of bytes serialize() writes: the op type string followed by axis_.
size_t ElementwisePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return SerializedSize(type_.c_str()) + SerializedSize(axis_);
}

// Writes the plugin state into `buffer`: the op type string, then axis_.
// The fields written here must stay in sync with getSerializationSize(). The
// read side is presumably a deserializing constructor declared elsewhere — TODO
// confirm the field order there when changing this.
void ElementwisePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  SerializeValue(&buffer, type_.c_str());
  SerializeValue(&buffer, axis_);
}

// Symbolic output shape of the dynamic-shape plugin.
//
// There is a single output (index 0) and two inputs; the output takes the
// shape of input X, into which Y is broadcast by enqueue().
nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(output_index, 0,
                    platform::errors::InvalidArgument(
                        "There is only one output in TRT elementwise "
                        "op plugin, but got output index: %d.",
                        output_index));
  PADDLE_ENFORCE_EQ(nb_inputs, 2, platform::errors::InvalidArgument(
                                      "There are 2 inputs in TRT elementwise "
                                      "op plugin, but got input number: %d.",
                                      nb_inputs));
  return inputs[0];
}

// Reports whether tensor `pos` of in_out may use the proposed type/format.
//
// Position 0 (X) must be float in linear format. Every later position (with
// two inputs: pos 1 is Y, pos 2 is the output) must match the type and format
// of the tensor immediately before it, which chains all of them to X's
// float/linear choice.
bool ElementwisePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of elementwise plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
    return (in.type == nvinfer1::DataType::kFLOAT) &&
           (in.format == nvinfer1::TensorFormat::kLINEAR);
  }
  // Input Y and the output: must agree with the preceding tensor.
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  return in.type == prev.type && in.format == prev.format;
}

// The plugin has a single output (index 0), which takes the data type of
// input X.
nvinfer1::DataType ElementwisePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The Elementwise Plugin only has one output, so the "
                        "index value should be 0, but got %d.",
                        index));
  return input_types[0];
}

// Runs the dynamic-shape plugin for the runtime shapes in input_desc.
//
// X is viewed as [prev, midd, post]:
//   prev = product of X dims before `axis` (including the leading dim),
//   midd = product of Y's dims after trailing size-1 dims are dropped,
//   post = product of the remaining X dims.
// Element i of X is combined with Y[(i / post) % midd]. Data is read and
// written as float. Returns 0 on success and non-zero if cudaGetLastError()
// reports an error.
int ElementwisePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
    void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT {
  const auto x_dims = input_desc[0].dims;
  const auto y_dims = input_desc[1].dims;
  const int axis = (axis_ == -1) ? x_dims.nbDims - y_dims.nbDims : axis_;

  // Trailing size-1 dims of Y do not change its element count; dropping them
  // folds the matching X dims into `post` (Y is broadcast over them).
  int trimed_nb_dims = y_dims.nbDims;
  for (; trimed_nb_dims > 0; --trimed_nb_dims) {
    if (y_dims.d[trimed_nb_dims - 1] != 1) {
      break;
    }
  }

  // Guard the x_dims.d[i + axis] reads below against out-of-range axes.
  PADDLE_ENFORCE_GE(axis, 0,
                    platform::errors::InvalidArgument(
                        "We expect [axis] >= 0 in TRT elementwise op plugin, "
                        "but got [axis] = %d.",
                        axis));
  PADDLE_ENFORCE_GE(x_dims.nbDims, axis + trimed_nb_dims,
                    platform::errors::InvalidArgument(
                        "We expect [number of x dims] >= [number of y dims + "
                        "axis] in TRT elementwise op plugin, but got [number "
                        "of x dims] = %d, [number of y dims + axis] = %d.",
                        x_dims.nbDims, axis + trimed_nb_dims));

  int prev_size = 1;
  for (int i = 0; i < axis; ++i) {
    prev_size *= x_dims.d[i];
  }

  int midd_size = 1;
  for (int i = 0; i < trimed_nb_dims; ++i) {
    PADDLE_ENFORCE_EQ(x_dims.d[i + axis], y_dims.d[i],
                      platform::errors::InvalidArgument(
                          "Broadcast dimension mismatch found in trt "
                          "elementwise plugin's x and y input."));
    midd_size *= y_dims.d[i];
  }

  int post_size = 1;
  for (int i = axis + trimed_nb_dims; i < x_dims.nbDims; ++i) {
    post_size *= x_dims.d[i];
  }

  if (type_ != "add" && type_ != "mul") {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Paddle-TRT only support elementwise operation: {add, mul} currently, "
        "but got %s.",
        type_));
  }

  const float *x = static_cast<const float *>(inputs[0]);
  const float *y = static_cast<const float *>(inputs[1]);
  float *out = static_cast<float *>(outputs[0]);

  const int num = prev_size * midd_size * post_size;
  // An empty tensor needs no work. Launching a grid of 0 blocks would instead
  // fail with cudaErrorInvalidConfiguration and be reported as an error.
  if (num == 0) {
    return 0;
  }
  const int thread = 256;
  const int block = (num + thread - 1) / thread;

  if (type_ == "add") {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size, midd_size, post_size, details::Add<float>());
  } else {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size, midd_size, post_size, details::Mul<float>());
  }

  return cudaGetLastError() != cudaSuccess;
}
#endif
261 262 263 264 265

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle