diff --git a/paddle/phi/kernels/cpu/diagonal_kernel.cc b/paddle/phi/kernels/cpu/diagonal_kernel.cc
index f125802c19e242c813ee8baf78bc32f234bc1f5a..d2361bee30a5feb29da9937d24b76b3589b9f268 100644
--- a/paddle/phi/kernels/cpu/diagonal_kernel.cc
+++ b/paddle/phi/kernels/cpu/diagonal_kernel.cc
@@ -35,6 +35,7 @@ void DiagonalKernel(const Context& dev_ctx,
   auto* output = out;
   T* output_data = dev_ctx.template Alloc<T>(output);
   auto output_dim = vectorize(output->dims());
+  auto output_dim_size = output_dim.size();
 
   const int64_t offset_ = offset;
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
@@ -43,40 +44,48 @@ void DiagonalKernel(const Context& dev_ctx,
   std::vector<int64_t> input_stride = funcs::ComputeDimStride(input_dim);
   std::vector<int64_t> output_stride = funcs::ComputeDimStride(output_dim);
 
-  int64_t numel = input->numel();
-
-  for (int64_t idx = 0; idx < numel; idx++) {
-    std::vector<int64_t> idx_dim(input_dim_size);
+  int64_t out_numel = out->numel();
+  for (int64_t idx = 0; idx < out_numel; idx++) {
+    std::vector<int64_t> idx_dim(output_dim_size);
     int64_t temp = 0;
-    for (size_t i = 0; i < input_dim_size; i++) {
-      idx_dim[i] = (idx - temp) / input_stride[i];
-      temp = temp + idx_dim[i] * input_stride[i];
+    for (size_t i = 0; i < output_dim_size; i++) {
+      idx_dim[i] = (idx - temp) / output_stride[i];
+      temp = temp + idx_dim[i] * output_stride[i];
     }
-
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-
-    idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_));
-    idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_));
-
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis2_dim);
-      flag = true;
+    int64_t tmp = idx_dim[output_dim_size - 1];
+    std::vector<int64_t> list;
+    list.clear();
+    int64_t l = std::min(axis1_, axis2_);
+    int64_t r = std::max(axis1_, axis2_);
+    for (size_t j = 0; j < output_dim_size - 1; j++) {
+      list.push_back(idx_dim[j]);
    }
-    if (flag) {
-      int64_t idx_output = 0;
-      for (size_t i = 0; i < idx_dim.size(); i++) {
-        idx_output = idx_output + idx_dim[i] * output_stride[i];
+    if (offset_ == 0) {
+      list.insert(list.begin() + l, tmp);
+      list.insert(list.begin() + r, tmp);
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp + offset_);
+      } else {
+        list.insert(list.begin() + l, tmp + offset_);
+        list.insert(list.begin() + r, tmp);
       }
-      output_data[idx_output] = input_data[idx];
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp - offset_);
+        list.insert(list.begin() + r, tmp);
+      } else {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp - offset_);
+      }
+    }
+
+    int64_t input_offset = 0;
+    for (size_t i = 0; i < input_dim_size; i++) {
+      input_offset = input_offset + list[i] * input_stride[i];
    }
+    output_data[idx] = input_data[input_offset];
  }
 }
 }  // namespace phi
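
The CPU kernel now walks the output tensor and maps each diagonal element back to its input coordinate, re-inserting the two reduced axes at the min/max positions, instead of scanning every input element against a diagonal predicate. The work becomes proportional to out->numel(), and the explicit axis1_ < axis2_ branches spell out the offset handling for either axis order, which the new tests below exercise. A minimal NumPy sketch of the same index arithmetic, checked against np.diagonal; diagonal_ref is a name invented for this illustration, not Paddle code:

import numpy as np

def diagonal_ref(x, offset=0, axis1=0, axis2=1):
    # Output layout follows np.diagonal: the untouched dims in order,
    # then the diagonal as the trailing dimension.
    axis1, axis2 = axis1 % x.ndim, axis2 % x.ndim
    rest_shape = [s for i, s in enumerate(x.shape) if i not in (axis1, axis2)]
    diag_len = min(x.shape[axis1] + min(offset, 0),
                   x.shape[axis2] - max(offset, 0))
    out = np.empty(rest_shape + [diag_len], dtype=x.dtype)
    for idx in np.ndindex(*out.shape):
        *rest, d = idx  # d = position along the diagonal
        coord = list(rest)
        # Re-insert the two reduced axes, smaller index first, mirroring the
        # kernel's list.insert(begin() + l, ...) / list.insert(begin() + r, ...).
        for ax, v in sorted([(axis1, d - min(offset, 0)),
                             (axis2, d + max(offset, 0))]):
            coord.insert(ax, v)
        out[idx] = x[tuple(coord)]
    return out

x = np.random.randn(4, 2, 4, 4)
assert np.array_equal(diagonal_ref(x, -2, 0, 3), np.diagonal(x, -2, 0, 3))
assert np.array_equal(diagonal_ref(x, 1, 3, 0), np.diagonal(x, 1, 3, 0))
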
diff --git a/paddle/phi/kernels/funcs/diagonal.h b/paddle/phi/kernels/funcs/diagonal.h
index 92f970aed327951e0d1833726bb1f32e8332e4a4..a30fb79f8c8b0466ed3426b5717566c93b8dde9d 100644
--- a/paddle/phi/kernels/funcs/diagonal.h
+++ b/paddle/phi/kernels/funcs/diagonal.h
@@ -156,59 +156,59 @@ __global__ void DiagonalCuda(const T* data1,
                              int64_t* x_stride,
                              int64_t* out_stride,
                              int64_t numel,
+                             int64_t out_numel,
                              bool is_grad) {
-  CUDA_KERNEL_LOOP(idx, numel) {
-    int64_t idx_dim[X_DIM_SIZE] = {0};
+  CUDA_KERNEL_LOOP(idx, out_numel) {
+    int64_t idx_dim[OUT_DIM_SIZE] = {0};
     int64_t temp = 0;
-    for (size_t i = 0; i < X_DIM_SIZE - 1; i++) {
-      idx_dim[i] = (idx - temp) / x_stride[i];
-      temp = temp + idx_dim[i] * x_stride[i];
+    for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
+      idx_dim[i] = (idx - temp) / out_stride[i];
+      temp = temp + idx_dim[i] * out_stride[i];
     }
-    idx_dim[X_DIM_SIZE - 1] = idx - temp;
-
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-
-    int64_t out_dim[OUT_DIM_SIZE] = {0};
-    int temp_pos = 0;
-    for (int i = 0; i < X_DIM_SIZE; i++) {
-      if (i != axis1_ && i != axis2_) {
-        out_dim[temp_pos] = idx_dim[i];
-        temp_pos++;
+    idx_dim[OUT_DIM_SIZE - 1] = idx - temp;
+    int64_t tmp = idx - temp;
+    int64_t list[9];
+    int64_t p = 0;
+    for (size_t j = 0; j < X_DIM_SIZE; j++) {
+      if (j == axis1_ || j == axis2_) {
+        list[j] = 0;
+      } else {
+        list[j] = idx_dim[p];
+        p += 1;
       }
     }
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis2_dim;
-      flag = true;
-    }
-    if (!is_grad) {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx_output] = data1[idx];
+    int64_t l = min(axis1_, axis2_);
+    int64_t r = max(axis1_, axis2_);
+    if (offset_ == 0) {
+      list[l] = tmp;
+      list[r] = tmp;
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp;
+        list[r] = tmp + offset_;
+      } else {
+        list[l] = tmp + offset_;
+        list[r] = tmp;
       }
-    } else {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx] = data1[idx_output];
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp - offset_;
+        list[r] = tmp;
       } else {
-        data2[idx] = static_cast<T>(0);
+        list[l] = tmp;
+        list[r] = tmp - offset_;
       }
     }
+    int64_t input_offset = 0;
+
+    for (size_t i = 0; i < X_DIM_SIZE; i++) {
+      input_offset = input_offset + list[i] * x_stride[i];
+    }
+    if (!is_grad) {
+      data2[idx] = data1[input_offset];
+    } else {
+      data2[input_offset] = data1[idx];
+    }
  }
 }
 #endif
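
The shared device kernel follows the same output-indexed scheme: each thread decodes its output coordinate from out_stride, rebuilds the full input coordinate in list (fixed at 9 entries, the maximum rank the launch switches below support), and collapses it to a flat offset through x_stride. A small sketch of that stride arithmetic, assuming funcs::ComputeDimStride yields row-major strides (stride[i] is the product of the dims after i); flat_offset is a name invented here:

def flat_offset(coord, shape):
    # Row-major strides: the last dimension varies fastest.
    stride = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        stride[i] = stride[i + 1] * shape[i + 1]
    return sum(c * s for c, s in zip(coord, stride))

# Decoding an output index is the inverse: peel off one dimension per step,
# which is what the kernel's (idx - temp) / out_stride[i] loop does.
assert flat_offset([1, 2, 3], [4, 5, 6]) == 1 * 30 + 2 * 6 + 3 * 1
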
diff --git a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
index 05a57426fcb213ef6fee47df38363f261301831c..a65d9af75f6a33c4669bf6a6059c1ed81d7b7d44 100644
--- a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
@@ -62,6 +62,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
   int threads = PADDLE_CUDA_NUM_THREADS;
   int blocks = (numel + threads - 1) / threads;
 
+  int64_t dout_numel = out_grad.numel();
+  phi::backends::gpu::GpuMemsetAsync(
+      dx_data, 0, numel * sizeof(T), dev_ctx.stream());
+
   switch (dx_dim_size) {
     case 2:
       funcs::DiagonalCuda<T, 2, 1><<<blocks, threads>>>(dout_data,
@@ -72,6 +76,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 3:
@@ -83,6 +88,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 4:
@@ -94,6 +100,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 5:
@@ -105,6 +112,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 6:
@@ -116,6 +124,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 7:
@@ -127,6 +136,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 8:
@@ -138,6 +148,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 9:
@@ -149,6 +160,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     default:
diff --git a/paddle/phi/kernels/gpu/diagonal_kernel.cu b/paddle/phi/kernels/gpu/diagonal_kernel.cu
index 74bad0ecd9a3509a3cd92d3a50465c0ea7fdbf7d..74e7db258c7d1ac320599ad0cacc7c3e67833a83 100644
--- a/paddle/phi/kernels/gpu/diagonal_kernel.cu
+++ b/paddle/phi/kernels/gpu/diagonal_kernel.cu
@@ -54,9 +54,10 @@ void DiagonalKernel(const Context& dev_ctx,
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
   int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2;
   int64_t numel = input->numel();
+  int64_t out_numel = out->numel();
 
   int threads = PADDLE_CUDA_NUM_THREADS;
-  int blocks = (numel + threads - 1) / threads;
+  int blocks = (out_numel + threads - 1) / threads;
 
   switch (input_dim_size) {
     case 2:
@@ -68,6 +69,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 3:
@@ -79,6 +81,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 4:
@@ -90,6 +93,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 5:
@@ -101,6 +105,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 6:
@@ -112,6 +117,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 7:
@@ -123,6 +129,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 8:
@@ -134,6 +141,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 9:
@@ -145,6 +153,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     default:
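
On the backward path the old kernel zeroed off-diagonal entries of dx inline (the data2[idx] = static_cast<T>(0) branch removed above); the rewritten kernel writes only diagonal positions, so dx is now zero-filled up front with GpuMemsetAsync before the scatter. The grad grid is still sized from numel, which is harmless: surplus threads fail the CUDA_KERNEL_LOOP bound on out_numel and exit. A NumPy model of what memset-plus-scatter computes, reusing the coordinate mapping from the earlier sketch; diagonal_grad_ref is an illustrative name, not a Paddle API:

import numpy as np

def diagonal_grad_ref(dout, x_shape, offset=0, axis1=0, axis2=1):
    dx = np.zeros(x_shape, dtype=dout.dtype)  # the GpuMemsetAsync step
    for idx in np.ndindex(*dout.shape):
        *rest, d = idx
        coord = list(rest)
        for ax, v in sorted([(axis1 % len(x_shape), d - min(offset, 0)),
                             (axis2 % len(x_shape), d + max(offset, 0))]):
            coord.insert(ax, v)
        dx[tuple(coord)] = dout[idx]  # the is_grad == true scatter
    return dx

x_shape = (4, 2, 4, 4)
dout = np.ones((2, 4, 2))  # shape of np.diagonal(x, -2, 0, 3)
dx = diagonal_grad_ref(dout, x_shape, offset=-2, axis1=0, axis2=3)
assert dx.sum() == dout.size  # ones on the diagonal, zeros elsewhere
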
diff --git a/python/paddle/fluid/tests/unittests/test_diagonal_op.py b/python/paddle/fluid/tests/unittests/test_diagonal_op.py
index 5b3c3830c57ca07194c627512cd060c18cce4c31..cb35a3fce5d0307a0e12634e16c117591c633be9 100644
--- a/python/paddle/fluid/tests/unittests/test_diagonal_op.py
+++ b/python/paddle/fluid/tests/unittests/test_diagonal_op.py
@@ -101,6 +101,35 @@ class TestDiagonalOpCase3(TestDiagonalOp):
         pass
 
 
+class TestDiagonalOpCase4(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(100, 100).astype('int64')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': 1, 'axis1': 1, 'axis2': 0}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+    def test_check_grad(self):
+        pass
+
+
+class TestDiagonalOpCase5(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(4, 2, 4, 4).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': -2, 'axis1': 0, 'axis2': 3}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+
 class TestDiagonalAPI(unittest.TestCase):
     def setUp(self):
         self.shape = [10, 3, 4]
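
For a quick manual check of the newly covered paths (axis1 > axis2, negative offset) outside the op-test harness, the public paddle.diagonal API can be compared against NumPy directly; a sketch, assuming a build that includes this patch:

import numpy as np
import paddle

x = np.random.randn(4, 2, 4, 4).astype('float32')
for offset, axis1, axis2 in [(1, 1, 0), (-2, 0, 3), (0, 3, 1)]:
    expect = np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
    got = paddle.diagonal(
        paddle.to_tensor(x), offset=offset, axis1=axis1, axis2=axis2
    ).numpy()
    np.testing.assert_allclose(got, expect, rtol=1e-6)
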