未验证 提交 cb09cf99 编写于 作者: C Charles-hit 提交者: GitHub

fix dygraph higer node creation (#47231)

* fix dygraph higer node creation

* fix eager generator

* modify generator code

* fix eager generator
上级 14536d0f
......@@ -1780,6 +1780,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name
def GenerateHigherOrderNodeCreationCode(self):
has_higher_order_node = False
namespace = self.namespace
grad_api_contents = self.grad_api_contents
forward_apis_dict = self.forward_apis_dict
......@@ -1807,14 +1808,33 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
next_grad_node_out_list = next_node_generator.grad_node_out_list
self.RecordGrad2NextGradNameMapping(next_node_generator)
is_invoke_forward_api = IsInvokeForwardApi(
self.grad_api_contents, self.forward_apis_dict
)
if next_node_generator is not None:
has_higher_order_node = True
return (
has_higher_order_node,
is_invoke_forward_api,
next_grad_node_creation_str,
next_grad_node_out_list,
next_node_generator.backward_forward_inputs_map,
)
else:
return next_grad_node_creation_str, next_grad_node_out_list, None
elif not is_invoke_forward_api:
next_grad_node_creation_str = f""" if(trace_backward) {{
PADDLE_THROW(phi::errors::Unavailable(
\"The Op {self.backward_api_name} doesn't have any grad\"
\"op. If you don't intend calculating higher order\"
\"derivatives, please set `create_graph`to False.\"));
}}"""
return (
has_higher_order_node,
is_invoke_forward_api,
next_grad_node_creation_str,
next_grad_node_out_list,
None,
)
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
......@@ -1906,6 +1926,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
def GenerateNodeDefinition(
self,
has_higher_order_node,
is_invoke_forward_api,
next_grad_node_creation_str,
next_grad_node_out_list,
backward_forward_inputs_map_next,
......@@ -1920,9 +1942,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_inplace_map = self.backward_inplace_map
indent = GetIndent(1)
is_invoke_forward_api = IsInvokeForwardApi(
self.grad_api_contents, self.forward_apis_dict
)
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
grad_api_args_len = (
......@@ -1968,7 +1987,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
is_optional = name in self.optional_inputs
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});"
if backward_inplace_map and name in backward_inplace_map.keys():
if len(next_grad_node_creation_str) > 0:
if has_higher_order_node:
if (
transformed_tensor_name
in backward_forward_inputs_map_next
......@@ -2039,7 +2058,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Inplace in backward op
if backward_inplace_map and name in backward_inplace_map.keys():
if len(next_grad_node_creation_str) > 0:
if has_higher_order_node:
if (
transformed_tensor_name
in backward_forward_inputs_map_next
......@@ -2144,7 +2163,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
inplace_str = f""" if (api_output_{out_index} != nullptr && can_be_inplaced) {{
egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
}}"""
if len(next_grad_node_creation_str) > 0:
if has_higher_order_node:
inplace_for_grad_outs_str += f"""
if (trace_backward) {{
{optional_inplace_str}
......@@ -2236,7 +2255,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
else:
assert IsVectorTensorType(rtype)
if len(next_grad_node_creation_str) > 0:
if has_higher_order_node > 0:
output_autograd_meta = f"""
auto& {transformed_tensor_name} = returns[{pos}];
std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});
......@@ -2327,6 +2346,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
#####################
# Higher-order GradNode generation
(
has_higher_order_node,
is_invoke_forward_api,
next_grad_node_creation_str,
next_grad_node_out_list,
backward_forward_inputs_map,
......@@ -2335,6 +2356,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition(
has_higher_order_node,
is_invoke_forward_api,
next_grad_node_creation_str,
next_grad_node_out_list,
backward_forward_inputs_map,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册