未验证 提交 21c333df 编写于 作者: Z zhangbo9674 提交者: GitHub

fix setTensorWrapper with no_need_buffers (#41892) (#41952)

上级 b4adbe5c
...@@ -730,11 +730,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -730,11 +730,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
is_optional = (name in optional_inputs) is_optional = (name in optional_inputs)
if is_fwd_input: if is_fwd_input:
need_input_data = "false" if name in self.no_need_buffers else "true"
if is_optional: if is_optional:
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
else: else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, {need_input_data});" set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
set_input_tensor_wrappers_list.append(set_tensor_wrappers) set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else: else:
if num_fwd_outputs > 1: if num_fwd_outputs > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册