From e5616448deda49ee9617473647f6b2e7316bb3b5 Mon Sep 17 00:00:00 2001 From: wangxiaoning <71813629+wangxn12138@users.noreply.github.com> Date: Wed, 15 Mar 2023 13:40:34 +0800 Subject: [PATCH] [AMP OP&Test]fix index_select bf16 test (#51652) --- paddle/phi/kernels/gpu/index_select_grad_kernel.cu | 3 --- .../paddle/fluid/tests/unittests/test_index_select_op.py | 8 ++++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu index be00fe0ddc8..9578241829f 100644 --- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu @@ -31,7 +31,6 @@ template __global__ void index_select_grad_cuda_kernel(const T* output_grad, T* input_grad, const IndexT* index, - int64_t nums, int64_t N, int64_t stride, int64_t size, @@ -104,7 +103,6 @@ void IndexSelectGradKernel(const Context& ctx, <<>>(output_grad_data, in_grad_data, index_data, - index_nums, out_nums, stride, size, @@ -115,7 +113,6 @@ void IndexSelectGradKernel(const Context& ctx, <<>>(output_grad_data, in_grad_data, index_data, - index_nums, out_nums, stride, size, diff --git a/python/paddle/fluid/tests/unittests/test_index_select_op.py b/python/paddle/fluid/tests/unittests/test_index_select_op.py index dd3bdd80591..a62f7028a9e 100644 --- a/python/paddle/fluid/tests/unittests/test_index_select_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_select_op.py @@ -21,6 +21,8 @@ import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard +np.random.seed(1024) + class TestIndexSelectOp(OpTest): def setUp(self): @@ -119,7 +121,7 @@ class TestIndexSelectBF16Op(OpTest): self.dim = 1 self.x_type = np.uint16 self.index_type = np.int64 - self.x_shape = (100, 4, 5) + self.x_shape = (20, 4, 5) self.index_size = 100 def test_check_output(self): @@ -137,10 +139,11 @@ class TestIndexSelectAPI(unittest.TestCase): [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ] - ) + ).astype("float32") self.data_index = np.array([0, 1, 1]).astype('int32') def test_index_select_api(self): + paddle.enable_static() self.input_data() # case 1: @@ -176,6 +179,7 @@ class TestIndexSelectAPI(unittest.TestCase): np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) def test_dygraph_api(self): + paddle.disable_static() self.input_data() # case 1: with fluid.dygraph.guard(): -- GitLab