未验证 提交 e397a3ff 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP_OP&Test] Add float16 OpTest for full_op (#50723)

上级 aab713ea
...@@ -30,60 +30,71 @@ def fill_wrapper(shape, value=0.0): ...@@ -30,60 +30,71 @@ def fill_wrapper(shape, value=0.0):
# Situation 1: Attr(shape) is a list(without tensor) # Situation 1: Attr(shape) is a list(without tensor)
class TestFillConstantOp1(OpTest): # Base case
class TestFillConstantOp(OpTest):
def setUp(self): def setUp(self):
'''Test fill_constant op with specified value''' '''Test fill_constant op with default value'''
self.op_type = "fill_constant" self.op_type = "fill_constant"
self.python_api = fill_wrapper self.python_api = fill_wrapper
self.init_dtype()
self.init_shape()
self.init_value()
self.inputs = {} self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3.8} self.attrs = {'shape': self.shape, 'value': self.value}
self.outputs = {'Out': np.full((123, 92), 3.8)} self.outputs = {'Out': np.full(self.shape, self.value)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def init_dtype(self):
self.dtype = np.float64
class TestFillConstantOp2(OpTest): def init_shape(self):
def setUp(self): self.shape = [123, 92]
'''Test fill_constant op with default value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
self.inputs = {} def init_value(self):
self.attrs = {'shape': [123, 92]} self.value = 0.0
self.outputs = {'Out': np.full((123, 92), 0.0)}
def test_check_output(self):
self.check_output()
class TestFillConstantFP32Op(TestFillConstantOp):
'''Test fill_constant op with specified value'''
class TestFillConstantOp3(OpTest): def init_dtype(self):
def setUp(self): self.dtype = np.float32
'''Test fill_constant op with specified int64 value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
self.inputs = {} def init_value(self):
self.attrs = {'shape': [123, 92], 'value': 10000000000} self.value = 3.8
self.outputs = {'Out': np.full((123, 92), 10000000000)}
def test_check_output(self):
self.check_output()
class TestFillConstantFP16Op(TestFillConstantOp):
'''Test fill_constant op with specified value'''
class TestFillConstantOp4(OpTest): def init_dtype(self):
def setUp(self): self.dtype = np.float16
'''Test fill_constant op with specified int value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
self.inputs = {} def init_value(self):
self.attrs = {'shape': [123, 92], 'value': 3} self.value = 3.8
self.outputs = {'Out': np.full((123, 92), 3)}
def test_check_output(self):
self.check_output() class TestFillConstantINT64Op(TestFillConstantOp):
'''Test fill_constant op with specified int64 value'''
def init_dtype(self):
self.dtype = np.int64
def init_value(self):
self.value = 10000000000
class TestFillConstantINT32Op(TestFillConstantOp):
'''Test fill_constant op with specified int value'''
def init_dtype(self):
self.dtype = np.int32
def init_value(self):
self.value = 3
@unittest.skipIf( @unittest.skipIf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册