未验证 提交 418b983c 编写于 作者: Y YuhangLi 提交者: GitHub

[AMP OP&Test]shape op fp/bf16 support (#52184)

上级 ceca55c5
...@@ -27,13 +27,13 @@ class TestShapeOp(OpTest): ...@@ -27,13 +27,13 @@ class TestShapeOp(OpTest):
self.op_type = "shape" self.op_type = "shape"
self.python_api = paddle.shape self.python_api = paddle.shape
self.config() self.config()
self.shape = [2, 3] input = np.zeros(self.shape, dtype=self.dtype)
input = np.zeros(self.shape)
self.inputs = {'Input': input} self.inputs = {'Input': input}
self.outputs = {'Out': np.array(self.shape)} self.outputs = {'Out': np.array(self.shape)}
def config(self): def config(self):
self.shape = [2, 3] self.shape = [2, 3]
self.dtype = np.float32
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -42,11 +42,31 @@ class TestShapeOp(OpTest): ...@@ -42,11 +42,31 @@ class TestShapeOp(OpTest):
class case1(TestShapeOp): class case1(TestShapeOp):
def config(self): def config(self):
self.shape = [2] self.shape = [2]
self.dtype = np.float32
class case2(TestShapeOp): class case2(TestShapeOp):
def config(self): def config(self):
self.shape = [1, 2, 3] self.shape = [1, 2, 3]
self.dtype = np.float32
class TestShapeOpFp16(TestShapeOp):
def config(self):
self.shape = [2, 3]
self.dtype = np.float16
class case1Fp16(TestShapeOp):
def config(self):
self.shape = [2]
self.dtype = np.float16
class case2Fp16(TestShapeOp):
def config(self):
self.shape = [1, 2, 3]
self.dtype = np.float16
class TestShapeWithSelectedRows(unittest.TestCase): class TestShapeWithSelectedRows(unittest.TestCase):
...@@ -95,7 +115,6 @@ class TestShapeOpBf16(OpTest): ...@@ -95,7 +115,6 @@ class TestShapeOpBf16(OpTest):
self.dtype = 'bfloat16' self.dtype = 'bfloat16'
self.python_api = paddle.shape self.python_api = paddle.shape
self.config() self.config()
self.shape = [2, 3]
input = np.zeros(self.shape) input = np.zeros(self.shape)
input = convert_float_to_uint16(input.astype('float32')) input = convert_float_to_uint16(input.astype('float32'))
self.inputs = {'Input': input} self.inputs = {'Input': input}
...@@ -109,5 +128,15 @@ class TestShapeOpBf16(OpTest): ...@@ -109,5 +128,15 @@ class TestShapeOpBf16(OpTest):
self.check_output_with_place(place) self.check_output_with_place(place)
class case1Bf16(TestShapeOpBf16):
def config(self):
self.shape = [2]
class case2Bf16(TestShapeOpBf16):
def config(self):
self.shape = [1, 2, 3]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册