Unverified commit c6b39a00, authored by Houjiang Chen, committed by GitHub

Merge pull request #14714 from NHZlX/add_prelu_gpu

Add PReLU CUDA kernel for inference.
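In short, this PR factors the PReLU CUDA kernels out of the TensorRT plugin into a shared math library (paddle/fluid/operators/math/prelu), reuses them from both the plugin and a new CUDA kernel registered for the prelu operator, and instantiates the functors for float and double. The rule all three kernels apply is out = x for x > 0 and out = alpha * x otherwise; the sketch below is a minimal standalone illustration of the channel-wise case (toy sizes and names, not code from this PR).

// Illustrative only: a standalone channel-wise PReLU mirroring the rule used
// by the kernels in this PR (out = x > 0 ? x : alpha[c] * x).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void PReluChannelWise(const float* in, const float* alpha,
                                 float* out, int channels, int spatial) {
  // One block per (batch, channel) slice; threads stride over the spatial dim.
  int c = blockIdx.x % channels;
  const float* x = in + blockIdx.x * spatial;
  float* y = out + blockIdx.x * spatial;
  for (int i = threadIdx.x; i < spatial; i += blockDim.x) {
    float v = x[i];
    y[i] = v > 0.f ? v : alpha[c] * v;
  }
}

int main() {
  const int n = 1, c = 2, spatial = 4, numel = n * c * spatial;
  float h_x[numel] = {-1, 2, -3, 4, -1, 2, -3, 4};
  float h_alpha[c] = {0.1f, 0.5f};
  float h_y[numel];

  float *d_x, *d_alpha, *d_y;
  cudaMalloc(&d_x, numel * sizeof(float));
  cudaMalloc(&d_alpha, c * sizeof(float));
  cudaMalloc(&d_y, numel * sizeof(float));
  cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
  cudaMemcpy(d_alpha, h_alpha, sizeof(h_alpha), cudaMemcpyHostToDevice);

  PReluChannelWise<<<n * c, 128>>>(d_x, d_alpha, d_y, c, spatial);
  cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);

  // Expected: -0.1 2 -0.3 4 -0.5 2 -1.5 4
  for (int i = 0; i < numel; ++i) printf("%g ", h_y[i]);
  printf("\n");
  cudaFree(d_x); cudaFree(d_alpha); cudaFree(d_y);
  return 0;
}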
@@ -90,5 +90,4 @@ TEST(prelu_op, test_scalar) {
} // namespace inference
} // namespace paddle
// USE_OP(prelu);
USE_CPU_ONLY_OP(prelu);
USE_OP(prelu);
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine)
DEPS enforce tensorrt_engine prelu)
@@ -14,92 +14,16 @@
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/operators/math/prelu.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
float *output, int channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float *out = output + offset;
float scale = alpha[blockIdx.x % channel];
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
const float *scale = alpha + offset;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale[i] * x;
}
}
__global__ void PReluScalarKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float scale = *alpha;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
static inline void PReluChannelWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, dims.d[0], spatial_size);
}
static inline void PReluElementWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
static inline void PReluScalar(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size, const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
@@ -110,19 +34,31 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
return output_dims;
}
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
float *output = reinterpret_cast<float **>(outputs)[0];
std::vector<int> input_shape;
input_shape.push_back(batch_size);
for (int i = 0; i < input_dims.nbDims; i++) {
input_shape.push_back(input_dims.d[i]);
}
if (mode_ == "channel") {
PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_shape);
} else if (mode_ == "element") {
PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
prelu_element_wise(stream, input, alpha, output, input_shape);
} else {
PReluScalar(stream, input, alpha, output, batchSize, input_dims);
operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
prelu_scalar(stream, input, alpha, output, input_shape);
}
return cudaGetLastError() != cudaSuccess;
}
@@ -70,7 +70,7 @@ endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
endif()
# FIXME(typhoonzero): operator deps may not needed.
@@ -59,6 +59,7 @@ math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
math_library(prelu)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/prelu.h"
namespace paddle {
namespace operators {
namespace math {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
T *output, int channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const T *in = input + offset;
T *out = output + offset;
T scale = alpha[blockIdx.x % channel];
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const T *in = input + offset;
const T *scale = alpha + offset;
T *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T x = in[i];
out[i] = (x > 0) ? x : scale[i] * x;
}
}
template <typename T>
__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const T *in = input + offset;
T scale = *alpha;
T *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
template <typename T>
static inline void PReluChannelWise(cudaStream_t stream, const T *input,
const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, input_shape[1], spatial_size);
}
template <typename T>
static inline void PReluElementWise(cudaStream_t stream, const T *input,
const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
template <typename T>
static inline void PReluScalar(cudaStream_t stream, const T *input,
const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
cudaStream_t stream, const T *input, const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, input_shape[1], spatial_size);
}
template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()(
cudaStream_t stream, const T *input, const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
template <typename T>
void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
const T *input, const T *alpha,
T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<double>;
} // namespace math
} // namespace operators
} // namespace paddle
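A note on the launch configuration above: each helper launches one block per (batch, channel) slice and asserts CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS), so batch_size * channels must stay below 65535. A small host-side sketch of that capacity check follows (the function name is illustrative, not from the repository).

#include <cstdio>
// Sketch of the launch-capacity constraint implied by
// CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS) in prelu.cu: one block is launched
// per (batch, channel) slice, so batch * channels must stay below 65535.
static bool FitsPreluGrid(int batch, int channels) {
  const int kMaxBlocks = 65535;  // CUDA_MAX_NUM_BLOCKS in prelu.cu
  return static_cast<long long>(batch) * channels < kMaxBlocks;
}
int main() {
  printf("%d\n", FitsPreluGrid(32, 64));    // 2048 blocks   -> 1 (fits)
  printf("%d\n", FitsPreluGrid(256, 512));  // 131072 blocks -> 0 (would fail CHECK_LT)
  return 0;
}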
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
namespace math {
#ifdef PADDLE_WITH_CUDA
template <typename T>
class PreluChannelWiseDirectCUDAFunctor {
public:
void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape);
};
template <typename T>
class PreluElementWiseDirectCUDAFunctor {
public:
void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape);
};
template <typename T>
class PreluScalarDirectCUDAFunctor {
public:
void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape);
};
#endif
} // namespace math
} // namespace operators
} // namespace paddle
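For reference, a hedged sketch of how the functors declared above would be invoked: raw device pointers plus an NCHW shape vector on an existing CUDA stream. Only the operator() signature and the header path are taken from this PR; the wrapper function and the assumption that alpha holds one value per channel (as the channel-wise kernel indexes alpha[blockIdx.x % channel]) are illustrative.

#include <vector>
#include <cuda_runtime.h>
#include "paddle/fluid/operators/math/prelu.h"

// Sketch only: calls the channel-wise functor with device pointers that were
// allocated and filled elsewhere. Shape is {N, C, H, W}; d_alpha holds C values.
void RunPreluChannelWise(cudaStream_t stream, const float* d_x,
                         const float* d_alpha, float* d_out,
                         int n, int c, int h, int w) {
  paddle::operators::math::PreluChannelWiseDirectCUDAFunctor<float> prelu;
  std::vector<int> shape = {n, c, h, w};
  prelu(stream, d_x, d_alpha, d_out, shape);
}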
@@ -58,7 +58,7 @@ class PReluOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
platform::CPUPlace());
ctx.device_context());
}
};
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class CUDAPReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* alpha = context.Input<Tensor>("Alpha");
auto* out = context.Output<Tensor>("Out");
const T* x_ptr = x->data<T>();
T* o_ptr = out->mutable_data<T>(context.GetPlace());
const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode");
int numel = x->numel();
auto dim = x->dims();
std::vector<int> input_shape = framework::vectorize2int(dim);
if (mode == "channel") {
math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, input_shape);
} else if (mode == "element") {
math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, input_shape);
} else {
math::PreluScalarDirectCUDAFunctor<T> prelu_scalar;
prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr,
o_ptr, input_shape);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
prelu, ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, float>,
ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, double>);