未验证 提交 eb322853 编写于 作者: W wangxiaoning 提交者: GitHub

support fp16 index sample (#47897)

* add index sample fp16 support

* remove fluid APIs in distributed_strategy.py and role_maker.py

* Revert "remove fluid APIs in distributed_strategy.py and role_maker.py"

This reverts commit 223bbee990d3bf69e252fc3c0f19e3873550a264.

* fix instantiated more than once

* clean codes
上级 e61df289
......@@ -99,6 +99,28 @@ class TestCase4(TestIndexSampleOp):
self.index_type = "int64"
class TestCase5(TestIndexSampleOp):
def config(self):
"""
For float16 x type
"""
self.x_shape = (10, 128)
self.x_type = "float16"
self.index_shape = (10, 64)
self.index_type = "int32"
class TestCase6(TestIndexSampleOp):
def config(self):
"""
For float16 x type
"""
self.x_shape = (10, 128)
self.x_type = "float16"
self.index_shape = (10, 64)
self.index_type = "int64"
class TestIndexSampleShape(unittest.TestCase):
def test_shape(self):
paddle.enable_static()
......
......@@ -792,7 +792,7 @@ def index_sample(x, index):
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64'],
'paddle.tensor.search.index_sample',
)
check_variable_and_dtype(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册