Unverified commit 5508c787, authored by LutaoChu, committed by GitHub

Fix bug: The calculation result of Diag_v2 Op under large size input is wrong (#27447)

The diag_v2 op returns wrong results on the GPU for large inputs: the CUDA kernel launch capped the block size with the device's maximum physical thread count instead of the per-block thread limit, so the launch configuration becomes invalid once the diagonal has more elements than that limit. Replace this logic with a GetBlockGridSize helper based on GetMaxThreadsPerBlock(), and extend the unit tests with 2000-element and 2000x2000 inputs.
Parent: bc5f0246
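For reference, a minimal standalone sketch of the corrected block/grid-size computation is shown below. The device limits (1024 threads per block, 80 × 2048 physical threads) and the `main` driver are illustrative assumptions; in the kernel itself these values come from the device context, as in the diff that follows.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <tuple>

// Sketch of the fixed launch-size computation: the block size is capped by the
// per-block thread limit, and the grid size is derived from it.
std::tuple<int64_t, int64_t> GetBlockGridSize(int64_t size,
                                              int64_t max_threads_per_block,
                                              int64_t max_physical_threads) {
  // The old code capped the block size with the physical thread count, which
  // exceeds the per-block limit as soon as size > max_threads_per_block and
  // makes the launch configuration invalid.
  const int64_t block_size = std::min(size, max_threads_per_block);
  const int64_t max_blocks =
      std::max((max_physical_threads - 1) / block_size + 1,
               static_cast<int64_t>(1));
  const int64_t grid_size =
      std::min(max_blocks, (size + block_size - 1) / block_size);
  return std::tuple<int64_t, int64_t>{block_size, grid_size};
}

int main() {
  // A 2000-element diagonal, as in the new unit tests: the old scheme would
  // request a 2000-thread block; the fixed scheme uses 1024 threads x 2 blocks.
  auto sizes = GetBlockGridSize(2000, 1024, 80 * 2048);
  std::cout << "block=" << std::get<0>(sizes)
            << " grid=" << std::get<1>(sizes) << std::endl;
  return 0;
}
```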
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <tuple>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_v2_op.h"
@@ -58,6 +59,17 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
auto out_dims = out->dims();
auto& dev_ctx = context.template device_context<DeviceContext>();
auto GetBlockGridSize = [&dev_ctx](int64_t size) {
const int64_t block_size =
std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
return std::tuple<int64_t, int64_t>{block_size, grid_size};
};
if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
math::SetConstant<DeviceContext, T> set_padding_value;
@@ -67,26 +79,23 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
auto size = (offset > 0) ? x_length + offset : x_length - offset;
const int& x_stride = ComputeStride(0, x_dims);
if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
const auto& out_stride_0 = ComputeStride(0, out_dims);
const auto& out_stride_1 = ComputeStride(1, out_dims);
auto start =
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
PasteDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
out_data, x_data, start, x_length, out_stride_0 + out_stride_1,
x_stride);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
PasteDiagonalKernel<
T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
dev_ctx.stream()>>>(out_data, x_data, start, x_length,
out_stride_0 + out_stride_1, x_stride);
}
} else {
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);
int size;
int64_t size;
if (offset > 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
@@ -94,18 +103,15 @@ class DiagV2CUDAKernel : public framework::OpKernel<T> {
}
if (size > 0) {
const int block_num = std::min(static_cast<int>(size),
dev_ctx.GetMaxPhysicalThreadCount());
int size_ = static_cast<int>(size);
int block_num_ = static_cast<int>(block_num);
const int grid_num =
std::min(1024, (size_ + block_num_ - 1) / block_num_);
auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
const auto& out_stride_0 = ComputeStride(0, out_dims);
ExtractDiagonalKernel<T><<<grid_num, block_num, 0, dev_ctx.stream()>>>(
out_data, x_data, start, size, x_stride_0 + x_stride_1,
out_stride_0);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
ExtractDiagonalKernel<
T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
dev_ctx.stream()>>>(out_data, x_data, start, size,
x_stride_0 + x_stride_1, out_stride_0);
}
}
}
...... (the remaining hunks are from the diag_v2 Python unit test)
@@ -119,6 +119,16 @@ class TestDiagV2API(unittest.TestCase):
(n, n)) + np.diag(self.input_np3, self.offset) - np.diag(
self.padding_value * np.ones(n))
self.input_np4 = np.random.random(size=(2000, 2000)).astype(np.float32)
self.expected6 = np.diag(self.input_np4)
self.expected7 = np.diag(self.input_np4, k=1)
self.expected8 = np.diag(self.input_np4, k=-1)
self.input_np5 = np.random.random(size=(2000)).astype(np.float32)
self.expected9 = np.diag(self.input_np5)
self.expected10 = np.diag(self.input_np5, k=1)
self.expected11 = np.diag(self.input_np5, k=-1)
def run_imperative(self):
x = paddle.to_tensor(self.input_np)
y = paddle.diag(x)
@@ -141,10 +151,32 @@ class TestDiagV2API(unittest.TestCase):
y = paddle.diag(x, padding_value=-8)
self.assertTrue(np.allclose(y.numpy(), self.expected5))
x = paddle.to_tensor(self.input_np4)
y = paddle.diag(x)
self.assertTrue(np.allclose(y.numpy(), self.expected6))
y = paddle.diag(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected7))
y = paddle.diag(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected8))
x = paddle.to_tensor(self.input_np5)
y = paddle.diag(x)
self.assertTrue(np.allclose(y.numpy(), self.expected9))
y = paddle.diag(x, offset=1)
self.assertTrue(np.allclose(y.numpy(), self.expected10))
y = paddle.diag(x, offset=-1)
self.assertTrue(np.allclose(y.numpy(), self.expected11))
def run_static(self, use_gpu=False):
x = paddle.data(name='input', shape=[10, 10], dtype='float32')
x2 = paddle.data(name='input2', shape=[100], dtype='float64')
x3 = paddle.data(name='input3', shape=[100], dtype='int64')
x4 = paddle.data(name='input4', shape=[2000, 2000], dtype='float32')
x5 = paddle.data(name='input5', shape=[2000], dtype='float32')
result0 = paddle.diag(x)
result1 = paddle.diag(x, offset=1)
result2 = paddle.diag(x, offset=-1)
@@ -152,17 +184,28 @@ class TestDiagV2API(unittest.TestCase):
result4 = paddle.diag(x2, padding_value=8)
result5 = paddle.diag(x3, padding_value=8.0)
result6 = paddle.diag(x3, padding_value=-8)
result7 = paddle.diag(x4)
result8 = paddle.diag(x4, offset=1)
result9 = paddle.diag(x4, offset=-1)
result10 = paddle.diag(x5)
result11 = paddle.diag(x5, offset=1)
result12 = paddle.diag(x5, offset=-1)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
res0, res1, res2, res4, res5, res6 = exe.run(
res0, res1, res2, res4, res5, res6, res7, res8, res9, res10, res11, res12 = exe.run(
feed={
"input": self.input_np,
"input2": self.input_np2,
'input3': self.input_np3
'input3': self.input_np3,
'input4': self.input_np4,
'input5': self.input_np5
},
fetch_list=[result0, result1, result2, result4, result5, result6])
fetch_list=[
result0, result1, result2, result4, result5, result6, result7,
result8, result9, result10, result11, result12
])
self.assertTrue(np.allclose(res0, self.expected0))
self.assertTrue(np.allclose(res1, self.expected1))
@@ -171,6 +214,12 @@ class TestDiagV2API(unittest.TestCase):
self.assertTrue(np.allclose(res4, self.expected3))
self.assertTrue(np.allclose(res5, self.expected4))
self.assertTrue(np.allclose(res6, self.expected5))
self.assertTrue(np.allclose(res7, self.expected6))
self.assertTrue(np.allclose(res8, self.expected7))
self.assertTrue(np.allclose(res9, self.expected8))
self.assertTrue(np.allclose(res10, self.expected9))
self.assertTrue(np.allclose(res11, self.expected10))
self.assertTrue(np.allclose(res12, self.expected11))
def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
......