未验证 提交 911cb2ea 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Support NoNeedBuffer for final state codegen (#39628)

* Support NoNeedBuffer for final state codegen

* Replaced pten with phi
上级 8d1d0bdf
......@@ -127,6 +127,15 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def ParseNoNeedBuffer(string):
# string: "x, y"
no_need_buffer_set = set()
for name in string.split(","):
no_need_buffer_set.add(name.strip())
return no_need_buffer_set
def ParseYamlArgs(string):
# Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
......@@ -397,7 +406,7 @@ def SlotNameMatching(backward_inputs_list, backward_returns_list,
def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
backward_attrs_list):
backward_attrs_list, no_need_buffer_set):
# Inputs:
# fwd_api_name = ""
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
......@@ -410,15 +419,20 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set:
no_need_buffer = "true"
else:
no_need_buffer = "false"
tensor_wrapper_name = GetSavedName(tname)
if IsPlainTensorType(ttype):
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """
void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
{} = egr::TensorWrapper({}, full_reserved);
{} = egr::TensorWrapper({}, full_reserved, {});
}}
"""
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname)
tname, tname, tensor_wrapper_name, tname, no_need_buffer)
PLAIN_TENSOR_MEMBER_TEMPLATE = """
egr::TensorWrapper {};
......@@ -430,12 +444,12 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
for(const auto& eager_tensor : {}) {{
{}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved) );
{}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) );
}};
}}
"""
set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tname, tensor_wrapper_name)
tname, tname, tname, tensor_wrapper_name, no_need_buffer)
VECTOR_TENSOR_MEMBER_TEMPLATE = """
std::vector<egr::TensorWrapper> {};
......@@ -997,6 +1011,10 @@ if __name__ == "__main__":
assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
no_need_buffer_set = set()
if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api['no_need_buffer'])
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
......@@ -1062,7 +1080,8 @@ if __name__ == "__main__":
# Node Declaration Generation
node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list)
fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str)
node_definition_str += GenerateNodeDefinition(
......
......@@ -34,7 +34,8 @@ class TensorWrapper {
public:
TensorWrapper() = default;
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
bool full_reserved = false) {
bool full_reserved = false,
bool no_need_buffer = false) {
/**
* Normally, we should fully reserved all non-output or non-leaf fwd tensor
* here. And for fwd output tensor, we should not reserve its autogradmeta,
......@@ -48,7 +49,22 @@ class TensorWrapper {
}
// shallow copy tensor_impl here
if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(tensor.impl().get());
auto tw_dense_tensor = std::make_shared<phi::DenseTensor>();
tw_dense_tensor->set_meta(dense_tensor->meta());
intermidiate_tensor_.set_impl(tw_dense_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
}
} else {
intermidiate_tensor_.set_impl(tensor.impl());
}
intermidiate_tensor_.set_name(tensor.name() + "@Saved");
PADDLE_ENFORCE_NOT_NULL(
EagerUtils::unsafe_autograd_meta(tensor),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册