未验证 提交 e5616448 编写于 作者: W wangxiaoning 提交者: GitHub

[AMP OP&Test]fix index_select bf16 test (#51652)

上级 64076727
......@@ -31,7 +31,6 @@ template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index,
int64_t nums,
int64_t N,
int64_t stride,
int64_t size,
......@@ -104,7 +103,6 @@ void IndexSelectGradKernel(const Context& ctx,
<<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
......@@ -115,7 +113,6 @@ void IndexSelectGradKernel(const Context& ctx,
<<<grid_dim, block_dim, 0, stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
......
......@@ -21,6 +21,8 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
np.random.seed(1024)
class TestIndexSelectOp(OpTest):
def setUp(self):
......@@ -119,7 +121,7 @@ class TestIndexSelectBF16Op(OpTest):
self.dim = 1
self.x_type = np.uint16
self.index_type = np.int64
self.x_shape = (100, 4, 5)
self.x_shape = (20, 4, 5)
self.index_size = 100
def test_check_output(self):
......@@ -137,10 +139,11 @@ class TestIndexSelectAPI(unittest.TestCase):
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
]
)
).astype("float32")
self.data_index = np.array([0, 1, 1]).astype('int32')
def test_index_select_api(self):
paddle.enable_static()
self.input_data()
# case 1:
......@@ -176,6 +179,7 @@ class TestIndexSelectAPI(unittest.TestCase):
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
def test_dygraph_api(self):
paddle.disable_static()
self.input_data()
# case 1:
with fluid.dygraph.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册