未验证 提交 56493c9e 编写于 作者: P pangyoki 提交者: GitHub

fix eager_gen node bug (#41165)

上级 11d1a51a
...@@ -1255,7 +1255,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1255,7 +1255,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if num_outputs == 1: if num_outputs == 1:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;" get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;"
else: else:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{fwd_position}];" get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];"
get_outputs_str += get_tensor_str + "\n" get_outputs_str += get_tensor_str + "\n"
# Prepare for Node Creation if Necessary # Prepare for Node Creation if Necessary
......
...@@ -325,8 +325,6 @@ class TestDygraphInplaceFlatten(TestDygraphInplace): ...@@ -325,8 +325,6 @@ class TestDygraphInplaceFlatten(TestDygraphInplace):
return var.flatten_() return var.flatten_()
"""
# TODO: need to fix bug
class TestDygraphInplaceScatter(TestDygraphInplace): class TestDygraphInplaceScatter(TestDygraphInplace):
def init_data(self): def init_data(self):
self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]]) self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]])
...@@ -345,7 +343,6 @@ class TestDygraphInplaceScatter(TestDygraphInplace): ...@@ -345,7 +343,6 @@ class TestDygraphInplaceScatter(TestDygraphInplace):
[[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
return paddle.scatter_(var, index, updates, overwrite=False) return paddle.scatter_(var, index, updates, overwrite=False)
"""
class TestDygraphInplaceElu(TestDygraphInplace): class TestDygraphInplaceElu(TestDygraphInplace):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册