prelu_op_plugin.cu 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
N
nhzlx 已提交
17
#include <vector>
18 19
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
N
nhzlx 已提交
20
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
N
nhzlx 已提交
21
#include "paddle/fluid/operators/math/prelu.h"
22 23 24 25

namespace paddle {
namespace inference {
namespace tensorrt {
26
namespace plugin {
27

N
nhzlx 已提交
28 29 30 31 32 33 34 35 36
// Deserialization factory: constructs a PReluPlugin with `new` from the
// serialized bytes in `buffer` (`length` bytes) and returns the raw pointer.
PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) {
  PReluPlugin *restored_plugin = new PReluPlugin(buffer, length);
  return restored_plugin;
}
// Ties the name "prelu_plugin" to CreatePreluPluginDeserialize through the
// REGISTER_TRT_PLUGIN macro (presumably provided by trt_plugin_factory.h and
// used to look up the factory when an engine is deserialized -- TODO confirm).
REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize);

int PReluPlugin::initialize() {
  cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
  cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
             cudaMemcpyHostToDevice);
37
  return 0;
N
nhzlx 已提交
38 39
}

40 41 42 43 44 45 46 47 48 49
// Returns the dimensions of the requested output, which are a copy of the
// dimensions of the single input (inputDims[0]).
//
// In builds where assert() is active, the call aborts unless nbInputs == 1
// and index < getNbOutputs().
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  return inputDims[0];
}

N
nhzlx 已提交
50
int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
51 52 53 54
                         void **outputs, void *workspace, cudaStream_t stream) {
  // input dims is CHW.
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
N
nhzlx 已提交
55 56
  // const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
  const float *alpha = p_gpu_weight_;
57
  float *output = reinterpret_cast<float **>(outputs)[0];
N
nhzlx 已提交
58 59 60 61 62 63 64

  std::vector<int> input_shape;
  input_shape.push_back(batch_size);
  for (int i = 0; i < input_dims.nbDims; i++) {
    input_shape.push_back(input_dims.d[i]);
  }

65
  if (mode_ == "channel") {
N
nhzlx 已提交
66 67 68
    operators::math::PreluChannelWiseDirectCUDAFunctor<float>
        prelu_channel_wise;
    prelu_channel_wise(stream, input, alpha, output, input_shape);
69
  } else if (mode_ == "element") {
N
nhzlx 已提交
70 71 72
    operators::math::PreluElementWiseDirectCUDAFunctor<float>
        prelu_element_wise;
    prelu_element_wise(stream, input, alpha, output, input_shape);
73
  } else {
N
nhzlx 已提交
74 75
    operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
    prelu_scalar(stream, input, alpha, output, input_shape);
76 77 78 79
  }
  return cudaGetLastError() != cudaSuccess;
}

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
#if IS_TRT_VERSION_GE(6000)

int PReluPluginDynamic::initialize() {
  cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
  cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  return 0;
}
// Reports a serialized size of zero bytes.
// NOTE(review): weight_ and mode_ are therefore not part of any serialized
// form of this plugin -- confirm that this is intended.
size_t PReluPluginDynamic::getSerializationSize() const {
  const size_t serialized_bytes = 0;
  return serialized_bytes;
}

// Serialization hook; currently a no-op that leaves `buffer` untouched.
void PReluPluginDynamic::serialize(void *buffer) const {
  // Nothing is written.
}

// Returns a copy of the first input's dimension expressions (inputs[0]) as
// the output shape. output_index, nb_inputs and expr_builder are not used.
nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  const nvinfer1::DimsExprs &first_input_shape = inputs[0];
  return first_input_shape;
}

// Reports whether the tensor at position `pos` of `in_out` has a type/format
// combination this plugin accepts.
//
// in_out must not be null and pos must be less than nb_inputs + nb_outputs;
// both conditions are enforced with PADDLE_ENFORCE_* checks.
//
// Returns true only when in_out[pos] is kFLOAT in kNCHW format.
bool PReluPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of prelu plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  return (desc.type == nvinfer1::DataType::kFLOAT) &&
         (desc.format == nvinfer1::PluginFormat::kNCHW);
}

// Returns the data type of output `index`, which is the type of the first
// input (input_types[0]).
//
// Enforces that index == 0 and that input_types[0] is kFLOAT, the only type
// this check accepts.
nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The PRelu Plugin only has one output, so the "
                        "index value should be 0, but get %d.",
                        index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true,
                    platform::errors::InvalidArgument(
                        "The input type should be float"));
  return input_types[0];
}

// Runs PRelu on inputs[0] and writes the result to outputs[0] on `stream`.
//
// The shape handed to the functor is taken verbatim from input_desc[0].dims.
// The slope values are read from p_gpu_weight_. mode_ selects the functor:
// "channel", "element", or, for any other value, the scalar variant.
//
// Returns 0 when cudaGetLastError() reports cudaSuccess afterwards, and 1
// otherwise.
int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs, void *const *outputs,
                                void *workspace, cudaStream_t stream) {
  using ChannelWisePrelu =
      operators::math::PreluChannelWiseDirectCUDAFunctor<float>;
  using ElementWisePrelu =
      operators::math::PreluElementWiseDirectCUDAFunctor<float>;
  using ScalarPrelu = operators::math::PreluScalarDirectCUDAFunctor<float>;

  auto runtime_dims = input_desc[0].dims;
  const float *slope = p_gpu_weight_;
  const float *src = static_cast<const float *>(inputs[0]);
  float *dst = static_cast<float *>(outputs[0]);

  std::vector<int> full_shape;
  for (int axis = 0; axis < runtime_dims.nbDims; ++axis) {
    full_shape.push_back(runtime_dims.d[axis]);
  }

  if (mode_ == "channel") {
    ChannelWisePrelu run_channel_wise;
    run_channel_wise(stream, src, slope, dst, full_shape);
  } else if (mode_ == "element") {
    ElementWisePrelu run_element_wise;
    run_element_wise(stream, src, slope, dst, full_shape);
  } else {
    ScalarPrelu run_scalar;
    run_scalar(stream, src, slope, dst, full_shape);
  }

  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}
#endif

158
}  // namespace plugin
159 160 161
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle