未验证 提交 e5616448 编写于 作者: W wangxiaoning 提交者: GitHub

[AMP OP&Test]fix index_select bf16 test (#51652)

上级 64076727
...@@ -31,7 +31,6 @@ template <typename T, typename IndexT> ...@@ -31,7 +31,6 @@ template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad, __global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad, T* input_grad,
const IndexT* index, const IndexT* index,
int64_t nums,
int64_t N, int64_t N,
int64_t stride, int64_t stride,
int64_t size, int64_t size,
...@@ -104,7 +103,6 @@ void IndexSelectGradKernel(const Context& ctx, ...@@ -104,7 +103,6 @@ void IndexSelectGradKernel(const Context& ctx,
<<<grid_dim, block_dim, 0, stream>>>(output_grad_data, <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
in_grad_data, in_grad_data,
index_data, index_data,
index_nums,
out_nums, out_nums,
stride, stride,
size, size,
...@@ -115,7 +113,6 @@ void IndexSelectGradKernel(const Context& ctx, ...@@ -115,7 +113,6 @@ void IndexSelectGradKernel(const Context& ctx,
<<<grid_dim, block_dim, 0, stream>>>(output_grad_data, <<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
in_grad_data, in_grad_data,
index_data, index_data,
index_nums,
out_nums, out_nums,
stride, stride,
size, size,
......
...@@ -21,6 +21,8 @@ import paddle ...@@ -21,6 +21,8 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
np.random.seed(1024)
class TestIndexSelectOp(OpTest): class TestIndexSelectOp(OpTest):
def setUp(self): def setUp(self):
...@@ -119,7 +121,7 @@ class TestIndexSelectBF16Op(OpTest): ...@@ -119,7 +121,7 @@ class TestIndexSelectBF16Op(OpTest):
self.dim = 1 self.dim = 1
self.x_type = np.uint16 self.x_type = np.uint16
self.index_type = np.int64 self.index_type = np.int64
self.x_shape = (100, 4, 5) self.x_shape = (20, 4, 5)
self.index_size = 100 self.index_size = 100
def test_check_output(self): def test_check_output(self):
...@@ -137,10 +139,11 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -137,10 +139,11 @@ class TestIndexSelectAPI(unittest.TestCase):
[5.0, 6.0, 7.0, 8.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0], [9.0, 10.0, 11.0, 12.0],
] ]
) ).astype("float32")
self.data_index = np.array([0, 1, 1]).astype('int32') self.data_index = np.array([0, 1, 1]).astype('int32')
def test_index_select_api(self): def test_index_select_api(self):
paddle.enable_static()
self.input_data() self.input_data()
# case 1: # case 1:
...@@ -176,6 +179,7 @@ class TestIndexSelectAPI(unittest.TestCase): ...@@ -176,6 +179,7 @@ class TestIndexSelectAPI(unittest.TestCase):
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
def test_dygraph_api(self): def test_dygraph_api(self):
paddle.disable_static()
self.input_data() self.input_data()
# case 1: # case 1:
with fluid.dygraph.guard(): with fluid.dygraph.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册