Unverified commit 18650db3, authored by 5u13 and committed by GitHub

optimization of depthwise_conv2d grad (#46332)

Parent 4839aca2
@@ -176,7 +176,8 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
         int offset = in_offset + h_in * input_width + w_in;
         T in_data = input_data[offset];
         if (fuse_relu_before_conv) {
-          value += weight[weight_offset] * T(max(0.0f, double(in_data)));
+          value += weight[weight_offset] *
+                   T(max(0.0f, static_cast<double>(in_data)));
         } else {
           value += weight[weight_offset] * in_data;
         }
@@ -228,7 +229,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
       T in_data = input_data[offset];
       const T* weight = filter_data + weight_offset * output_channels + c_out;
       if (fuse_relu_before_conv) {
-        value += weight[0] * T(max(0.0f, double(in_data)));
+        value += weight[0] * T(max(0.0f, static_cast<double>(in_data)));
      } else {
        value += weight[0] * in_data;
      }
@@ -281,7 +282,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
         int offset = in_offset + h_in * input_width + w_in;
         if (fuse_relu_before_conv) {
           value += r_weight[h_f * c_filter + w_f] *
-                   T(max(0.0f, double(input_data[offset])));
+                   T(max(0.0f, static_cast<double>(input_data[offset])));
         } else {
           value += r_weight[h_f * c_filter + w_f] * input_data[offset];
         }
@@ -337,7 +338,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
             in_offset + (h_in * input_width + w_in) * input_channels + c_in;
         if (fuse_relu_before_conv) {
           value += r_weight[h_f * c_filter + w_f] *
-                   T(max(0.0, double(input_data[offset])));
+                   T(max(0.0, static_cast<double>(input_data[offset])));
         } else {
           value += r_weight[h_f * c_filter + w_f] * input_data[offset];
         }
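The four forward-kernel hunks above all touch the same fused-ReLU read: when `fuse_relu_before_conv` is true, the input value is clamped to `max(0, x)` as it is loaded, and the only change is swapping the functional-style `double(...)` cast for `static_cast<double>`. Below is a minimal, self-contained sketch of that read pattern for float inputs; `LoadMaybeRelu` is a hypothetical helper for illustration, not a function in this file.

```cuda
// Hypothetical illustration of the fuse_relu_before_conv read pattern.
// Negative activations contribute nothing to the accumulated value, so the
// ReLU can be folded into the load instead of materializing a ReLU output.
__device__ __forceinline__ float LoadMaybeRelu(const float* input,
                                               int offset,
                                               bool fuse_relu_before_conv) {
  float x = input[offset];
  if (fuse_relu_before_conv) {
    // Same numerics as the kernel's T(max(0.0f, static_cast<double>(x))).
    return static_cast<float>(fmax(0.0, static_cast<double>(x)));
  }
  return x;
}
```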
@@ -880,7 +881,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
                           image_wk;
             if (fuse_relu_before_conv) {
               s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-                   T(max(0.0f, double(input_data[input_id])));
+                   T(max(0.0f, static_cast<double>(input_data[input_id])));
             } else {
               s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                    input_data[input_id];
@@ -891,7 +892,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
   }
   T val = BlockReduceSum(s);
-  platform::CudaAtomicAdd(&filter_grad_data[gbid], val);
+  if (threadIdx.y == 0 && threadIdx.x == 0) filter_grad_data[gbid] = val;
 }

 template <typename T, bool fuse_relu_before_conv>
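The hunk above carries the optimization named in the commit title: `BlockReduceSum` already combines every thread's partial sum, so the block-wide result is written to `filter_grad_data[gbid]` once, by thread (0, 0), with a plain store instead of `platform::CudaAtomicAdd`. A minimal sketch of the pattern follows, assuming each block exclusively owns its output slot and a power-of-two block size of at most 1024 threads; `block_reduce_sum` and `reduce_and_store` are stand-ins for illustration, not the Paddle kernels.

```cuda
#include <cuda_runtime.h>

// Stand-in for a block-wide sum: tree reduction in shared memory.
// Assumes blockDim.x * blockDim.y is a power of two and <= 1024.
__device__ float block_reduce_sum(float v) {
  __shared__ float buf[1024];
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  buf[tid] = v;
  __syncthreads();
  for (int stride = (blockDim.x * blockDim.y) / 2; stride > 0; stride /= 2) {
    if (tid < stride) buf[tid] += buf[tid + stride];
    __syncthreads();
  }
  return buf[0];  // every thread sees the block-wide sum after the last sync
}

// Each block owns one gradient element (indexed by blockIdx.x), so the
// reduced value can be stored directly by one thread; no atomics needed.
__global__ void reduce_and_store(const float* partials, float* grad) {
  int threads = blockDim.x * blockDim.y;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  float s = partials[blockIdx.x * threads + tid];
  float val = block_reduce_sum(s);
  if (threadIdx.y == 0 && threadIdx.x == 0) grad[blockIdx.x] = val;
}
```

The trade-off is that a direct store is only safe when no two blocks write the same slot; if several blocks contributed to one filter element, the atomic add (or a second reduction pass) would still be required.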
@@ -941,7 +942,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
                        kernel_id / filter_multiplier;
         if (fuse_relu_before_conv) {
           s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
-               T(max(0.0f, double(input_data[input_id])));
+               T(max(0.0f, static_cast<double>(input_data[input_id])));
         } else {
           s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
                input_data[input_id];
@@ -1013,7 +1014,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
           T s(0);
           if (fuse_relu_before_conv) {
             s = output_grad_data[output_id] *
-                T(max(0.0f, double(input_data[input_id])));
+                T(max(0.0f, static_cast<double>(input_data[input_id])));
           } else {
             s = output_grad_data[output_id] * input_data[input_id];
           }
......