Unverified commit 1efc80c6 authored by Jiabin Yang, committed by GitHub

fix deriv with inplace (#43930)

Parent 77d75aa4
@@ -152,17 +152,22 @@ paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallV
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
// Prepare Grad function call
{}
// Get GradIn autograd_meta
{}
// Compute Require Grad
{}
// Inplace Check
{}
// Inplace Strategy
{}
// Call grad_api function
VLOG(3) << \"Final State Running: {}\";
{}
// Check NaN and Inf if needed
{}
// Get GradIn autograd_meta
{}
// Get GradOut autograd_meta
{}
// Compute Require Grad
{}
// Create Grad Node
{}
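To make the reordering above concrete, here is a minimal, hypothetical Python sketch of how such a template is filled. It is not the real `GRAD_FUNCTION_TEMPLATE` from eager_gen.py, which carries more placeholders (the grad node name, fill-zero logic, and the returns), but the section order matches the new layout: autograd metas and the inplace check now run before the grad API call.

```python
# Hypothetical, trimmed stand-in for GRAD_FUNCTION_TEMPLATE; the section
# order below is the new order introduced by this commit.
SKETCH_TEMPLATE = """
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
// Prepare Grad function call
{}
// Get GradIn autograd_meta
{}
// Compute Require Grad
{}
// Inplace Check
{}
// Inplace Strategy
{}
// Call grad_api function
{}
// Check NaN and Inf if needed
{}
// Get GradOut autograd_meta
{}
// Create Grad Node
{}
"""

# Positional arguments must follow the same order, which is why the commit
# also reorders the .format(...) call shown near the end of this diff.
print(SKETCH_TEMPLATE.format(
    "/* get_grad_in_args_str */",
    "/* grad_function_prepare_str */",
    "/* inputs_autograd_meta_str */",
    "/* compute_require_grad_str */",
    "/* inplace_check_str */",
    "/* inplace_for_grad_outs_str */",
    "/* grad_function_call_str */",
    "/* check_nan_inf_str */",
    "/* outputs_autograd_meta_str */",
    "/* next_grad_node_creation_str */",
))
```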
@@ -1219,6 +1224,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
next_grad_node_creation_str = ""
next_grad_node_out_list = []
next_node_generator = None
if next_grad_api_contents:
# Fake forward_api_contents and backward_api_contents
forward_api_contents = grad_api_contents
@@ -1229,12 +1235,15 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
forward_api_contents, backward_api_contents, namespace)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes()
next_grad_node_creation_str = next_node_generator.node_creation_str
next_grad_node_out_list = next_node_generator.grad_node_out_list
self.RecordGrad2NextGradNameMapping(next_node_generator)
return next_grad_node_creation_str, next_grad_node_out_list
if next_node_generator is not None:
return next_grad_node_creation_str, next_grad_node_out_list, next_node_generator.backward_forward_inputs_map
else:
return next_grad_node_creation_str, next_grad_node_out_list, None
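A simplified, hypothetical mirror of the new three-value contract; the stub class stands in for the real `DygraphNodeGenerator` so the control flow can run on its own. Callers now also receive the next node's `backward_forward_inputs_map`, or `None` when no higher-order node exists:

```python
# Stub standing in for DygraphNodeGenerator; not the real class.
class _FakeNextNodeGenerator:
    node_creation_str = "/* next node creation */"
    grad_node_out_list = ["grad_x"]
    backward_forward_inputs_map = {"x": (0, True, 0)}

def generate_higher_order_node_creation_code(next_grad_api_contents):
    next_grad_node_creation_str = ""
    next_grad_node_out_list = []
    next_node_generator = None
    if next_grad_api_contents:
        next_node_generator = _FakeNextNodeGenerator()
        next_grad_node_creation_str = next_node_generator.node_creation_str
        next_grad_node_out_list = next_node_generator.grad_node_out_list
    # The map only exists when a next-node generator actually ran.
    if next_node_generator is not None:
        return (next_grad_node_creation_str, next_grad_node_out_list,
                next_node_generator.backward_forward_inputs_map)
    return next_grad_node_creation_str, next_grad_node_out_list, None

print(generate_higher_order_node_creation_code({"api": "matmul_grad"}))
print(generate_higher_order_node_creation_code(None))
```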
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
@@ -1296,7 +1305,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
attribute_members_str)
def GenerateNodeDefinition(self, next_grad_node_creation_str,
next_grad_node_out_list):
next_grad_node_out_list,
backward_forward_inputs_map_next):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
@@ -1330,6 +1340,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n"
inplace_grad_input_str = ""
inplaced_tensor_wrapper = False
inplace_check_str = ""
# Grad Ins from TensorWrappers
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
@@ -1340,7 +1352,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});"
if backward_inplace_map and name in backward_inplace_map.keys():
tensor_wrapper_intermidiate_tensor_str = f"(&this->{tensor_wrapper_name})->get_intermidiate_tensor()"
tensor_wrapper_recover_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
inplace_check_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name, name,
transformed_tensor_name, transformed_tensor_name,
transformed_tensor_name, transformed_tensor_name,
@@ -1359,6 +1371,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
get_grad_in_args_list.append(tensor_wrapper_recover_str)
optional_inplace_check = False
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
@@ -1370,8 +1383,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Inplace in backward op
if backward_inplace_map and name in backward_inplace_map.keys():
if len(next_grad_node_creation_str) > 0:
if (transformed_tensor_name
in backward_forward_inputs_map_next) and (
backward_forward_inputs_map_next[
transformed_tensor_name][1]):
optional_inplace_check = False
grads_tensor_str = f"grads[{fwd_position}][0]"
get_tensor_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
inplace_check_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name, name,
transformed_tensor_name, transformed_tensor_name,
transformed_tensor_name, transformed_tensor_name,
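The nested condition above is the heart of the fix: a grad input may only be reused inplace if the next (double-grad) node does not keep it as a forward input, since such inputs are saved in TensorWrappers and must not be overwritten. A hedged Python condensation of that guard, with the entry layout assumed to be `(_, is_fwd_input, grad_api_position)` as in `backward_forward_inputs_map`:

```python
def next_node_saves_tensor(transformed_tensor_name,
                           next_grad_node_creation_str,
                           backward_forward_inputs_map_next):
    """True when the higher-order grad node records this tensor as one of
    its forward inputs, in which case inplacing it is unsafe."""
    if len(next_grad_node_creation_str) == 0:
        return False  # no double-grad node was generated at all
    entry = backward_forward_inputs_map_next.get(transformed_tensor_name)
    # entry layout assumed: (_, is_fwd_input, grad_api_position)
    return entry is not None and bool(entry[1])

# A double-grad node that saves this grad input blocks the inplace path:
print(next_node_saves_tensor("grad_out", "/* next node */",
                             {"grad_out": (0, True, 0)}))   # -> True
print(next_node_saves_tensor("grad_out", "", {}))           # -> False
```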
@@ -1406,14 +1425,15 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Grad Function Call String
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
grad_api_namespace = f"paddle::experimental::{namespace}"
grad_function_call_str = f"""
grad_function_prepare_str = f"""
const auto& out_metas = OutputMeta();
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs});
for (int i = 0; i < {slot_num_bwd_outputs}; ++i) {{
out_metas[i].size() == 0 ? returns[i].resize(1) : returns[i].resize(out_metas[i].size());
}}
"""
inplace_for_grad_outs_str = ""
optional_inplace_str = ""
# Grad Outputs
out_index = -1
for name, (ttype, fwd_position,
@@ -1421,22 +1441,35 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
transformed_tensor_name = self.TransformToNextGradName(name)
out_index = out_index + 1
grad_api_args.append(f"api_output_{out_index}")
if not optional_inplace_check:
optional_inplace_str = "VLOG(6) << \"No Inplace should happend for wrappered input\";"
else:
optional_inplace_str = f"""if (api_output_{out_index} != nullptr && can_be_inplaced) {{
egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
}}"""
if IsPlainTensorType(ttype):
inplace_for_grad_outs_str = ""
if backward_inplace_map and name in backward_inplace_map.values(
):
inplace_for_grad_outs_str = f"""
{indent}if (api_output_{out_index} != nullptr && can_be_inplaced) {{
{indent} egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
{indent}}}"""
inplace_str = f"""if (api_output_{out_index} != nullptr && can_be_inplaced) {{
egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
}}"""
if len(next_grad_node_creation_str) > 0:
inplace_for_grad_outs_str += f"""
if (!require_any_grad) {{
{inplace_str}
}} else {{
{optional_inplace_str}
}}"""
else:
inplace_for_grad_outs_str += inplace_str
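Putting the branches together, the generated C++ now makes a runtime choice per grad output. A hypothetical Python model of that decision (names such as `saved_by_next_node` are illustrative, not from the source):

```python
def should_inplace(has_next_grad_node, require_any_grad,
                   saved_by_next_node, can_be_inplaced):
    if not can_be_inplaced:
        return False
    if not has_next_grad_node:
        return True   # no double grad compiled in: unconditional fast path
    if not require_any_grad:
        return True   # double grad exists but is not needed this run
    # Double grad will run: only inplace if the next node did not save this
    # tensor in a TensorWrapper (optional_inplace_check stayed True).
    return not saved_by_next_node

# A double grad that saves its grad input skips the view reuse:
print(should_inplace(True, True, True, True))    # -> False
print(should_inplace(False, True, True, True))   # -> True
```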
grad_function_call_str += f"""
auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? nullptr : &returns[{fwd_position}][0];{inplace_for_grad_outs_str}"""
grad_function_prepare_str += f"""
auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? nullptr : &returns[{fwd_position}][0];"""
else:
assert IsVectorTensorType(ttype)
grad_function_call_str += f"""
grad_function_prepare_str += f"""
std::vector<paddle::experimental::Tensor*> api_output_{out_index};
api_output_{out_index}.reserve(returns[{fwd_position}].size());
for (size_t i = 0; i < returns[{fwd_position}].size(); ++i) {{
@@ -1449,7 +1482,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_args_str = ", ".join(grad_api_args)
grad_function_call_str = grad_function_call_str + f"""
grad_function_call_str = f"""
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""
# Check Nan and Inf
@@ -1542,9 +1575,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name = GetGradNodeName(self.backward_api_name)
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, check_nan_inf_str, inputs_autograd_meta_str,
outputs_autograd_meta_str, compute_require_grad_str,
grad_node_name, fill_zero_str, get_grad_in_args_str,
grad_function_prepare_str, inputs_autograd_meta_str,
compute_require_grad_str, inplace_check_str,
inplace_for_grad_outs_str, grad_node_name, grad_function_call_str,
check_nan_inf_str, outputs_autograd_meta_str,
next_grad_node_creation_str, returns_str)
def run(self):
@@ -1556,13 +1591,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ##
#####################
# Higher-order GradNode generation
next_grad_node_creation_str, next_grad_node_out_list = self.GenerateHigherOrderNodeCreationCode(
next_grad_node_creation_str, next_grad_node_out_list, backward_forward_inputs_map = self.GenerateHigherOrderNodeCreationCode(
)
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition(next_grad_node_creation_str,
next_grad_node_out_list)
next_grad_node_out_list,
backward_forward_inputs_map)
class DygraphForwardAndNodesGenerator(GeneratorBase):
@@ -95,8 +95,6 @@ void GradTensorHolder::add(size_t slot_id,
paddle::platform::errors::Fatal(
"Invalid slot_id for GradTensorHolder::add() "
"which exceeds size of buffer"));
VLOG(6) << "Add Tensor for buffer_ slot: " << slot_id
<< ", size: " << buffer_[slot_id].size();
if (buffer_[slot_id].empty()) {
VLOG(6) << "Pass add Tensor for buffer_ slot: " << slot_id
<< " since its buffer_ is empty ";
@@ -119,8 +117,12 @@ void GradTensorHolder::add(size_t slot_id,
// framework::Variable is initialized.
if ((!buffer_tensor.defined() || !buffer_tensor.initialized())) {
// Simply copy tensor->impl
VLOG(6) << "Move Tensor for buffer_ slot: " << slot_id
<< ", size: " << buffer_[slot_id].size();
buffer_tensor = t;
} else {
VLOG(6) << "Add Tensor for buffer_ slot: " << slot_id
<< ", size: " << buffer_[slot_id].size();
// Accumulation
PADDLE_ENFORCE_EQ(t.initialized(),
true,
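The relocated VLOGs distinguish the two paths of `GradTensorHolder::add`: the first gradient written to a slot is moved in, and later ones are accumulated. A toy Python analogue (not the Paddle API) of that move-vs-accumulate split:

```python
# buffer is a list of slots; each slot holds per-rank gradient values.
def add_grad(buffer, slot_id, rank, t):
    if buffer[slot_id][rank] is None:            # tensor not yet defined
        print(f"Move Tensor for buffer_ slot: {slot_id}")
        buffer[slot_id][rank] = t                # simply take ownership
    else:
        print(f"Add Tensor for buffer_ slot: {slot_id}")
        buffer[slot_id][rank] += t               # accumulation

buf = [[None]]
add_grad(buf, 0, 0, 1.0)   # -> Move
add_grad(buf, 0, 0, 2.0)   # -> Add
print(buf)                 # -> [[3.0]]
```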
@@ -19,6 +19,7 @@
#include "gtest/gtest.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/selected_rows.h"
@@ -76,7 +77,7 @@ TEST(GradTensorHolder, Interfaces) {
std::vector<GradSlotMeta> slot_meta(1);
GradTensorHolder grad_tensor_holder =
GradTensorHolder({slot_meta, slot_meta});
egr::EagerUtils::autograd_meta(&et0);
// add():
// fill one
grad_tensor_holder.CopyValueFromTensor(0, 0, et0, true);