未验证 提交 e249d9a3 编写于 作者: L lilong12 提交者: GitHub

fix the computation for dx (grad for x) for prelu operation. (#20949)

* set the default value of alpha for prelu to 0.25, test=develop

* add the call to __syncthreads(), test=develop

* fix the implementation of cpu prelu, test=develop

* repair the implementation of element mode prelu, test=develop

* modify test_prelu_op.py, test=develop
上级 edd6680a
......@@ -18,108 +18,76 @@ namespace paddle {
namespace operators {
namespace math {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
#define CUDA_NUM_THREADS 1024
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
inline static int PADDLE_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
T *output, int channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const T *in = input + offset;
T *out = output + offset;
T scale = alpha[blockIdx.x % channel];
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T x = in[i];
out[i] = (x > 0) ? x : scale * x;
T *output, size_t channel_num,
size_t plane_size, size_t numel) {
size_t index;
CUDA_KERNEL_LOOP(index, numel) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
T scale = alpha[channel_index];
T x = input[index];
output[index] = (x > 0) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const T *in = input + offset;
const T *scale = alpha + offset;
T *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T x = in[i];
out[i] = (x > 0) ? x : scale[i] * x;
T *output, size_t spatial_size,
size_t numel) {
size_t index;
CUDA_KERNEL_LOOP(index, numel) {
size_t element_index = index % spatial_size;
T scale = alpha[element_index];
T x = input[index];
output[index] = (x > 0) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const T *in = input + offset;
T scale = *alpha;
T *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T x = in[i];
out[i] = (x > 0) ? x : scale * x;
size_t numel) {
T scale = alpha[0];
size_t index;
CUDA_KERNEL_LOOP(index, numel) {
T x = input[index];
output[index] = (x > 0) ? x : scale * x;
}
}
template <typename T>
static inline void PReluChannelWise(cudaStream_t stream, const T *input,
const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, input_shape[1], spatial_size);
}
template <typename T>
static inline void PReluElementWise(cudaStream_t stream, const T *input,
const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
template <typename T>
static inline void PReluScalar(cudaStream_t stream, const T *input,
const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
cudaStream_t stream, const T *input, const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, input_shape[1], spatial_size);
size_t plane_size = input_shape[2] * input_shape[3];
size_t spatial_size = input_shape[1] * plane_size;
size_t numel = input_shape[0] * spatial_size;
PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, input_shape[1],
plane_size, numel);
}
template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()(
cudaStream_t stream, const T *input, const T *alpha, T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
size_t plane_size = input_shape[2] * input_shape[3];
size_t spatial_size = input_shape[1] * plane_size;
size_t numel = input_shape[0] * spatial_size;
PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, spatial_size, numel);
}
template <typename T>
......@@ -127,11 +95,11 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
const T *input, const T *alpha,
T *output,
std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
size_t plane_size = input_shape[2] * input_shape[3];
size_t spatial_size = input_shape[1] * plane_size;
size_t numel = input_shape[0] * spatial_size;
PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, numel);
}
template class PreluChannelWiseDirectCUDAFunctor<float>;
......
......@@ -42,6 +42,7 @@ class PreluScalarDirectCUDAFunctor {
void operator()(cudaStream_t stream, const T *input, const T *alpha,
T *output, std::vector<int> input_shape);
};
#endif
} // namespace math
......
......@@ -42,10 +42,21 @@ class PReluOp : public framework::OperatorWithKernel {
"equal to the number of channels, should be %d",
x_dim[1]);
} else if (mode == "element") {
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim),
"For element-wise mode, size of weight Alpha must be "
"equal to the number of input, should be %d",
product(x_dim));
auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size();
auto x_rank = x_dim.size();
size_t x_product = 1;
size_t alpha_product = 1;
PADDLE_ENFORCE_EQ(alpha_rank, x_rank,
"For element-wise mode, rank of weight Alpha must be ",
"equal to the rank of input.");
for (int64_t i = x_rank - 1; i > 0; i--) {
x_product *= x_dim[i];
alpha_product *= alpha_dim[i];
}
PADDLE_ENFORCE_EQ(x_product, alpha_product,
"For element-wise mode, size of weight Alpha must be "
"equal to the number of input.");
} else {
PADDLE_THROW("Unkown mode %s", mode);
}
......
......@@ -20,11 +20,19 @@ limitations under the License. */
namespace paddle {
namespace operators {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
using Tensor = framework::Tensor;
#define CUDA_NUM_THREADS 1024
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
inline static int PADDLE_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename DeviceContext, typename T>
class CUDAPReluKernel : public framework::OpKernel<T> {
public:
......@@ -59,71 +67,47 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
}
};
namespace prelu {
struct ElementWiseMode {};
struct ChannelMode {};
struct ScalarMode {};
} /* namespace prelu */
template <typename T, typename M>
struct AlphaFunctor {
HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
size_t spatial_size, size_t idx) const {}
};
template <typename T>
struct AlphaFunctor<T, prelu::ElementWiseMode> {
HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
size_t spatial_size, size_t idx) const {
return alpha[blockIdx.x * spatial_size + idx];
}
};
enum PRELU_MODE { Element, Channel, Scalar };
template <typename T>
struct AlphaFunctor<T, prelu::ChannelMode> {
HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
size_t spatial_size, size_t idx) const {
return alpha[blockIdx.x % channel];
}
};
template <typename T>
struct AlphaFunctor<T, prelu::ScalarMode> {
HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
size_t spatial_size, size_t idx) const {
return alpha[0];
}
};
template <typename T, typename M>
__global__ void PReluGradElementWiseKernel(const T* x_ptr, const T* y_ptr,
const T* alpha_ptr, const T* dy_ptr,
T* dx_ptr, T* dalpha_ptr,
size_t channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
AlphaFunctor<T, M> alpha_func;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
T y = y_ptr[offset + i];
T x = x_ptr[offset + i];
T dy = dy_ptr[offset + i];
T alpha = alpha_func(alpha_ptr, channel, spatial_size, i);
if (dx_ptr != nullptr) dx_ptr[offset + i] = (y > 0) ? dy : alpha * dy;
if (dalpha_ptr != nullptr) dalpha_ptr[offset + i] = (x > 0) ? 0 : x * dy;
__global__ void PReluOpGradKernel(const T* x_ptr, const T* y_ptr,
const T* alpha_ptr, const T* dy_ptr,
T* dx_ptr, T* dalpha_ptr, size_t channel_num,
size_t plane_size, size_t spatial_size,
size_t numel, PRELU_MODE mode) {
size_t index;
CUDA_KERNEL_LOOP(index, numel) {
T scale;
if (mode == Element) {
size_t element_index = index % spatial_size;
scale = alpha_ptr[element_index];
} else if (mode == Channel) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
scale = alpha_ptr[channel_index];
} else {
scale = alpha_ptr[0];
}
T x = x_ptr[index];
T dy = dy_ptr[index];
if (dx_ptr != nullptr) dx_ptr[index] = (x > 0) ? dy : scale * dy;
if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > 0) ? 0 : x * dy;
}
}
template <typename T, typename M>
class PreluGradElementwiseFunctor {
template <typename T>
class PreluOpGradFunctor {
public:
void operator()(cudaStream_t stream, const T* x, const T* y, const T* alpha,
const T* dy, T* dx, T* dalpha, std::vector<int> input_shape) {
size_t unroll = input_shape[0] * input_shape[1];
size_t spatial_size = input_shape[2] * input_shape[3];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluGradElementWiseKernel<T, M><<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
x, y, alpha, dy, dx, dalpha, input_shape[1], spatial_size);
const T* dy, T* dx, T* dalpha, std::vector<int> input_shape,
PRELU_MODE mode) {
size_t plane_size = input_shape[2] * input_shape[3];
size_t spatial_size = plane_size * input_shape[1];
size_t numel = spatial_size * input_shape[0];
PReluOpGradKernel<
T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
x, y, alpha, dy, dx, dalpha, input_shape[1], plane_size, spatial_size,
numel, mode);
}
};
......@@ -162,7 +146,7 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
T* dalpha_tmp_ptr;
Tensor dalpha_tmp;
if (mode == "element" || dalpha_ptr == nullptr) {
if (dalpha_ptr == nullptr) {
dalpha_tmp_ptr = dalpha_ptr;
} else {
auto& dev_ctx = context.template device_context<DeviceContext>();
......@@ -170,25 +154,24 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
dalpha_tmp_ptr = dalpha_tmp.mutable_data<T>(context.GetPlace());
}
PRELU_MODE m;
if (mode == "element") {
PreluGradElementwiseFunctor<T, prelu::ElementWiseMode> prelu_grad;
prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
dalpha_tmp_ptr, input_shape);
m = Element;
} else if (mode == "channel") {
PreluGradElementwiseFunctor<T, prelu::ChannelMode> prelu_grad;
prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
dalpha_tmp_ptr, input_shape);
m = Channel;
} else {
PreluGradElementwiseFunctor<T, prelu::ScalarMode> prelu_grad;
prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
dalpha_tmp_ptr, input_shape);
m = Scalar;
}
PreluOpGradFunctor<T> prelu_grad;
prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr,
input_shape, m);
if (mode == "element" || dalpha_tmp_ptr == nullptr) return;
if (dalpha_tmp_ptr == nullptr) return;
std::vector<int> reduce_dims;
for (size_t i = 0; i < input_shape.size(); i++) {
if (mode == "channel" && i == 1) continue;
if (mode == "element" && i != 0) continue;
reduce_dims.push_back(i);
}
......
......@@ -45,8 +45,10 @@ class PReluKernel : public framework::OpKernel<T> {
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else if (mode == "element") {
int temp = numel / dim[0];
for (i = 0; i < numel; i++) {
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i];
index = i % temp;
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
......@@ -64,12 +66,10 @@ class PReluGradKernel : public framework::OpKernel<T> {
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));
auto* out = context.Input<Tensor>("Out");
auto* alpha = context.Input<Tensor>("Alpha");
const T* alpha_ptr = alpha->data<T>();
const T* x_ptr = x->data<T>();
const T* dout_ptr = dout->data<T>();
const T* out_ptr = out->data<T>();
std::string mode = context.Attr<std::string>("mode");
int numel = x->numel();
auto dim = x->dims();
......@@ -83,15 +83,18 @@ class PReluGradKernel : public framework::OpKernel<T> {
temp = numel / (dim[0] * dim[1]);
index = (i / temp) % dim[1];
dx_ptr[i] =
out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
} else if (mode == "element") {
temp = numel / dim[0];
for (i = 0; i < numel; i++) {
dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i];
index = i % temp;
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
dx_ptr[i] = x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
}
}
}
......@@ -105,15 +108,17 @@ class PReluGradKernel : public framework::OpKernel<T> {
for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
index = (i / temp) % dim[1];
dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else if (mode == "element") {
temp = numel / dim[0];
for (i = 0; i < numel; i++) {
dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
index = i % temp;
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
dalpha_ptr[0] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
}
}
......
......@@ -66,6 +66,7 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
Ty reduce_var = init;
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
reduce_var = reducer(reduce_var, transformer(x[idx_x + idx_y]));
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
......@@ -113,6 +114,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
reduce_var = static_cast<Ty>(reducer(reduce_var, transformer(x[idx_x])));
}
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
......
......@@ -10034,14 +10034,14 @@ def prelu(x, mode, param_attr=None, name=None):
if mode == 'channel':
alpha_shape = [1, x.shape[1], 1, 1]
elif mode == 'element':
alpha_shape = x.shape
alpha_shape = x.shape[1:]
dtype = helper.input_dtype(input_param_name='x')
alpha = helper.create_parameter(
attr=helper.param_attr,
shape=alpha_shape,
dtype='float32',
is_bias=False,
default_initializer=Constant(1.0))
default_initializer=Constant(0.25))
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prelu",
......
......@@ -37,7 +37,8 @@ class PReluTest(OpTest):
alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32")
self.inputs = {'X': x_np, 'Alpha': alpha_np}
else:
alpha_np = np.random.rand(*x_np.shape).astype("float32")
alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], \
x_np.shape[3]).astype("float32")
self.inputs = {'X': x_np, 'Alpha': alpha_np}
out_np = np.maximum(self.inputs['X'], 0.)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册