未验证 提交 f8d02146 编写于 作者: C Charles-hit 提交者: GitHub

support zero_dim for some prim ops (#54892)

上级 abc1c3d4
......@@ -554,6 +554,8 @@ def squeeze2_composite(x, axis):
axis can only be list, not int
"""
rank = len(x.shape)
if rank == 0:
return [assign(x), None]
if len(axis) == 0:
dims = set(range(rank))
else:
......
......@@ -1385,6 +1385,11 @@ class TestSqrtCompFp32(TestActivation):
self.dtype = np.float32
class TestSqrtComp_ZeroDim(TestSqrtComp):
def init_shape(self):
self.shape = []
class TestRsqrt(TestActivation):
def setUp(self):
self.op_type = "rsqrt"
......@@ -2029,7 +2034,7 @@ class TestLeakyRelu_ZeroDim(TestLeakyRelu):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
pass
class TestLeakyReluAPI(unittest.TestCase):
......
......@@ -584,10 +584,7 @@ class TestProdOp_ZeroDim(OpTest):
self.public_python_api = raw_reduce_prod
self.op_type = "reduce_prod"
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}
self.init_inputs_and_outputs()
# 0-D tensor doesn't support in cinn
self.enable_cinn = False
......@@ -603,6 +600,29 @@ class TestProdOp_ZeroDim(OpTest):
self.check_grad(['X'], 'Out', check_prim=True)
class TestProdOp_ZeroDim1(TestProdOp):
def setUp(self):
self.python_api = paddle.prod
self.public_python_api = paddle.prod
self.op_type = "reduce_prod"
self.prim_op_type = "prim"
self.init_inputs_and_outputs()
# 0-D tensor doesn't support in cinn
self.enable_cinn = False
def init_inputs_and_outputs(self):
self.inputs = {'X': np.random.random([100]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}
class TestProdOp_ZeroDim2(TestProdOp_ZeroDim1):
def init_inputs_and_outputs(self):
self.inputs = {'X': np.random.random([5, 6, 10]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}
class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
......
......@@ -100,6 +100,20 @@ class TestSqueezeOp1BF16Op(TestSqueezeOp):
self.dtype = np.uint16
class TestSqueezeOp_ZeroDim1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (0,)
self.new_shape = ()
class TestSqueezeOp_ZeroDim2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 1, 1)
self.axes = (0, 1, 2)
self.new_shape = ()
# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册