Unverified commit ed857585, authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Support depthwise_conv2d fp16. (#44642)

* depthwise_fp16

* depthwise_fp16

* depthwise_fp16

* depthwise_fp16
Parent commit: 20759c30
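
The hunks below repeat a single numeric pattern so that the same kernel templates also compile for half precision: accumulators are direct-initialized as T value(0) / T s(0) instead of T value = 0, comparisons against a literal zero become <= T(0), and the pre-conv ReLU routes through max on double before being cast back to T. The following is a minimal standalone sketch of that pattern, not the Paddle kernels themselves; Float16 and AccumulateOneTap are invented names standing in for platform::float16 and the real kernel bodies.

#include <cstdint>

// "Float16" is a hypothetical stand-in for platform::float16: explicit
// conversions only, no implicit construction from int, and no max() overload
// of its own. The conversions below are placeholders, not real half-precision
// arithmetic.
struct Float16 {
  uint16_t bits;
  __host__ __device__ Float16() : bits(0) {}
  __host__ __device__ explicit Float16(double v) : bits(static_cast<uint16_t>(v)) {}
  __host__ __device__ explicit operator double() const { return static_cast<double>(bits); }
  __host__ __device__ Float16& operator+=(Float16 o) { bits += o.bits; return *this; }
};
__host__ __device__ inline Float16 operator*(Float16 a, Float16 b) {
  return Float16(static_cast<double>(a) * static_cast<double>(b));
}

// The accumulation pattern used by every kernel body in this patch.
template <typename T, bool fuse_relu_before_conv>
__device__ T AccumulateOneTap(const T* weight, const T* input, int offset) {
  T value(0);  // direct-initialization compiles even when T's constructors are explicit
  T in_data = input[offset];
  if (fuse_relu_before_conv) {
    // Apply the ReLU in double so max() needs no overload for T, then cast back.
    value += weight[0] * T(max(0.0, static_cast<double>(in_data)));
  } else {
    value += weight[0] * in_data;
  }
  return value;
}

// A toy kernel that forces the template to instantiate for all three types.
__global__ void Probe(const float* wf, const float* xf, float* of,
                      const double* wd, const double* xd, double* od,
                      const Float16* wh, const Float16* xh, Float16* oh) {
  of[0] = AccumulateOneTap<float, true>(wf, xf, 0);
  od[0] = AccumulateOneTap<double, true>(wd, xd, 0);
  oh[0] = AccumulateOneTap<Float16, true>(wh, xh, 0);
}

Routing only the ReLU through double (rather than accumulating in double) presumably keeps the change minimal while avoiding the need for a half-precision max overload; the accumulator itself stays in T.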
@@ -153,7 +153,7 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
 const int c_in = c_out / filter_multiplier;
 const T* weight = filter_data + c_out * filter_height * filter_width;
-T value = 0;
+T value(0);
 const int h_in_start = -padding_height + h_out * stride_height;
 const int w_in_start = -padding_width + w_out * stride_width;
 const int h_in_end = h_in_start + filter_height * dilate_height;
@@ -176,7 +176,7 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
 int offset = in_offset + h_in * input_width + w_in;
 T in_data = input_data[offset];
 if (fuse_relu_before_conv) {
-value += weight[weight_offset] * max(0.0f, in_data);
+value += weight[weight_offset] * T(max(0.0f, double(in_data)));
 } else {
 value += weight[weight_offset] * in_data;
 }
@@ -205,7 +205,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
 const int batch = idx / output_width / output_height / output_channels;
 const int c_in = c_out / filter_multiplier;
-T value = 0;
+T value(0);
 const int h_in_start = -padding_height + h_out * stride_height;
 const int w_in_start = -padding_width + w_out * stride_width;
 const int h_in_end = h_in_start + filter_height * dilate_height;
@@ -228,7 +228,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
 T in_data = input_data[offset];
 const T* weight = filter_data + weight_offset * output_channels + c_out;
 if (fuse_relu_before_conv) {
-value += weight[0] * max(0.0f, in_data);
+value += weight[0] * T(max(0.0f, double(in_data)));
 } else {
 value += weight[0] * in_data;
 }
@@ -258,7 +258,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
 const int c_out = blockIdx.x;
 const int c_in = c_out / filter_multiplier;
-T value = 0;
+T value(0);
 const int h_in_start = -padding_height + h_out * stride_height;
 const int w_in_start = -padding_width + w_out * stride_width;
 const int h_in_end = h_in_start + c_filter * dilate_height;
@@ -281,7 +281,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
 int offset = in_offset + h_in * input_width + w_in;
 if (fuse_relu_before_conv) {
 value += r_weight[h_f * c_filter + w_f] *
-max(0.0f, input_data[offset]);
+T(max(0.0f, double(input_data[offset])));
 } else {
 value += r_weight[h_f * c_filter + w_f] * input_data[offset];
 }
@@ -325,7 +325,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
 if (w_out >= output_width) {
 continue;
 }
-T value = 0;
+T value(0);
 const int w_in_start = -padding_width + w_out * stride_width;
 for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
 h_in += dilate_height, h_f++) {
@@ -337,7 +337,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
 in_offset + (h_in * input_width + w_in) * input_channels + c_in;
 if (fuse_relu_before_conv) {
 value += r_weight[h_f * c_filter + w_f] *
-max(0.0f, input_data[offset]);
+T(max(0.0, double(input_data[offset])));
 } else {
 value += r_weight[h_f * c_filter + w_f] * input_data[offset];
 }
@@ -482,13 +482,13 @@ __device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
 w_in - (filter_width - 1) * dilate_width + padding_width;
 int w_out_end = w_in + padding_width;
-T value = 0;
+T value(0);
 int index =
 ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
 w_in;
 if (fuse_relu_before_conv) {
-if (input_data[index] <= 0) {
+if (input_data[index] <= T(0)) {
 input_grad_data[index] = 0;
 continue;
 }
@@ -539,12 +539,12 @@ __device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
 int w_out_start =
 w_in - (filter_width - 1) * dilate_width + padding_width;
-T value = 0;
+T value(0);
 int index = ((batch * input_height + h_in) * input_width + w_in) *
 input_channels +
 c_in;
 if (fuse_relu_before_conv) {
-if (input_data[index] <= 0) {
+if (input_data[index] <= T(0)) {
 input_grad_data[index] = 0;
 continue;
 }
@@ -603,12 +603,12 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
 int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;
 int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
-T value = 0;
+T value(0);
 int index =
 ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
 w_in;
 if (fuse_relu_before_conv) {
-if (input_data[index] <= 0) {
+if (input_data[index] <= T(0)) {
 input_grad_data[index] = 0;
 continue;
 }
@@ -676,12 +676,12 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
 }
 int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
-T value = 0;
+T value(0);
 int index = ((batch * input_height + h_in) * input_width + w_in) *
 input_channels +
 c_in;
 if (fuse_relu_before_conv) {
-if (input_data[index] <= 0) {
+if (input_data[index] <= T(0)) {
 input_grad_data[index] = 0;
 continue;
 }
@@ -854,7 +854,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
 const int dilate_height,
 const int dilate_width,
 T* filter_grad_data) {
-T s = 0;
+T s(0);
 int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
 for (int image_w = threadIdx.x; image_w < output_width;
@@ -880,7 +880,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
 image_wk;
 if (fuse_relu_before_conv) {
 s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-max(0.0f, input_data[input_id]);
+T(max(0.0f, double(input_data[input_id])));
 } else {
 s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
 input_data[input_id];
@@ -921,7 +921,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
 int kernel_ih = blockIdx.x / filter_width;
 for (int kernel_id = threadIdx.x; kernel_id < output_channels;
 kernel_id += blockDim.x) {
-T s = 0;
+T s(0);
 int gbid =
 ((kernel_id * filter_height) + kernel_ih) * filter_width + kernel_iw;
 for (int image_w = threadIdx.y; image_w < output_width;
@@ -941,7 +941,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
 kernel_id / filter_multiplier;
 if (fuse_relu_before_conv) {
 s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
-max(0.0f, input_data[input_id]);
+T(max(0.0f, double(input_data[input_id])));
 } else {
 s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
 input_data[input_id];
@@ -1010,9 +1010,10 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
 ((bid * output_height + image_h) * output_width + image_w) *
 output_channels +
 kernel_id;
-T s = 0;
+T s(0);
 if (fuse_relu_before_conv) {
-s = output_grad_data[output_id] * max(0.0f, input_data[input_id]);
+s = output_grad_data[output_id] *
+    T(max(0.0f, double(input_data[input_id])));
 } else {
 s = output_grad_data[output_id] * input_data[input_id];
 }
@@ -1672,21 +1673,35 @@ class DepthwiseConvFilterGradFunctor<phi::GPUContext,
 template class DepthwiseConvFunctor<phi::GPUContext, float, false>;
 template class DepthwiseConvFunctor<phi::GPUContext, double, false>;
+template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, false>;
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, false>;
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, false>;
+template class DepthwiseConvInputGradFunctor<phi::GPUContext,
+                                             platform::float16,
+                                             false>;
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, false>;
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, false>;
+template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
+                                              platform::float16,
+                                              false>;
 template class DepthwiseConvFunctor<phi::GPUContext, float, true>;
 template class DepthwiseConvFunctor<phi::GPUContext, double, true>;
+template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, true>;
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, true>;
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, true>;
+template class DepthwiseConvInputGradFunctor<phi::GPUContext,
+                                             platform::float16,
+                                             true>;
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, true>;
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, true>;
+template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
+                                              platform::float16,
+                                              true>;
 } // namespace math
 } // namespace operators
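
The added "template class ..." lines above are explicit instantiations: the functor definitions live in this .cu translation unit, so each element type used from other files must be instantiated here or the build fails at link time, which is why platform::float16 gets one instantiation per functor and per fuse_relu flag. A minimal generic illustration of that requirement (not Paddle's real headers; Half and ScaleFunctor are invented names):

#include <cstdint>

struct Half {                        // invented stand-in for platform::float16
  uint16_t bits;
  Half() : bits(0) {}
  explicit Half(float) : bits(0) {}  // placeholder conversions only
  explicit operator float() const { return 0.0f; }
};

// Declared in a header so any file can name ScaleFunctor<T>...
template <typename T>
struct ScaleFunctor {
  T Apply(T x) const;
};

// ...but defined only in this one translation unit, like the depthwise functors.
template <typename T>
T ScaleFunctor<T>::Apply(T x) const {
  return T(static_cast<float>(x) * 2.0f);
}

// Without one explicit instantiation per element type, code in other files that
// calls ScaleFunctor<Half>::Apply would compile but fail to link. These mirror
// the float/double lines that already existed and the float16 lines this patch adds.
template struct ScaleFunctor<float>;
template struct ScaleFunctor<double>;
template struct ScaleFunctor<Half>;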
...
@@ -139,4 +139,5 @@ PD_REGISTER_KERNEL(depthwise_conv2d_grad,
 ALL_LAYOUT,
 phi::DepthwiseConvGradKernel,
 float,
-double) {}
+double,
+phi::dtype::float16) {}
@@ -124,4 +124,5 @@ PD_REGISTER_KERNEL(depthwise_conv2d,
 ALL_LAYOUT,
 phi::DepthwiseConvKernel,
 float,
-double) {}
+double,
+phi::dtype::float16) {}
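
Adding phi::dtype::float16 to the two PD_REGISTER_KERNEL lists is what lets the runtime select a native fp16 kernel for depthwise_conv2d and its gradient: conceptually, the macro stamps out one kernel instantiation per listed dtype and records it in a dtype-keyed registry. A rough self-contained sketch of that idea (illustrative names only, not Paddle's actual registry API):

#include <functional>
#include <map>
#include <string>
#include <utility>

// Toy registry keyed by (kernel name, dtype); a real framework registry also
// keys on backend and layout (GPU, ALL_LAYOUT, ...).
enum class DataType { FLOAT32, FLOAT64, FLOAT16 };
using KernelFn = std::function<void()>;

std::map<std::pair<std::string, DataType>, KernelFn>& Registry() {
  static std::map<std::pair<std::string, DataType>, KernelFn> registry;
  return registry;
}

struct Half { unsigned short bits = 0; };  // fp16 stand-in for this sketch

template <typename T>
void DepthwiseConv2dKernelStub() { /* templated kernel body */ }

// Listing a dtype in the registration macro effectively adds one entry like
// these; before this patch there was no FLOAT16 entry, so an fp16 graph could
// not be served by a native GPU depthwise_conv2d kernel.
void RegisterDepthwiseConv2d() {
  Registry()[{"depthwise_conv2d", DataType::FLOAT32}] = DepthwiseConv2dKernelStub<float>;
  Registry()[{"depthwise_conv2d", DataType::FLOAT64}] = DepthwiseConv2dKernelStub<double>;
  Registry()[{"depthwise_conv2d", DataType::FLOAT16}] = DepthwiseConv2dKernelStub<Half>;
}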