Unverified commit 9e3433bd, authored by Bo Zhang, committed by GitHub

Merge dimensions && OP performance optimization (#43931)

Parent cf8e86df
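The optimization is the same in both files below: before computing strides, the kernel collapses all dimensions in front of the cross-product axis `dim` into one merged dimension and all dimensions behind it into another, so the index math runs over at most three dimensions regardless of the input rank. A minimal host-side sketch of that merging step, mirroring the merged_dims loops added in this commit (the function name MergeDims and the sample shape are mine, for illustration only):

#include <cstdio>
#include <vector>

// Collapse dims [0, dim) into one leading dimension and dims (dim, rank)
// into one trailing dimension; the cross-product axis itself (size 3)
// stays in the middle at index merge_axis.
std::vector<int> MergeDims(const std::vector<int>& input_x_dims, int dim) {
  std::vector<int> merged_dims;
  for (int i = 0; i < dim; i++) {
    if (i == 0) {
      merged_dims.push_back(input_x_dims[i]);
    } else {
      merged_dims[0] *= input_x_dims[i];
    }
  }
  int merge_axis = static_cast<int>(merged_dims.size());
  merged_dims.push_back(input_x_dims[dim]);
  for (int i = dim + 1; i < static_cast<int>(input_x_dims.size()); i++) {
    if (i == dim + 1) {
      merged_dims.push_back(input_x_dims[i]);
    } else {
      merged_dims[merge_axis + 1] *= input_x_dims[i];
    }
  }
  return merged_dims;
}

int main() {
  // A rank-4 tensor [2, 4, 3, 5] with the cross product along axis 2
  // collapses to [8, 3, 5]; the stride loops then only see 3 dims.
  for (int d : MergeDims({2, 4, 3, 5}, 2)) printf("%d ", d);  // 8 3 5
  return 0;
}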
@@ -22,8 +22,6 @@
 namespace phi {
 
-using funcs::IndexCalculator;
-
 template <typename T>
 __global__ void CrossGrad(const T* x,
                           const T* y,
@@ -32,7 +30,7 @@ __global__ void CrossGrad(const T* x,
                           T* out_dy,
                           const int stride,
                           const int N,
-                          IndexCalculator index_calculator) {
+                          phi::funcs::IndexCalculator index_calculator) {
   CUDA_KERNEL_LOOP(i, N) {
     int offset = index_calculator(i);
@@ -107,32 +105,52 @@ void CrossGradKernel(const Context& dev_ctx,
   std::vector<int> cal_dims;
   std::vector<int> left_strides;
   std::vector<int> full_strides;
+  std::vector<int> merged_dims;
+  for (int i = 0; i < dim; i++) {
+    if (i == 0) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[0] *= input_x_dims[i];
+    }
+  }
+  int merge_axis = merged_dims.size();
+  merged_dims.push_back(input_x_dims[dim]);
+  for (int i = dim + 1; i < input_x_dims.size(); i++) {
+    if (i == dim + 1) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[merge_axis + 1] *= input_x_dims[i];
+    }
+  }
 
   int full_dim = 1;
-  int left_dim = 1;
-  for (auto i = 0; i < input_x_dims.size(); i++) {
+  for (int i = 0; i < merged_dims.size(); i++) {
     full_strides.insert(full_strides.begin(), full_dim);
-    full_dim *= input_x_dims[input_x_dims.size() - i - 1];
-    if (i == dim) {
+    full_dim *= merged_dims[merged_dims.size() - i - 1];
+    if (i == merge_axis) {
       continue;
     }
     cal_dims.push_back(i);
+  }
+  int left_dim = 1;
+  for (int i = merged_dims.size() - 1; i >= 0; i--) {
+    if (i == merge_axis) {
+      continue;
+    }
     left_strides.insert(left_strides.begin(), left_dim);
-    left_dim *= input_x_dims[input_x_dims.size() - i - 1];
+    left_dim *= merged_dims[i];
   }
 
   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
   const auto* input_out_grad_data = input_out_grad.data<T>();
   auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
   auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
-  auto index_calculator = IndexCalculator(
-      input_x_dims.size() - 1, cal_dims, left_strides, full_strides);
+  auto index_calculator = phi::funcs::IndexCalculator(
+      merged_dims.size() - 1, cal_dims, left_strides, full_strides);
 
   int64_t numel = x.numel();
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
@@ -144,7 +162,7 @@ void CrossGradKernel(const Context& dev_ctx,
                           input_out_grad_data,
                           output_x_grad_data,
                           output_y_grad_data,
-                          full_strides[dim],
+                          full_strides[merge_axis],
                           numel / 3,
                           index_calculator);
 }
...
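To make the rewritten stride loops concrete, here is a hedged standalone check (not part of the commit) that replays them for merged_dims = {8, 3, 5} and merge_axis = 1. IndexCalculator itself is assumed here to decompose a logical vector index over left_strides and re-linearize it with full_strides over the dims listed in cal_dims; only the stride construction below is copied from the diff.

#include <cstdio>
#include <vector>

int main() {
  // Replay of the stride loops from CrossGradKernel for a merged
  // shape of {8, 3, 5} with the cross-product axis in the middle.
  std::vector<int> merged_dims = {8, 3, 5};
  int merge_axis = 1;

  std::vector<int> cal_dims, full_strides, left_strides;
  int full_dim = 1;
  for (int i = 0; i < static_cast<int>(merged_dims.size()); i++) {
    full_strides.insert(full_strides.begin(), full_dim);
    full_dim *= merged_dims[merged_dims.size() - i - 1];
    if (i == merge_axis) continue;
    cal_dims.push_back(i);
  }
  int left_dim = 1;
  for (int i = static_cast<int>(merged_dims.size()) - 1; i >= 0; i--) {
    if (i == merge_axis) continue;
    left_strides.insert(left_strides.begin(), left_dim);
    left_dim *= merged_dims[i];
  }

  // full_strides = {15, 5, 1}: row-major strides of the merged shape.
  // cal_dims     = {0, 2}:     the dims the index calculator walks.
  // left_strides = {5, 1}:     strides of the 8 x 5 space of 3-vectors.
  printf("element stride inside a 3-vector: %d\n",
         full_strides[merge_axis]);                           // 5
  printf("number of 3-vectors (numel / 3): %d\n", left_dim);  // 40
  return 0;
}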
@@ -22,15 +22,13 @@
 namespace phi {
 
-using funcs::IndexCalculator;
-
 template <typename T>
 __global__ void Cross(const T* x,
                       const T* y,
                       T* out,
                       const int stride,
                       const int N,
-                      IndexCalculator index_calculator) {
+                      phi::funcs::IndexCalculator index_calculator) {
   CUDA_KERNEL_LOOP(i, N) {
     int offset = index_calculator(i);
@@ -96,30 +94,50 @@ void CrossKernel(const Context& dev_ctx,
   std::vector<int> cal_dims;
   std::vector<int> left_strides;
   std::vector<int> full_strides;
+  std::vector<int> merged_dims;
+  for (int i = 0; i < dim; i++) {
+    if (i == 0) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[0] *= input_x_dims[i];
+    }
+  }
+  int merge_axis = merged_dims.size();
+  merged_dims.push_back(input_x_dims[dim]);
+  for (int i = dim + 1; i < input_x_dims.size(); i++) {
+    if (i == dim + 1) {
+      merged_dims.push_back(input_x_dims[i]);
+    } else {
+      merged_dims[merge_axis + 1] *= input_x_dims[i];
+    }
+  }
 
-  int dims0 = 1;
-  int dims1 = 1;
-  for (auto i = 0; i < input_x_dims.size(); i++) {
-    full_strides.insert(full_strides.begin(), dims0);
-    dims0 *= input_x_dims[input_x_dims.size() - i - 1];
-    if (i == dim) {
+  int full_dim = 1;
+  for (int i = 0; i < merged_dims.size(); i++) {
+    full_strides.insert(full_strides.begin(), full_dim);
+    full_dim *= merged_dims[merged_dims.size() - i - 1];
+    if (i == merge_axis) {
      continue;
     }
     cal_dims.push_back(i);
-    left_strides.insert(left_strides.begin(), dims1);
-    dims1 *= input_x_dims[input_x_dims.size() - i - 1];
+  }
+  int left_dim = 1;
+  for (int i = merged_dims.size() - 1; i >= 0; i--) {
+    if (i == merge_axis) {
+      continue;
+    }
+    left_strides.insert(left_strides.begin(), left_dim);
+    left_dim *= merged_dims[i];
   }
 
   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
   auto* out_data = dev_ctx.template Alloc<T>(out);
-  auto index_calculator = IndexCalculator(
-      input_x_dims.size() - 1, cal_dims, left_strides, full_strides);
+  auto index_calculator = phi::funcs::IndexCalculator(
+      merged_dims.size() - 1, cal_dims, left_strides, full_strides);
 
   int64_t numel = x.numel();
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
@@ -129,7 +147,7 @@ void CrossKernel(const Context& dev_ctx,
                     dev_ctx.stream()>>>(input_x_data,
                                         input_y_data,
                                         out_data,
-                                        full_strides[dim],
+                                        full_strides[merge_axis],
                                         numel / 3,
                                         index_calculator);
 }
...
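The launch configuration asks for numel / 3 logical threads because each logical index names one 3-vector along the merged cross-product axis. The kernel bodies are collapsed in this diff; the following is only a sketch of the access pattern they imply, with full_strides[merge_axis] passed in as stride (the helper name and factoring are mine, not the commit's):

#include <cstdio>

// Sketch of the per-vector work the collapsed kernel bodies imply.
// In the CUDA kernel this would run once per thread, with
// offset = index_calculator(i) and stride = full_strides[merge_axis];
// the three components of each vector sit `stride` elements apart.
template <typename T>
void CrossOneVector(const T* x, const T* y, T* out, int offset, int stride) {
  T x0 = x[offset], x1 = x[offset + stride], x2 = x[offset + 2 * stride];
  T y0 = y[offset], y1 = y[offset + stride], y2 = y[offset + 2 * stride];
  out[offset]              = x1 * y2 - x2 * y1;  // standard cross product
  out[offset + stride]     = x2 * y0 - x0 * y2;
  out[offset + 2 * stride] = x0 * y1 - x1 * y0;
}

int main() {
  // One contiguous 3-vector (stride 1): (1,0,0) x (0,1,0) = (0,0,1).
  float x[3] = {1, 0, 0}, y[3] = {0, 1, 0}, out[3];
  CrossOneVector(x, y, out, /*offset=*/0, /*stride=*/1);
  printf("%g %g %g\n", out[0], out[1], out[2]);  // 0 0 1
  return 0;
}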