prelu_op_plugin.cu 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
16

17
#include <cassert>
N
nhzlx 已提交
18
#include <vector>
19

20 21
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
N
nhzlx 已提交
22
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
N
nhzlx 已提交
23
#include "paddle/fluid/operators/math/prelu.h"
24 25 26 27

namespace paddle {
namespace inference {
namespace tensorrt {
28
namespace plugin {
29

N
nhzlx 已提交
30 31 32 33 34 35 36 37 38
// Deserialization factory registered under the name "prelu_plugin".
// Builds a PReluPlugin from a serialized byte buffer of `length` bytes by
// forwarding both arguments to the PReluPlugin(buffer, length) constructor.
// NOTE(review): the returned object is heap-allocated with `new`; ownership is
// presumably taken by the plugin factory/TensorRT — confirm against the caller.
PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) {
  PReluPlugin *restored = new PReluPlugin(buffer, length);
  return restored;
}
REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize);

// Uploads the host-side PReLU weights (`weight_`) into a newly allocated device
// buffer (`p_gpu_weight_`) that enqueue() later passes to the CUDA kernels.
//
// Returns 0 on success. Returns a non-zero value if either the device
// allocation or the host-to-device copy fails. A non-zero return signals
// failure under TensorRT's IPluginV2::initialize contract. The original
// ignored both CUDA return codes and always reported success, so a failed
// upload would only surface later as garbage output or an illegal-address
// fault.
int PReluPlugin::initialize() {
  const size_t bytes = sizeof(float) * weight_.size();

  cudaError_t status = cudaMalloc(&p_gpu_weight_, bytes);
  if (status != cudaSuccess) {
    // Leave the pointer in a known state so a later cudaFree is harmless.
    p_gpu_weight_ = nullptr;
    return 1;
  }

  status = cudaMemcpy(p_gpu_weight_, weight_.data(), bytes,
                      cudaMemcpyHostToDevice);
  if (status != cudaSuccess) {
    // Do not leak the allocation when the copy fails.
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
    return 1;
  }
  return 0;
}

42 43 44 45 46 47 48 49 50 51
// Shape inference for the static-shape plugin. PReLU is applied element by
// element, so the single output has exactly the dimensions of the single input.
// Preconditions (checked with assert in debug builds): exactly one input,
// and `index` must be a valid output index.
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < getNbOutputs());
  return inputDims[0];
}

N
nhzlx 已提交
52
// Runs PReLU for the static-shape (implicit-batch) plugin.
//
// getInputDims(0) holds the per-sample dims (CHW). The leading batch
// dimension is not part of them; TensorRT passes it separately as
// `batch_size`. The Paddle PReLU functors take the real batch size and the
// total element count, mirroring how PReluPluginDynamic::enqueue calls them
// with full NCHW dims. The original ignored `batch_size` and passed C as the
// batch and H as the channel count. As a result, only the first sample was
// processed and the channel/element alpha lookup used the wrong axes.
//
// Returns 0 on success, non-zero if a CUDA error is pending after the launch.
int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
                         void **outputs, void *workspace, cudaStream_t stream) {
  // Per-sample input dims: CHW, without the batch dimension.
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
  // Device copy of the weights uploaded by initialize().
  const float *alpha = p_gpu_weight_;
  float *output = reinterpret_cast<float **>(outputs)[0];

  // Elements in one sample, then across the whole batch.
  size_t sample_numel = 1;
  for (int i = 0; i < input_dims.nbDims; i++) {
    sample_numel *= static_cast<size_t>(input_dims.d[i]);
  }
  const size_t numel = static_cast<size_t>(batch_size) * sample_numel;

  // Nothing to do for an empty tensor. The functors divide numel by
  // batch_size (and by channel), so an empty batch must not reach them.
  if (numel == 0) {
    return 0;
  }

  if (mode_ == "channel") {
    // One alpha per channel; the channel axis is d[0] of the CHW dims.
    operators::math::PreluChannelWiseDirectCUDAFunctor<float>
        prelu_channel_wise;
    prelu_channel_wise(stream, input, alpha, output,
                       static_cast<size_t>(batch_size),
                       static_cast<size_t>(input_dims.d[0]), numel);
  } else if (mode_ == "element") {
    // One alpha per element of a single sample, shared across the batch.
    operators::math::PreluElementWiseDirectCUDAFunctor<float>
        prelu_element_wise;
    prelu_element_wise(stream, input, alpha, output,
                       static_cast<size_t>(batch_size), numel);
  } else {
    // A single alpha shared by every element.
    operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
    prelu_scalar(stream, input, alpha, output, numel);
  }
  return cudaGetLastError() != cudaSuccess;
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
#if IS_TRT_VERSION_GE(6000)

// Uploads the host-side PReLU weights (`weight_`) into a newly allocated device
// buffer (`p_gpu_weight_`) that enqueue() later passes to the CUDA kernels.
//
// Returns 0 on success. Returns a non-zero value if either the device
// allocation or the host-to-device copy fails. A non-zero return signals
// failure under TensorRT's plugin initialize() contract. The original
// ignored both CUDA return codes and always reported success.
int PReluPluginDynamic::initialize() {
  const size_t bytes = sizeof(float) * weight_.size();

  cudaError_t status = cudaMalloc(&p_gpu_weight_, bytes);
  if (status != cudaSuccess) {
    // Leave the pointer in a known state so a later cudaFree is harmless.
    p_gpu_weight_ = nullptr;
    return 1;
  }

  status = cudaMemcpy(p_gpu_weight_, weight_.data(), bytes,
                      cudaMemcpyHostToDevice);
  if (status != cudaSuccess) {
    // Do not leak the allocation when the copy fails.
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
    return 1;
  }
  return 0;
}
// Number of bytes serialize() writes: always zero.
// NOTE(review): neither the weights nor mode_ are serialized here, so an engine
// serialized with this dynamic-shape plugin presumably cannot restore them —
// confirm before relying on engine serialization.
size_t PReluPluginDynamic::getSerializationSize() const {
  constexpr size_t kSerializedBytes = 0;
  return kSerializedBytes;
}

// Writes no bytes into `buffer`; consistent with getSerializationSize()
// reporting a size of zero.
void PReluPluginDynamic::serialize(void *buffer) const {
  (void)buffer;  // intentionally unused: no state is serialized
}

// Dynamic-shape inference. PReLU is element-wise, so the output shape
// expression is just the first input's shape expression. `output_index`,
// `nb_inputs` and `expr_builder` are not consulted.
nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  const nvinfer1::DimsExprs &same_as_input = inputs[0];
  return same_as_input;
}

// Reports whether the tensor at position `pos` in `in_out` (inputs first, then
// outputs) may use the type/format described there. Every input and output of
// this plugin must be FP32 with PluginFormat::kNCHW.
//
// Raises InvalidArgument if `in_out` is null or `pos` is out of range.
// Compared with the original, the error text now names the prelu plugin (it
// was copy-pasted from the swish plugin) and fixes the "shoule" typo. The
// no-op expression statement that re-evaluated the preconditions has been
// removed.
bool PReluPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of prelu plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  return desc.type == nvinfer1::DataType::kFLOAT &&
         desc.format == nvinfer1::PluginFormat::kNCHW;
}

// Returns the data type of output `index`, which is the type of the first
// input.
//
// Raises InvalidArgument if `index` is not 0 or if the input type is not FP32.
// `index` is the index of the requested OUTPUT, so the message now says
// "one output" rather than "one input". The type message no longer claims
// half precision is accepted, because only kFLOAT passes the check.
nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The PRelu Plugin only has one output, so the "
                                  "index value should be 0, but got %d.",
                                  index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true,
                    platform::errors::InvalidArgument(
                        "The input type should be float"));
  return input_types[0];
}

// Runs PReLU for the dynamic-shape (explicit-batch) plugin.
//
// input_desc[0].dims carries the full runtime shape. This code treats d[0] as
// the batch size and d[1] as the channel count, and multiplies all dims into
// the element count handed to the functors. The Paddle PReLU functor is chosen
// from mode_: "channel", "element", or anything else for a single shared
// alpha. Returns non-zero if a CUDA error is pending after the launch.
int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs, void *const *outputs,
                                void *workspace, cudaStream_t stream) {
  const nvinfer1::Dims shape = input_desc[0].dims;
  const float *x = static_cast<const float *>(inputs[0]);
  float *y = static_cast<float *>(outputs[0]);
  // Device copy of the weights uploaded by initialize().
  const float *slope = p_gpu_weight_;

  int total = 1;
  for (int axis = 0; axis < shape.nbDims; ++axis) {
    total *= shape.d[axis];
  }

  const bool per_channel = (mode_ == "channel");
  const bool per_element = (mode_ == "element");

  if (per_channel) {
    operators::math::PreluChannelWiseDirectCUDAFunctor<float> functor;
    functor(stream, x, slope, y, shape.d[0], shape.d[1], total);
  } else if (per_element) {
    operators::math::PreluElementWiseDirectCUDAFunctor<float> functor;
    functor(stream, x, slope, y, shape.d[0], total);
  } else {
    operators::math::PreluScalarDirectCUDAFunctor<float> functor;
    functor(stream, x, slope, y, total);
  }

  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}
#endif

159
}  // namespace plugin
160 161 162
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle