未验证 提交 d7660a7c 编写于 作者: Y yangjianfengo1 提交者: GitHub

[AMP OP&Test] Tile OP (#51380)

* tile_op

* fix bfloat16 x

* update review

* del out
上级 d0d739ca
......@@ -196,19 +196,23 @@ class TestTileOpInteger(OpTest):
self.check_output()
class TestTileOpFloat16(OpTest):
class TestTileFP16OP(OpTest):
def setUp(self):
self.op_type = "tile"
self.dtype = np.float16
self.__class__.op_type = self.op_type
self.python_api = paddle.tile
self.inputs = {
'X': np.random.uniform(10, size=(100, 4, 5)).astype(self.dtype)
}
self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.init_data()
x = np.random.uniform(10, size=self.ori_shape).astype(self.dtype)
output = np.tile(x, self.repeat_times)
self.inputs = {'X': x}
self.attrs = {'repeat_times': self.repeat_times}
self.outputs = {'Out': output}
def init_data(self):
self.dtype = np.float16
self.ori_shape = [100, 4, 5]
self.repeat_times = [2, 1, 4]
def test_check_output(self):
self.check_output()
......@@ -221,22 +225,27 @@ class TestTileOpFloat16(OpTest):
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestWhereOpBFloat16(OpTest):
class TestTileBF16OP(OpTest):
def setUp(self):
self.op_type = 'tile'
self.dtype = np.uint16
self.__class__.op_type = self.op_type
self.python_api = paddle.tile
x = np.random.uniform(10, size=(100, 4, 5)).astype(np.float32)
output = np.tile(x, (2, 1, 4))
self.init_data()
x = np.random.uniform(10, size=self.ori_shape).astype(np.float32)
output = np.tile(x, self.repeat_times)
self.inputs = {'X': convert_float_to_uint16(x)}
self.attrs = {'repeat_times': [2, 1, 4]}
self.attrs = {'repeat_times': self.repeat_times}
self.outputs = {'Out': convert_float_to_uint16(output)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def init_data(self):
self.dtype = np.uint16
self.ori_shape = [100, 4, 5]
self.repeat_times = [2, 1, 4]
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
......
......@@ -3159,7 +3159,7 @@ def tile(x, repeat_times, name=None):
[
'bool',
'float16',
'bfloat16',
'uint16',
'float32',
'float64',
'int32',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册