elementwise_op_plugin.cu 9.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

namespace details {
template <typename T>
struct Add {
26
  __device__ T operator()(const T &a, const T &b) const { return a + b; }
27 28 29 30
};

template <typename T>
struct Mul {
31
  __device__ T operator()(const T &a, const T &b) const { return a * b; }
32
};
33 34 35 36 37

template <typename T>
struct Div {
  __device__ T operator()(const T &a, const T &b) const { return a / b; }
};
38
}  // namespace details
39 40

template <typename T, typename Operator>
41 42 43 44 45 46 47 48 49 50 51
__global__ void elementwise_kernel(const size_t total, const T *x_data,
                                   const T *y_data, T *out_data, int pre, int n,
                                   int post, Operator op) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < total) {
    int idx = tid / post % n;
#if __CUDA_ARCH__ >= 350
    out_data[tid] = op(__ldg(x_data + tid), __ldg(y_data + idx));
#else
    out_data[tid] = op(x_data[tid], y_data[idx]);
#endif
52 53 54 55
  }
}

nvinfer1::Dims ElementWisePlugin::getOutputDimensions(
56
    int index, const nvinfer1::Dims *input_dims, int num_inputs) TRT_NOEXCEPT {
57 58 59 60 61 62 63 64 65 66 67 68
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "There is only one output in TRT elementwise "
                                  "op plugin, but got output index: %d.",
                                  index));
  PADDLE_ENFORCE_EQ(num_inputs, 2, platform::errors::InvalidArgument(
                                       "There are 2 inputs in TRT elementwise "
                                       "op plugin, but got input number: %d.",
                                       num_inputs));
  PADDLE_ENFORCE_NOT_NULL(
      input_dims,
      platform::errors::InvalidArgument(
          "The input dims of TRT elementwise op plugin should not be null."));
69 70 71
  return input_dims[0];
}

72
int ElementWisePlugin::initialize() TRT_NOEXCEPT {
73 74 75 76 77 78 79 80 81
  axis_ = (axis_ == -1) ? dims_x_.nbDims - dims_y_.nbDims : axis_;
  int trimed_nb_dims = dims_y_.nbDims;
  for (; trimed_nb_dims > 0; --trimed_nb_dims) {
    if (dims_y_.d[trimed_nb_dims - 1] != 1) {
      break;
    }
  }
  dims_y_.nbDims = trimed_nb_dims;

82 83 84 85 86 87 88 89 90 91 92 93
  PADDLE_ENFORCE_GE(dims_x_.nbDims, dims_y_.nbDims + axis_,
                    platform::errors::InvalidArgument(
                        "We expect [number of x dims] >= [number of y dims + "
                        "axis] in TRT elementwise op plugin, but got [number "
                        "of x dims] = %d, [number of y dims + axis] = %d.",
                        dims_x_.nbDims, dims_y_.nbDims + axis_));
  PADDLE_ENFORCE_LT(
      axis_, dims_x_.nbDims,
      platform::errors::InvalidArgument("We expect [axis] < [number of x dims] "
                                        "in TRT elementwise op plugin, but got "
                                        "[axis] = %d, [number of x dims] = %d.",
                                        axis_, dims_x_.nbDims));
94 95 96 97 98 99 100 101 102 103

  prev_size_ = 1;
  midd_size_ = 1;
  post_size_ = 1;
  for (int i = 0; i < axis_; ++i) {
    prev_size_ *= dims_x_.d[i];
  }

  for (int i = 0; i < dims_y_.nbDims; ++i) {
    PADDLE_ENFORCE_EQ(dims_x_.d[i + axis_], dims_y_.d[i],
104 105 106
                      platform::errors::InvalidArgument(
                          "Broadcast dimension mismatch. The dims of input Y "
                          "should be a subsequence of X."));
107 108 109 110 111 112 113 114 115
    midd_size_ *= dims_y_.d[i];
  }

  for (int i = axis_ + dims_y_.nbDims; i < dims_x_.nbDims; ++i) {
    post_size_ *= dims_x_.d[i];
  }
  return 0;
}

116
int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,
117
#if IS_TRT_VERSION_LT(8000)
118
                               void **outputs, void *workspace,
119 120 121
#else
                               void *const *outputs, void *workspace,
#endif
122
                               cudaStream_t stream) TRT_NOEXCEPT {
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  const float *x = reinterpret_cast<const float *>(inputs[0]);
  const float *y = reinterpret_cast<const float *>(inputs[1]);
  float *out = reinterpret_cast<float *>(outputs[0]);

  int num = batch_size * prev_size_ * midd_size_ * post_size_;
  int thread = 256;
  int block = (num + thread - 1) / thread;
  if (type_ == "add") {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
        details::Add<float>());
  } else if (type_ == "mul") {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
        details::Mul<float>());
138 139 140 141
  } else if (type_ == "div") {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
        details::Div<float>());
142 143 144 145 146 147 148 149 150 151 152
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The %s type elementwise is not implemented in trt plugin.", type_));
  }

  return cudaGetLastError() != cudaSuccess;
}

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

153
int ElementwisePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
154

155
size_t ElementwisePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
156 157
  return SerializedSize(type_.c_str()) + SerializedSize(axis_);
}
158

159
void ElementwisePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
160 161 162
  SerializeValue(&buffer, type_.c_str());
  SerializeValue(&buffer, axis_);
}
163 164 165

nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
166
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
167 168 169 170 171
  return inputs[0];
}

bool ElementwisePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
172
    int nb_outputs) TRT_NOEXCEPT {
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of swish plugin shoule not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));
  (in_out && pos < (nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
    return (in.type == nvinfer1::DataType::kFLOAT) &&
           (in.format == nvinfer1::TensorFormat::kLINEAR);
  }
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  // output
  return in.type == prev.type && in.format == prev.format;
}

nvinfer1::DataType ElementwisePluginDynamic::getOutputDataType(
195 196
    int index, const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
197 198 199 200 201 202 203 204 205 206 207
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The Elementwise Plugin only has one input, so the "
                        "index value should be 0, but get %d.",
                        index));
  return input_types[0];
}

int ElementwisePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
208
    void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT {
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
  auto x_dims = input_desc[0].dims;
  auto y_dims = input_desc[1].dims;
  int axis = (axis_ == -1) ? x_dims.nbDims - y_dims.nbDims : axis_;
  int batch_size = x_dims.d[0];

  int prev_size = 1;
  int midd_size = 1;
  int post_size = 1;
  for (int i = 0; i < axis; ++i) {
    prev_size *= x_dims.d[i];
  }

  int trimed_nb_dims = y_dims.nbDims;
  for (; trimed_nb_dims > 0; --trimed_nb_dims) {
    if (y_dims.d[trimed_nb_dims - 1] != 1) {
      break;
    }
  }

  for (int i = 0; i < trimed_nb_dims; ++i) {
    PADDLE_ENFORCE_EQ(x_dims.d[i + axis], y_dims.d[i],
                      platform::errors::InvalidArgument(
                          "Broadcast dimension mismatch found in trt "
                          "elementwise plugin's x and y input."));
    midd_size *= y_dims.d[i];
  }

  for (int i = axis + trimed_nb_dims; i < x_dims.nbDims; ++i) {
    post_size *= x_dims.d[i];
  }

  const float *x = static_cast<const float *>(inputs[0]);
  const float *y = static_cast<const float *>(inputs[1]);

  float *out = static_cast<float *>(outputs[0]);
244

245 246 247
  int num = prev_size * midd_size * post_size;
  int thread = 256;
  int block = (num + thread - 1) / thread;
N
nhzlx 已提交
248
  if (type_ == "add") {
249 250
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size, midd_size, post_size, details::Add<float>());
N
nhzlx 已提交
251
  } else if (type_ == "mul") {
252 253
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size, midd_size, post_size, details::Mul<float>());
254 255 256
  } else if (type_ == "div") {
    elementwise_kernel<<<block, thread, 0, stream>>>(
        num, x, y, out, prev_size, midd_size, post_size, details::Div<float>());
257
  } else {
258 259 260 261 262
    PADDLE_THROW(
        platform::errors::Unimplemented("Paddle-TRT only support elementwise "
                                        "operation: {add, mul, div} currently, "
                                        "but got %s.",
                                        type_));
263 264 265 266
  }

  return cudaGetLastError() != cudaSuccess;
}
267
#endif
268 269 270 271 272

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle