From 97cd0f51d6a24c7aff7946a6b2cbf3eea377a424 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 30 Mar 2022 10:45:56 +0800 Subject: [PATCH] Refactor code auto-gene for no_need_buffer (#41025) * refactor code auto-gene for no_need_buffer * fix some bug * delete test code --- .../final_state_generator/codegen_utils.py | 6 +-- .../final_state_generator/eager_gen.py | 18 ++++---- paddle/phi/api/lib/CMakeLists.txt | 2 +- python/paddle/utils/code_gen/api_base.py | 46 +++++++++++-------- .../paddle/utils/code_gen/backward_api_gen.py | 24 ++++++++++ .../utils/code_gen/sparse_bw_api_gen.py | 3 ++ 6 files changed, 66 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index 89939a68f2..62ba3ee503 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -334,10 +334,10 @@ class FunctionGeneratorBase: self.inplace_map[key] = val def ParseNoNeedBuffer(self): - forward_api_contents = self.forward_api_contents + grad_api_contents = self.grad_api_contents - if 'no_need_buffer' in forward_api_contents.keys(): - no_need_buffer_str = forward_api_contents['no_need_buffer'] + if 'no_need_buffer' in grad_api_contents.keys(): + no_need_buffer_str = grad_api_contents['no_need_buffer'] for name in no_need_buffer_str.split(","): name = name.strip() name = RemoveSpecialSymbolsInName(name) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 3f91abdebb..7339f3581a 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -1,11 +1,11 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -56,7 +56,7 @@ def ParseArguments(): ######################## SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \ """ - void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{ + void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{ {} = egr::TensorWrapper({}, full_reserved, {}); }} """ @@ -121,19 +121,19 @@ NODE_DECLARATION_TEMPLATE = \ virtual std::vector> operator()( std::vector>& grads, bool create_graph = false) override; std::string name() override {{ return \" {} \"; }} - + void ClearTensorWrappers() override {{ {} is_tensor_wrappers_cleared = true; }} - + // SetTensorWrapperX, SetTensorWrapperY, ... {} // SetAttributes {} bool IsTensorWrappersCleared() override {{ - return is_tensor_wrappers_cleared; + return is_tensor_wrappers_cleared; }} private: // TensorWrappers @@ -192,7 +192,7 @@ FORWARD_BODY_TEMPLATE = \ if(require_any_grad) {{ {} egr::EagerUtils::PassStopGradient({}); - + // Node Construction {} // SetAttributes @@ -379,7 +379,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): #self.forward_outputs_position_map #self.optional_inputs #self.no_need_buffers - #self.intermediate_outputs + #self.intermediate_outputs #self.inplace_map FunctionGeneratorBase.__init__(self, forward_api_contents, namespace) diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index cd525368e5..af25330191 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -169,7 +169,7 @@ cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kerne cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl) -cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl) +cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils) cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl) cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl) cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api) diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 438e6f788e..1c58334794 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -315,6 +315,31 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self return api_declaration + # Backward API Override this method + def gene_kernel_backend_select(self): + backend_select_code = "" + if self.kernel['backend'] is not None: + if '>' in self.kernel['backend']: + vars_list = self.kernel['backend'].split('>') + assert len( + vars_list + ) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}." + assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \ + f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type." + backend_select_code = f""" + kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); +""" + + else: + backend_args = [ + ele.strip() for ele in self.kernel['backend'].split(',') + ] + backend_select_code = f""" + kernel_backend = ParseBackend({", ".join(backend_args)}); +""" + + return backend_select_code + def gene_kernel_select(self) -> str: api = self.api input_names = self.inputs['names'] @@ -345,26 +370,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self attr_data_type_count = attr_data_type_count + 1 # preprocess kernel configures - kernel_select_code = "" - if kernel['backend'] is not None: - if '>' in kernel['backend']: - vars_list = kernel['backend'].split('>') - assert len( - vars_list - ) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}." - assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \ - f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type." - kernel_select_code = kernel_select_code + f""" - kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); -""" - - else: - args_str = "" - for ele in kernel['backend'].split(','): - args_str = args_str + ele.strip() + ', ' - kernel_select_code = kernel_select_code + f""" - kernel_backend = ParseBackend({args_str[:-2]}); -""" + kernel_select_code = self.gene_kernel_backend_select() if kernel['layout'] is not None: if '>' in kernel['layout']: diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index 69631b574d..bf3c775236 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -24,6 +24,7 @@ class BackwardAPI(BaseAPI): def __init__(self, backward_item_yaml): super(BackwardAPI, self).__init__(backward_item_yaml) self.check_args(backward_item_yaml['forward']) + self.no_need_buffer = self.parse_no_need_buffer(backward_item_yaml) def get_api_name(self, api_item_yaml): return api_item_yaml['backward_api'] @@ -41,6 +42,15 @@ class BackwardAPI(BaseAPI): return api, fw_inputs, fw_attrs, outputs + def parse_no_need_buffer(self, api_item_yaml): + no_need_buffer = [] + if 'no_need_buffer' in api_item_yaml: + no_need_buffer = [ + item.strip() + for item in api_item_yaml['no_need_buffer'].split(',') + ] + return no_need_buffer + def check_args(self, forward_config): # parse the forward and backward config _, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config( @@ -67,6 +77,19 @@ class BackwardAPI(BaseAPI): f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \ Please check the output of {self.api} in yaml." + def gene_kernel_backend_select(self): + all_no_need_buffer = True + for in_name in self.inputs['names']: + if in_name not in self.no_need_buffer: + all_no_need_buffer = False + + if all_no_need_buffer: + return """ + kernel_backend = ParseBackend(egr::Controller::Instance().GetExpectedPlace()); +""" + else: + return super().gene_kernel_backend_select() + def get_return_type(self, out_type_list): return out_type_list[0] if len( out_type_list) == 1 else "std::vector>" @@ -154,6 +177,7 @@ def source_include(header_file_path): #include "paddle/phi/api/include/api.h" #include "paddle/phi/infermeta/backward.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" """ diff --git a/python/paddle/utils/code_gen/sparse_bw_api_gen.py b/python/paddle/utils/code_gen/sparse_bw_api_gen.py index 5e30f509ea..9f74cf9ad5 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api_gen.py +++ b/python/paddle/utils/code_gen/sparse_bw_api_gen.py @@ -28,6 +28,9 @@ class SparseBackwardAPI(SparseAPI, BackwardAPI): def get_api_func_name(self): return self.api + def gene_kernel_backend_select(self): + return BackwardAPI.gene_kernel_backend_select(self) + def get_return_type(self, out_type_list): return BackwardAPI.get_return_type(self, out_type_list) -- GitLab