Unverified  Commit b91bbd32  authored by 201716010711, committed by GitHub

Optimize Paddle diagonal (#47904)

Parent de2c5fd6
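This commit rewrites the diagonal kernels (CPU and CUDA) to loop over the output tensor instead of the input. The old code visited every input element, tested whether it lay on the requested diagonal (the flag branch), and scattered hits into the output; the new code visits only the out_numel diagonal elements, rebuilds the full input coordinate for each, and gathers the value directly. A minimal NumPy sketch of that index mapping follows; the function name and structure are illustrative, not part of the patch, and axes are assumed already normalized to be non-negative:

import numpy as np

def diagonal_gather(x, offset=0, axis1=0, axis2=1):
    # One pass over the *output*: rebuild the matching input coordinate
    # for every output element and read it directly (no branch, no mask).
    out_shape = np.diagonal(x, offset, axis1, axis2).shape
    result = np.empty(out_shape, dtype=x.dtype)
    l, r = min(axis1, axis2), max(axis1, axis2)
    for idx in np.ndindex(out_shape):
        *rest, d = idx              # the last output dim walks the diagonal
        coord = list(rest)
        # coordinates along axis1/axis2, shifted apart by the offset
        a1 = d if offset >= 0 else d - offset
        a2 = d + offset if offset >= 0 else d
        coord.insert(l, a1 if axis1 < axis2 else a2)
        coord.insert(r, a2 if axis1 < axis2 else a1)
        result[idx] = x[tuple(coord)]
    return result

x = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(diagonal_gather(x, 1, 0, 2), np.diagonal(x, 1, 0, 2))

The two insert calls mirror the kernel's list.insert calls (CPU) and list[l]/list[r] writes (CUDA): the lower of the two reduced axes is filled first so the second insertion lands at the correct shifted position.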
@@ -35,6 +35,7 @@ void DiagonalKernel(const Context& dev_ctx,
   auto* output = out;
   T* output_data = dev_ctx.template Alloc<T>(output);
   auto output_dim = vectorize(output->dims());
+  auto output_dim_size = output_dim.size();
   const int64_t offset_ = offset;
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
@@ -43,40 +44,48 @@ void DiagonalKernel(const Context& dev_ctx,
   std::vector<int64_t> input_stride = funcs::ComputeDimStride(input_dim);
   std::vector<int64_t> output_stride = funcs::ComputeDimStride(output_dim);
-  int64_t numel = input->numel();
-  for (int64_t idx = 0; idx < numel; idx++) {
-    std::vector<int64_t> idx_dim(input_dim_size);
+  int64_t out_numel = out->numel();
+  for (int64_t idx = 0; idx < out_numel; idx++) {
+    std::vector<int64_t> idx_dim(output_dim_size);
     int64_t temp = 0;
-    for (size_t i = 0; i < input_dim_size; i++) {
-      idx_dim[i] = (idx - temp) / input_stride[i];
-      temp = temp + idx_dim[i] * input_stride[i];
+    for (size_t i = 0; i < output_dim_size; i++) {
+      idx_dim[i] = (idx - temp) / output_stride[i];
+      temp = temp + idx_dim[i] * output_stride[i];
     }
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-    idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_));
-    idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_));
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis2_dim);
-      flag = true;
-    }
-    if (flag) {
-      int64_t idx_output = 0;
-      for (size_t i = 0; i < idx_dim.size(); i++) {
-        idx_output = idx_output + idx_dim[i] * output_stride[i];
-      }
-      output_data[idx_output] = input_data[idx];
-    }
+    int64_t tmp = idx_dim[output_dim_size - 1];
+    std::vector<int64_t> list;
+    list.clear();
+    int64_t l = std::min(axis1_, axis2_);
+    int64_t r = std::max(axis1_, axis2_);
+    for (size_t j = 0; j < output_dim_size - 1; j++) {
+      list.push_back(idx_dim[j]);
+    }
+    if (offset_ == 0) {
+      list.insert(list.begin() + l, tmp);
+      list.insert(list.begin() + r, tmp);
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp + offset_);
+      } else {
+        list.insert(list.begin() + l, tmp + offset_);
+        list.insert(list.begin() + r, tmp);
+      }
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp - offset_);
+        list.insert(list.begin() + r, tmp);
+      } else {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp - offset_);
+      }
+    }
+    int64_t input_offset = 0;
+    for (size_t i = 0; i < input_dim_size; i++) {
+      input_offset = input_offset + list[i] * input_stride[i];
+    }
+    output_data[idx] = input_data[input_offset];
   }
 }
 }  // namespace phi
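Both the old and the new CPU loop decode a flat index into per-dimension coordinates using the strides from funcs::ComputeDimStride. A sketch of that decomposition under the assumption of row-major (C-order) strides; the helper names here are illustrative:

def compute_dim_stride(dims):
    # Row-major strides in elements: stride[i] = prod(dims[i+1:]).
    stride = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        stride[i] = stride[i + 1] * dims[i + 1]
    return stride

def unravel(idx, stride):
    # Mirrors the kernel's loop: peel off one coordinate per dimension.
    coord, temp = [], 0
    for s in stride:
        coord.append((idx - temp) // s)
        temp += coord[-1] * s
    return coord

assert compute_dim_stride([2, 3, 4]) == [12, 4, 1]
assert unravel(7, compute_dim_stride([2, 3, 4])) == [0, 1, 3]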
......
@@ -156,59 +156,59 @@ __global__ void DiagonalCuda(const T* data1,
                              int64_t* x_stride,
                              int64_t* out_stride,
                              int64_t numel,
+                             int64_t out_numel,
                              bool is_grad) {
-  CUDA_KERNEL_LOOP(idx, numel) {
-    int64_t idx_dim[X_DIM_SIZE] = {0};
+  CUDA_KERNEL_LOOP(idx, out_numel) {
+    int64_t idx_dim[OUT_DIM_SIZE] = {0};
     int64_t temp = 0;
-    for (size_t i = 0; i < X_DIM_SIZE - 1; i++) {
-      idx_dim[i] = (idx - temp) / x_stride[i];
-      temp = temp + idx_dim[i] * x_stride[i];
+    for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
+      idx_dim[i] = (idx - temp) / out_stride[i];
+      temp = temp + idx_dim[i] * out_stride[i];
     }
-    idx_dim[X_DIM_SIZE - 1] = idx - temp;
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-    int64_t out_dim[OUT_DIM_SIZE] = {0};
-    int temp_pos = 0;
-    for (int i = 0; i < X_DIM_SIZE; i++) {
-      if (i != axis1_ && i != axis2_) {
-        out_dim[temp_pos] = idx_dim[i];
-        temp_pos++;
-      }
-    }
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis2_dim;
-      flag = true;
-    }
-    if (!is_grad) {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx_output] = data1[idx];
-      }
-    } else {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx] = data1[idx_output];
-      } else {
-        data2[idx] = static_cast<T>(0);
-      }
-    }
+    idx_dim[OUT_DIM_SIZE - 1] = idx - temp;
+    int64_t tmp = idx - temp;
+    int64_t list[9];
+    int64_t p = 0;
+    for (size_t j = 0; j < X_DIM_SIZE; j++) {
+      if (j == axis1_ || j == axis2_) {
+        list[j] = 0;
+      } else {
+        list[j] = idx_dim[p];
+        p += 1;
+      }
+    }
+    int64_t l = min(axis1_, axis2_);
+    int64_t r = max(axis1_, axis2_);
+    if (offset_ == 0) {
+      list[l] = tmp;
+      list[r] = tmp;
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp;
+        list[r] = tmp + offset_;
+      } else {
+        list[l] = tmp + offset_;
+        list[r] = tmp;
+      }
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp - offset_;
+        list[r] = tmp;
+      } else {
+        list[l] = tmp;
+        list[r] = tmp - offset_;
+      }
+    }
+    int64_t input_offset = 0;
+    for (size_t i = 0; i < X_DIM_SIZE; i++) {
+      input_offset = input_offset + list[i] * x_stride[i];
+    }
+    if (!is_grad) {
+      data2[idx] = data1[input_offset];
+    } else {
+      data2[input_offset] = data1[idx];
+    }
   }
 }
 #endif
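In the rewritten CUDA kernel above, the is_grad branch scatters each out_grad value to its diagonal position in dx (data2[input_offset] = data1[idx]) and writes nothing anywhere else, so the off-diagonal entries of dx must be zeroed before launch; that is what the GpuMemsetAsync added to the grad kernel below provides. The same idea in NumPy form, reusing the coordinate mapping sketched earlier (illustrative names):

import numpy as np

def diagonal_grad(dout, x_shape, offset=0, axis1=0, axis2=1):
    dx = np.zeros(x_shape, dtype=dout.dtype)   # the GpuMemsetAsync analogue
    l, r = min(axis1, axis2), max(axis1, axis2)
    for idx in np.ndindex(dout.shape):
        *rest, d = idx
        coord = list(rest)
        a1 = d if offset >= 0 else d - offset
        a2 = d + offset if offset >= 0 else d
        coord.insert(l, a1 if axis1 < axis2 else a2)
        coord.insert(r, a2 if axis1 < axis2 else a1)
        dx[tuple(coord)] = dout[idx]           # write the diagonal only
    return dx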
......
@@ -62,6 +62,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
   int threads = PADDLE_CUDA_NUM_THREADS;
   int blocks = (numel + threads - 1) / threads;
+  int64_t dout_numel = out_grad.numel();
+  phi::backends::gpu::GpuMemsetAsync(
+      dx_data, 0, numel * sizeof(T), dev_ctx.stream());
+
   switch (dx_dim_size) {
     case 2:
       funcs::DiagonalCuda<T, 2, 1><<<blocks, threads>>>(dout_data,
@@ -72,6 +76,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 3:
@@ -83,6 +88,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 4:
@@ -94,6 +100,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 5:
@@ -105,6 +112,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 6:
@@ -116,6 +124,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 7:
@@ -127,6 +136,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 8:
@@ -138,6 +148,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 9:
@@ -149,6 +160,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     default:
......
@@ -54,9 +54,10 @@ void DiagonalKernel(const Context& dev_ctx,
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
   int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2;
   int64_t numel = input->numel();
+  int64_t out_numel = out->numel();
   int threads = PADDLE_CUDA_NUM_THREADS;
-  int blocks = (numel + threads - 1) / threads;
+  int blocks = (out_numel + threads - 1) / threads;
   switch (input_dim_size) {
     case 2:
@@ -68,6 +69,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 3:
@@ -79,6 +81,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 4:
@@ -90,6 +93,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 5:
@@ -101,6 +105,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 6:
@@ -112,6 +117,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 7:
@@ -123,6 +129,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 8:
@@ -134,6 +141,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     case 9:
@@ -145,6 +153,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                       input_stride,
                                                       output_stride,
                                                       numel,
+                                                      out_numel,
                                                       false);
       break;
     default:
......
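A note on the forward launches above: the grid is now sized from out_numel rather than numel, so a large matrix with a short diagonal runs far fewer threads. Rough arithmetic, assuming the usual 512 threads per block behind PADDLE_CUDA_NUM_THREADS:

threads = 512
numel, out_numel = 4096 * 4096, 4096
old_blocks = (numel + threads - 1) // threads       # 32768 blocks
new_blocks = (out_numel + threads - 1) // threads   # 8 blocks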
@@ -101,6 +101,35 @@ class TestDiagonalOpCase3(TestDiagonalOp):
         pass


+class TestDiagonalOpCase4(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(100, 100).astype('int64')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': 1, 'axis1': 1, 'axis2': 0}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+    def test_check_grad(self):
+        pass
+
+
+class TestDiagonalOpCase5(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(4, 2, 4, 4).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': -2, 'axis1': 0, 'axis2': 3}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+
 class TestDiagonalAPI(unittest.TestCase):
     def setUp(self):
         self.shape = [10, 3, 4]
......
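The new cases exercise swapped axes (axis1 > axis2) and a negative offset on a 4-D input, the paths restructured above. They can also be checked end to end against NumPy; a sketch assuming a Paddle build that includes this change:

import numpy as np
import paddle

x = np.random.randn(4, 2, 4, 4).astype('float32')
out = paddle.diagonal(paddle.to_tensor(x), offset=-2, axis1=0, axis2=3)
np.testing.assert_allclose(out.numpy(), np.diagonal(x, -2, 0, 3), rtol=1e-6)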