未验证 提交 e397a3ff 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP_OP&Test] Add float16 OpTest for full_op (#50723)

上级 aab713ea
......@@ -30,60 +30,71 @@ def fill_wrapper(shape, value=0.0):
# Situation 1: Attr(shape) is a list(without tensor)
class TestFillConstantOp1(OpTest):
# Base case
class TestFillConstantOp(OpTest):
def setUp(self):
'''Test fill_constant op with specified value'''
'''Test fill_constant op with default value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
self.init_dtype()
self.init_shape()
self.init_value()
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3.8}
self.outputs = {'Out': np.full((123, 92), 3.8)}
self.attrs = {'shape': self.shape, 'value': self.value}
self.outputs = {'Out': np.full(self.shape, self.value)}
def test_check_output(self):
self.check_output()
def init_dtype(self):
self.dtype = np.float64
class TestFillConstantOp2(OpTest):
def setUp(self):
'''Test fill_constant op with default value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
def init_shape(self):
self.shape = [123, 92]
self.inputs = {}
self.attrs = {'shape': [123, 92]}
self.outputs = {'Out': np.full((123, 92), 0.0)}
def init_value(self):
self.value = 0.0
def test_check_output(self):
self.check_output()
class TestFillConstantFP32Op(TestFillConstantOp):
'''Test fill_constant op with specified value'''
class TestFillConstantOp3(OpTest):
def setUp(self):
'''Test fill_constant op with specified int64 value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
def init_dtype(self):
self.dtype = np.float32
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 10000000000}
self.outputs = {'Out': np.full((123, 92), 10000000000)}
def init_value(self):
self.value = 3.8
def test_check_output(self):
self.check_output()
class TestFillConstantFP16Op(TestFillConstantOp):
'''Test fill_constant op with specified value'''
class TestFillConstantOp4(OpTest):
def setUp(self):
'''Test fill_constant op with specified int value'''
self.op_type = "fill_constant"
self.python_api = fill_wrapper
def init_dtype(self):
self.dtype = np.float16
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3}
self.outputs = {'Out': np.full((123, 92), 3)}
def init_value(self):
self.value = 3.8
def test_check_output(self):
self.check_output()
class TestFillConstantINT64Op(TestFillConstantOp):
'''Test fill_constant op with specified int64 value'''
def init_dtype(self):
self.dtype = np.int64
def init_value(self):
self.value = 10000000000
class TestFillConstantINT32Op(TestFillConstantOp):
'''Test fill_constant op with specified int value'''
def init_dtype(self):
self.dtype = np.int32
def init_value(self):
self.value = 3
@unittest.skipIf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册