Unverified Commit b91bbd32, authored by 201716010711, committed by GitHub

Optimize Paddle diagonal (#47904)

Parent de2c5fd6
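In short, the diagonal kernels previously launched one unit of work per *input* element and used a `flag` test to decide whether that element sits on the requested diagonal; this commit inverts the loop so that each *output* (diagonal) element computes its source offset in the input directly from the strides. The work therefore scales with the diagonal's length rather than with the whole tensor. A minimal standalone sketch of the idea for the 2-D, `offset == 0` case (plain C++ with illustrative names, not Paddle's API):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch: gather the main diagonal of a row-major rows x cols matrix by
// iterating over the output only, mirroring the direction of this change.
std::vector<float> diagonal2d(const std::vector<float>& in,
                              int64_t rows, int64_t cols) {
  int64_t out_numel = std::min(rows, cols);  // length of the diagonal
  std::vector<float> out(out_numel);
  for (int64_t idx = 0; idx < out_numel; ++idx) {
    // Input offset of element (idx, idx) under row-major strides {cols, 1}.
    out[idx] = in[idx * cols + idx];
  }
  return out;
}
```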
@@ -35,6 +35,7 @@ void DiagonalKernel(const Context& dev_ctx,
   auto* output = out;
   T* output_data = dev_ctx.template Alloc<T>(output);
   auto output_dim = vectorize(output->dims());
+  auto output_dim_size = output_dim.size();
 
   const int64_t offset_ = offset;
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
@@ -43,40 +44,48 @@ void DiagonalKernel(const Context& dev_ctx,
   std::vector<int64_t> input_stride = funcs::ComputeDimStride(input_dim);
   std::vector<int64_t> output_stride = funcs::ComputeDimStride(output_dim);
-  int64_t numel = input->numel();
+  int64_t out_numel = out->numel();
 
-  for (int64_t idx = 0; idx < numel; idx++) {
-    std::vector<int64_t> idx_dim(input_dim_size);
+  for (int64_t idx = 0; idx < out_numel; idx++) {
+    std::vector<int64_t> idx_dim(output_dim_size);
     int64_t temp = 0;
-    for (size_t i = 0; i < input_dim_size; i++) {
-      idx_dim[i] = (idx - temp) / input_stride[i];
-      temp = temp + idx_dim[i] * input_stride[i];
+    for (size_t i = 0; i < output_dim_size; i++) {
+      idx_dim[i] = (idx - temp) / output_stride[i];
+      temp = temp + idx_dim[i] * output_stride[i];
     }
+    int64_t tmp = idx_dim[output_dim_size - 1];
 
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-
-    idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_));
-    idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_));
-
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis2_dim);
-      flag = true;
-    }
-    if (flag) {
-      int64_t idx_output = 0;
-      for (size_t i = 0; i < idx_dim.size(); i++) {
-        idx_output = idx_output + idx_dim[i] * output_stride[i];
-      }
-      output_data[idx_output] = input_data[idx];
-    }
+    std::vector<int64_t> list;
+    list.clear();
+    int64_t l = std::min(axis1_, axis2_);
+    int64_t r = std::max(axis1_, axis2_);
+    for (size_t j = 0; j < output_dim_size - 1; j++) {
+      list.push_back(idx_dim[j]);
+    }
+    if (offset_ == 0) {
+      list.insert(list.begin() + l, tmp);
+      list.insert(list.begin() + r, tmp);
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp + offset_);
+      } else {
+        list.insert(list.begin() + l, tmp + offset_);
+        list.insert(list.begin() + r, tmp);
+      }
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp - offset_);
+        list.insert(list.begin() + r, tmp);
+      } else {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp - offset_);
+      }
+    }
+    int64_t input_offset = 0;
+    for (size_t i = 0; i < input_dim_size; i++) {
+      input_offset = input_offset + list[i] * input_stride[i];
+    }
+    output_data[idx] = input_data[input_offset];
   }
 }
 }  // namespace phi
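The rewritten CPU loop decodes the flat output index `idx` into output coordinates with `output_stride`, then rebuilds the full input coordinate vector: the non-diagonal coordinates are copied into `list`, and the diagonal position `tmp` (always the last output coordinate) is re-inserted at both diagonal axes, with `offset_` applied to whichever axis the offset shifts. Inserting at `l = std::min(axis1_, axis2_)` before `r = std::max(axis1_, axis2_)` matters: the first insertion shifts later elements right by one, so `begin() + r` then points at the correct final slot. A small self-contained check of that ordering (plain C++, not Paddle code):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Rank-3 input, axis1_ = 0, axis2_ = 2, offset_ = 0. The output
  // coordinate is {c, tmp}; c is the sole non-diagonal coordinate.
  int64_t axis1_ = 0, axis2_ = 2;
  int64_t c = 5, tmp = 7;                        // arbitrary sample values
  std::vector<int64_t> list = {c};               // non-diagonal coords
  int64_t l = std::min(axis1_, axis2_);
  int64_t r = std::max(axis1_, axis2_);
  list.insert(list.begin() + l, tmp);            // -> {7, 5}
  list.insert(list.begin() + r, tmp);            // -> {7, 5, 7}
  for (int64_t v : list) std::cout << v << ' ';  // prints: 7 5 7
}
```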
...
@@ -156,59 +156,59 @@ __global__ void DiagonalCuda(const T* data1,
                              int64_t* x_stride,
                              int64_t* out_stride,
                              int64_t numel,
+                             int64_t out_numel,
                              bool is_grad) {
-  CUDA_KERNEL_LOOP(idx, numel) {
-    int64_t idx_dim[X_DIM_SIZE] = {0};
+  CUDA_KERNEL_LOOP(idx, out_numel) {
+    int64_t idx_dim[OUT_DIM_SIZE] = {0};
     int64_t temp = 0;
-    for (size_t i = 0; i < X_DIM_SIZE - 1; i++) {
-      idx_dim[i] = (idx - temp) / x_stride[i];
-      temp = temp + idx_dim[i] * x_stride[i];
+    for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
+      idx_dim[i] = (idx - temp) / out_stride[i];
+      temp = temp + idx_dim[i] * out_stride[i];
     }
-    idx_dim[X_DIM_SIZE - 1] = idx - temp;
+    idx_dim[OUT_DIM_SIZE - 1] = idx - temp;
+    int64_t tmp = idx - temp;
 
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-
-    int64_t out_dim[OUT_DIM_SIZE] = {0};
-    int temp_pos = 0;
-    for (int i = 0; i < X_DIM_SIZE; i++) {
-      if (i != axis1_ && i != axis2_) {
-        out_dim[temp_pos] = idx_dim[i];
-        temp_pos++;
+    int64_t list[9];
+    int64_t p = 0;
+    for (size_t j = 0; j < X_DIM_SIZE; j++) {
+      if (j == axis1_ || j == axis2_) {
+        list[j] = 0;
+      } else {
+        list[j] = idx_dim[p];
+        p += 1;
       }
     }
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis2_dim;
-      flag = true;
-    }
-    if (!is_grad) {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx_output] = data1[idx];
+    int64_t l = min(axis1_, axis2_);
+    int64_t r = max(axis1_, axis2_);
+    if (offset_ == 0) {
+      list[l] = tmp;
+      list[r] = tmp;
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp;
+        list[r] = tmp + offset_;
+      } else {
+        list[l] = tmp + offset_;
+        list[r] = tmp;
       }
-    } else {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx] = data1[idx_output];
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp - offset_;
+        list[r] = tmp;
       } else {
-        data2[idx] = static_cast<T>(0);
+        list[l] = tmp;
+        list[r] = tmp - offset_;
       }
     }
+    int64_t input_offset = 0;
+    for (size_t i = 0; i < X_DIM_SIZE; i++) {
+      input_offset = input_offset + list[i] * x_stride[i];
+    }
+    if (!is_grad) {
+      data2[idx] = data1[input_offset];
+    } else {
+      data2[input_offset] = data1[idx];
+    }
   }
 }
 #endif
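The CUDA path applies the same output-to-input mapping with fixed-size arrays, since device code cannot use `std::vector`; `int64_t list[9]` is safe because the host-side `switch` statements below only instantiate `DiagonalCuda` for ranks 2 through 9. `CUDA_KERNEL_LOOP` is Paddle's grid-stride loop macro, so each thread now covers output indices rather than input indices. One kernel also serves both directions: the forward pass gathers and the backward pass scatters through the same computed offset, which is what lets the old per-thread `flag` test and the explicit zero-write branch disappear. A plain-C++ sketch of that symmetry:

```cpp
#include <cstdint>

// Sketch of DiagonalCuda's final copy step: the same (idx, input_offset)
// pair is used as a gather in the forward pass and a scatter in the
// backward pass.
template <typename T>
void copy_one(T* data2, const T* data1,
              int64_t idx, int64_t input_offset, bool is_grad) {
  if (!is_grad) {
    data2[idx] = data1[input_offset];  // forward: out[idx] <- x[offset]
  } else {
    data2[input_offset] = data1[idx];  // backward: dx[offset] <- dout[idx]
  }
}
```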
...
@@ -62,6 +62,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
   int threads = PADDLE_CUDA_NUM_THREADS;
   int blocks = (numel + threads - 1) / threads;
+  int64_t dout_numel = out_grad.numel();
+
+  phi::backends::gpu::GpuMemsetAsync(
+      dx_data, 0, numel * sizeof(T), dev_ctx.stream());
 
   switch (dx_dim_size) {
     case 2:
       funcs::DiagonalCuda<T, 2, 1><<<blocks, threads>>>(dout_data,
@@ -72,6 +76,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 3:
@@ -83,6 +88,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 4:
@@ -94,6 +100,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 5:
@@ -105,6 +112,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 6:
@@ -116,6 +124,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 7:
@@ -127,6 +136,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 8:
@@ -138,6 +148,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 9:
@@ -149,6 +160,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     default:
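The new `GpuMemsetAsync` call is required by the loop inversion: the grad kernel now launches one thread per `out_grad` element and writes only the diagonal positions of `dx`, whereas the old kernel iterated over all of `dx` and wrote `static_cast<T>(0)` to the off-diagonal slots itself. Zero-filling `dx` up front restores those zeros. A minimal host-side sketch of the pattern (illustrative names, not Paddle's API):

```cpp
#include <cstdint>
#include <cstring>

// Sketch: pre-zero the destination, then scatter the incoming gradient
// onto the diagonal offsets only. Assumes T's all-zero bit pattern means
// zero, which holds for the numeric types involved.
template <typename T>
void diagonal_grad(T* dx, const T* dout, const int64_t* diag_offsets,
                   int64_t dx_numel, int64_t dout_numel) {
  std::memset(dx, 0, dx_numel * sizeof(T));  // stands in for GpuMemsetAsync
  for (int64_t i = 0; i < dout_numel; ++i) {
    dx[diag_offsets[i]] = dout[i];           // one write per diagonal element
  }
}
```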
...
@@ -54,9 +54,10 @@ void DiagonalKernel(const Context& dev_ctx,
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
   int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2;
   int64_t numel = input->numel();
+  int64_t out_numel = out->numel();
 
   int threads = PADDLE_CUDA_NUM_THREADS;
-  int blocks = (numel + threads - 1) / threads;
+  int blocks = (out_numel + threads - 1) / threads;
 
   switch (input_dim_size) {
     case 2:
@@ -68,6 +69,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 3:
@@ -79,6 +81,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 4:
@@ -90,6 +93,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 5:
@@ -101,6 +105,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 6:
@@ -112,6 +117,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 7:
@@ -123,6 +129,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 8:
@@ -134,6 +141,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     case 9:
@@ -145,6 +153,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                               input_stride,
                                               output_stride,
                                               numel,
+                                              out_numel,
                                               false);
       break;
     default:
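Sizing the grid from `out_numel` instead of `numel` is where the launch itself gets cheaper: the diagonal is smaller than the tensor by roughly a factor of the reduced axis's extent. A quick back-of-the-envelope comparison, assuming `PADDLE_CUDA_NUM_THREADS` is 512:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int threads = 512;     // assumed value of PADDLE_CUDA_NUM_THREADS
  int64_t numel = 100 * 100;   // 100 x 100 input
  int64_t out_numel = 99;      // diagonal length for offset = 1
  std::cout << (numel + threads - 1) / threads << " blocks before, "
            << (out_numel + threads - 1) / threads << " after\n";
  // prints: 20 blocks before, 1 after
}
```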
...
@@ -101,6 +101,35 @@ class TestDiagonalOpCase3(TestDiagonalOp):
         pass
 
 
+class TestDiagonalOpCase4(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(100, 100).astype('int64')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': 1, 'axis1': 1, 'axis2': 0}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+    def test_check_grad(self):
+        pass
+
+
+class TestDiagonalOpCase5(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(4, 2, 4, 4).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': -2, 'axis1': 0, 'axis2': 3}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+
 class TestDiagonalAPI(unittest.TestCase):
     def setUp(self):
         self.shape = [10, 3, 4]
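The two new cases target exactly the branches this commit adds: TestDiagonalOpCase4 runs with axis1 > axis2 (axis1=1, axis2=0) and a positive offset, exercising the swapped-axis insert order, while TestDiagonalOpCase5 combines a negative offset with non-adjacent axes (axis1=0, axis2=3) on a 4-D input. Case4 stubs out test_check_grad, presumably because its int64 input has no gradient to check, matching the other integer cases in this file.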
...