diff --git a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc index 453f222f1f1e3f3b9ee8fa7bd49f4cab2286e7ea..b086c910d38a243d98315f2d6eb82ecc0ec5c06d 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc @@ -90,5 +90,4 @@ TEST(prelu_op, test_scalar) { } // namespace inference } // namespace paddle -// USE_OP(prelu); -USE_CPU_ONLY_OP(prelu); +USE_OP(prelu); diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index e822785ad6f4f6f67b72141f3e7b04aefa72e58b..95443e813327c1247ac530c4d2e68b3607ff0e73 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,4 +1,4 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu avg_pool_op_plugin.cu - DEPS enforce tensorrt_engine) + DEPS enforce tensorrt_engine prelu) diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index e8f4254402a5d8a5e6c5a2384bf9fbe48341956e..3075e87ea6d719a3f49d14c8c4b8015f7d688a50 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -14,92 +14,16 @@ #include #include +#include #include "glog/logging.h" #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" +#include "paddle/fluid/operators/math/prelu.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { -static const int CUDA_NUM_THREADS = 1024; -static const int CUDA_MAX_NUM_BLOCKS = 65535; -inline static int GET_NUM_BLOCKS(const int N) { - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -__global__ void PReluChannelWiseKernel(const float *input, const float *alpha, - float *output, int channel, - size_t spatial_size) { - size_t offset = blockIdx.x * spatial_size; - const float *in = input + offset; - float *out = output + offset; - float scale = alpha[blockIdx.x % channel]; - - for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { - float x = in[i]; - out[i] = (x > 0) ? x : scale * x; - } -} - -__global__ void PReluElementWiseKernel(const float *input, const float *alpha, - float *output, size_t spatial_size) { - size_t offset = blockIdx.x * spatial_size; - const float *in = input + offset; - const float *scale = alpha + offset; - float *out = output + offset; - - for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { - float x = in[i]; - out[i] = (x > 0) ? x : scale[i] * x; - } -} - -__global__ void PReluScalarKernel(const float *input, const float *alpha, - float *output, size_t spatial_size) { - size_t offset = blockIdx.x * spatial_size; - const float *in = input + offset; - float scale = *alpha; - float *out = output + offset; - - for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { - float x = in[i]; - out[i] = (x > 0) ? x : scale * x; - } -} - -static inline void PReluChannelWise(cudaStream_t stream, const float *input, - const float *alpha, float *output, - int batch_size, - const nvinfer1::Dims &dims) { - size_t unroll = batch_size * dims.d[0]; - size_t spatial_size = dims.d[1] * dims.d[2]; - CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); - PReluChannelWiseKernel<<>>( - input, alpha, output, dims.d[0], spatial_size); -} - -static inline void PReluElementWise(cudaStream_t stream, const float *input, - const float *alpha, float *output, - int batch_size, - const nvinfer1::Dims &dims) { - size_t unroll = batch_size * dims.d[0]; - size_t spatial_size = dims.d[1] * dims.d[2]; - CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); - PReluElementWiseKernel<<>>( - input, alpha, output, spatial_size); -} - -static inline void PReluScalar(cudaStream_t stream, const float *input, - const float *alpha, float *output, - int batch_size, const nvinfer1::Dims &dims) { - size_t unroll = batch_size * dims.d[0]; - size_t spatial_size = dims.d[1] * dims.d[2]; - CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); - PReluScalarKernel<<>>( - input, alpha, output, spatial_size); -} - nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, const nvinfer1::Dims *inputDims, int nbInputs) { @@ -110,19 +34,31 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, return output_dims; } -int PReluPlugin::enqueue(int batchSize, const void *const *inputs, +int PReluPlugin::enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) { // input dims is CHW. const auto &input_dims = this->getInputDims(0); const float *input = reinterpret_cast(inputs[0]); const float *alpha = reinterpret_cast(alpha_.get().values); float *output = reinterpret_cast(outputs)[0]; + + std::vector input_shape; + input_shape.push_back(batch_size); + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + if (mode_ == "channel") { - PReluChannelWise(stream, input, alpha, output, batchSize, input_dims); + operators::math::PreluChannelWiseDirectCUDAFunctor + prelu_channel_wise; + prelu_channel_wise(stream, input, alpha, output, input_shape); } else if (mode_ == "element") { - PReluElementWise(stream, input, alpha, output, batchSize, input_dims); + operators::math::PreluElementWiseDirectCUDAFunctor + prelu_element_wise; + prelu_element_wise(stream, input, alpha, output, input_shape); } else { - PReluScalar(stream, input, alpha, output, batchSize, input_dims); + operators::math::PreluScalarDirectCUDAFunctor prelu_scalar; + prelu_scalar(stream, input, alpha, output, input_shape); } return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 8c8dc7026e1b3c1bb1899ebcf151f52711ea5bc1..257bfc0a3f926d20abc4647b27e8e9cc2c49e014 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -70,7 +70,7 @@ endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) if (WITH_GPU) - set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) endif() # FIXME(typhoonzero): operator deps may not needed. diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 63363086adbf12c38ac09949ac20483116ccf4ee..b3d2ea38eb1bfffadc1f68c5a34bc4d557bdea3b 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -59,6 +59,7 @@ math_library(matrix_bit_code) math_library(unpooling) math_library(vol2col) +math_library(prelu) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu new file mode 100644 index 0000000000000000000000000000000000000000..701a802080f65ea32b95402682dc46362ccf0966 --- /dev/null +++ b/paddle/fluid/operators/math/prelu.cu @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/prelu.h" + +namespace paddle { +namespace operators { +namespace math { + +static const int CUDA_NUM_THREADS = 1024; +static const int CUDA_MAX_NUM_BLOCKS = 65535; +inline static int GET_NUM_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__global__ void PReluChannelWiseKernel(const T *input, const T *alpha, + T *output, int channel, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const T *in = input + offset; + T *out = output + offset; + T scale = alpha[blockIdx.x % channel]; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +template +__global__ void PReluElementWiseKernel(const T *input, const T *alpha, + T *output, size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const T *in = input + offset; + const T *scale = alpha + offset; + T *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T x = in[i]; + out[i] = (x > 0) ? x : scale[i] * x; + } +} + +template +__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const T *in = input + offset; + T scale = *alpha; + T *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +template +static inline void PReluChannelWise(cudaStream_t stream, const T *input, + const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluChannelWiseKernel<<>>( + input, alpha, output, input_shape[1], spatial_size); +} + +template +static inline void PReluElementWise(cudaStream_t stream, const T *input, + const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluElementWiseKernel<<>>( + input, alpha, output, spatial_size); +} + +template +static inline void PReluScalar(cudaStream_t stream, const T *input, + const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluScalarKernel<<>>( + input, alpha, output, spatial_size); +} + +template +void PreluChannelWiseDirectCUDAFunctor::operator()( + cudaStream_t stream, const T *input, const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluChannelWiseKernel<<>>( + input, alpha, output, input_shape[1], spatial_size); +} + +template +void PreluElementWiseDirectCUDAFunctor::operator()( + cudaStream_t stream, const T *input, const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluElementWiseKernel<<>>( + input, alpha, output, spatial_size); +} + +template +void PreluScalarDirectCUDAFunctor::operator()(cudaStream_t stream, + const T *input, const T *alpha, + T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluScalarKernel<<>>( + input, alpha, output, spatial_size); +} + +template class PreluChannelWiseDirectCUDAFunctor; +template class PreluChannelWiseDirectCUDAFunctor; + +template class PreluElementWiseDirectCUDAFunctor; +template class PreluElementWiseDirectCUDAFunctor; + +template class PreluScalarDirectCUDAFunctor; +template class PreluScalarDirectCUDAFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h new file mode 100644 index 0000000000000000000000000000000000000000..3237c6d4cbf956aafb4046ea2ffa42efe62e7b28 --- /dev/null +++ b/paddle/fluid/operators/math/prelu.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +#ifdef PADDLE_WITH_CUDA +template +class PreluChannelWiseDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T *input, const T *alpha, + T *output, std::vector input_shape); +}; + +template +class PreluElementWiseDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T *input, const T *alpha, + T *output, std::vector input_shape); +}; + +template +class PreluScalarDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T *input, const T *alpha, + T *output, std::vector input_shape); +}; +#endif + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 58cfbb76e93a1c15c9b7cf9f9e596066c29b7ebb..64d94ab6044c1992145062319120b0372f5061c0 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -58,7 +58,7 @@ class PReluOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..36b5259ae5106914f5668625cad535ebc8aa72ec --- /dev/null +++ b/paddle/fluid/operators/prelu_op.cu @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/prelu.h" +#include "paddle/fluid/operators/prelu_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class CUDAPReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* alpha = context.Input("Alpha"); + auto* out = context.Output("Out"); + + const T* x_ptr = x->data(); + T* o_ptr = out->mutable_data(context.GetPlace()); + + const T* alpha_ptr = alpha->data(); + auto& mode = context.Attr("mode"); + + int numel = x->numel(); + auto dim = x->dims(); + std::vector input_shape = framework::vectorize2int(dim); + + if (mode == "channel") { + math::PreluChannelWiseDirectCUDAFunctor prelu_channel_wise; + prelu_channel_wise(context.cuda_device_context().stream(), x_ptr, + alpha_ptr, o_ptr, input_shape); + } else if (mode == "element") { + math::PreluElementWiseDirectCUDAFunctor prelu_element_wise; + prelu_element_wise(context.cuda_device_context().stream(), x_ptr, + alpha_ptr, o_ptr, input_shape); + } else { + math::PreluScalarDirectCUDAFunctor prelu_scalar; + prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr, + o_ptr, input_shape); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + prelu, ops::CUDAPReluKernel, + ops::CUDAPReluKernel);