From 49f5a97bff2b416c2ba4ed9f49687687ff940098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Wed, 4 Jan 2023 14:32:46 +0800 Subject: [PATCH] Add for-else (#49521) * add for-else * add * for unpacking --- python/paddle/fluid/layers/math_op_patch.py | 12 +++++------- .../tests/unittests/test_imperative_auto_prune.py | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 126bc1c6eb..feed6641af 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -381,22 +381,20 @@ def monkey_patch_variable(): lhs_dtype = safe_get_dtype(self) if not isinstance(other_var, Variable): if reverse: - has_batch_size = False for elem in self.shape: if elem < 0: - has_batch_size = True + other_var = create_tensor_with_batchsize( + self, other_var, lhs_dtype + ) break - if not has_batch_size: + else: + # when break is not triggered, enter the else branch other_var = create_tensor( current_block(self), other_var, dtype=lhs_dtype, shape=self.shape, ) - else: - other_var = create_tensor_with_batchsize( - self, other_var, lhs_dtype - ) else: # add fill_op to current_block other_var = create_scalar( diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py index 54cba6eb80..679a141fc5 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py @@ -372,9 +372,9 @@ class TestImperativeAutoPrune(unittest.TestCase): loss = model.embed_linear0(indices) loss.backward() _, params_grads = optimizer.minimize(loss) - for items in params_grads: - assert items[0].name is not model.embed1.weight.name - assert items[0].name is not model.linear_1.weight.name + for (items_0, *items_len) in params_grads: + assert items_0.name is not model.embed1.weight.name + assert items_0.name is not model.linear_1.weight.name assert model.embed1.weight._grad_ivar() is None assert model.linear_1.weight._grad_ivar() is None -- GitLab