未验证 提交 9f2d88e9 编写于 作者: L LoneRanger 提交者: GitHub

【PaddlePaddle Hackathon 4】No.63 : add embedding fp16 test (#51321)

上级 dd0681e3
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, skip_check_grad_ci
from eager_op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
from op import Operator
import paddle
......@@ -47,11 +47,17 @@ class TestLookupTableOp(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
table = np.random.random((17, 31)).astype("float64")
self.init_dtype()
table = np.random.random((17, 31)).astype(self.dtype)
ids = np.random.randint(0, 17, 4).astype(self.id_dtype())
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]}
def init_dtype(self):
self.dtype = "float64"
def id_dtype(self):
return "int64"
......@@ -297,6 +303,53 @@ class TestEmbedOpError(unittest.TestCase):
)
class TestEmbeddingFP16OP(TestLookupTableOp):
def setUp(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
self.init_dtype()
table = np.random.random((18, 32)).astype(self.dtype)
ids = np.random.randint(0, 18, 4).astype(self.id_dtype())
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]}
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestEmbeddingBF16OP(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
self.dtype = np.uint16
table = np.random.random((18, 32)).astype("float32")
ids = np.random.randint(0, 18, 4).astype(self.id_dtype())
self.inputs = {'W': convert_float_to_uint16(table), 'Ids': ids}
self.outputs = {'Out': convert_float_to_uint16(table[ids])}
def id_dtype(self):
return "int64"
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_cinn=True)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['W'], 'Out', no_grad_set=set('Ids'), check_cinn=True
)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册