提交 c43f8afb 编写于 作者: P phlrain

fix generator bug;

上级 b6a94760
...@@ -760,12 +760,10 @@ def GenerateNodeCreationCodes( ...@@ -760,12 +760,10 @@ def GenerateNodeCreationCodes(
# SetTensorWrappers # SetTensorWrappers
set_tensor_wrappers_list = [] set_tensor_wrappers_list = []
fwd_api_input_num = 0
for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items(): for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs) is_optional = (name in optional_inputs)
if is_fwd_input: if is_fwd_input:
fwd_api_input_num += 1
if is_optional: if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else: else:
...@@ -777,9 +775,7 @@ def GenerateNodeCreationCodes( ...@@ -777,9 +775,7 @@ def GenerateNodeCreationCodes(
fwd_output_pos = forward_outputs_position_map[name][1] fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)" tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else: else:
assert IsPlainTensorType(atype), atype tw_name = f"api_result"
out_pos = pos - fwd_api_input_num
tw_name = f"std::get<{out_pos}>(api_result)"
if is_optional: if is_optional:
set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);" set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册