未验证 提交 d7660a7c 编写于 作者: Y yangjianfengo1 提交者: GitHub

[AMP OP&Test] Tile OP (#51380)

* tile_op

* fix bfloat16 x

* update review

* del out
上级 d0d739ca
@@ -196,19 +196,23 @@ class TestTileOpInteger(OpTest):
         self.check_output()
-class TestTileOpFloat16(OpTest):
+class TestTileFP16OP(OpTest):
     def setUp(self):
         self.op_type = "tile"
         self.dtype = np.float16
+        self.__class__.op_type = self.op_type
         self.python_api = paddle.tile
-        self.inputs = {
-            'X': np.random.uniform(10, size=(100, 4, 5)).astype(self.dtype)
-        }
-        self.attrs = {'repeat_times': [2, 1, 4]}
-        output = np.tile(self.inputs['X'], (2, 1, 4))
+        self.init_data()
+        x = np.random.uniform(10, size=self.ori_shape).astype(self.dtype)
+        output = np.tile(x, self.repeat_times)
+        self.inputs = {'X': x}
+        self.attrs = {'repeat_times': self.repeat_times}
         self.outputs = {'Out': output}
+    def init_data(self):
+        self.dtype = np.float16
+        self.ori_shape = [100, 4, 5]
+        self.repeat_times = [2, 1, 4]
     def test_check_output(self):
         self.check_output()
@@ -221,22 +225,27 @@ class TestTileOpFloat16(OpTest):
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
     "core is not complied with CUDA and not support the bfloat16",
 )
-class TestWhereOpBFloat16(OpTest):
+class TestTileBF16OP(OpTest):
     def setUp(self):
         self.op_type = 'tile'
-        self.dtype = np.uint16
         self.__class__.op_type = self.op_type
         self.python_api = paddle.tile
-        x = np.random.uniform(10, size=(100, 4, 5)).astype(np.float32)
-        output = np.tile(x, (2, 1, 4))
+        self.init_data()
+        x = np.random.uniform(10, size=self.ori_shape).astype(np.float32)
+        output = np.tile(x, self.repeat_times)
         self.inputs = {'X': convert_float_to_uint16(x)}
-        self.attrs = {'repeat_times': [2, 1, 4]}
+        self.attrs = {'repeat_times': self.repeat_times}
         self.outputs = {'Out': convert_float_to_uint16(output)}
     def test_check_output(self):
         place = core.CUDAPlace(0)
         self.check_output_with_place(place)
+    def init_data(self):
+        self.dtype = np.uint16
+        self.ori_shape = [100, 4, 5]
+        self.repeat_times = [2, 1, 4]
     def test_check_grad(self):
         place = core.CUDAPlace(0)
         self.check_grad_with_place(place, ['X'], 'Out')
......
@@ -3159,7 +3159,7 @@ def tile(x, repeat_times, name=None):
             [
                 'bool',
                 'float16',
-                'bfloat16',
+                'uint16',
                 'float32',
                 'float64',
                 'int32',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册