未验证 提交 6bad3009 编写于 作者: Z zhangbo9674 提交者: GitHub

fix setTensorWrapper with no_need_buffers (#41892)

上级 b3959fe4
......@@ -730,11 +730,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
is_optional = (name in optional_inputs)
if is_fwd_input:
need_input_data = "false" if name in self.no_need_buffers else "true"
if is_optional:
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, {need_input_data});"
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else:
if num_fwd_outputs > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册