未验证 提交 97cd0f51 编写于 作者: Z zyfncg 提交者: GitHub

Refactor code auto-gene for no_need_buffer (#41025)

* refactor code auto-gene for no_need_buffer

* fix some bug

* delete test code
上级 9219495c
...@@ -334,10 +334,10 @@ class FunctionGeneratorBase: ...@@ -334,10 +334,10 @@ class FunctionGeneratorBase:
self.inplace_map[key] = val self.inplace_map[key] = val
def ParseNoNeedBuffer(self): def ParseNoNeedBuffer(self):
forward_api_contents = self.forward_api_contents grad_api_contents = self.grad_api_contents
if 'no_need_buffer' in forward_api_contents.keys(): if 'no_need_buffer' in grad_api_contents.keys():
no_need_buffer_str = forward_api_contents['no_need_buffer'] no_need_buffer_str = grad_api_contents['no_need_buffer']
for name in no_need_buffer_str.split(","): for name in no_need_buffer_str.split(","):
name = name.strip() name = name.strip()
name = RemoveSpecialSymbolsInName(name) name = RemoveSpecialSymbolsInName(name)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -56,7 +56,7 @@ def ParseArguments(): ...@@ -56,7 +56,7 @@ def ParseArguments():
######################## ########################
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
""" """
void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{ void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
{} = egr::TensorWrapper({}, full_reserved, {}); {} = egr::TensorWrapper({}, full_reserved, {});
}} }}
""" """
...@@ -121,19 +121,19 @@ NODE_DECLARATION_TEMPLATE = \ ...@@ -121,19 +121,19 @@ NODE_DECLARATION_TEMPLATE = \
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override; std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }} std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{ void ClearTensorWrappers() override {{
{} {}
is_tensor_wrappers_cleared = true; is_tensor_wrappers_cleared = true;
}} }}
// SetTensorWrapperX, SetTensorWrapperY, ... // SetTensorWrapperX, SetTensorWrapperY, ...
{} {}
// SetAttributes // SetAttributes
{} {}
bool IsTensorWrappersCleared() override {{ bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared; return is_tensor_wrappers_cleared;
}} }}
private: private:
// TensorWrappers // TensorWrappers
...@@ -192,7 +192,7 @@ FORWARD_BODY_TEMPLATE = \ ...@@ -192,7 +192,7 @@ FORWARD_BODY_TEMPLATE = \
if(require_any_grad) {{ if(require_any_grad) {{
{} {}
egr::EagerUtils::PassStopGradient({}); egr::EagerUtils::PassStopGradient({});
// Node Construction // Node Construction
{} {}
// SetAttributes // SetAttributes
...@@ -379,7 +379,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -379,7 +379,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
#self.forward_outputs_position_map #self.forward_outputs_position_map
#self.optional_inputs #self.optional_inputs
#self.no_need_buffers #self.no_need_buffers
#self.intermediate_outputs #self.intermediate_outputs
#self.inplace_map #self.inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace) FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
......
...@@ -169,7 +169,7 @@ cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kerne ...@@ -169,7 +169,7 @@ cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kerne
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl) cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl) cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils)
cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl) cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl)
cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl) cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl)
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api) cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api)
......
...@@ -315,6 +315,31 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -315,6 +315,31 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
return api_declaration return api_declaration
# Backward API Override this method
def gene_kernel_backend_select(self):
backend_select_code = ""
if self.kernel['backend'] is not None:
if '>' in self.kernel['backend']:
vars_list = self.kernel['backend'].split('>')
assert len(
vars_list
) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
backend_select_code = f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
else:
backend_args = [
ele.strip() for ele in self.kernel['backend'].split(',')
]
backend_select_code = f"""
kernel_backend = ParseBackend({", ".join(backend_args)});
"""
return backend_select_code
def gene_kernel_select(self) -> str: def gene_kernel_select(self) -> str:
api = self.api api = self.api
input_names = self.inputs['names'] input_names = self.inputs['names']
...@@ -345,26 +370,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -345,26 +370,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
attr_data_type_count = attr_data_type_count + 1 attr_data_type_count = attr_data_type_count + 1
# preprocess kernel configures # preprocess kernel configures
kernel_select_code = "" kernel_select_code = self.gene_kernel_backend_select()
if kernel['backend'] is not None:
if '>' in kernel['backend']:
vars_list = kernel['backend'].split('>')
assert len(
vars_list
) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
else:
args_str = ""
for ele in kernel['backend'].split(','):
args_str = args_str + ele.strip() + ', '
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackend({args_str[:-2]});
"""
if kernel['layout'] is not None: if kernel['layout'] is not None:
if '>' in kernel['layout']: if '>' in kernel['layout']:
......
...@@ -24,6 +24,7 @@ class BackwardAPI(BaseAPI): ...@@ -24,6 +24,7 @@ class BackwardAPI(BaseAPI):
def __init__(self, backward_item_yaml): def __init__(self, backward_item_yaml):
super(BackwardAPI, self).__init__(backward_item_yaml) super(BackwardAPI, self).__init__(backward_item_yaml)
self.check_args(backward_item_yaml['forward']) self.check_args(backward_item_yaml['forward'])
self.no_need_buffer = self.parse_no_need_buffer(backward_item_yaml)
def get_api_name(self, api_item_yaml): def get_api_name(self, api_item_yaml):
return api_item_yaml['backward_api'] return api_item_yaml['backward_api']
...@@ -41,6 +42,15 @@ class BackwardAPI(BaseAPI): ...@@ -41,6 +42,15 @@ class BackwardAPI(BaseAPI):
return api, fw_inputs, fw_attrs, outputs return api, fw_inputs, fw_attrs, outputs
def parse_no_need_buffer(self, api_item_yaml):
no_need_buffer = []
if 'no_need_buffer' in api_item_yaml:
no_need_buffer = [
item.strip()
for item in api_item_yaml['no_need_buffer'].split(',')
]
return no_need_buffer
def check_args(self, forward_config): def check_args(self, forward_config):
# parse the forward and backward config # parse the forward and backward config
_, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config( _, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config(
...@@ -67,6 +77,19 @@ class BackwardAPI(BaseAPI): ...@@ -67,6 +77,19 @@ class BackwardAPI(BaseAPI):
f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \ f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
Please check the output of {self.api} in yaml." Please check the output of {self.api} in yaml."
def gene_kernel_backend_select(self):
all_no_need_buffer = True
for in_name in self.inputs['names']:
if in_name not in self.no_need_buffer:
all_no_need_buffer = False
if all_no_need_buffer:
return """
kernel_backend = ParseBackend(egr::Controller::Instance().GetExpectedPlace());
"""
else:
return super().gene_kernel_backend_select()
def get_return_type(self, out_type_list): def get_return_type(self, out_type_list):
return out_type_list[0] if len( return out_type_list[0] if len(
out_type_list) == 1 else "std::vector<std::vector<Tensor>>" out_type_list) == 1 else "std::vector<std::vector<Tensor>>"
...@@ -154,6 +177,7 @@ def source_include(header_file_path): ...@@ -154,6 +177,7 @@ def source_include(header_file_path):
#include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
""" """
......
...@@ -28,6 +28,9 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): ...@@ -28,6 +28,9 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI):
def get_api_func_name(self): def get_api_func_name(self):
return self.api return self.api
def gene_kernel_backend_select(self):
return BackwardAPI.gene_kernel_backend_select(self)
def get_return_type(self, out_type_list): def get_return_type(self, out_type_list):
return BackwardAPI.get_return_type(self, out_type_list) return BackwardAPI.get_return_type(self, out_type_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册