Unverified commit 6c650445, authored by yangjianfengo1, committed by GitHub

AMP tile_op & Test (#51193)

* tile_op

* fix bfloat16 x
Parent commit 37dbbbd1
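Sketched below is what this change enables: once float16 and bfloat16 are accepted by the tile dtype check, paddle.tile can run under AMP. This is a minimal illustration only, assuming a CUDA build of Paddle with bfloat16 support; the shapes, casts, and device setup are illustrative and not part of this commit.

import paddle

paddle.set_device('gpu')

# float16 input to paddle.tile
x_fp16 = paddle.cast(paddle.rand([100, 4, 5]), 'float16')
y_fp16 = paddle.tile(x_fp16, repeat_times=[2, 1, 4])  # shape [200, 4, 20]

# bfloat16 input is accepted as well after this change
x_bf16 = paddle.cast(paddle.rand([100, 4, 5]), 'bfloat16')
y_bf16 = paddle.tile(x_bf16, repeat_times=[2, 1, 4])
print(y_fp16.shape, y_bf16.shape)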
@@ -17,7 +17,7 @@ import unittest
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
@@ -196,6 +196,52 @@ class TestTileOpInteger(OpTest):
         self.check_output()
 
 
+class TestTileOpFloat16(OpTest):
+    def setUp(self):
+        self.op_type = "tile"
+        self.dtype = np.float16
+        self.__class__.op_type = self.op_type
+        self.python_api = paddle.tile
+        self.inputs = {
+            'X': np.random.uniform(10, size=(100, 4, 5)).astype(self.dtype)
+        }
+        self.attrs = {'repeat_times': [2, 1, 4]}
+        output = np.tile(self.inputs['X'], (2, 1, 4))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestTileOpBFloat16(OpTest):
+    def setUp(self):
+        self.op_type = 'tile'
+        self.dtype = np.uint16
+        self.__class__.op_type = self.op_type
+        self.python_api = paddle.tile
+        x = np.random.uniform(10, size=(100, 4, 5)).astype(np.float32)
+        output = np.tile(x, (2, 1, 4))
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {'repeat_times': [2, 1, 4]}
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+
 # Situation 5: input x is Bool
 class TestTileOpBoolean(OpTest):
     def setUp(self):
......
@@ -3156,7 +3156,15 @@ def tile(x, repeat_times, name=None):
     check_variable_and_dtype(
         x,
         'x',
-        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+        [
+            'bool',
+            'float16',
+            'bfloat16',
+            'float32',
+            'float64',
+            'int32',
+            'int64',
+        ],
         'tile',
     )
     if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient:
......
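For readers unfamiliar with the uint16 trick in the bfloat16 test above: bfloat16 values are stored as their 16-bit patterns in a uint16 array, which is what convert_float_to_uint16 produces. The sketch below shows the standard float32 to bfloat16 bit conversion under the assumption of simple truncation; the actual helper in eager_op_test may round to nearest instead, so treat this as illustrative only.

import numpy as np

def float32_to_bfloat16_bits(x):
    # Reinterpret float32 as uint32 and keep the upper 16 bits: bfloat16 shares
    # the float32 exponent width, so dropping the low 16 mantissa bits yields
    # the bfloat16 bit pattern (truncation; no rounding here).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

x = np.random.uniform(low=0, high=10, size=(2, 3)).astype(np.float32)
print(float32_to_bfloat16_bits(x))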