Unverified commit e5ebd347, authored by pangyoki, committed by GitHub

support backward inplace for eager dygraph mode (#42795)

* support inplace in backward

* fix final_state_linear

* fix format of backward_inplace_map

* little change

* add subtract in yaml

* fix hook mem leak

* fix hook use_count

* little format change

* fix
Co-authored-by: JiabinYang <360788950@qq.com>
Parent 2cb61405
......@@ -307,6 +307,23 @@ def ParseYamlBackward(args_str, returns_str):
return inputs_list, attrs_list, returns_list
def ParseYamlInplaceInfo(string):
# inplace_map_str: "(x -> out0), (y -> out2)"
inplace_map = {}
for pair in string.split(","):
pair = pair.strip()
if pair.startswith("("):
pair = pair[1:]
if pair.endswith(")"):
pair = pair[:-1]
key = pair.split("->")[0].strip()
val = pair.split("->")[1].strip()
inplace_map[key] = val
return inplace_map
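
A quick sanity sketch of what this helper returns for the sample string shown in the comment above (the call below is illustrative and not part of the diff):

# Illustrative use of ParseYamlInplaceInfo: the yaml inplace annotation is turned
# into a plain {input_name: output_name} dict.
inplace_map = ParseYamlInplaceInfo("(x -> out0), (y -> out2)")
assert inplace_map == {"x": "out0", "y": "out2"}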
########################
### Generator Base ###
########################
......@@ -334,25 +351,14 @@ class FunctionGeneratorBase:
self.optional_inputs = [] #[name, ...]
self.no_need_buffers = [] #[name, ...]
self.intermediate_outputs = [] #[name, ...]
self.inplace_map = {} #{name : name, ...}
self.forward_inplace_map = {} #{name : name, ...}
def ParseInplaceInfo(self):
def ParseForwardInplaceInfo(self):
forward_api_contents = self.forward_api_contents
if 'inplace' not in forward_api_contents.keys(): return
# inplace_map_str: "(x -> out0), (y -> out2)"
inplace_map_str = forward_api_contents['inplace']
for pair in inplace_map_str.split(","):
pair = pair.strip()
if pair.startswith("("):
pair = pair[1:]
if pair.endswith(")"):
pair = pair[:-1]
key = pair.split("->")[0].strip()
val = pair.split("->")[1].strip()
self.inplace_map[key] = val
self.forward_inplace_map = ParseYamlInplaceInfo(inplace_map_str)
def ParseNoNeedBuffer(self):
grad_api_contents = self.grad_api_contents
......
......@@ -29,6 +29,7 @@ from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFu
from codegen_utils import GetInplacedFunctionName
from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import ParseYamlInplaceInfo
from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage, GetIndent
......@@ -347,6 +348,16 @@ CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = \
if( {}.impl() ) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
"""
CHECK_BACKWARD_INPLACE_TEMPLATE = \
"""
bool can_be_inplaced = false;
if ({}.initialized()) {{
VLOG(10) << {}.name() << "({}) use_count: " << {}.impl().use_count();
if ({}.impl().use_count() == 1 || ({}.impl().use_count() == 2 && {}.impl().get() == {}.impl().get())) {{
can_be_inplaced = true;
}}
}}"""
CHECK_NAN_AND_INF_TEMPLATE = \
""" if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
"""
......@@ -407,7 +418,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
#self.optional_inputs
#self.no_need_buffers
#self.intermediate_outputs
#self.inplace_map
#self.forward_inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
self.grad_api_contents = grad_api_contents
......@@ -438,6 +449,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.backward_grad_outputs_map = {
} #{ "name" : [type, fwd_position, orig_position] ...}
self.backward_inplace_map = {} #{name : name, ...}
def ParseBackwardInplaceInfo(self):
grad_api_contents = self.grad_api_contents
if 'inplace' not in grad_api_contents.keys(): return
inplace_map_str = grad_api_contents['inplace']
self.backward_inplace_map = ParseYamlInplaceInfo(inplace_map_str)
def DygraphYamlValidationCheck(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents
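
For a concrete sense of what ParseBackwardInplaceInfo produces, a minimal sketch assuming a grad_api_contents dict parsed from the add_grad yaml entry added further below:

# Assumed (simplified) parsed yaml entry; only the 'inplace' key is consumed here.
grad_api_contents = {"backward_api": "add_grad", "inplace": "(out_grad -> x_grad)"}
backward_inplace_map = ParseYamlInplaceInfo(grad_api_contents["inplace"])
assert backward_inplace_map == {"out_grad": "x_grad"}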
......@@ -777,8 +797,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
##########################
## Parsing Raw Contents ##
##########################
# Parse inplace_map
self.ParseInplaceInfo()
# Parse forward and backward inplace_map
self.ParseForwardInplaceInfo()
self.ParseBackwardInplaceInfo()
# Parse no_need_buffer
self.ParseNoNeedBuffer()
......@@ -837,7 +858,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
inplace_map = self.inplace_map if is_inplaced else {}
forward_inplace_map = self.forward_inplace_map if is_inplaced else {}
indent = GetIndent(1)
# Get Function Args
......@@ -869,7 +890,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(NEW_{name}_temp_tensor) : {name};\n"
)
else:
if is_inplaced and inplace_map and name in inplace_map.keys(
if is_inplaced and forward_inplace_map and name in forward_inplace_map.keys(
):
arg_str = f"paddle::experimental::Tensor& {name}"
amp_tensors_vector_list.append(f"{{{name}}}")
......@@ -944,13 +965,15 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
returns_list[pos] = f"{name}"
if IsPlainTensorType(rtype):
if is_inplaced and inplace_map and name in inplace_map.values():
if is_inplaced and forward_inplace_map and name in forward_inplace_map.values(
):
returns_type_list[pos] = "paddle::experimental::Tensor&"
else:
returns_type_list[pos] = "paddle::experimental::Tensor"
else:
assert IsVectorTensorType(rtype)
if is_inplaced and inplace_map and name in inplace_map.values():
if is_inplaced and forward_inplace_map and name in forward_inplace_map.values(
):
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>&"
else:
......@@ -1014,7 +1037,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
for inplace_name in forward_inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
......@@ -1258,6 +1281,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
backward_inplace_map = self.backward_inplace_map
indent = GetIndent(1)
# Construct grad_api function args
......@@ -1282,6 +1306,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
else:
fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n"
inplace_grad_input_str = ""
# Grad Ins from TensorWrappers
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
......@@ -1290,6 +1315,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
is_optional = (name in self.optional_inputs)
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});"
if backward_inplace_map and name in backward_inplace_map.keys():
tensor_wrapper_intermidiate_tensor_str = f"(&this->{tensor_wrapper_name})->get_intermidiate_tensor()"
tensor_wrapper_recover_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name, name,
transformed_tensor_name, transformed_tensor_name,
transformed_tensor_name, transformed_tensor_name,
tensor_wrapper_intermidiate_tensor_str)
inplace_grad_input_str = transformed_tensor_name
if is_optional:
tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name,
......@@ -1312,6 +1345,16 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if IsPlainTensorType(ttype):
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
# Inplace in backward op
if backward_inplace_map and name in backward_inplace_map.keys():
grads_tensor_str = f"grads[{fwd_position}][0]"
get_tensor_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name, name,
transformed_tensor_name, transformed_tensor_name,
transformed_tensor_name, transformed_tensor_name,
grads_tensor_str)
inplace_grad_input_str = transformed_tensor_name
if is_optional:
get_tensor_str += "\n" + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name,
......@@ -1357,8 +1400,16 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_args.append(f"api_output_{out_index}")
if IsPlainTensorType(ttype):
inplace_for_grad_outs_str = ""
if backward_inplace_map and name in backward_inplace_map.values(
):
inplace_for_grad_outs_str = f"""
{indent}if (api_output_{out_index} != nullptr && can_be_inplaced) {{
{indent} egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
{indent}}}"""
grad_function_call_str += f"""
auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? nullptr : &returns[{fwd_position}][0];"""
auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? nullptr : &returns[{fwd_position}][0];{inplace_for_grad_outs_str}"""
else:
assert IsVectorTensorType(ttype)
......
......@@ -259,7 +259,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
#self.optional_inputs
#self.no_need_buffers
#self.intermediate_outputs
#self.inplace_map
#self.forward_inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
self.is_forward_only = True
......@@ -275,7 +275,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
def GeneratePythonCFunction(self):
namespace = self.namespace
inplace_map = self.inplace_map
forward_inplace_map = self.forward_inplace_map
forward_api_name = self.forward_api_name
orig_forward_attrs_list = self.orig_forward_attrs_list
forward_inputs_position_map = self.forward_inputs_position_map
......@@ -359,7 +359,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name_prefix, forward_api_name, namespace,
forward_api_name, forward_api_name)
if inplace_map:
if forward_inplace_map:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
if is_forward_only:
......@@ -372,9 +372,9 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
GetForwardFunctionName(inplaced_forward_api_name))
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items():
forward_inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(forward_inplace_map)}"
for inplace_input, inplace_output in forward_inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input,
inplaced_forward_api_name, inplace_output)
......@@ -401,8 +401,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Initialized optional_inputs
self.ParseDispensable()
# Initialized inplace_map
self.ParseInplaceInfo()
# Initialized forward_inplace_map
self.ParseForwardInplaceInfo()
# Initialized orig_forward_inputs_list, orig_forward_returns_list, orig_forward_attrs_list
self.CollectOriginalForwardInfo()
......
......@@ -116,6 +116,10 @@ class TensorWrapper {
return recovered_tensor;
}
paddle::experimental::Tensor get_intermidiate_tensor() {
return intermidiate_tensor_;
}
void clear() { intermidiate_tensor_.reset(); }
private:
......
......@@ -271,6 +271,33 @@ void EagerUtils::HandleViewBetweenInputAndOutput(
}
}
void EagerUtils::HandleViewBetweenInputAndOutput(
const paddle::experimental::Tensor& input_tensor,
paddle::experimental::Tensor* view_output_tensor) {
PADDLE_ENFORCE_EQ(
input_tensor.initialized(), true,
paddle::platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", input_tensor.name()));
if (input_tensor.is_dense_tensor()) {
auto input_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input_tensor.impl());
if (view_output_tensor->impl() == nullptr) {
view_output_tensor->set_impl(std::make_shared<phi::DenseTensor>());
}
auto view_output_dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(view_output_tensor->impl());
view_output_dense_tensor->ShareBufferWith(*input_dense_tensor);
view_output_dense_tensor->ShareInplaceVersionCounterWith(
*input_dense_tensor);
VLOG(3) << "Perform View between Output Tensor("
<< view_output_tensor->name() << ") and Input Tensor("
<< input_tensor.name()
<< "), share allocation and inplace version.";
}
}
std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
const std::vector<std::shared_ptr<EagerVariable>>& outs) {
std::vector<paddle::experimental::Tensor> res;
......
......@@ -172,6 +172,9 @@ class EagerUtils {
static void HandleViewBetweenInputAndOutput(
const std::shared_ptr<EagerVariable>& input_var,
const std::shared_ptr<EagerVariable>& view_output_var);
static void HandleViewBetweenInputAndOutput(
const paddle::experimental::Tensor& input_tensor,
paddle::experimental::Tensor* view_output_tensor);
// TensorWrapper Utils
static paddle::experimental::Tensor RecoverTensorWrapper(TensorWrapper* tw);
......
......@@ -3169,7 +3169,7 @@ def reshape(x, shape, name=None):
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = _C_ops.reshape2(x, None, 'shape', shape)
out = _C_ops.final_state_reshape(x, shape)
elif isinstance(shape, tmp_tensor_type):
shape.stop_gradient = True
out, _ = _C_ops.reshape2(x, shape)
......
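
A usage-level sketch of the public API that exercises this dygraph branch (shapes are arbitrary; per the change above, a list-of-int shape in eager mode is now routed to _C_ops.final_state_reshape):

import paddle  # assumes an eager-mode (dygraph) Paddle build

x = paddle.rand([2, 3, 4])
y = paddle.reshape(x, [6, 4])  # list-of-int shape: eager path uses final_state_reshape
print(y.shape)                 # [6, 4]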
......@@ -66,6 +66,7 @@
func : add_grad
no_need_buffer : x, y
backward : add_double_grad
inplace : (out_grad -> x_grad)
- backward_api : add_n_grad
forward : add_n (Tensor[] x) -> Tensor(out)
......@@ -383,6 +384,7 @@
kernel :
func : cross_entropy_with_softmax_grad
data_type : softmax
inplace : (softmax -> input_grad)
- backward_api : cross_grad
forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out)
......@@ -646,6 +648,7 @@
data_type: out_grad
backend: out_grad
layout: out_grad
inplace : (out_grad -> x_grad)
- backward_api : flip_grad
forward : flip (Tensor x, int[] axis) -> Tensor(out)
......@@ -1492,6 +1495,7 @@
backend: out_grad
layout: out_grad
backward : reshape_double_grad
inplace : (out_grad -> x_grad)
- backward_api : roi_align_grad
forward : roi_align (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned) -> Tensor(out)
......@@ -1563,6 +1567,7 @@
output : Tensor(x_grad)
invoke : scale(out_grad, scale, 0.0, bias_after_scale)
backward : scale_double_grad
inplace : (out_grad -> x_grad)
- backward_api : scale_triple_grad
forward : scale_double_grad (Tensor grad_grad_x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(grad_grad_out)
......@@ -1755,6 +1760,7 @@
param: [xshape]
kernel :
func : squeeze_grad
inplace : (out_grad -> x_grad)
- backward_api : stack_grad
forward : stack (Tensor[] x, int axis) -> Tensor(out)
......@@ -1802,6 +1808,7 @@
func : subtract_grad
no_need_buffer : x, y
backward : subtract_double_grad
inplace : (out_grad -> x_grad)
- backward_api : sum_double_grad
forward : sum_grad (Tensor x, Tensor grad_out, int64_t[] dims, bool keep_dim, bool reduce_all=false) -> Tensor(grad_x)
......@@ -2025,6 +2032,7 @@
param: [xshape]
kernel :
func : unsqueeze_grad
inplace : (out_grad -> x_grad)
- backward_api : where_grad
forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
......