diff --git a/paddle/fluid/operators/diag_v2_op.cu b/paddle/fluid/operators/diag_v2_op.cu
index 4386cc6b8183c03b4d4a19aba7d1126eac2ab495..12ea31945f8d032e1f395c2fb92d9ef31d10c7e8 100644
--- a/paddle/fluid/operators/diag_v2_op.cu
+++ b/paddle/fluid/operators/diag_v2_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <algorithm>
+#include <tuple>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/diag_v2_op.h"
 
@@ -58,6 +59,17 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
     auto out_dims = out->dims();
     auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
 
+    auto GetBlockGridSize = [&dev_ctx](int64_t size) {
+      const int64_t block_size =
+          std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
+      int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+      const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
+                                          static_cast<int64_t>(1));
+      const int64_t grid_size =
+          std::min(max_blocks, (size + block_size - 1) / block_size);
+      return std::tuple<int64_t, int64_t>{block_size, grid_size};
+    };
+
     if (x_dims.size() == 1) {
       float padding_value = context.Attr<float>("padding_value");
       math::SetConstant<platform::CUDADeviceContext, T> set_padding_value;
@@ -67,26 +79,23 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
       auto size = (offset > 0) ? x_length + offset : x_length - offset;
       const int& x_stride = ComputeStride(0, x_dims);
       if (size > 0) {
-        const int block_num = std::min(static_cast<int>(size),
-                                       dev_ctx.GetMaxPhysicalThreadCount());
-        int size_ = static_cast<int>(size);
-        int block_num_ = static_cast<int>(block_num);
-        const int grid_num =
-            std::min(1024, (size_ + block_num_ - 1) / block_num_);
         const auto& out_stride_0 = ComputeStride(0, out_dims);
         const auto& out_stride_1 = ComputeStride(1, out_dims);
         auto start =
            (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
-        PasteDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
-            out_data, x_data, start, x_length, out_stride_0 + out_stride_1,
-            x_stride);
+        std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
+
+        PasteDiagonalKernel<
+            T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
+                 dev_ctx.stream()>>>(out_data, x_data, start, x_length,
+                                     out_stride_0 + out_stride_1, x_stride);
       }
     } else {
       const int& x_stride_0 = ComputeStride(0, x_dims);
       const int& x_stride_1 = ComputeStride(1, x_dims);
-      int size;
+      int64_t size;
       if (offset > 0) {
         size = std::min(x_dims[0], x_dims[1] - offset);
       } else {
@@ -94,18 +103,15 @@
       }
 
       if (size > 0) {
-        const int block_num = std::min(static_cast<int>(size),
-                                       dev_ctx.GetMaxPhysicalThreadCount());
-        int size_ = static_cast<int>(size);
-        int block_num_ = static_cast<int>(block_num);
-        const int grid_num =
-            std::min(1024, (size_ + block_num_ - 1) / block_num_);
         auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
         const auto& out_stride_0 = ComputeStride(0, out_dims);
-        ExtractDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
-            out_data, x_data, start, size, x_stride_0 + x_stride_1,
-            out_stride_0);
+        std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
+
+        ExtractDiagonalKernel<
+            T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
+                 dev_ctx.stream()>>>(out_data, x_data, start, size,
+                                     x_stride_0 + x_stride_1, out_stride_0);
       }
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/test_diag.py b/python/paddle/fluid/tests/unittests/test_diag.py
index 780d57b53310bb5f385a131d4ad52dd6f5e695f0..ddf1240e4ef27775a24cee540c5f193399112270 100644
--- a/python/paddle/fluid/tests/unittests/test_diag.py
+++ b/python/paddle/fluid/tests/unittests/test_diag.py
@@ -119,6 +119,16 @@ class TestDiagV2API(unittest.TestCase):
             (n, n)) + np.diag(self.input_np3, self.offset) - np.diag(
                 self.padding_value * np.ones(n))
 
+        self.input_np4 = np.random.random(size=(2000, 2000)).astype(np.float32)
+        self.expected6 = np.diag(self.input_np4)
+        self.expected7 = np.diag(self.input_np4, k=1)
+        self.expected8 = np.diag(self.input_np4, k=-1)
+
+        self.input_np5 = np.random.random(size=(2000)).astype(np.float32)
+        self.expected9 = np.diag(self.input_np5)
+        self.expected10 = np.diag(self.input_np5, k=1)
+        self.expected11 = np.diag(self.input_np5, k=-1)
+
     def run_imperative(self):
         x = paddle.to_tensor(self.input_np)
         y = paddle.diag(x)
@@ -141,10 +151,32 @@ class TestDiagV2API(unittest.TestCase):
         y = paddle.diag(x, padding_value=-8)
         self.assertTrue(np.allclose(y.numpy(), self.expected5))
 
+        x = paddle.to_tensor(self.input_np4)
+        y = paddle.diag(x)
+        self.assertTrue(np.allclose(y.numpy(), self.expected6))
+
+        y = paddle.diag(x, offset=1)
+        self.assertTrue(np.allclose(y.numpy(), self.expected7))
+
+        y = paddle.diag(x, offset=-1)
+        self.assertTrue(np.allclose(y.numpy(), self.expected8))
+
+        x = paddle.to_tensor(self.input_np5)
+        y = paddle.diag(x)
+        self.assertTrue(np.allclose(y.numpy(), self.expected9))
+
+        y = paddle.diag(x, offset=1)
+        self.assertTrue(np.allclose(y.numpy(), self.expected10))
+
+        y = paddle.diag(x, offset=-1)
+        self.assertTrue(np.allclose(y.numpy(), self.expected11))
+
     def run_static(self, use_gpu=False):
         x = paddle.data(name='input', shape=[10, 10], dtype='float32')
         x2 = paddle.data(name='input2', shape=[100], dtype='float64')
         x3 = paddle.data(name='input3', shape=[100], dtype='int64')
+        x4 = paddle.data(name='input4', shape=[2000, 2000], dtype='float32')
+        x5 = paddle.data(name='input5', shape=[2000], dtype='float32')
         result0 = paddle.diag(x)
         result1 = paddle.diag(x, offset=1)
         result2 = paddle.diag(x, offset=-1)
@@ -152,17 +184,28 @@ class TestDiagV2API(unittest.TestCase):
         result4 = paddle.diag(x2, padding_value=8)
         result5 = paddle.diag(x3, padding_value=8.0)
         result6 = paddle.diag(x3, padding_value=-8)
+        result7 = paddle.diag(x4)
+        result8 = paddle.diag(x4, offset=1)
+        result9 = paddle.diag(x4, offset=-1)
+        result10 = paddle.diag(x5)
+        result11 = paddle.diag(x5, offset=1)
+        result12 = paddle.diag(x5, offset=-1)
 
         place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
-        res0, res1, res2, res4, res5, res6 = exe.run(
+        res0, res1, res2, res4, res5, res6, res7, res8, res9, res10, res11, res12 = exe.run(
             feed={
                 "input": self.input_np,
                 "input2": self.input_np2,
-                'input3': self.input_np3
+                'input3': self.input_np3,
+                'input4': self.input_np4,
+                'input5': self.input_np5
             },
-            fetch_list=[result0, result1, result2, result4, result5, result6])
+            fetch_list=[
+                result0, result1, result2, result4, result5, result6, result7,
+                result8, result9, result10, result11, result12
+            ])
 
         self.assertTrue(np.allclose(res0, self.expected0))
         self.assertTrue(np.allclose(res1, self.expected1))
@@ -171,6 +214,12 @@ class TestDiagV2API(unittest.TestCase):
         self.assertTrue(np.allclose(res4, self.expected3))
         self.assertTrue(np.allclose(res5, self.expected4))
         self.assertTrue(np.allclose(res6, self.expected5))
+        self.assertTrue(np.allclose(res7, self.expected6))
+        self.assertTrue(np.allclose(res8, self.expected7))
+        self.assertTrue(np.allclose(res9, self.expected8))
+        self.assertTrue(np.allclose(res10, self.expected9))
+        self.assertTrue(np.allclose(res11, self.expected10))
+        self.assertTrue(np.allclose(res12, self.expected11))
 
     def test_cpu(self):
         paddle.disable_static(place=paddle.fluid.CPUPlace())