diff --git a/paddle/phi/kernels/gpu/cross_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_grad_kernel.cu
index ada78adb77fc97a62a179b0ebc60e0e89ce1e1b8..a6399ba39dcaec657bbfe68b67a5d1d9132a6b45 100644
--- a/paddle/phi/kernels/gpu/cross_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_grad_kernel.cu
@@ -22,8 +22,6 @@
 
 namespace phi {
 
-using funcs::IndexCalculator;
-
 template <typename T>
 __global__ void CrossGrad(const T* x,
                           const T* y,
@@ -32,7 +30,7 @@ __global__ void CrossGrad(const T* x,
                           T* out_dy,
                           const int stride,
                           const int N,
-                          IndexCalculator index_calculator) {
+                          phi::funcs::IndexCalculator index_calculator) {
   CUDA_KERNEL_LOOP(i, N) {
     int offset = index_calculator(i);
 
@@ -107,32 +105,52 @@ void CrossGradKernel(const Context& dev_ctx,
   std::vector<int> cal_dims;
   std::vector<int> left_strides;
   std::vector<int> full_strides;
+  std::vector<int> merged_dims;
+
+  for (int i = 0; i < dim; i++) {
+    if (i == 0) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[0] *= input_x_dims[i];
+    }
+  }
+  int merge_axis = merged_dims.size();
+  merged_dims.push_back(input_x_dims[dim]);
+  for (int i = dim + 1; i < input_x_dims.size(); i++) {
+    if (i == dim + 1) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[merge_axis + 1] *= input_x_dims[i];
+    }
+  }
 
   int full_dim = 1;
-  int left_dim = 1;
-  for (auto i = 0; i < input_x_dims.size(); i++) {
+  for (int i = 0; i < merged_dims.size(); i++) {
     full_strides.insert(full_strides.begin(), full_dim);
-    full_dim *= input_x_dims[input_x_dims.size() - i - 1];
-    if (i == dim) {
+    full_dim *= merged_dims[merged_dims.size() - i - 1];
+    if (i == merge_axis) {
       continue;
     }
     cal_dims.push_back(i);
+  }
+  int left_dim = 1;
+  for (int i = merged_dims.size() - 1; i >= 0; i--) {
+    if (i == merge_axis) {
+      continue;
+    }
     left_strides.insert(left_strides.begin(), left_dim);
-    left_dim *= input_x_dims[input_x_dims.size() - i - 1];
+    left_dim *= merged_dims[i];
   }
 
   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
   const auto* input_out_grad_data = input_out_grad.data<T>();
-
   auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
   auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
-
-  auto index_calculator = IndexCalculator(
-      input_x_dims.size() - 1, cal_dims, left_strides, full_strides);
+  auto index_calculator = phi::funcs::IndexCalculator(
+      merged_dims.size() - 1, cal_dims, left_strides, full_strides);
 
   int64_t numel = x.numel();
-
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
 
@@ -144,7 +162,7 @@ void CrossGradKernel(const Context& dev_ctx,
                                      input_out_grad_data,
                                      output_x_grad_data,
                                      output_y_grad_data,
-                                     full_strides[dim],
+                                     full_strides[merge_axis],
                                      numel / 3,
                                      index_calculator);
 }
diff --git a/paddle/phi/kernels/gpu/cross_kernel.cu b/paddle/phi/kernels/gpu/cross_kernel.cu
index 44173f4fbe62d98807e5671c98d27e38dbe5a2ac..0e1e7b3a42568b2d80ccd69d0b5e11146870fa14 100644
--- a/paddle/phi/kernels/gpu/cross_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_kernel.cu
@@ -22,15 +22,13 @@
 
 namespace phi {
 
-using funcs::IndexCalculator;
-
 template <typename T>
 __global__ void Cross(const T* x,
                       const T* y,
                       T* out,
                       const int stride,
                       const int N,
-                      IndexCalculator index_calculator) {
+                      phi::funcs::IndexCalculator index_calculator) {
   CUDA_KERNEL_LOOP(i, N) {
     int offset = index_calculator(i);
 
@@ -96,30 +94,50 @@ void CrossKernel(const Context& dev_ctx,
   std::vector<int> cal_dims;
   std::vector<int> left_strides;
   std::vector<int> full_strides;
+  std::vector<int> merged_dims;
+
+  for (int i = 0; i < dim; i++) {
+    if (i == 0) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[0] *= input_x_dims[i];
+    }
+  }
+  int merge_axis = merged_dims.size();
+  merged_dims.push_back(input_x_dims[dim]);
+  for (int i = dim + 1; i < input_x_dims.size(); i++) {
+    if (i == dim + 1) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[merge_axis + 1] *= input_x_dims[i];
+    }
+  }
 
-  int dims0 = 1;
-  int dims1 = 1;
-  for (auto i = 0; i < input_x_dims.size(); i++) {
-    full_strides.insert(full_strides.begin(), dims0);
-    dims0 *= input_x_dims[input_x_dims.size() - i - 1];
-    if (i == dim) {
+  int full_dim = 1;
+  for (int i = 0; i < merged_dims.size(); i++) {
+    full_strides.insert(full_strides.begin(), full_dim);
+    full_dim *= merged_dims[merged_dims.size() - i - 1];
+    if (i == merge_axis) {
       continue;
     }
     cal_dims.push_back(i);
-    left_strides.insert(left_strides.begin(), dims1);
-    dims1 *= input_x_dims[input_x_dims.size() - i - 1];
+  }
+  int left_dim = 1;
+  for (int i = merged_dims.size() - 1; i >= 0; i--) {
+    if (i == merge_axis) {
+      continue;
+    }
+    left_strides.insert(left_strides.begin(), left_dim);
+    left_dim *= merged_dims[i];
   }
 
   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
-
   auto* out_data = dev_ctx.template Alloc<T>(out);
-
-  auto index_calculator = IndexCalculator(
-      input_x_dims.size() - 1, cal_dims, left_strides, full_strides);
+  auto index_calculator = phi::funcs::IndexCalculator(
+      merged_dims.size() - 1, cal_dims, left_strides, full_strides);
 
   int64_t numel = x.numel();
-
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
 
@@ -129,7 +147,7 @@ void CrossKernel(const Context& dev_ctx,
                          dev_ctx.stream()>>>(input_x_data,
                                              input_y_data,
                                              out_data,
-                                             full_strides[dim],
+                                             full_strides[merge_axis],
                                              numel / 3,
                                              index_calculator);
 }