Unverified commit 5508c787, authored by LutaoChu, committed via GitHub

Fix bug: The calculation result of Diag_v2 Op under large size input is wrong (#27447)

The calculation result of Diag_v2 Op under large size input is wrong 
Parent commit: bc5f0246
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <tuple>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_v2_op.h" #include "paddle/fluid/operators/diag_v2_op.h"
...@@ -58,6 +59,17 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> { ...@@ -58,6 +59,17 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
auto out_dims = out->dims(); auto out_dims = out->dims();
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
// Computes a CUDA launch configuration for `size` elements, entirely in
// int64_t so that large inputs (the bug this commit fixes) do not overflow
// the intermediate arithmetic the old int-based code used.
// Returns {block_size, grid_size} as a std::tuple<int64_t, int64_t>.
auto GetBlockGridSize = [&dev_ctx](int64_t size) {
  // Never ask for more threads per block than the device supports,
  // nor more threads than there are elements.
  const int64_t block_size =
      std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
  int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  // Upper bound on blocks derived from the device's physical thread count;
  // clamped to at least 1 so the launch is always valid.
  const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
                                      static_cast<int64_t>(1));
  // Ceiling-divide size by block_size, capped by max_blocks. The kernels
  // launched with this config are therefore expected to stride if
  // grid_size * block_size < size — NOTE(review): confirm the kernels
  // (Paste/ExtractDiagonalKernel) loop over elements rather than assuming
  // one thread per element.
  const int64_t grid_size =
      std::min(max_blocks, (size + block_size - 1) / block_size);
  return std::tuple<int64_t, int64_t>{block_size, grid_size};
};
if (x_dims.size() == 1) { if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value"); float padding_value = context.Attr<float>("padding_value");
math::SetConstant<DeviceContext, T> set_padding_value; math::SetConstant<DeviceContext, T> set_padding_value;
...@@ -67,26 +79,23 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> { ...@@ -67,26 +79,23 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
auto size = (offset > 0) ? x_length + offset : x_length - offset; auto size = (offset > 0) ? x_length + offset : x_length - offset;
const int& x_stride = ComputeStride(0, x_dims); const int& x_stride = ComputeStride(0, x_dims);
if (size > 0) { if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
const auto& out_stride_0 = ComputeStride(0, out_dims); const auto& out_stride_0 = ComputeStride(0, out_dims);
const auto& out_stride_1 = ComputeStride(1, out_dims); const auto& out_stride_1 = ComputeStride(1, out_dims);
auto start = auto start =
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0); (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
PasteDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>( std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
out_data, x_data, start, x_length, out_stride_0 + out_stride_1,
x_stride); PasteDiagonalKernel<
T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
dev_ctx.stream()>>>(out_data, x_data, start, x_length,
out_stride_0 + out_stride_1, x_stride);
} }
} else { } else {
const int& x_stride_0 = ComputeStride(0, x_dims); const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims); const int& x_stride_1 = ComputeStride(1, x_dims);
int size; int64_t size;
if (offset > 0) { if (offset > 0) {
size = std::min(x_dims[0], x_dims[1] - offset); size = std::min(x_dims[0], x_dims[1] - offset);
} else { } else {
...@@ -94,18 +103,15 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> { ...@@ -94,18 +103,15 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
} }
if (size > 0) { if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0); auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
const auto& out_stride_0 = ComputeStride(0, out_dims); const auto& out_stride_0 = ComputeStride(0, out_dims);
ExtractDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>( std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
out_data, x_data, start, size, x_stride_0 + x_stride_1,
out_stride_0); ExtractDiagonalKernel<
T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
dev_ctx.stream()>>>(out_data, x_data, start, size,
x_stride_0 + x_stride_1, out_stride_0);
} }
} }
} }
......
...@@ -119,6 +119,16 @@ class TestDiagV2API(unittest.TestCase): ...@@ -119,6 +119,16 @@ class TestDiagV2API(unittest.TestCase):
(n, n)) + np.diag(self.input_np3, self.offset) - np.diag( (n, n)) + np.diag(self.input_np3, self.offset) - np.diag(
self.padding_value * np.ones(n)) self.padding_value * np.ones(n))
self.input_np4 = np.random.random(size=(2000, 2000)).astype(np.float32)
self.expected6 = np.diag(self.input_np4)
self.expected7 = np.diag(self.input_np4, k=1)
self.expected8 = np.diag(self.input_np4, k=-1)
self.input_np5 = np.random.random(size=(2000)).astype(np.float32)
self.expected9 = np.diag(self.input_np5)
self.expected10 = np.diag(self.input_np5, k=1)
self.expected11 = np.diag(self.input_np5, k=-1)
def run_imperative(self): def run_imperative(self):
x = paddle.to_tensor(self.input_np) x = paddle.to_tensor(self.input_np)
y = paddle.diag(x) y = paddle.diag(x)
...@@ -141,10 +151,32 @@ class TestDiagV2API(unittest.TestCase): ...@@ -141,10 +151,32 @@ class TestDiagV2API(unittest.TestCase):
y = paddle.diag(x, padding_value=-8) y = paddle.diag(x, padding_value=-8)
self.assertTrue(np.allclose(y.numpy(), self.expected5)) self.assertTrue(np.allclose(y.numpy(), self.expected5))
x = paddle.to_tensor(self.input_np4)
y = paddle.diag(x)
self.assertTrue(np.allclose(y.numpy(), self.expected6))
y = paddle.diag(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected7))
y = paddle.diag(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected8))
x = paddle.to_tensor(self.input_np5)
y = paddle.diag(x)
self.assertTrue(np.allclose(y.numpy(), self.expected9))
y = paddle.diag(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected10))
y = paddle.diag(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected11))
def run_static(self, use_gpu=False): def run_static(self, use_gpu=False):
x = paddle.data(name='input', shape=[10, 10], dtype='float32') x = paddle.data(name='input', shape=[10, 10], dtype='float32')
x2 = paddle.data(name='input2', shape=[100], dtype='float64') x2 = paddle.data(name='input2', shape=[100], dtype='float64')
x3 = paddle.data(name='input3', shape=[100], dtype='int64') x3 = paddle.data(name='input3', shape=[100], dtype='int64')
x4 = paddle.data(name='input4', shape=[2000, 2000], dtype='float32')
x5 = paddle.data(name='input5', shape=[2000], dtype='float32')
result0 = paddle.diag(x) result0 = paddle.diag(x)
result1 = paddle.diag(x, offset=1) result1 = paddle.diag(x, offset=1)
result2 = paddle.diag(x, offset=-1) result2 = paddle.diag(x, offset=-1)
...@@ -152,17 +184,28 @@ class TestDiagV2API(unittest.TestCase): ...@@ -152,17 +184,28 @@ class TestDiagV2API(unittest.TestCase):
result4 = paddle.diag(x2, padding_value=8) result4 = paddle.diag(x2, padding_value=8)
result5 = paddle.diag(x3, padding_value=8.0) result5 = paddle.diag(x3, padding_value=8.0)
result6 = paddle.diag(x3, padding_value=-8) result6 = paddle.diag(x3, padding_value=-8)
result7 = paddle.diag(x4)
result8 = paddle.diag(x4, offset=1)
result9 = paddle.diag(x4, offset=-1)
result10 = paddle.diag(x5)
result11 = paddle.diag(x5, offset=1)
result12 = paddle.diag(x5, offset=-1)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
res0, res1, res2, res4, res5, res6 = exe.run( res0, res1, res2, res4, res5, res6, res7, res8, res9, res10, res11, res12 = exe.run(
feed={ feed={
"input": self.input_np, "input": self.input_np,
"input2": self.input_np2, "input2": self.input_np2,
'input3': self.input_np3 'input3': self.input_np3,
'input4': self.input_np4,
'input5': self.input_np5
}, },
fetch_list=[result0, result1, result2, result4, result5, result6]) fetch_list=[
result0, result1, result2, result4, result5, result6, result7,
result8, result9, result10, result11, result12
])
self.assertTrue(np.allclose(res0, self.expected0)) self.assertTrue(np.allclose(res0, self.expected0))
self.assertTrue(np.allclose(res1, self.expected1)) self.assertTrue(np.allclose(res1, self.expected1))
...@@ -171,6 +214,12 @@ class TestDiagV2API(unittest.TestCase): ...@@ -171,6 +214,12 @@ class TestDiagV2API(unittest.TestCase):
self.assertTrue(np.allclose(res4, self.expected3)) self.assertTrue(np.allclose(res4, self.expected3))
self.assertTrue(np.allclose(res5, self.expected4)) self.assertTrue(np.allclose(res5, self.expected4))
self.assertTrue(np.allclose(res6, self.expected5)) self.assertTrue(np.allclose(res6, self.expected5))
self.assertTrue(np.allclose(res7, self.expected6))
self.assertTrue(np.allclose(res8, self.expected7))
self.assertTrue(np.allclose(res9, self.expected8))
self.assertTrue(np.allclose(res10, self.expected9))
self.assertTrue(np.allclose(res11, self.expected10))
self.assertTrue(np.allclose(res12, self.expected11))
def test_cpu(self): def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace()) paddle.disable_static(place=paddle.fluid.CPUPlace())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册