From 24b2cc8d0880ed54bab6fd21c797ecd9b8b27c72 Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Thu, 29 Jun 2023 15:02:11 +0800
Subject: [PATCH] support cast, dropout, gather, mean prim ops zero dim (#54966)

---
 .../composite_backward_api.h              | 15 +++++++--
 .../incubate/autograd/composite_rules.py  |  8 +++--
 test/legacy_test/test_cast_op.py          | 11 ++++++-
 test/legacy_test/test_dropout_op.py       | 32 +++++++++++++++++++
 test/legacy_test/test_gather_op.py        | 14 ++++++++
 test/legacy_test/test_mean_op.py          | 18 +++++++++--
 6 files changed, 88 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 0b75cfef148..b0090fce1f2 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -150,11 +150,20 @@ void gather_grad(const Tensor& x,
   }
 
   // transpose out_grad and zero grad to target rank.
-  auto tmp_zero_x_grad = transpose(zero_tensor, tmp_perm);
-  auto tmp_out_grad = transpose(out_grad, tmp_perm);
+  auto tmp_zero_x_grad = zero_tensor;
+  auto tmp_out_grad = out_grad;
+  if (zero_tensor.dims().size() > 0) {
+    tmp_zero_x_grad = transpose(zero_tensor, tmp_perm);
+  }
+  if (out_grad.dims().size() > 0) {
+    tmp_out_grad = transpose(out_grad, tmp_perm);
+  }
 
   // scatter grad to grad_x
   auto tmp_grad_x = scatter(tmp_zero_x_grad, index, tmp_out_grad, false);
-  auto tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm);
+  auto tmp_grad_x_tranposed = tmp_grad_x;
+  if (tmp_grad_x.dims().size() > 0) {
+    tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm);
+  }
   set_output(tmp_grad_x_tranposed, grad_x);
 }
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 0941da78768..7640b9d95a9 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -255,9 +255,11 @@ def mean_composite(x, axis, keepdim):
     axes = axis or list(range(0, len(x.shape)))
     axes = [axes] if isinstance(axes, int) else axes
     sum_x = sum(x, axis=axes, keepdim=keepdim)
-    value_to_fill = functools.reduce(
-        operator.mul, [x.shape[axis] for axis in axes]
-    )
+    ele_nums_list = [x.shape[axis] for axis in axes]
+    if ele_nums_list == []:
+        value_to_fill = 1
+    else:
+        value_to_fill = functools.reduce(operator.mul, ele_nums_list)
     norm = fill_constant(
         shape=[],
         value=value_to_fill,
diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py
index dde01a2296c..d1353b24f49 100644
--- a/test/legacy_test/test_cast_op.py
+++ b/test/legacy_test/test_cast_op.py
@@ -34,7 +34,8 @@ def cast_wrapper(x, out_dtype=None):
 
 class TestCastOpFp32ToFp64(OpTest):
     def setUp(self):
-        ipt = np.random.random(size=[10, 10])
+        self.init_shapes()
+        ipt = np.random.random(size=self.input_shape)
         self.inputs = {'X': ipt.astype('float32')}
         self.outputs = {'Out': ipt.astype('float64')}
         self.attrs = {
@@ -46,6 +47,9 @@ class TestCastOpFp32ToFp64(OpTest):
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
 
+    def init_shapes(self):
+        self.input_shape = [10, 10]
+
     def test_check_output(self):
         self.check_output()
 
@@ -53,6 +57,11 @@ class TestCastOpFp32ToFp64(OpTest):
         self.check_grad(['X'], ['Out'], check_prim=True)
 
 
+class TestCastOpFp32ToFp64_ZeroDim(TestCastOpFp32ToFp64):
+    def init_shapes(self):
+        self.input_shape = ()
+
+
 class TestCastOpFp16ToFp32(OpTest):
     def setUp(self):
         ipt = np.random.random(size=[10, 10])
diff --git a/test/legacy_test/test_dropout_op.py b/test/legacy_test/test_dropout_op.py
index 4a0b4edba07..088e9ce4830 100644
--- a/test/legacy_test/test_dropout_op.py
+++ b/test/legacy_test/test_dropout_op.py
@@ -87,6 +87,24 @@ class TestDropoutOp(OpTest):
         self.check_grad(['X'], 'Out', check_prim=False)
 
 
+class TestDropoutOp_ZeroDim(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.prim_op_type = "comp"
+        self.python_api = dropout_wapper
+        self.public_python_api = prim_dropout_wrapper
+        self.inputs = {'X': np.random.random(()).astype("float32")}
+        self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
+        self.outputs = {
+            'Out': self.inputs['X'],
+            'Mask': np.ones(()).astype('uint8'),
+        }
+        # Because prim op compare res with dygraph
+        # when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
+        # but in static mode x_grad = []
+        self.enable_check_static_comp = False
+
+
 class TestDropoutOpInput1d(OpTest):
     def setUp(self):
         self.op_type = "dropout"
@@ -126,6 +144,20 @@ class TestDropoutOp2(TestDropoutOp):
         }
 
 
+class TestDropoutOp2_ZeroDim(TestDropoutOp2):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.python_api = dropout_wapper
+        self.public_python_api = prim_dropout_wrapper
+        self.prim_op_type = "comp"
+        self.inputs = {'X': np.random.random(()).astype("float32")}
+        self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False}
+        self.outputs = {
+            'Out': np.zeros(()).astype('float32'),
+            'Mask': np.zeros(()).astype('uint8'),
+        }
+
+
 class TestDropoutOp3(TestDropoutOp):
     def setUp(self):
         self.op_type = "dropout"
diff --git a/test/legacy_test/test_gather_op.py b/test/legacy_test/test_gather_op.py
index 0565b412438..ff67a85484a 100644
--- a/test/legacy_test/test_gather_op.py
+++ b/test/legacy_test/test_gather_op.py
@@ -70,6 +70,20 @@ class TestGatherOp(OpTest):
         pass
 
 
+class TestGatherOp_ZeroDim(TestGatherOp):
+    def config(self):
+        """
+        For multi-dimension input
+        """
+        self.x_shape = 100
+        self.config_dtype()
+        self.index = 2
+        self.index_type = "int32"
+
+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
+
 class TestGatherOpFP16(TestGatherOp):
     def config_dtype(self):
         self.x_type = "float16"
diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py
index 9e5e7fc17c9..5f569263041 100644
--- a/test/legacy_test/test_mean_op.py
+++ b/test/legacy_test/test_mean_op.py
@@ -152,17 +152,20 @@ class TestReduceMeanOp(OpTest):
         self.public_python_api = reduce_mean_wrapper
         self.prim_op_type = "comp"
         self.dtype = 'float64'
-        self.shape = [2, 3, 4, 5]
+        self.init_shapes()
         self.axis = [0]
+        if self.shape == []:
+            self.axis = []
         self.keepdim = False
         self.set_attrs()
         self.if_enable_cinn()
 
         np.random.seed(10)
         x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
-        if not hasattr(self, "reduce_all"):
+        if not hasattr(self, "reduce_all") and not x_np.shape == ():
             self.reduce_all = (not self.axis) or len(self.axis) == len(x_np)
-
+        if x_np.shape == ():
+            self.reduce_all = True
         out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all)
         self.inputs = {'X': x_np}
         self.outputs = {'Out': out_np}
@@ -172,6 +175,9 @@ class TestReduceMeanOp(OpTest):
             'reduce_all': self.reduce_all,
         }
 
+    def init_shapes(self):
+        self.shape = [2, 3, 4, 5]
+
     def set_attrs(self):
         pass
 
@@ -195,6 +201,12 @@ class TestReduceMeanOp(OpTest):
         )
 
 
+class TestReduceMeanOp_ZeroDim(TestReduceMeanOp):
+    def init_shapes(self):
+        self.shape = []
+        self.enable_cinn = False
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
--
GitLab
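
For reviewers who want to try the behaviour by hand, the snippet below is a minimal, illustrative sketch of the zero-dim (0-D) cases the new tests cover. It is not part of the patch: it assumes a Paddle build that already contains this change, and the values, variable names, and asserts are made up for illustration (the prim/composite lowering itself is exercised by the OpTest classes in the patch, not by anything set here).

    import paddle

    # cast: a 0-D float32 tensor stays 0-D after casting
    # (the case TestCastOpFp32ToFp64_ZeroDim adds).
    x = paddle.to_tensor(1.5)                  # shape []
    assert paddle.cast(x, 'float64').shape == []

    # mean: reducing a 0-D tensor uses an empty axis list, the case
    # mean_composite now special-cases by filling the divisor with 1.
    assert paddle.mean(paddle.to_tensor(3.0)).shape == []

    # gather: a 0-D index into a 1-D tensor yields a 0-D output, and
    # gather_grad must skip its transposes on rank-0 tensors.
    src = paddle.rand([100])
    src.stop_gradient = False
    idx = paddle.to_tensor(2, dtype='int32')   # 0-D index
    out = paddle.gather(src, idx)
    assert out.shape == []
    out.backward()
    assert src.grad.shape == [100]

    # dropout: p=0.0 on a 0-D input returns it unchanged
    # (what TestDropoutOp_ZeroDim encodes as expected Out/Mask).
    assert paddle.nn.functional.dropout(paddle.to_tensor(0.7), p=0.0).shape == []

Each assert mirrors what the corresponding OpTest subclass checks numerically: outputs keep rank 0, and the composite backward rules no longer apply transposes to rank-0 tensors.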