From 56493c9e98c1c91076db8e5a4c76fdef5388f138 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Thu, 31 Mar 2022 09:52:01 +0800 Subject: [PATCH] fix eager_gen node bug (#41165) --- .../auto_code_generator/final_state_generator/eager_gen.py | 2 +- python/paddle/fluid/tests/unittests/test_inplace.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index a601784042..0f78763d6c 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -1255,7 +1255,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): if num_outputs == 1: get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;" else: - get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{fwd_position}];" + get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];" get_outputs_str += get_tensor_str + "\n" # Prepare for Node Creation if Necessary diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index 6670f2a174..617e9811d6 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -325,8 +325,6 @@ class TestDygraphInplaceFlatten(TestDygraphInplace): return var.flatten_() -""" -# TODO: need to fix bug class TestDygraphInplaceScatter(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]]) @@ -345,7 +343,6 @@ class TestDygraphInplaceScatter(TestDygraphInplace): [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') return paddle.scatter_(var, index, updates, overwrite=False) -""" class TestDygraphInplaceElu(TestDygraphInplace): -- GitLab