未验证 提交 866c2877 编写于 作者: W Wang Xinyu 提交者: GitHub

[AMP OP&Test] Add float16 OpTest for squeeze, unsqueeze (#52018)

* add squeeze, unsqueeze, transpose fp16 unitest

* Update test_transpose_op.py
上级 d6011cb6
...@@ -52,6 +52,31 @@ class TestSqueezeOp(OpTest): ...@@ -52,6 +52,31 @@ class TestSqueezeOp(OpTest):
self.attrs = {"axes": self.axes} self.attrs = {"axes": self.axes}
class TestSqueezeFP16Op(OpTest):
def setUp(self):
self.op_type = "squeeze"
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float16")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (1, 3, 1, 40)
self.axes = (0, 2)
self.new_shape = (3, 40)
def init_attrs(self):
self.attrs = {"axes": self.axes}
class TestSqueezeBF16Op(OpTest): class TestSqueezeBF16Op(OpTest):
def setUp(self): def setUp(self):
self.op_type = "squeeze" self.op_type = "squeeze"
......
...@@ -50,6 +50,29 @@ class TestUnsqueezeOp(OpTest): ...@@ -50,6 +50,29 @@ class TestUnsqueezeOp(OpTest):
self.attrs = {"axes": self.axes} self.attrs = {"axes": self.axes}
class TestUnsqueezeFP16Op(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float16")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (3, 40)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 40)
def init_attrs(self):
self.attrs = {"axes": self.axes}
class TestUnsqueezeBF16Op(OpTest): class TestUnsqueezeBF16Op(OpTest):
def setUp(self): def setUp(self):
self.init_test_case() self.init_test_case()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册