未验证 提交 911cb2ea 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Support NoNeedBuffer for final state codegen (#39628)

* Support NoNeedBuffer for final state codegen

* Replaced pten with phi
上级 8d1d0bdf
...@@ -127,6 +127,15 @@ def ReadBwdFile(filepath): ...@@ -127,6 +127,15 @@ def ReadBwdFile(filepath):
###################### ######################
### Yaml Parsers ### ### Yaml Parsers ###
###################### ######################
def ParseNoNeedBuffer(string):
    """Parse the yaml 'no_need_buffer' field into a set of tensor names.

    Args:
        string: comma-separated tensor names, e.g. "x, y".

    Returns:
        A set of whitespace-stripped names. Blank entries (an empty
        input string, stray/trailing commas) are skipped so the set
        never contains "" — the original implementation would add an
        empty-string member in those cases, which could falsely match
        during later membership checks.
    """
    return {name.strip() for name in string.split(",") if name.strip()}
def ParseYamlArgs(string): def ParseYamlArgs(string):
# Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y # Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
...@@ -397,7 +406,7 @@ def SlotNameMatching(backward_inputs_list, backward_returns_list, ...@@ -397,7 +406,7 @@ def SlotNameMatching(backward_inputs_list, backward_returns_list,
def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
backward_attrs_list): backward_attrs_list, no_need_buffer_set):
# Inputs: # Inputs:
# fwd_api_name = "" # fwd_api_name = ""
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...} # backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
...@@ -410,15 +419,20 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -410,15 +419,20 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
set_tensor_wrapper_methods_str = "" set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = "" tensor_wrapper_members_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items(): for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set:
no_need_buffer = "true"
else:
no_need_buffer = "false"
tensor_wrapper_name = GetSavedName(tname) tensor_wrapper_name = GetSavedName(tname)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """
void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{ void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
{} = egr::TensorWrapper({}, full_reserved); {} = egr::TensorWrapper({}, full_reserved, {});
}} }}
""" """
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname) tname, tname, tensor_wrapper_name, tname, no_need_buffer)
PLAIN_TENSOR_MEMBER_TEMPLATE = """ PLAIN_TENSOR_MEMBER_TEMPLATE = """
egr::TensorWrapper {}; egr::TensorWrapper {};
...@@ -430,12 +444,12 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -430,12 +444,12 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{ void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
for(const auto& eager_tensor : {}) {{ for(const auto& eager_tensor : {}) {{
{}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved) ); {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) );
}}; }};
}} }}
""" """
set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tname, tensor_wrapper_name) tname, tname, tname, tensor_wrapper_name, no_need_buffer)
VECTOR_TENSOR_MEMBER_TEMPLATE = """ VECTOR_TENSOR_MEMBER_TEMPLATE = """
std::vector<egr::TensorWrapper> {}; std::vector<egr::TensorWrapper> {};
...@@ -997,6 +1011,10 @@ if __name__ == "__main__": ...@@ -997,6 +1011,10 @@ if __name__ == "__main__":
assert 'output' in fwd_api.keys() assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys() assert 'backward' in fwd_api.keys()
no_need_buffer_set = set()
if 'no_need_buffer' in fwd_api.keys():
no_need_buffer_set = ParseNoNeedBuffer(fwd_api['no_need_buffer'])
fwd_api_name = fwd_api['api'] fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args'] fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output'] fwd_returns_str = fwd_api['output']
...@@ -1062,7 +1080,8 @@ if __name__ == "__main__": ...@@ -1062,7 +1080,8 @@ if __name__ == "__main__":
# Node Declaration Generation # Node Declaration Generation
node_declaration_str += GenerateNodeDeclaration( node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list) fwd_api_name, backward_fwd_input_map, backward_attrs_list,
no_need_buffer_set)
print("Generated Node Declaration: ", node_declaration_str) print("Generated Node Declaration: ", node_declaration_str)
node_definition_str += GenerateNodeDefinition( node_definition_str += GenerateNodeDefinition(
......
...@@ -34,7 +34,8 @@ class TensorWrapper { ...@@ -34,7 +34,8 @@ class TensorWrapper {
public: public:
TensorWrapper() = default; TensorWrapper() = default;
explicit TensorWrapper(const paddle::experimental::Tensor& tensor, explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
bool full_reserved = false) { bool full_reserved = false,
bool no_need_buffer = false) {
/** /**
* Normally, we should fully reserved all non-output or non-leaf fwd tensor * Normally, we should fully reserved all non-output or non-leaf fwd tensor
* here. And for fwd output tensor, we should not reserve its autogradmeta, * here. And for fwd output tensor, we should not reserve its autogradmeta,
...@@ -48,7 +49,22 @@ class TensorWrapper { ...@@ -48,7 +49,22 @@ class TensorWrapper {
} }
// shallow copy tensor_impl here // shallow copy tensor_impl here
intermidiate_tensor_.set_impl(tensor.impl()); if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(tensor.impl().get());
auto tw_dense_tensor = std::make_shared<phi::DenseTensor>();
tw_dense_tensor->set_meta(dense_tensor->meta());
intermidiate_tensor_.set_impl(tw_dense_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
}
} else {
intermidiate_tensor_.set_impl(tensor.impl());
}
intermidiate_tensor_.set_name(tensor.name() + "@Saved"); intermidiate_tensor_.set_name(tensor.name() + "@Saved");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
EagerUtils::unsafe_autograd_meta(tensor), EagerUtils::unsafe_autograd_meta(tensor),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册