未验证 提交 04eb211a 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of depthwise_conv_bwd of filter (#46490)

* Optimize performance of depthwise_conv_bwd of filter

* op-benchmark

* fix

* op benchmark

* merge bwd
上级 f17a73e9
......@@ -87,43 +87,36 @@ class DepthwiseConvFilterGradFunctor {
const DataLayout data_layout = DataLayout::kNCHW);
};
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
typedef cub::WarpReduce<T> WarpReduce;
typename WarpReduce::TempStorage temp_storage;
val = WarpReduce(temp_storage).Sum(val, warp_size);
__forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
val += platform::CudaShuffleDownSync(lane_mask, val, mask);
return val;
}
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val) {
static __shared__ T shared[32];
int thread_id = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * blockDim.x * blockDim.y;
int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
int lane = thread_id % warp_size;
int wid = thread_id / warp_size;
val = WarpReduceSum(val, warp_size); // Each warp performs partial reduction
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
__syncthreads(); // Wait for all partial reductions
// read from shared memory only if that warp existed
int block_size = blockDim.x * blockDim.y * blockDim.z;
if (thread_id < (block_size - 1) / warp_size + 1) {
val = shared[lane];
} else {
val = static_cast<T>(0);
}
__forceinline__ __device__ T BlockReduceSum(T val, unsigned mask = FINAL_MASK) {
static __shared__ T shared[WARP_SIZE];
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int lane = tid & 0x1f;
int wid = tid >> 5;
val = WarpReduceSum<T>(val, mask);
if (wid == 0) {
val = WarpReduceSum(val, warp_size); // Final reduce within first warp
}
__syncthreads();
if (thread_id != 0) {
val = static_cast<T>(0);
}
if (lane == 0) shared[wid] = val;
__syncthreads();
// align block_span to WARP_SIZE
int block_span = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5;
val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = WarpReduceSum<T>(val, mask);
return val;
}
......@@ -858,45 +851,81 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
const int dilate_height,
const int dilate_width,
T* filter_grad_data) {
T s(0);
int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
for (int image_w = threadIdx.x; image_w < output_width;
image_w += blockDim.x) {
for (int bid = 0; bid < num; bid++) {
for (int image_h = threadIdx.y; image_h < output_height;
image_h += blockDim.y) {
int kernel_id = blockIdx.z;
int kernel_h = blockIdx.y * dilate_height - padding_height;
int kernel_w = blockIdx.x * dilate_width - padding_width;
int image_hk = image_h * stride_height + kernel_h;
int image_wk = image_w * stride_width + kernel_w;
if (image_hk < 0 || image_hk >= input_height) continue;
if (image_wk < 0 || image_wk >= input_width) continue;
#define gaid(N, C, H, W) \
((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
int input_id = ((bid * (gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier) *
input_height +
image_hk) *
input_width +
image_wk;
T f_grad(0);
const bool loop_batch = output_height * output_width >= WARP_SIZE;
int kw_id = blockIdx.x;
int kh_id = blockIdx.y;
int oc_id = blockIdx.z;
int ic_id = oc_id / filter_multiplier;
int idx = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
const int ohw = output_height * output_width;
const int onhw = num * ohw;
const int h_offset = kh_id * dilate_height - padding_height;
const int w_offset = kw_id * dilate_width - padding_width;
if (loop_batch) {
for (int og_w = threadIdx.x; og_w < output_width; og_w += blockDim.x) {
for (int bid = 0; bid < num; ++bid) {
for (int og_h = threadIdx.y; og_h < output_height; og_h += blockDim.y) {
int i_h = og_h * stride_height + h_offset;
int i_w = og_w * stride_width + w_offset;
if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) {
int input_offset =
((bid * input_channels + ic_id) * input_height + i_h) *
input_width +
i_w;
int output_grad_offset =
((bid * output_channels + oc_id) * output_height + og_h) *
output_width +
og_w;
if (fuse_relu_before_conv) {
f_grad +=
output_grad_data[output_grad_offset] *
static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_offset])));
} else {
f_grad += output_grad_data[output_grad_offset] *
input_data[input_offset];
}
}
}
}
}
} else {
for (int id = threadIdx.x; id < onhw; id += blockDim.x) {
int bid = id / ohw;
int og_hw = id - bid * ohw;
int og_h = og_hw / output_width;
int og_w = og_hw - og_h * output_width;
int i_h = og_h * stride_height + h_offset;
int i_w = og_w * stride_width + w_offset;
if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) {
int input_offset =
((bid * input_channels + ic_id) * input_height + i_h) *
input_width +
i_w;
int output_grad_offset = (bid * output_channels + oc_id) * ohw + og_hw;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
f_grad += output_grad_data[output_grad_offset] *
static_cast<T>(max(
0.0f, static_cast<double>(input_data[input_offset])));
} else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[input_id];
f_grad +=
output_grad_data[output_grad_offset] * input_data[input_offset];
}
#undef gaid
}
}
}
T val = BlockReduceSum(s);
if (threadIdx.y == 0 && threadIdx.x == 0) filter_grad_data[gbid] = val;
T val = BlockReduceSum<T>(f_grad);
if (threadIdx.x == 0 && threadIdx.y == 0) {
filter_grad_data[idx] = val;
}
}
template <typename T, bool fuse_relu_before_conv>
......@@ -1572,6 +1601,10 @@ class DepthwiseConvFilterGradFunctor<phi::GPUContext,
blocks = std::min(std::max(block_size / output_width, 1), output_height);
grid = dim3(ksize_width, ksize_height, output_channels);
threads = dim3(std::min(output_width, block_size), blocks, 1);
if (output_height * output_width < WARP_SIZE) {
threads = dim3(
std::min(block_size, batch_size * output_height * output_width));
}
} else {
blocks = std::min(
std::max(block_size / output_channels, 1),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册