From 68c9e3e4e6f1c175cd1243b0a6c20736318c81fb Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Thu, 24 Mar 2022 14:10:39 +0800
Subject: [PATCH] [Refactor] refactored eager_gen.py PR #1 (#40815)

* [Refactor] refactored eager_gen.py PR #1

* [Refactor] refactored eager_gen.py PR #1

* Refactored version 2

* Added automatic code generation utils

* Fixed merge issues
---
 .../final_state_generator/codegen_utils.py    |  415 +++
 .../final_state_generator/eager_gen.py        | 2275 ++++++++---------
 .../final_state_generator/python_c_gen.py     |  200 +-
 3 files changed, 1536 insertions(+), 1354 deletions(-)
 create mode 100644 paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py

diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
new file mode 100644
index 0000000000..6e1bee37a4
--- /dev/null
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
@@ -0,0 +1,415 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import yaml
+import re
+import argparse
+import os
+
+########################
+### Global Variables ###
+########################
+ops_to_fill_zero_for_empty_grads = set(["split"])
+
+# For API dispatch used at python-level
+# { op_name : [arg_name, ...] 
}
+core_ops_returns_info = {}
+core_ops_args_info = {}
+core_ops_args_type_info = {}
+
+yaml_types_mapping = {
+    'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t',  'size_t' : 'size_t', \
+    'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
+    'str' : 'std::string', \
+    'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
+    'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
+    'Tensor' : 'Tensor',
+    'Tensor[]' : 'std::vector<Tensor>',
+    'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
+    'Scalar' : 'paddle::experimental::Scalar',
+    'ScalarArray' : 'paddle::experimental::ScalarArray'
+}
+
+
+#############################
+###  File Reader Helpers  ###
+#############################
+def ReadFwdFile(filepath):
+    f = open(filepath, 'r')
+    contents = yaml.load(f, Loader=yaml.FullLoader)
+    f.close()
+    return contents
+
+
+def ReadBwdFile(filepath):
+    f = open(filepath, 'r')
+    contents = yaml.load(f, Loader=yaml.FullLoader)
+    ret = {}
+    for content in contents:
+        if 'backward_api' in content.keys():
+            api_name = content['backward_api']
+        else:
+            assert False
+
+        ret[api_name] = content
+    f.close()
+    return ret
+
+
+##################################
+###  Generic Helper Functions  ###
+##################################
+def FindGradName(string):
+    return string + "_grad"
+
+
+def FindForwardName(string):
+    if not string.endswith("_grad"):
+        return None
+    return string[:-5]
+
+
+def IsPlainTensorType(string):
+    plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
+    if string in plain_tensor_types:
+        return True
+    return False
+
+
+def IsVectorTensorType(string):
+    vector_tensor_types = [
+        'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
+    ]
+    if string in vector_tensor_types:
+        return True
+    return False
+
+
+def GetSavedName(string):
+    return string + "_"
+
+
+def GetConstReference(string):
+    ret = string
+    if not string.startswith("const "):
+        ret = "const " + string
+    if not string.endswith("&"):
+        ret += "&"
+    return ret
+
+
+def RemoveConstAndReference(string):
+    ret = string
+    if string.startswith("const "):
+        ret = ret[6:]
+    if string.endswith("&"):
+        ret = ret[:-1]
+
+    return ret
+
+
+def GetGradNodeName(string):
+    return f"FinalGradNode{string}"
+
+
+def GetDygraphForwardFunctionName(string):
+    return f"{string}_final_state_dygraph_function"
+
+
+def GetIntermediateAPIFunctionName(string):
+    return string + "_intermediate"
+
+
+def GetAutoGradMetaName(string):
+    return f"{string}_autograd_meta"
+
+
+def GetAutoGradMetaVectorName(string):
+    return f"{string}_autograd_meta_vec"
+
+
+def RemoveSpecialSymbolsInName(string):
+    # Remove any name after '@'
+    ret = string.split("@")[0]
+    return ret
+
+
+def RecoverBaseNameOfInplaceFunction(function_name):
+    return function_name[:-1]
+
+
+def GetInplacedFunctionName(function_name):
+    return function_name + "_"
+
+
+def GetForwardFunctionName(string):
+    return f"{string}_final_state_dygraph_function"
+
+
+######################
+###  Yaml Parsers  ###
+######################
+def ParseYamlArgs(string):
+    # Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
+
+    # inputs_list = [ [arg_name, arg_type, orig_position], ...]
+    inputs_list = []
+    # attrs_list = [ [arg_name, arg_type, default_value, orig_position], ...]
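+    # Editor's sketch (hypothetical input, not part of the original patch):
+    # given the args string
+    #   "Tensor x, Tensor y, bool transpose_x = false"
+    # the parse below is expected to yield
+    #   inputs_list = [['x', 'Tensor', 0], ['y', 'Tensor', 1]]
+    #   attrs_list  = [['transpose_x', 'bool', 'false', 2]]
+    # i.e. tensors keep no default value, and each entry records its original
+    # position in the signature.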
+ attrs_list = [] + + args = [x.strip() for x in string.strip().split(",")] + atype = r'((const )?\S+) ' + aname = r'(.*)' + pattern = f'{atype}{aname}' + for i in range(len(args)): + arg = args[i] + m = re.search(pattern, arg) + arg_type = m.group(1).strip() + arg_name = m.group(3).split("=")[0].strip() + default_value = m.group(3).split("=")[1].strip() if len( + m.group(3).split("=")) > 1 else None + + assert arg_type in yaml_types_mapping.keys( + ), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping." + arg_type = yaml_types_mapping[arg_type] + + arg_name = RemoveSpecialSymbolsInName(arg_name) + if "Tensor" in arg_type: + assert default_value is None + inputs_list.append([arg_name, arg_type, i]) + else: + attrs_list.append([arg_name, arg_type, default_value, i]) + + return inputs_list, attrs_list + + +def ParseYamlReturns(string): + # Example0: Tensor(out), Tensor(out1) + # Example1: Tensor, Tensor + # Example2: Tensor[](out), Tensor + + # list = [ [ret_name, ret_type, orig_position], ...] + returns_list = [] + + returns = [x.strip() for x in string.strip().split(",")] + + for i in range(len(returns)): + ret = returns[i] + + ret_name = "" + if "(" in ret and ")" in ret: + # Remove trailing ')' + ret = ret[:-1] + ret_type = ret.split("(")[0].strip() + ret_name = ret.split("(")[1].strip() + else: + ret_type = ret.strip() + + assert ret_type in yaml_types_mapping.keys( + ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping." + ret_type = yaml_types_mapping[ret_type] + + assert "Tensor" in ret_type + ret_name = RemoveSpecialSymbolsInName(ret_name) + returns_list.append([ret_name, ret_type, i]) + + return returns_list + + +def ParseYamlForwardFromBackward(string): + # Example: matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out) + + fname = r'(.*?)' + wspace = r'\s*' + fargs = r'(.*?)' + frets = r'(.*)' + pattern = f'{fname}{wspace}\({wspace}{fargs}{wspace}\){wspace}->{wspace}{frets}' + + m = re.search(pattern, string) + function_name = m.group(1) + function_args = m.group(2) + function_returns = m.group(3) + + forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args) + forward_returns_list = ParseYamlReturns(function_returns) + + return forward_inputs_list, forward_attrs_list, forward_returns_list + + +def ParseYamlForward(args_str, returns_str): + # args Example: (const Tensor& x, const Tensor& y, bool transpose_x = false, bool transpose_y = false) + # returns Example: Tensor, Tensor + + fargs = r'(.*?)' + wspace = r'\s*' + args_pattern = f'\({fargs}\)' + args_str = re.search(args_pattern, args_str).group(1) + + inputs_list, attrs_list = ParseYamlArgs(args_str) + returns_list = ParseYamlReturns(returns_str) + + return inputs_list, attrs_list, returns_list + + +def ParseYamlBackward(args_str, returns_str): + # args Example: (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false) + # returns Example: Tensor(x_grad), Tensor(y_grad) + + fargs = r'(.*?)' + wspace = r'\s*' + args_pattern = f'\({fargs}\)' + args_str = re.search(args_pattern, args_str).group(1) + + inputs_list, attrs_list = ParseYamlArgs(args_str) + returns_list = ParseYamlReturns(returns_str) + + return inputs_list, attrs_list, returns_list + + +######################## +### Generator Base ### +######################## +class FunctionGeneratorBase: + def __init__(self, forward_api_contents, namespace): + self.forward_api_contents = forward_api_contents + 
self.namespace = namespace + + self.forward_api_name = "" + + self.orig_forward_inputs_list = [ + ] #[ [arg_name, arg_type, orig_position], ...] + self.orig_forward_attrs_list = [ + ] #[ [attr_name, attr_type, default_value, orig_position], ...] + self.orig_forward_returns_list = [ + ] #[ [ret_name, ret_type, orig_position], ...] + + # Processed Forward Data + self.forward_inputs_position_map = { + } #{ "name" : [type, fwd_position] } + self.forward_outputs_position_map = { + } #{ "name" : [type, fwd_position] } + + # Special Op Attributes + self.optional_inputs = [] #[name, ...] + self.no_need_buffers = [] #[name, ...] + self.intermediate_outputs = [] #[name, ...] + self.inplace_map = {} #{name : name, ...} + + def ParseInplaceInfo(self): + forward_api_contents = self.forward_api_contents + if 'inplace' not in forward_api_contents.keys(): return + + # inplace_map_str: "(x -> out0), (y -> out2)" + inplace_map_str = forward_api_contents['inplace'] + for pair in inplace_map_str.split(","): + pair = pair.strip() + if pair.startswith("("): + pair = pair[1:] + + if pair.endswith(")"): + pair = pair[:-1] + + key = pair.split("->")[0].strip() + val = pair.split("->")[1].strip() + self.inplace_map[key] = val + + def ParseNoNeedBuffer(self): + forward_api_contents = self.forward_api_contents + + if 'no_need_buffer' in forward_api_contents.keys(): + no_need_buffer_str = forward_api_contents['no_need_buffer'] + for name in no_need_buffer_str.split(","): + name = name.strip() + name = RemoveSpecialSymbolsInName(name) + self.no_need_buffers.append(name.strip()) + + def ParseDispensable(self): + forward_api_contents = self.forward_api_contents + + if 'optional' in forward_api_contents.keys(): + optional_inputs_str = forward_api_contents['optional'] + for name in optional_inputs_str.split(","): + name = name.strip() + name = RemoveSpecialSymbolsInName(name) + self.optional_inputs.append(name) + + def ParseIntermediate(self): + forward_api_contents = self.forward_api_contents + + if 'intermediate' in forward_api_contents.keys(): + intermediate_str = forward_api_contents['intermediate'] + for name in intermediate_str.split(","): + name = name.strip() + name = RemoveSpecialSymbolsInName(name) + self.intermediate_outputs.append(name) + + def CollectOriginalForwardInfo(self): + forward_api_contents = self.forward_api_contents + + self.forward_api_name = forward_api_contents['api'] + forward_args_str = forward_api_contents['args'] + forward_returns_str = forward_api_contents['output'] + + assert 'api' in forward_api_contents.keys( + ), "Unable to find \"api\" in forward_api_contents keys" + assert 'args' in forward_api_contents.keys( + ), "Unable to find \"args\" in forward_api_contents keys" + assert 'output' in forward_api_contents.keys( + ), "Unable to find \"output\" in forward_api_contents keys" + + # Collect Original Forward Inputs/Outputs and then perform validation checks + self.orig_forward_inputs_list, self.orig_forward_attrs_list, self.orig_forward_returns_list = ParseYamlForward( + forward_args_str, forward_returns_str) + + def DetermineForwardPositionMap(self, forward_inputs_list, + forward_returns_list): + for i in range(len(forward_inputs_list)): + forward_input = forward_inputs_list[i] + input_name = forward_input[0] + input_type = forward_input[1] + input_pos = forward_input[2] + + self.forward_inputs_position_map[ + input_name] = [input_type, input_pos] + + for i in range(len(forward_returns_list)): + forward_return = forward_returns_list[i] + return_name = forward_return[0] + return_type = 
forward_return[1] + return_pos = forward_return[2] + + self.forward_outputs_position_map[ + return_name] = [return_type, return_pos] + print("Generated Forward Input Position Map: ", + self.forward_inputs_position_map) + print("Generated Forward Output Position Map: ", + self.forward_outputs_position_map) + + +class YamlGeneratorBase: + def __init__(self, api_yaml_path): + self.namespace = "" + self.api_yaml_path = api_yaml_path + + self.forward_api_list = [] + + def ParseForwardYamlContents(self): + api_yaml_path = self.api_yaml_path + self.forward_api_list = ReadFwdFile(api_yaml_path) + + def InferNameSpace(self): + api_yaml_path = self.api_yaml_path + if "sparse" in api_yaml_path: + self.namespace = "sparse::" diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 1d18cbe782..fd750c0d07 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -16,31 +16,25 @@ import yaml import re import argparse import os - -ops_to_fill_zero_for_empty_grads = set(list("split")) - -# For API dispatch used at python-level -# { op_name : [arg_name, ...] } -core_ops_returns_info = {} -core_ops_args_info = {} -core_ops_args_type_info = {} - -namespace = "" - -yaml_types_mapping = { - 'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \ - 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ - 'str' : 'std::string', \ - 'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ - 'int64[]' : 'std::vector', 'int[]' : 'std::vector', - 'Tensor' : 'Tensor', - 'Tensor[]' : 'std::vector', - 'Tensor[Tensor[]]' : 'std::vector>', - 'Scalar' : 'paddle::experimental::Scalar', - 'ScalarArray' : 'paddle::experimental::ScalarArray' -} - - +from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info +from codegen_utils import yaml_types_mapping +from codegen_utils import ReadFwdFile, ReadBwdFile +from codegen_utils import FindGradName, FindForwardName, GetSavedName, GetGradNodeName +from codegen_utils import IsPlainTensorType, IsVectorTensorType +from codegen_utils import GetConstReference, RemoveConstAndReference +from codegen_utils import GetDygraphForwardFunctionName, GetIntermediateAPIFunctionName +from codegen_utils import GetAutoGradMetaName, GetAutoGradMetaVectorName +from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFunction +from codegen_utils import GetInplacedFunctionName +from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward +from codegen_utils import ParseYamlForward, ParseYamlBackward +from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase +from codegen_utils import ops_to_fill_zero_for_empty_grads + + +########### +## Utils ## +########### def ParseArguments(): parser = argparse.ArgumentParser( description='Eager Code Generator Args Parser') @@ -55,845 +49,129 @@ def ParseArguments(): return args -################# -### Helpers ### -################# -def RecoverBaseNameOfInplaceFunction(function_name): - return function_name[:-1] - - -def GetInplacedFunctionName(function_name): - return function_name + "_" - - -def FindGradName(string): - return string + "_grad" - - -def FindForwardName(string): - if not string.endswith("_grad"): - return None - return string[:-5] - - 
-def IsPlainTensorType(string): - plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor'] - if string in plain_tensor_types: - return True - return False - - -def IsVectorTensorType(string): - vector_tensor_types = [ - 'std::vector>', 'std::vector' - ] - if string in vector_tensor_types: - return True - return False - - -def GetSavedName(string): - return string + "_" - - -def GetConstReference(string): - ret = string - if not string.startswith("const "): - ret = "const " + string - if not string.endswith("&"): - ret += "&" - return ret - - -def RemoveConstAndReference(string): - ret = string - if string.startswith("const "): - ret = ret[6:] - if string.endswith("&"): - ret = ret[:-1] - - return ret - - -def GetGradNodeName(string): - return f"FinalGradNode{string}" - - -def GetForwardFunctionName(string): - return f"{string}_final_state_dygraph_function" - - -def GetAutoGradMetaName(string): - return f"{string}_autograd_meta" - - -def GetAutoGradMetaVectorName(string): - return f"{string}_autograd_meta_vec" - - -###################### -### File Readers ### -###################### -def ReadFwdFile(filepath): - f = open(filepath, 'r') - contents = yaml.load(f, Loader=yaml.FullLoader) - f.close() - return contents - - -def ReadBwdFile(filepath): - f = open(filepath, 'r') - contents = yaml.load(f, Loader=yaml.FullLoader) - ret = {} - for content in contents: - if 'backward_api' in content.keys(): - api_name = content['backward_api'] - else: - assert False - - ret[api_name] = content - f.close() - return ret - - -###################### -### Yaml Parsers ### -###################### -def ParseInplaceInfo(string): - # string: "(x -> out0), (y -> out2)" - inplace_map = {} - for pair in string.split(","): - pair = pair.strip() - if pair.startswith("("): - pair = pair[1:] - - if pair.endswith(")"): - pair = pair[:-1] - - key = pair.split("->")[0].strip() - val = pair.split("->")[1].strip() - inplace_map[key] = val - - return inplace_map - - -def RemoveSpecialSymbolsInName(string): - # Remove any name after '@' - ret = string.split("@")[0] - return ret - - -def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): - # intermediate_outputs : [name0, name1, ...] - # forward_returns_list : [[ret_name, type, orig_pos], ...] - """ - Check whether intermediate_outputs are positioned - at the very end of forward_returns_list - """ - - intermediate_positions = range( - len(forward_returns_list) - len(intermediate_outputs), - len(forward_returns_list)) - for ret_name, _, pos in forward_returns_list: - if ret_name in intermediate_outputs: - assert pos in intermediate_positions - - -def ParseDispensable(string): - # string: "X, Y" - string = RemoveSpecialSymbolsInName(string) - return [v.strip() for v in string.split(",")] - - -def ParseIntermediate(string): - string = RemoveSpecialSymbolsInName(string) - return [v.strip() for v in string.split(",")] - - -def ParseNoNeedBuffer(string): - # string: "x, y" - string = RemoveSpecialSymbolsInName(string) - - no_need_buffer_set = set() - for name in string.split(","): - no_need_buffer_set.add(name.strip()) - - return no_need_buffer_set - - -def ParseYamlArgs(string): - # Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y - - # inputs_list = [ [arg_name, arg_type, orig_position], ...] - inputs_list = [] - # attrs_list = [ [arg_name, arg_type, default_value, orig_position], ...] 
- attrs_list = [] - - args = [x.strip() for x in string.strip().split(",")] - atype = r'((const )?\S+) ' - aname = r'(.*)' - pattern = f'{atype}{aname}' - for i in range(len(args)): - arg = args[i] - m = re.search(pattern, arg) - arg_type = m.group(1).strip() - arg_name = m.group(3).split("=")[0].strip() - default_value = m.group(3).split("=")[1].strip() if len( - m.group(3).split("=")) > 1 else None - - assert arg_type in yaml_types_mapping.keys( - ), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping." - arg_type = yaml_types_mapping[arg_type] - - arg_name = RemoveSpecialSymbolsInName(arg_name) - if "Tensor" in arg_type: - assert default_value is None - inputs_list.append([arg_name, arg_type, i]) - else: - attrs_list.append([arg_name, arg_type, default_value, i]) - - return inputs_list, attrs_list - - -def ParseYamlReturns(string): - # Example0: Tensor(out), Tensor(out1) - # Example1: Tensor, Tensor - # Example2: Tensor[](out), Tensor - - # list = [ [ret_name, ret_type, orig_position], ...] - returns_list = [] - - returns = [x.strip() for x in string.strip().split(",")] - - for i in range(len(returns)): - ret = returns[i] - - ret_name = "" - if "(" in ret and ")" in ret: - # Remove trailing ')' - ret = ret[:-1] - ret_type = ret.split("(")[0].strip() - ret_name = ret.split("(")[1].strip() - else: - ret_type = ret.strip() - - assert ret_type in yaml_types_mapping.keys( - ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping." - ret_type = yaml_types_mapping[ret_type] - - assert "Tensor" in ret_type - ret_name = RemoveSpecialSymbolsInName(ret_name) - returns_list.append([ret_name, ret_type, i]) - - return returns_list - - -def ParseYamlForwardFromBackward(string): - # Example: matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out) - - fname = r'(.*?)' - wspace = r'\s*' - fargs = r'(.*?)' - frets = r'(.*)' - pattern = f'{fname}{wspace}\({wspace}{fargs}{wspace}\){wspace}->{wspace}{frets}' - - m = re.search(pattern, string) - function_name = m.group(1) - function_args = m.group(2) - function_returns = m.group(3) - - forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args) - forward_returns_list = ParseYamlReturns(function_returns) - - return forward_inputs_list, forward_attrs_list, forward_returns_list - - -def ParseYamlForward(args_str, returns_str): - # args Example: (const Tensor& x, const Tensor& y, bool transpose_x = false, bool transpose_y = false) - # returns Example: Tensor, Tensor - - fargs = r'(.*?)' - wspace = r'\s*' - args_pattern = f'\({fargs}\)' - args_str = re.search(args_pattern, args_str).group(1) - - inputs_list, attrs_list = ParseYamlArgs(args_str) - returns_list = ParseYamlReturns(returns_str) - - return inputs_list, attrs_list, returns_list - - -def ParseYamlBackward(args_str, returns_str): - # args Example: (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false) - # returns Example: Tensor(x_grad), Tensor(y_grad) - - fargs = r'(.*?)' - wspace = r'\s*' - args_pattern = f'\({fargs}\)' - args_str = re.search(args_pattern, args_str).group(1) - - inputs_list, attrs_list = ParseYamlArgs(args_str) - returns_list = ParseYamlReturns(returns_str) - - return inputs_list, attrs_list, returns_list - - -####################### -### Preprocessing ### -####################### -def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list, - forward_returns_list, orig_forward_inputs_list, - orig_forward_attrs_list, 
orig_forward_returns_list): - for i in range(len(forward_inputs_list)): - forward_input_name = forward_inputs_list[i][0] - forward_input_type = forward_inputs_list[i][1] - forward_input_pos = forward_inputs_list[i][2] - orig_input_name = orig_forward_inputs_list[i][0] - orig_input_type = orig_forward_inputs_list[i][1] - orig_input_pos = orig_forward_inputs_list[i][2] - - assert forward_input_type == orig_input_type - assert forward_input_pos == orig_input_pos - - for i in range(len(forward_attrs_list)): - orig_attr_name = orig_forward_attrs_list[i][0] - orig_attr_type = orig_forward_attrs_list[i][1] - orig_attr_default = orig_forward_attrs_list[i][2] - orig_attr_pos = orig_forward_attrs_list[i][3] - forward_attr_name = forward_attrs_list[i][0] - forward_attr_type = forward_attrs_list[i][1] - forward_attr_default = forward_attrs_list[i][2] - forward_attr_pos = forward_attrs_list[i][3] - assert orig_attr_type == forward_attr_type - assert orig_attr_default == forward_attr_default - assert orig_attr_pos == forward_attr_pos - - for i in range(len(forward_returns_list)): - orig_return_type = orig_forward_returns_list[i][1] - orig_return_pos = orig_forward_returns_list[i][2] - forward_return_type = forward_returns_list[i][1] - forward_return_pos = forward_returns_list[i][2] - - assert orig_return_type == forward_return_type - assert orig_return_pos == forward_return_pos - - # Check Order: Inputs, Attributes - max_input_position = -1 - for _, _, pos in forward_inputs_list: - max_input_position = max(max_input_position, pos) - - max_attr_position = -1 - for _, _, _, pos in forward_attrs_list: - assert pos > max_input_position - max_attr_position = max(max_attr_position, pos) - - -def BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map, - backward_attrs_list): - - # Check Order: TensorWrappers, GradTensors, Attributes - max_fwd_input_position = -1 - for _, (_, _, pos) in backward_fwd_input_map.items(): - max_fwd_input_position = max(max_fwd_input_position, pos) - - max_grad_tensor_position = -1 - for _, (_, _, pos) in backward_grad_input_map.items(): - assert pos > max_fwd_input_position - max_grad_tensor_position = max(max_grad_tensor_position, pos) - - max_attr_position = -1 - for _, _, _, pos in backward_attrs_list: - assert pos > max_grad_tensor_position - max_attr_position = max(max_attr_position, pos) - - -def DetermineForwardPositionMap(forward_inputs_list, forward_returns_list): - forward_inputs_position_map = {} - forward_outputs_position_map = {} - for i in range(len(forward_inputs_list)): - forward_input = forward_inputs_list[i] - input_name = forward_input[0] - input_type = forward_input[1] - input_pos = forward_input[2] - - forward_inputs_position_map[input_name] = [input_type, input_pos] - - for i in range(len(forward_returns_list)): - forward_return = forward_returns_list[i] - return_name = forward_return[0] - return_type = forward_return[1] - return_pos = forward_return[2] - - forward_outputs_position_map[return_name] = [return_type, return_pos] - - return forward_inputs_position_map, forward_outputs_position_map - - -def SlotNameMatching(backward_inputs_list, backward_returns_list, - forward_inputs_position_map, forward_outputs_position_map): - - backward_fwd_input_map = {} - backward_grad_input_map = {} - backward_grad_output_map = {} - - for backward_input in backward_inputs_list: - backward_input_name = backward_input[0] - backward_input_type = backward_input[1] - backward_input_pos = backward_input[2] - - backward_fwd_name = 
FindForwardName(backward_input_name) - if backward_fwd_name: - # Grad Input - assert backward_fwd_name in forward_outputs_position_map.keys() - matched_forward_output_type = forward_outputs_position_map[ - backward_fwd_name][0] - matched_forward_output_pos = forward_outputs_position_map[ - backward_fwd_name][1] - - backward_grad_input_map[backward_input_name] = [ - backward_input_type, matched_forward_output_pos, - backward_input_pos - ] - else: - # TensorWrapper Input - if backward_input_name in forward_inputs_position_map.keys(): - tensor_wrapper_type = forward_inputs_position_map[ - backward_input_name][0] - backward_fwd_input_map[backward_input_name] = [ - backward_input_type, True, backward_input_pos - ] - - elif backward_input_name in forward_outputs_position_map.keys(): - tensor_wrapper_type = forward_outputs_position_map[ - backward_input_name][0] - backward_fwd_input_map[backward_input_name] = [ - backward_input_type, False, backward_input_pos - ] - else: - assert False, backward_input_name - - for backward_output in backward_returns_list: - backward_output_name = backward_output[0] - backward_output_type = backward_output[1] - backward_output_pos = backward_output[2] - - backward_fwd_name = FindForwardName(backward_output_name) - assert backward_fwd_name is not None - assert backward_fwd_name in forward_inputs_position_map.keys( - ), backward_fwd_name - - matched_forward_input_type = forward_inputs_position_map[ - backward_fwd_name][0] - matched_forward_input_pos = forward_inputs_position_map[ - backward_fwd_name][1] - - backward_grad_output_map[backward_output_name] = [ - backward_output_type, matched_forward_input_pos, backward_output_pos - ] - - return backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map - - -def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, - backward_attrs_list, no_need_buffer_set): - # Inputs: - # fwd_api_name = "" - # backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...} - # backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] 
- - # Determine Node Name - forward_op_name = fwd_api_name - - # SetTensorWrapper Methods & TensorWrapper Members - set_tensor_wrapper_methods_str = "" - tensor_wrapper_members_str = "" - clear_tensor_wrapper_str = "" - for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items(): - if tname in no_need_buffer_set: - no_need_buffer = "true" - else: - no_need_buffer = "false" - - tensor_wrapper_name = GetSavedName(tname) - if IsPlainTensorType(ttype): - SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """ +######################## +## Code Gen Templates ## +######################## +SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \ +""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{ {} = egr::TensorWrapper({}, full_reserved, {}); }} """ - set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( - tname, tname, tensor_wrapper_name, tname, no_need_buffer) - PLAIN_TENSOR_MEMBER_TEMPLATE = """ - egr::TensorWrapper {}; +PLAIN_TENSOR_MEMBER_TEMPLATE = \ +""" + egr::TensorWrapper {}; """ - tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( - tensor_wrapper_name) - CLEAR_TENSOR_WRAPPERS_TEMPLATE = """ - {}.clear(); +CLEAR_TENSOR_WRAPPER_TEMPLATE = \ +""" + {}.clear(); """ - clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format( - tensor_wrapper_name) - else: - assert IsVectorTensorType(ttype) - SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """ - void SetTensorWrapper{}(const std::vector& {}, bool full_reserved) {{ - for(const auto& eager_tensor : {}) {{ - {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) ); - }}; - }} +SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \ +""" + void SetTensorWrapper{}(const std::vector& {}, bool full_reserved) {{ + for(const auto& eager_tensor : {}) {{ + {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) ); + }}; + }} """ - set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( - tname, tname, tname, tensor_wrapper_name, no_need_buffer) - VECTOR_TENSOR_MEMBER_TEMPLATE = """ - std::vector {}; +VECTOR_TENSOR_MEMBER_TEMPLATE = \ +""" + std::vector {}; """ - tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( - tensor_wrapper_name) - CLEAR_TENSOR_WRAPPERS_TEMPLATE = """ - for (auto tw: {}) { - tw.clear(); - }; +CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \ """ - clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format( - tensor_wrapper_name) - - # End: SetTensorWrapper Methods & TensorWrapper Members - - # SetAttributes & Attribute Members - set_attribute_methods_str = "" - attribute_members_str = "" - for aname, atype, default_val, _ in backward_attrs_list: - saved_attr_name = GetSavedName(aname) - SET_ATTR_METHOD_TEMPLATE = """ - void SetAttribute{}({} {}) {{ - {} = {}; - }} + for (auto tw: {}) { + tw.clear(); + }; """ - set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format( - aname, GetConstReference(atype), aname, saved_attr_name, aname) - if default_val: - ATTRIBUTE_MEMBER_TEMPLATE = """ +SET_ATTR_METHOD_TEMPLATE = \ +""" + void SetAttribute{}({} {}) {{ + {} = {}; + }} +""" + +ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE = \ +""" {} {} = {}; - """ - attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( - RemoveConstAndReference(atype), saved_attr_name, default_val) - else: - ATTRIBUTE_MEMBER_TEMPLATE = """ +""" + +ATTRIBUTE_MEMBER_TEMPLATE = \ +""" {} {}; - """ - attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( - RemoveConstAndReference(atype), saved_attr_name) - # End: SetAttributes & Attribute Members - - grad_node_name = 
GetGradNodeName(fwd_api_name) - NODE_DECLARATION_TEMPLATE = """ -class {} : public egr::GradNodeBase {{ - public: - {}() : egr::GradNodeBase() {{}} - {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : - egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}} - ~{}() override = default; - - virtual std::vector> operator()( - std::vector>& grads, bool create_graph = false) override; - - std::string name() override {{ return \" {} \"; }} - - void ClearTensorWrappers() override {{ - {} - is_tensor_wrappers_cleared = true; - }} - - // SetTensorWrapperX, SetTensorWrapperY, ... - {} - // SetAttributes - {} - - bool IsTensorWrappersCleared() override {{ - return is_tensor_wrappers_cleared; - }} - private: - // TensorWrappers - {} - - bool is_tensor_wrappers_cleared = false; - - // Attributes - {} -}}; """ - node_declaration_str = NODE_DECLARATION_TEMPLATE.format( - grad_node_name, grad_node_name, grad_node_name, grad_node_name, - grad_node_name, clear_tensor_wrapper_str, - set_tensor_wrapper_methods_str, set_attribute_methods_str, - tensor_wrapper_members_str, attribute_members_str) - - return node_declaration_str - - -def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, - backward_grad_input_map, backward_grad_output_map, - backward_attrs_list): - # fwd_api_name = "" - # backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...} - # backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...} - # backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...} - # backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] - - # Construct grad_api function args - # Order: TensorWrappers, GradTensors, Attributes - grad_api_args_len = len(backward_fwd_input_map.keys()) + len( - backward_grad_input_map.keys()) + len(backward_attrs_list) - grad_api_args = ["" for i in range(grad_api_args_len)] - for name, (_, is_fwd_input, - grad_api_position), in backward_fwd_input_map.items(): - tensor_wrapper_name = GetSavedName(name) - grad_api_args[ - grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" - - for _, (ttype, fwd_position, - grad_api_position) in backward_grad_input_map.items(): - if IsPlainTensorType(ttype): - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}][0]" - else: - assert IsVectorTensorType(ttype) - grad_api_args[grad_api_position] = f"hooked_grads[{fwd_position}]" - - for name, _, _, grad_api_position in backward_attrs_list: - saved_attribute_name = GetSavedName(name) - grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" - grad_api_args_str = ", ".join(grad_api_args) - - # Construct grad_api returns - num_bwd_outputs = len(backward_grad_output_map.keys()) - returns_str = f"std::vector> returns({num_bwd_outputs});\n" - for _, (ttype, fwd_position, - grad_api_position) in backward_grad_output_map.items(): - # Infer Grad API Return Type - if num_bwd_outputs == 1: - # Single tensor output, return as is - if IsPlainTensorType(ttype): - returns_str += "returns[0] = { grad_api_returns };\n" - else: - assert IsVectorTensorType(ttype) - returns_str += "returns[0] = grad_api_returns;\n" - else: - # Rearrange output order accordingly - returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" - returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" - returns_str += f"return returns;\n" - grad_node_name = GetGradNodeName(fwd_api_name) 
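+# Editor's note (illustrative addition, not from the original patch): the
+# module-level templates below are filled with str.format(), so literal C++
+# braces must be doubled ('{{' / '}}') while a bare '{}' marks a substitution
+# slot. A minimal self-checking sketch of the convention:
+assert "void f() {{ {}; }}".format("body") == "void f() { body; }"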
+NODE_DECLARATION_TEMPLATE = \ +""" + class {} : public egr::GradNodeBase {{ + public: + {}() : egr::GradNodeBase() {{}} + {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : + egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}} + ~{}() override = default; + + virtual std::vector> operator()( + std::vector>& grads, bool create_graph = false) override; + std::string name() override {{ return \" {} \"; }} + + void ClearTensorWrappers() override {{ + {} + is_tensor_wrappers_cleared = true; + }} + + // SetTensorWrapperX, SetTensorWrapperY, ... + {} + // SetAttributes + {} - fill_zero_str = "" - if fwd_api_name in ops_to_fill_zero_for_empty_grads: - fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" + bool IsTensorWrappersCleared() override {{ + return is_tensor_wrappers_cleared; + }} + private: + // TensorWrappers + {} - if len(namespace) > 0: - grad_api_namespace = f"paddle::experimental::{namespace}" - else: - grad_api_namespace = f"paddle::experimental" + bool is_tensor_wrappers_cleared = false; - FUNCTION_TEMPLATE = """ -std::vector> {}::operator()(std::vector>& grads, bool create_graph) {{ - {} - auto hooked_grads = ApplyGradientHooks(grads); - - // Call grad_api function - VLOG(3) << \"Final State Running: \" << \"{}\"; - auto grad_api_returns = {}::{}({}); - {} -}} - """ - - node_definition_str = FUNCTION_TEMPLATE.format( - grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, - bwd_api_name, grad_api_args_str, returns_str) - - return node_definition_str - - -def GenerateNodeCreationCodes( - fwd_api_name, bwd_api_name, forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list, forward_call_str, - backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list, optional_inputs, - inplace_map): - # fwd_api_name = "" - # forward_inputs_position_map = { "name" : [type, fwd_position] } - # forward_outputs_position_map = { "name" : [type, fwd_position] } - # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] - # backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...} - # backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...} - # backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...} - # backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] 
- - # Get Input AutoGradMeta - inputs_autograd_meta_list = [] - compute_require_grad_args_list = ["trace_backward"] - for name, (ttype, pos) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" - else: - assert IsVectorTensorType(ttype) - input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" - - inputs_autograd_meta_list.append(input_autograd_meta) - compute_require_grad_args_list.append(input_autograd_meta_name) - inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) - compute_require_grad_args_str = ",".join(compute_require_grad_args_list) - - # Get Output AutoGradMeta - outputs_autograd_meta_list = [] - pass_stop_gradient_args_list = ["false"] - num_fwd_outputs = len(forward_outputs_position_map.keys()) - for name, (rtype, pos) in forward_outputs_position_map.items(): - output_autograd_meta_name = GetAutoGradMetaName(name) - output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - if num_fwd_outputs == 1: - if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" - else: - assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - else: - # Tuple api_result - if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" - else: - assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" - - outputs_autograd_meta_list.append(output_autograd_meta) - pass_stop_gradient_args_list.append(output_autograd_meta_name) - - # ComputeRequireGrad & PassStopGradient - outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) - pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list) - - # Check Inplace - check_inplace_str = "" - bump_inplace_version_str = "" - for inplace_name in inplace_map.keys(): - inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) - check_inplace_str += f""" - // Check Inplace - egr::EagerUtils::CheckInplace({inplace_name}, {inplace_autograd_meta_name}, require_any_grad);\n + // Attributes + {} + }}; """ - bump_inplace_version_str += f""" - // Bump Inplace Version - {inplace_name}.bump_inplace_version(); - VLOG(3) << \"Tensor(\" << {inplace_name}.name() << \") uses Inplace Strategy.\";\n +FUNCTION_TEMPLATE = \ +""" + std::vector> {}::operator()(std::vector>& grads, bool create_graph) {{ + {} + auto hooked_grads = ApplyGradientHooks(grads); + + // Call grad_api function + VLOG(3) << \"Final State Running: \" << \"{}\"; + auto grad_api_returns = {}{}({}); + {} + }} """ - # Node Construction - num_bwd_inputs = len(backward_grad_input_map.keys()) - num_bwd_outputs = 
len(backward_grad_output_map.keys()) - grad_node_name = GetGradNodeName( - RecoverBaseNameOfInplaceFunction( - fwd_api_name)) if inplace_map else GetGradNodeName(fwd_api_name) - node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});" - - # SetAttributes - set_attributes_list = [] - forward_attrs_name_set = set() - for name, _, _, _ in forward_attrs_list: - forward_attrs_name_set.add(name) - - for name, _, default_val_attr, _ in backward_attrs_list: - if name in forward_attrs_name_set: - set_attributes = f" grad_node->SetAttribute{name}({name});" - else: - set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});" - set_attributes_list.append(set_attributes) - set_attributes_str = "\n".join(set_attributes_list) - - # SetTensorWrappers - set_tensor_wrappers_list = [] - for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items(): - is_optional = (name in optional_inputs) - - if is_fwd_input: - if is_optional: - set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" - else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" - else: - if num_fwd_outputs > 1: - # Aligned with forward output position - assert name in forward_outputs_position_map.keys() - fwd_output_pos = forward_outputs_position_map[name][1] - tw_name = f"std::get<{fwd_output_pos}>(api_result)" - else: - tw_name = f"api_result" +FORWARD_FUNCTION_TEMPLATE = \ +""" + {} {}({}) {{ + {} + + {} - if is_optional: - set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);" - else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" - set_tensor_wrappers_list.append(set_tensor_wrappers) - set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) - - # SetGradOutMeta & SetEdges - set_grad_out_meta_list = [] - set_edges_list = [] - for name, (_, pos) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});" - set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});" - set_grad_out_meta_list.append(set_grad_out_meta) - set_edges_list.append(set_edges) - set_grad_out_meta_str = "\n".join(set_grad_out_meta_list) - set_edges_str = "\n".join(set_edges_list) - - # SetOutRank & SetHistory & SetGradInMeta - set_out_rank_list = [] - set_history_list = [] - set_grad_in_meta_list = [] - set_retain_grad_list = [] - num_outputs = len(forward_outputs_position_map.keys()) - for name, (_, pos) in forward_outputs_position_map.items(): - output_autograd_meta_name = GetAutoGradMetaName(name) - set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" - set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" - if num_outputs == 1: - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);" - set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});" - else: - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));" - set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});" - - set_out_rank_list.append(set_out_rank) - set_history_list.append(set_history) - set_grad_in_meta_list.append(set_grad_in_meta) - set_retain_grad_list.append(set_retain_grad) - - set_out_rank_str = "\n".join(set_out_rank_list) - set_history_str = "\n".join(set_history_list) - 
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list) - set_retain_grad_str = "\n".join(set_retain_grad_list) - - node_event_name = fwd_api_name + " node_creation" - NODE_CREATION_TEMPLATE = """ - paddle::platform::RecordEvent node_creation_record_event(\"{}\", paddle::platform::TracerEventType::Operator, 1);\n - """ - node_creation_event_str = NODE_CREATION_TEMPLATE.format(node_event_name) + // Returns + return {}; + }} - NODE_CREATION_TEMPLATE = """ +""" +NODE_CREATION_TEMPLATE = \ +""" // Get AutoGradMeta {} bool trace_backward = egr::Controller::Instance().HasGrad(); @@ -924,185 +202,72 @@ def GenerateNodeCreationCodes( {} }} }} +""" +NAMESPACE_WRAPPER_TEMPLATE = \ +""" +namespace {} {{ + {} +}} """ - node_creation_str = NODE_CREATION_TEMPLATE.format( - inputs_autograd_meta_str, compute_require_grad_args_str, - check_inplace_str, forward_call_str, bump_inplace_version_str, - node_creation_event_str, outputs_autograd_meta_str, - pass_stop_gradient_args_str, node_construction_str, set_attributes_str, - set_tensor_wrappers_str, set_grad_out_meta_str, set_edges_str, - set_out_rank_str, set_history_str, set_grad_in_meta_str, - set_retain_grad_str) - - return node_creation_str - - -def GenerateForwardDefinition( - fwd_api_name, bwd_api_name, forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list, - backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list, optional_inputs, - intermediate_outputs, inplace_map): - # fwd_api_name = "" - # forward_inputs_position_map = { "name" : [type, fwd_position] } - # forward_outputs_position_map = { "name" : [type, fwd_position] } - # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] - # backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...} - # backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...} - # backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...} - # backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] - # optional_inputs = ["name0", ...] 
- - # Get Function Args - num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys( - )) - inputs_args_definition_list = ["" for i in range(num_inputs)] - inputs_args_declaration_list = ["" for i in range(num_inputs)] - inputs_call_list = ["" for i in range(num_inputs)] - for name, (ttype, pos) in forward_inputs_position_map.items(): - inputs_call_list[pos] = f"{name}" - is_optional = (name in optional_inputs) - if IsPlainTensorType(ttype): - if is_optional: - arg_str = f"const paddle::optional& {name}" - else: - if inplace_map and name in inplace_map.keys(): - arg_str = f"paddle::experimental::Tensor& {name}" - else: - arg_str = f"const paddle::experimental::Tensor& {name}" - else: - assert IsVectorTensorType(ttype) - arg_str = f"const std::vector& {name}" - inputs_args_definition_list[pos] = arg_str - inputs_args_declaration_list[pos] = arg_str +NODE_CC_FILE_TEMPLATE = \ +""" +#include "glog/logging.h" +#include "paddle/phi/api/all.h" +#include "paddle/phi/api/backward/backward_api.h" +#include "paddle/phi/api/backward/sparse_bw_api.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" +#include "paddle/fluid/eager/to_static/run_program_op_node.h" - for name, atype, default_val, pos in forward_attrs_list: - inputs_call_list[pos] = name - if default_val is not None: - inputs_args_declaration_list[ - pos] = f"{atype} {name} = {default_val}" - else: - inputs_args_declaration_list[pos] = f"{atype} {name}" - inputs_args_definition_list[pos] = f"{atype} {name}" - - inputs_args_declaration_str = ", ".join(inputs_args_declaration_list) - inputs_args_definition_str = ", ".join(inputs_args_definition_list) - inputs_call_args_str = ", ".join(inputs_call_list) - - # Forward Full Logic - if len(intermediate_outputs) == 0: - function_name = fwd_api_name - else: - function_name = fwd_api_name + "_intermediate" - - if len(namespace) > 0: - forward_call_str = f"auto api_result = paddle::experimental::{namespace}::{function_name}({inputs_call_args_str});" - else: - forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});" - - # Get return type list & outputs - num_outputs = len(forward_outputs_position_map.keys()) - len( - intermediate_outputs) - returns_type_list = ["" for i in range(num_outputs)] - returns_list = ["" for i in range(num_outputs)] - for name, (rtype, pos) in forward_outputs_position_map.items(): - if name in intermediate_outputs: - continue - if num_outputs == 1: - returns_list[0] = f"api_result" - else: - # Tuple api_result - returns_list[pos] = f"std::get<{pos}>(api_result)" +#include "paddle/phi/api/include/sparse_api.h" - if IsPlainTensorType(rtype): - returns_type_list[pos] = "paddle::experimental::Tensor" - else: - assert IsVectorTensorType(rtype) - returns_type_list[pos] = "std::vector" - - if num_outputs == 1: - returns_str = returns_list[0] - returns_type_str = returns_type_list[0] - else: - returns_type_str = ", ".join(returns_type_list) - returns_type_str = f"std::tuple<{returns_type_str}>" - returns_str = ", ".join(returns_list) - returns_str = f"std::make_tuple({returns_str})" - - node_creation_str = GenerateNodeCreationCodes( - fwd_api_name, bwd_api_name, forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list, forward_call_str, - backward_fwd_input_map, 
backward_grad_input_map, - backward_grad_output_map, backward_attrs_list, optional_inputs, - inplace_map) - - dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{fwd_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" - - FORWARD_FUNCTION_TEMPLATE = """ -{} {}({}) {{ - {} - {} - - // Returns - return {}; -}} """ - forward_function_name = GetForwardFunctionName(fwd_api_name) - forward_function_str = FORWARD_FUNCTION_TEMPLATE.format( - returns_type_str, forward_function_name, inputs_args_definition_str, - dygraph_event_str, node_creation_str, returns_str) - forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});" - - return forward_function_str, forward_function_declaration_str - - -def CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list): - # fwd_api_name : "" - # forward_inputs_position_map = { "name" : [type, fwd_position] } - # forward_outputs_position_map = { "name" : [type, fwd_position] } - # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] - num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list) - num_returns = len(forward_outputs_position_map.keys()) - - final_state_fwd_api_name = "final_state_" + fwd_api_name - core_ops_returns_info[ - final_state_fwd_api_name] = ["" for i in range(num_returns)] - core_ops_args_info[final_state_fwd_api_name] = ["" for i in range(num_args)] - core_ops_args_type_info[ - final_state_fwd_api_name] = ["" for i in range(num_args)] - for name, (ttype, pos) in forward_inputs_position_map.items(): - core_ops_args_info[final_state_fwd_api_name][pos] = name - if IsPlainTensorType(ttype): - core_ops_args_type_info[final_state_fwd_api_name][pos] = "tensor" - else: - assert IsVectorTensorType(ttype) - core_ops_args_type_info[final_state_fwd_api_name][pos] = "list" - - for name, _, _, pos in forward_attrs_list: - core_ops_args_info[final_state_fwd_api_name][pos] = name +NODE_H_FILE_TEMPLATE = \ +""" +#pragma once +#include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/eager/grad_node_info.h" - for name, (ttype, pos) in forward_outputs_position_map.items(): - core_ops_returns_info[final_state_fwd_api_name][pos] = name +{} +""" +FORWARD_CC_FILE_TEMPLATE = \ +""" +#include "paddle/phi/api/lib/dygraph_api.h" +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" -def GenerateCoreOpInfoDeclaration(): - core_ops_declaration_str = """ - extern std::unordered_map> core_ops_final_state_args_info; - extern std::unordered_map> core_ops_final_state_args_type_info; - extern std::unordered_map> core_ops_final_state_returns_info; +#include "paddle/phi/api/include/sparse_api.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" +{} +{} """ - return core_ops_declaration_str +FORWARD_H_FILE_TEMPLATE = \ +""" +#pragma once +#include "glog/logging.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/phi/api/all.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/eager/to_static/run_program_op_func.h" -def GenerateCoreOpInfoDefinition(): +{} +{} +""" - CORE_OPS_INFO_TEMPLATE = """ +CORE_OPS_INFO_TEMPLATE = \ +""" std::unordered_map> core_ops_final_state_args_info = {{ {} }}; @@ -1114,6 
+279,38 @@ std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_r
 }};
 """
+
+CORE_OPS_DECLARATION_TEMPLATE = \
+"""
+    extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
+    extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
+    extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
+
+"""
+
+CHECK_INPLACE_TEMPLATE = \
+"""
+    // Check Inplace
+    egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
+"""
+
+BUMP_INPLACE_VERSION_TEMPLATE = \
+"""
+    // Bump Inplace Version
+    {}.bump_inplace_version();
+    VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";\n
+"""
+
+
+#######################
+## Generator Helpers ##
+#######################
+def GenerateCoreOpInfoDeclaration():
+    return CORE_OPS_DECLARATION_TEMPLATE
+
+
+def GenerateCoreOpInfoDefinition():
+
     op_args_info_list = []
     for op_name, arg_list in core_ops_args_info.items():
         arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
@@ -1142,68 +339,864 @@ std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_r
     return core_ops_info_definition_str


+#####################
+## Generator Class ##
+#####################
+class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
+    def __init__(self, forward_api_contents, grad_api_contents, namespace):
+        self.forward_api_contents = forward_api_contents
+        # Members from Parent:
+        #self.namespace
+        #self.forward_api_contents
+        #self.forward_api_name
+        #self.orig_forward_inputs_list
+        #self.orig_forward_attrs_list
+        #self.orig_forward_returns_list
+        #self.forward_inputs_position_map
+        #self.forward_outputs_position_map
+        #self.optional_inputs
+        #self.no_need_buffers
+        #self.intermediate_outputs
+        #self.inplace_map
+        FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
+
+        self.grad_api_contents = grad_api_contents
+
+        # Raw Contents
+        self.backward_forward_str = ""
+        self.backward_api_name = ""
+
+        self.forward_attrs_list = [
+        ]  #[ [attr_name, attr_type, default_value, orig_position], ...]
+        self.forward_inputs_list = [
+        ]  #[ [arg_name, arg_type, orig_position], ...]
+        self.forward_returns_list = [
+        ]  #[ [ret_name, ret_type, orig_position], ...]
+
+        self.backward_inputs_list = [
+        ]  #[ [arg_name, arg_type, orig_position], ...]
+        self.backward_attrs_list = [
+        ]  #[ [attr_name, attr_type, default_value, orig_position], ...]
+        self.backward_returns_list = [
+        ]  #[ [ret_name, ret_type, orig_position], ...]
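+
+        # Editor's sketch (hypothetical matmul-style op, not generator output):
+        # after CollectBackwardInfo() runs ParseYamlBackward, these hold roughly
+        #   backward_inputs_list  = [['x', 'Tensor', 0], ['y', 'Tensor', 1],
+        #                            ['out_grad', 'Tensor', 2]]
+        #   backward_attrs_list   = [['transpose_x', 'bool', 'false', 3],
+        #                            ['transpose_y', 'bool', 'false', 4]]
+        #   backward_returns_list = [['x_grad', 'Tensor', 0], ['y_grad', 'Tensor', 1]]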
+ + # SlotNameMatched Backward Data + self.backward_forward_inputs_map = { + } #{ "name" : [type, is_fwd_input, orig_position] ...} + self.backward_grad_inputs_map = { + } #{ "name" : [type, fwd_position, orig_position] ...} + self.backward_grad_outputs_map = { + } #{ "name" : [type, fwd_position, orig_position] ...} + + # Generated Results + self.forward_definition_str = "" + self.forward_declaration_str = "" + self.node_declaration_str = "" + self.node_definition_str = "" + + def DygraphYamlValidationCheck(self): + forward_api_contents = self.forward_api_contents + grad_api_contents = self.grad_api_contents + + assert 'api' in forward_api_contents.keys() + assert 'args' in forward_api_contents.keys() + assert 'output' in forward_api_contents.keys() + assert 'backward' in forward_api_contents.keys() + + assert 'args' in grad_api_contents.keys() + assert 'output' in grad_api_contents.keys() + assert 'forward' in grad_api_contents.keys() + + def ForwardsValidationCheck(self): + forward_inputs_list = self.forward_inputs_list + forward_attrs_list = self.forward_attrs_list + forward_returns_list = self.forward_returns_list + + orig_forward_inputs_list = self.orig_forward_inputs_list + orig_forward_attrs_list = self.orig_forward_attrs_list + orig_forward_returns_list = self.orig_forward_returns_list + + for i in range(len(forward_inputs_list)): + forward_input_name = forward_inputs_list[i][0] + forward_input_type = forward_inputs_list[i][1] + forward_input_pos = forward_inputs_list[i][2] + orig_input_name = orig_forward_inputs_list[i][0] + orig_input_type = orig_forward_inputs_list[i][1] + orig_input_pos = orig_forward_inputs_list[i][2] + + assert forward_input_type == orig_input_type + assert forward_input_pos == orig_input_pos + + for i in range(len(forward_attrs_list)): + orig_attr_name = orig_forward_attrs_list[i][0] + orig_attr_type = orig_forward_attrs_list[i][1] + orig_attr_default = orig_forward_attrs_list[i][2] + orig_attr_pos = orig_forward_attrs_list[i][3] + forward_attr_name = forward_attrs_list[i][0] + forward_attr_type = forward_attrs_list[i][1] + forward_attr_default = forward_attrs_list[i][2] + forward_attr_pos = forward_attrs_list[i][3] + assert orig_attr_type == forward_attr_type + assert orig_attr_default == forward_attr_default + assert orig_attr_pos == forward_attr_pos + + for i in range(len(forward_returns_list)): + orig_return_type = orig_forward_returns_list[i][1] + orig_return_pos = orig_forward_returns_list[i][2] + forward_return_type = forward_returns_list[i][1] + forward_return_pos = forward_returns_list[i][2] + + assert orig_return_type == forward_return_type + assert orig_return_pos == forward_return_pos + + # Check Order: Inputs, Attributes + max_input_position = -1 + for _, _, pos in forward_inputs_list: + max_input_position = max(max_input_position, pos) + + max_attr_position = -1 + for _, _, _, pos in forward_attrs_list: + assert pos > max_input_position + max_attr_position = max(max_attr_position, pos) + + def BackwardValidationCheck(self): + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_attrs_list = self.backward_attrs_list + + # Check Order: TensorWrappers, GradTensors, Attributes + max_fwd_input_position = -1 + for _, (_, _, pos) in backward_forward_inputs_map.items(): + max_fwd_input_position = max(max_fwd_input_position, pos) + + max_grad_tensor_position = -1 + for _, (_, _, pos) in backward_grad_inputs_map.items(): + assert pos > max_fwd_input_position + 
max_grad_tensor_position = max(max_grad_tensor_position, pos)
+
+        max_attr_position = -1
+        for _, _, _, pos in backward_attrs_list:
+            assert pos > max_grad_tensor_position
+            max_attr_position = max(max_attr_position, pos)
+
+    def IntermediateValidationCheck(self):
+        intermediate_outputs = self.intermediate_outputs
+        forward_returns_list = self.forward_returns_list
+        """
+        Check whether intermediate_outputs are positioned
+        at the very end of forward_returns_list
+        """
+        intermediate_positions = range(
+            len(forward_returns_list) - len(intermediate_outputs),
+            len(forward_returns_list))
+        for ret_name, _, pos in forward_returns_list:
+            if ret_name in intermediate_outputs:
+                assert pos in intermediate_positions
+
+    def CollectBackwardInfo(self):
+        forward_api_contents = self.forward_api_contents
+        grad_api_contents = self.grad_api_contents
+
+        self.backward_api_name = forward_api_contents['backward']
+        self.backward_forward_str = grad_api_contents['forward']
+
+        backward_args_str = grad_api_contents['args']
+        backward_returns_str = grad_api_contents['output']
+
+        self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
+            backward_args_str, backward_returns_str)
+        print("Parsed Backward Inputs List: ", self.backward_inputs_list)
+        print("Parsed Backward Attrs List: ", self.backward_attrs_list)
+        print("Parsed Backward Returns List: ", self.backward_returns_list)
+
+    def CollectForwardInfoFromBackwardContents(self):
+
+        backward_forward_str = self.backward_forward_str
+
+        self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward(
+            backward_forward_str)
+
+    def SlotNameMatching(self):
+        backward_inputs_list = self.backward_inputs_list
+        backward_returns_list = self.backward_returns_list
+        forward_inputs_position_map = self.forward_inputs_position_map
+        forward_outputs_position_map = self.forward_outputs_position_map
+
+        for backward_input in backward_inputs_list:
+            backward_input_name = backward_input[0]
+            backward_input_type = backward_input[1]
+            backward_input_pos = backward_input[2]
+
+            backward_fwd_name = FindForwardName(backward_input_name)
+            if backward_fwd_name:
+                # Grad Input
+                assert backward_fwd_name in forward_outputs_position_map.keys()
+                matched_forward_output_type = forward_outputs_position_map[
+                    backward_fwd_name][0]
+                matched_forward_output_pos = forward_outputs_position_map[
+                    backward_fwd_name][1]
+
+                self.backward_grad_inputs_map[backward_input_name] = [
+                    backward_input_type, matched_forward_output_pos,
+                    backward_input_pos
+                ]
+            else:
+                # TensorWrapper Input
+                if backward_input_name in forward_inputs_position_map.keys():
+                    tensor_wrapper_type = forward_inputs_position_map[
+                        backward_input_name][0]
+                    self.backward_forward_inputs_map[backward_input_name] = [
+                        backward_input_type, True, backward_input_pos
+                    ]
+
+                elif backward_input_name in forward_outputs_position_map.keys():
+                    tensor_wrapper_type = forward_outputs_position_map[
+                        backward_input_name][0]
+                    self.backward_forward_inputs_map[backward_input_name] = [
+                        backward_input_type, False, backward_input_pos
+                    ]
+                else:
+                    assert False, backward_input_name
+
+        for backward_output in backward_returns_list:
+            backward_output_name = backward_output[0]
+            backward_output_type = backward_output[1]
+            backward_output_pos = backward_output[2]
+
+            backward_fwd_name = FindForwardName(backward_output_name)
+            assert backward_fwd_name is not None
+            assert backward_fwd_name in forward_inputs_position_map.keys(
+            ), f"Unable to find
{backward_fwd_name} in forward inputs" + + matched_forward_input_type = forward_inputs_position_map[ + backward_fwd_name][0] + matched_forward_input_pos = forward_inputs_position_map[ + backward_fwd_name][1] + + self.backward_grad_outputs_map[backward_output_name] = [ + backward_output_type, matched_forward_input_pos, + backward_output_pos + ] + print("Generated Backward Fwd Input Map: ", + self.backward_forward_inputs_map) + print("Generated Backward Grad Input Map: ", + self.backward_grad_inputs_map) + print("Generated Backward Grad Output Map: ", + self.backward_grad_outputs_map) + + def GenerateNodeDeclaration(self): + forward_op_name = self.forward_api_name + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_attrs_list = self.backward_attrs_list + no_need_buffers = self.no_need_buffers + + # SetTensorWrapper Methods & TensorWrapper Members + set_tensor_wrapper_methods_str = "" + tensor_wrapper_members_str = "" + clear_tensor_wrapper_str = "" + for tname, (ttype, is_fwd_input, + _) in backward_forward_inputs_map.items(): + no_need_buffer = "true" if tname in no_need_buffers else "false" + tensor_wrapper_name = GetSavedName(tname) + if IsPlainTensorType(ttype): + set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format( + tname, tname, tensor_wrapper_name, tname, no_need_buffer) + + tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( + tensor_wrapper_name) + + clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format( + tensor_wrapper_name) + + else: + assert IsVectorTensorType(ttype) + set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format( + tname, tname, tname, tensor_wrapper_name, no_need_buffer) + + tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( + tensor_wrapper_name) + + clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format( + tensor_wrapper_name) + + # SetAttributes & Attribute Members + set_attribute_methods_str = "" + attribute_members_str = "" + for aname, atype, default_val, _ in backward_attrs_list: + saved_attr_name = GetSavedName(aname) + set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format( + aname, GetConstReference(atype), aname, saved_attr_name, aname) + + if default_val: + attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format( + RemoveConstAndReference(atype), saved_attr_name, + default_val) + else: + attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( + RemoveConstAndReference(atype), saved_attr_name) + + grad_node_name = GetGradNodeName(forward_op_name) + self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format( + grad_node_name, grad_node_name, grad_node_name, grad_node_name, + grad_node_name, clear_tensor_wrapper_str, + set_tensor_wrapper_methods_str, set_attribute_methods_str, + tensor_wrapper_members_str, attribute_members_str) + + print("Generated Node Declaration: ", self.node_declaration_str) + + def GenerateNodeDefinition(self): + namespace = self.namespace + forward_api_name = self.forward_api_name + backward_api_name = self.backward_api_name + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map + backward_attrs_list = self.backward_attrs_list + + # Construct grad_api function args + # Order: TensorWrappers, GradTensors, Attributes + grad_api_args_len = len(backward_forward_inputs_map.keys()) + len( + backward_grad_inputs_map.keys()) + len(backward_attrs_list) + 
grad_api_args = ["" for i in range(grad_api_args_len)] + for name, (_, is_fwd_input, + grad_api_position), in backward_forward_inputs_map.items(): + tensor_wrapper_name = GetSavedName(name) + grad_api_args[ + grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)" + + for _, (ttype, fwd_position, + grad_api_position) in backward_grad_inputs_map.items(): + if IsPlainTensorType(ttype): + grad_api_args[ + grad_api_position] = f"hooked_grads[{fwd_position}][0]" + else: + assert IsVectorTensorType(ttype) + grad_api_args[ + grad_api_position] = f"hooked_grads[{fwd_position}]" + + for name, _, _, grad_api_position in backward_attrs_list: + saved_attribute_name = GetSavedName(name) + grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" + grad_api_args_str = ", ".join(grad_api_args) + + # Construct grad_api returns + num_bwd_outputs = len(backward_grad_outputs_map.keys()) + returns_str = f"std::vector> returns({num_bwd_outputs});\n" + for _, (ttype, fwd_position, + grad_api_position) in backward_grad_outputs_map.items(): + # Infer Grad API Return Type + if num_bwd_outputs == 1: + # Single tensor output, return as is + if IsPlainTensorType(ttype): + returns_str += "returns[0] = { grad_api_returns };\n" + else: + assert IsVectorTensorType(ttype) + returns_str += "returns[0] = grad_api_returns;\n" + else: + # Rearrange output order accordingly + returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" + returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" + returns_str += f"return returns;\n" + + grad_node_name = GetGradNodeName(forward_api_name) + + fill_zero_str = "" + if forward_api_name in ops_to_fill_zero_for_empty_grads: + fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" + + grad_api_namespace = f"paddle::experimental::{namespace}" + + self.node_definition_str = FUNCTION_TEMPLATE.format( + grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, + backward_api_name, grad_api_args_str, returns_str) + + print("Generated Node Definition: ", self.node_definition_str) + + def GenerateForwardDefinition(self, is_inplaced): + namespace = self.namespace + forward_api_name = GetInplacedFunctionName( + self.forward_api_name) if is_inplaced else self.forward_api_name + backward_api_name = self.backward_api_name + forward_inputs_position_map = self.forward_inputs_position_map + forward_outputs_position_map = self.forward_outputs_position_map + forward_attrs_list = self.forward_attrs_list + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map + backward_attrs_list = self.backward_attrs_list + optional_inputs = self.optional_inputs + intermediate_outputs = self.intermediate_outputs + inplace_map = self.inplace_map + + # Get Function Args + num_inputs = len(forward_attrs_list) + len( + forward_inputs_position_map.keys()) + inputs_args_definition_list = ["" for i in range(num_inputs)] + inputs_args_declaration_list = ["" for i in range(num_inputs)] + inputs_call_list = ["" for i in range(num_inputs)] + for name, (ttype, pos) in forward_inputs_position_map.items(): + inputs_call_list[pos] = f"{name}" + is_optional = (name in optional_inputs) + if IsPlainTensorType(ttype): + if is_optional: + arg_str = f"const paddle::optional& {name}" + else: + if inplace_map and name in inplace_map.keys(): + arg_str = 
f"paddle::experimental::Tensor& {name}" + else: + arg_str = f"const paddle::experimental::Tensor& {name}" + else: + assert IsVectorTensorType(ttype) + arg_str = f"const std::vector& {name}" + + inputs_args_definition_list[pos] = arg_str + inputs_args_declaration_list[pos] = arg_str + + for name, atype, default_val, pos in forward_attrs_list: + inputs_call_list[pos] = name + if default_val is not None: + inputs_args_declaration_list[ + pos] = f"{atype} {name} = {default_val}" + else: + inputs_args_declaration_list[pos] = f"{atype} {name}" + inputs_args_definition_list[pos] = f"{atype} {name}" + + inputs_args_declaration_str = ", ".join(inputs_args_declaration_list) + inputs_args_definition_str = ", ".join(inputs_args_definition_list) + inputs_call_args_str = ", ".join(inputs_call_list) + + # Forward Full Logic + function_name = forward_api_name + if len(intermediate_outputs) > 0: + function_name = GetIntermediateAPIFunctionName(function_name) + + forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" + + # Get return type list & outputs + num_outputs = len(forward_outputs_position_map.keys()) - len( + intermediate_outputs) + returns_type_list = ["" for i in range(num_outputs)] + returns_list = ["" for i in range(num_outputs)] + for name, (rtype, pos) in forward_outputs_position_map.items(): + if name in intermediate_outputs: + continue + if num_outputs == 1: + returns_list[0] = f"api_result" + else: + # Tuple api_result + returns_list[pos] = f"std::get<{pos}>(api_result)" + + if IsPlainTensorType(rtype): + returns_type_list[pos] = "paddle::experimental::Tensor" + else: + assert IsVectorTensorType(rtype) + returns_type_list[ + pos] = "std::vector" + + if num_outputs == 1: + returns_str = returns_list[0] + returns_type_str = returns_type_list[0] + else: + returns_type_str = ", ".join(returns_type_list) + returns_type_str = f"std::tuple<{returns_type_str}>" + returns_str = ", ".join(returns_list) + returns_str = f"std::make_tuple({returns_str})" + + self.GenerateNodeCreationCodes(forward_call_str) + + node_creation_str = self.node_creation_str + dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" + forward_function_name = GetDygraphForwardFunctionName(forward_api_name) + + self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( + returns_type_str, forward_function_name, inputs_args_definition_str, + dygraph_event_str, node_creation_str, returns_str) + self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" + + print("Generated Forward Definition: ", self.forward_definition_str) + print("Generated Forward Declaration: ", self.forward_declaration_str) + + def GenerateNodeCreationCodes(self, forward_call_str): + forward_api_name = self.forward_api_name + forward_inputs_position_map = self.forward_inputs_position_map + forward_outputs_position_map = self.forward_outputs_position_map + forward_attrs_list = self.forward_attrs_list + backward_forward_inputs_map = self.backward_forward_inputs_map + backward_grad_inputs_map = self.backward_grad_inputs_map + backward_grad_outputs_map = self.backward_grad_outputs_map + backward_attrs_list = self.backward_attrs_list + optional_inputs = self.optional_inputs + inplace_map = self.inplace_map + + # Get Input AutoGradMeta + inputs_autograd_meta_list = [] + compute_require_grad_args_list = ["trace_backward"] + for name, (ttype, 
pos) in forward_inputs_position_map.items(): + input_autograd_meta_name = GetAutoGradMetaName(name) + if IsPlainTensorType(ttype): + input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" + else: + assert IsVectorTensorType(ttype) + input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) + input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" + input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + + inputs_autograd_meta_list.append(input_autograd_meta) + compute_require_grad_args_list.append(input_autograd_meta_name) + inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) + compute_require_grad_args_str = ",".join(compute_require_grad_args_list) + + # Get Output AutoGradMeta + outputs_autograd_meta_list = [] + pass_stop_gradient_args_list = ["false"] + num_fwd_outputs = len(forward_outputs_position_map.keys()) + for name, (rtype, pos) in forward_outputs_position_map.items(): + output_autograd_meta_name = GetAutoGradMetaName(name) + output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) + if num_fwd_outputs == 1: + if IsPlainTensorType(rtype): + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" + else: + assert IsVectorTensorType(rtype) + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" + output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + else: + # Tuple api_result + if IsPlainTensorType(rtype): + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" + else: + assert IsVectorTensorType(rtype) + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n" + output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + + outputs_autograd_meta_list.append(output_autograd_meta) + pass_stop_gradient_args_list.append(output_autograd_meta_name) + + # ComputeRequireGrad & PassStopGradient + outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) + pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list) + + # Check Inplace + check_inplace_str = "" + bump_inplace_version_str = "" + for inplace_name in inplace_map.keys(): + inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) + bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name) + + # Node Construction + num_backward_inputs = len(backward_grad_inputs_map.keys()) + num_backward_outputs = len(backward_grad_outputs_map.keys()) + grad_node_name = GetGradNodeName(forward_api_name) + + node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_backward_inputs}, {num_backward_outputs});" + + # SetAttributes + set_attributes_list = [] + forward_attrs_name_set = set() + for name, _, _, _ in forward_attrs_list: + forward_attrs_name_set.add(name) + + for name, _, default_val_attr, _ in backward_attrs_list: + if name in forward_attrs_name_set: + set_attributes = f" grad_node->SetAttribute{name}({name});" + else: + set_attributes = f" 
grad_node->SetAttribute{name}({default_val_attr});" + set_attributes_list.append(set_attributes) + set_attributes_str = "\n".join(set_attributes_list) + + # SetTensorWrappers + set_tensor_wrappers_list = [] + for name, (atype, is_fwd_input, + pos) in backward_forward_inputs_map.items(): + is_optional = (name in optional_inputs) + + if is_fwd_input: + if is_optional: + set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" + else: + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" + else: + if num_fwd_outputs > 1: + # Aligned with forward output position + assert name in forward_outputs_position_map.keys() + fwd_output_pos = forward_outputs_position_map[name][1] + tw_name = f"std::get<{fwd_output_pos}>(api_result)" + else: + tw_name = f"api_result" + + if is_optional: + set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);" + else: + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" + set_tensor_wrappers_list.append(set_tensor_wrappers) + set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) + + # SetGradOutMeta & SetEdges + set_grad_out_meta_list = [] + set_edges_list = [] + for name, (_, pos) in forward_inputs_position_map.items(): + input_autograd_meta_name = GetAutoGradMetaName(name) + set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});" + set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});" + set_grad_out_meta_list.append(set_grad_out_meta) + set_edges_list.append(set_edges) + set_grad_out_meta_str = "\n".join(set_grad_out_meta_list) + set_edges_str = "\n".join(set_edges_list) + + # SetOutRank & SetHistory & SetGradInMeta + set_out_rank_list = [] + set_history_list = [] + set_grad_in_meta_list = [] + set_retain_grad_list = [] + num_outputs = len(forward_outputs_position_map.keys()) + for name, (_, pos) in forward_outputs_position_map.items(): + output_autograd_meta_name = GetAutoGradMetaName(name) + set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" + set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" + + if num_outputs == 1: + set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);" + set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});" + else: + set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));" + set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});" + set_out_rank_list.append(set_out_rank) + set_history_list.append(set_history) + set_grad_in_meta_list.append(set_grad_in_meta) + set_retain_grad_list.append(set_retain_grad) + + set_out_rank_str = "\n".join(set_out_rank_list) + set_history_str = "\n".join(set_history_list) + set_grad_in_meta_str = "\n".join(set_grad_in_meta_list) + set_retain_grad_str = "\n".join(set_retain_grad_list) + + node_event_name = forward_api_name + " node_creation" + node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" + + self.node_creation_str = NODE_CREATION_TEMPLATE.format( + inputs_autograd_meta_str, compute_require_grad_args_str, + check_inplace_str, forward_call_str, bump_inplace_version_str, + node_creation_event_str, outputs_autograd_meta_str, + pass_stop_gradient_args_str, node_construction_str, + set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str, + 
set_edges_str, set_out_rank_str, set_history_str,
+            set_grad_in_meta_str, set_retain_grad_str)
+
+    def GenerateInplacedForwardDygraphFunctions(self):
+        # Inplaced Version Dygraph Function Generation
+        forward_api_name = self.forward_api_name
+        forward_api_contents = self.forward_api_contents
+
+        if forward_api_name != "sum" and "inplace" in forward_api_contents.keys(
+        ):
+            # Node Definition Generation
+            self.GenerateForwardDefinition(is_inplaced=True)
+            self.UpdateCoreOpsInformation(is_inplaced=True)
+
+    def UpdateCoreOpsInformation(self, is_inplaced):
+        forward_api_name = GetInplacedFunctionName(
+            self.forward_api_name) if is_inplaced else self.forward_api_name
+        forward_inputs_position_map = self.forward_inputs_position_map
+        forward_outputs_position_map = self.forward_outputs_position_map
+        forward_attrs_list = self.forward_attrs_list
+
+        num_args = len(forward_inputs_position_map.keys()) + len(
+            forward_attrs_list)
+        num_returns = len(forward_outputs_position_map.keys())
+
+        final_state_fwd_api_name = "final_state_" + forward_api_name
+        core_ops_returns_info[
+            final_state_fwd_api_name] = ["" for i in range(num_returns)]
+        core_ops_args_info[
+            final_state_fwd_api_name] = ["" for i in range(num_args)]
+        core_ops_args_type_info[
+            final_state_fwd_api_name] = ["" for i in range(num_args)]
+        for name, (ttype, pos) in forward_inputs_position_map.items():
+            core_ops_args_info[final_state_fwd_api_name][pos] = name
+            if IsPlainTensorType(ttype):
+                core_ops_args_type_info[final_state_fwd_api_name][
+                    pos] = "tensor"
+            else:
+                assert IsVectorTensorType(ttype)
+                core_ops_args_type_info[final_state_fwd_api_name][pos] = "list"
+
+        for name, _, _, pos in forward_attrs_list:
+            core_ops_args_info[final_state_fwd_api_name][pos] = name
+
+        for name, (ttype, pos) in forward_outputs_position_map.items():
+            core_ops_returns_info[final_state_fwd_api_name][pos] = name
+
+    def run(self):
+        # Basic Validation Check
+        self.DygraphYamlValidationCheck()
+
+        ##########################
+        ## Parsing Raw Contents ##
+        ##########################
+        # Parse inplace_map
+        self.ParseInplaceInfo()
+
+        # Parse no_need_buffer
+        self.ParseNoNeedBuffer()
+
+        # Parse optional_inputs
+        self.ParseDispensable()
+
+        # Parse intermediate_outputs
+        self.ParseIntermediate()
+        self.IntermediateValidationCheck()
+
+        # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
+        self.CollectBackwardInfo()
+
+        # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
+        self.CollectForwardInfoFromBackwardContents()
+
+        # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
+        self.CollectOriginalForwardInfo()
+
+        # Forwards Validation Check
+        self.ForwardsValidationCheck()
+
+        #############################
+        ## Process Parsed Contents ##
+        #############################
+        # Initialize forward_inputs_position_map, forward_outputs_position_map
+        self.DetermineForwardPositionMap(self.forward_inputs_list,
+                                         self.forward_returns_list)
+
+        # Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
+        self.SlotNameMatching()
+
+        # Backward Validation Check
+        self.BackwardValidationCheck()
+
+        #####################
+        ## Code Generation ##
+        #####################
+        self.GenerateNodeDeclaration()
+        self.GenerateNodeDefinition()
+        self.GenerateForwardDefinition(is_inplaced=False)
+
+        self.UpdateCoreOpsInformation(is_inplaced=False)
+
+        self.GenerateInplacedForwardDygraphFunctions()
+
+
+class DygraphYamlGenerator(YamlGeneratorBase):
+    def
__init__(self, api_yaml_path, backward_yaml_path): + # Parent members: + # self.namespace + # self.api_yaml_path + # self.forward_api_list + YamlGeneratorBase.__init__(self, api_yaml_path) + + self.backward_yaml_path = backward_yaml_path + self.grad_api_dict = {} + + self.forward_definition_str = "" + self.forward_declaration_str = "" + self.node_declaration_str = "" + self.node_definition_str = "" + + def ParseYamlContents(self): + self.ParseForwardYamlContents() + + backward_yaml_path = self.backward_yaml_path + self.grad_api_dict = ReadBwdFile(backward_yaml_path) + + def GetBackwardAPIContents(self, forward_api_contents): + grad_api_dict = self.grad_api_dict + + if 'backward' not in forward_api_contents.keys(): return None + + backward_api_name = forward_api_contents['backward'] + assert backward_api_name in grad_api_dict.keys() + backward_api_contents = grad_api_dict[backward_api_name] + + return backward_api_contents + + def GenerateCode(self): + forward_api_list = self.forward_api_list + grad_api_dict = self.grad_api_dict + namespace = self.namespace + + for forward_api_contents in forward_api_list: + backward_api_contents = self.GetBackwardAPIContents( + forward_api_contents) + if backward_api_contents is None: continue + + d_generator = DygraphSingleFunctionGenerator( + forward_api_contents, backward_api_contents, namespace) + d_generator.run() + + self.forward_definition_str += d_generator.forward_definition_str + "\n" + self.forward_declaration_str += d_generator.forward_declaration_str + "\n" + self.node_declaration_str += d_generator.node_declaration_str + "\n" + self.node_definition_str += d_generator.node_definition_str + "\n" + + if len(namespace) > 0: + if namespace.endswith("::"): + namespace = namespace[:-2] + self.forward_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format( + namespace, self.forward_definition_str) + self.forward_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format( + namespace, self.forward_declaration_str) + self.node_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format( + namespace, self.node_declaration_str) + self.node_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format( + namespace, self.node_definition_str) + + def run(self): + self.ParseYamlContents() + + self.InferNameSpace() + + self.GenerateCode() + + +################## +## File Writers ## +################## def GenerateNodeCCFile(filepath, node_definition_str): - file_contents = """ -#include "glog/logging.h" -#include "paddle/phi/api/all.h" -#include "paddle/phi/api/backward/backward_api.h" -#include "paddle/fluid/imperative/tracer.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/eager/utils.h" -#include "paddle/fluid/eager/api/utils/global_utils.h" -#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" -#include "paddle/fluid/eager/to_static/run_program_op_node.h" + if os.path.exists(filepath): + os.remove(filepath) -#include "paddle/phi/api/backward/sparse_bw_api.h" -""" - file_contents += node_definition_str + file_contents = NODE_CC_FILE_TEMPLATE.format(node_definition_str) with open(filepath, 'a') as f: f.write(file_contents) def GenerateNodeHFile(filepath, node_declaration_str): - file_contents = """ -#pragma once -#include "paddle/fluid/eager/tensor_wrapper.h" -#include "paddle/fluid/eager/grad_node_info.h" + if os.path.exists(filepath): + os.remove(filepath) -""" - file_contents += node_declaration_str + file_contents = NODE_H_FILE_TEMPLATE.format(node_declaration_str) with open(filepath, 'a') as f: f.write(file_contents) def 
GenerateForwardCCFile(filepath, forward_definition_str): - file_contents = """ -#include "paddle/phi/api/lib/dygraph_api.h" -#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" -#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" - -#include "paddle/phi/api/include/sparse_api.h" -#include "paddle/fluid/eager/api/utils/global_utils.h" -#include "paddle/fluid/platform/profiler/event_tracing.h" - -""" + if os.path.exists(filepath): + os.remove(filepath) - file_contents += GenerateCoreOpInfoDefinition() - file_contents += forward_definition_str + core_ops_info_str = GenerateCoreOpInfoDefinition() + file_contents = FORWARD_CC_FILE_TEMPLATE.format(core_ops_info_str, + forward_definition_str) with open(filepath, 'a') as f: f.write(file_contents) def GenerateForwardHFile(filepath, forward_function_declaration_str): - file_contents = """ -#pragma once -#include "glog/logging.h" -#include "paddle/fluid/eager/autograd_meta.h" -#include "paddle/phi/api/all.h" -#include "paddle/fluid/eager/utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/eager/to_static/run_program_op_func.h" + if os.path.exists(filepath): + os.remove(filepath) -""" - file_contents += GenerateCoreOpInfoDeclaration() - file_contents += forward_function_declaration_str + core_ops_info_str = GenerateCoreOpInfoDeclaration() + file_contents = FORWARD_H_FILE_TEMPLATE.format( + core_ops_info_str, forward_function_declaration_str) with open(filepath, 'a') as f: f.write(file_contents) @@ -1224,199 +1217,13 @@ if __name__ == "__main__": api_yaml_path = api_yaml_paths[i] backward_yaml_path = backward_yaml_paths[i] - if "sparse" in api_yaml_path: - assert "sparse" in backward_yaml_path - namespace = "sparse" - else: - namespace = "" - - fwd_api_list = ReadFwdFile(api_yaml_path) - grad_api_dict = ReadBwdFile(backward_yaml_path) - - yaml_forward_definition_str = "" - yaml_forward_declaration_str = "" - yaml_node_declaration_str = "" - yaml_node_definition_str = "" - for fwd_api in fwd_api_list: - # We only generate Ops with grad - if 'backward' not in fwd_api.keys(): - continue + generator = DygraphYamlGenerator(api_yaml_path, backward_yaml_path) + generator.run() - assert 'api' in fwd_api.keys() - assert 'args' in fwd_api.keys() - assert 'output' in fwd_api.keys() - assert 'backward' in fwd_api.keys() - - no_need_buffer_set = set() - if 'no_need_buffer' in fwd_api.keys(): - no_need_buffer_set = ParseNoNeedBuffer(fwd_api[ - 'no_need_buffer']) - - fwd_api_name = fwd_api['api'] - fwd_args_str = fwd_api['args'] - fwd_returns_str = fwd_api['output'] - - inplace_map = {} - if 'inplace' in fwd_api.keys(): - inplace_map = ParseInplaceInfo(fwd_api['inplace']) - - bwd_api_name = fwd_api['backward'] - assert bwd_api_name in grad_api_dict.keys(), bwd_api_name - bwd_api = grad_api_dict[bwd_api_name] - - assert 'args' in bwd_api.keys() - assert 'output' in bwd_api.keys() - assert 'forward' in bwd_api.keys() - - # Parse Dispensable Inputs - optional_inputs = [] - if 'optional' in fwd_api.keys(): - optional_inputs = ParseDispensable(fwd_api['optional']) - - bwd_forward_str = bwd_api['forward'] - bwd_args_str = bwd_api['args'] - bwd_returns_str = bwd_api['output'] - - # Collect Forward Inputs/Outputs - forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward( - bwd_forward_str) - print("Parsed Forward Inputs List: ", forward_inputs_list) - print("Prased Forward Attrs List: ", forward_attrs_list) - print("Parsed Forward Returns List: ", 
forward_returns_list) - - intermediate_outputs = [] - if 'intermediate' in fwd_api.keys(): - intermediate_outputs = ParseIntermediate(fwd_api[ - 'intermediate']) - - IntermediateValidationCheck(intermediate_outputs, - forward_returns_list) - - # Collect Original Forward Inputs/Outputs and then perform validation checks - orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( - fwd_args_str, fwd_returns_str) - print("Parsed Original Forward Inputs List: ", - orig_forward_inputs_list) - print("Prased Original Forward Attrs List: ", - orig_forward_attrs_list) - print("Parsed Original Forward Returns List: ", - orig_forward_returns_list) - - # Forward Validation Checks - ForwardsValidationCheck( - forward_inputs_list, forward_attrs_list, forward_returns_list, - orig_forward_inputs_list, orig_forward_attrs_list, - orig_forward_returns_list) - - # Parse Backward Inputs/Outputs - backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward( - bwd_args_str, bwd_returns_str) - print("Parsed Backward Inputs List: ", backward_inputs_list) - print("Prased Backward Attrs List: ", backward_attrs_list) - print("Parsed Backward Returns List: ", backward_returns_list) - - # Determine Forward Inputs/Outputs Position - forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap( - forward_inputs_list, forward_returns_list) - print("Generated Forward Input Position Map: ", - forward_inputs_position_map) - print("Generated Forward Output Position Map: ", - forward_outputs_position_map) - - # SlotName Matching - backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching( - backward_inputs_list, backward_returns_list, - forward_inputs_position_map, forward_outputs_position_map) - print("Generated Backward Fwd Input Map: ", backward_fwd_input_map) - print("Generated Backward Grad Input Map: ", - backward_grad_input_map) - print("Generated Backward Grad Output Map: ", - backward_grad_output_map) - - # Backward Validation Check - BackwardValidationCheck(backward_fwd_input_map, - backward_grad_input_map, - backward_attrs_list) - - # Node Declaration Generation - yaml_node_declaration_str += GenerateNodeDeclaration( - fwd_api_name, backward_fwd_input_map, backward_attrs_list, - no_need_buffer_set) - print("Generated Node Declaration: ", node_declaration_str) - - yaml_node_definition_str += GenerateNodeDefinition( - fwd_api_name, bwd_api_name, backward_fwd_input_map, - backward_grad_input_map, backward_grad_output_map, - backward_attrs_list) - print("Generated Node Definition: ", node_definition_str) - - # Node Definition Generation - definition_declaration_pair = GenerateForwardDefinition( - fwd_api_name, bwd_api_name, forward_inputs_position_map, - forward_outputs_position_map, orig_forward_attrs_list, - backward_fwd_input_map, backward_grad_input_map, - backward_grad_output_map, backward_attrs_list, optional_inputs, - intermediate_outputs, {}) - print("Generated Forward Definition: ", forward_definition_str) - print("Generated Forward Declaration: ", forward_declaration_str) - yaml_forward_definition_str += definition_declaration_pair[0] - yaml_forward_declaration_str += definition_declaration_pair[1] - - # For python-level API dispatch - CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, - forward_outputs_position_map, - orig_forward_attrs_list) - - # Inplaced Version Dygraph Function Generation - if fwd_api_name != "sum" and "inplace" in fwd_api.keys(): - 
fwd_api_name_inplaced = GetInplacedFunctionName(fwd_api_name) - - # Node Definition Generation - definition_declaration_pair = GenerateForwardDefinition( - fwd_api_name_inplaced, bwd_api_name, - forward_inputs_position_map, forward_outputs_position_map, - forward_attrs_list, backward_fwd_input_map, - backward_grad_input_map, backward_grad_output_map, - backward_attrs_list, optional_inputs, intermediate_outputs, - inplace_map) - print("Generated Inplaced Forward Definition: ", - forward_definition_str) - print("Generated Inplaced Forward Declaration: ", - forward_declaration_str) - forward_definition_str += definition_declaration_pair[0] - forward_declaration_str += definition_declaration_pair[1] - - # For python-level API dispatch - CollectCoreOpsInformation( - fwd_api_name_inplaced, forward_inputs_position_map, - forward_outputs_position_map, forward_attrs_list) - - if len(namespace) > 0: - forward_definition_str += f"""namespace {namespace} {{ - {yaml_forward_definition_str} -}} -""" - - forward_declaration_str += f"""namespace {namespace} {{ - {yaml_forward_declaration_str} -}} -""" - - node_declaration_str += f"""namespace {namespace} {{ - {yaml_node_declaration_str} -}} -""" - - node_definition_str += f"""namespace {namespace} {{ - {yaml_node_definition_str} -}} -""" - - else: - forward_definition_str += yaml_forward_definition_str - forward_declaration_str += yaml_forward_declaration_str - node_declaration_str += yaml_node_declaration_str - node_definition_str += yaml_node_definition_str + node_declaration_str += generator.node_declaration_str + "\n" + node_definition_str += generator.node_definition_str + "\n" + forward_definition_str += generator.forward_definition_str + "\n" + forward_declaration_str += generator.forward_declaration_str + "\n" # Generate Files nodes_h_path = args.nodes_h_path @@ -1424,12 +1231,6 @@ if __name__ == "__main__": forwards_h_path = args.forwards_h_path forwards_cc_path = args.forwards_cc_path - for path in [ - nodes_cc_path, nodes_h_path, forwards_h_path, forwards_cc_path - ]: - if os.path.exists(path): - os.remove(path) - GenerateNodeCCFile(nodes_cc_path, node_definition_str) GenerateNodeHFile(nodes_h_path, node_declaration_str) GenerateForwardCCFile(forwards_cc_path, forward_definition_str) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 5a732212a5..c7be9480f5 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -15,7 +15,10 @@ import os import argparse import logging -from eager_gen import namespace, yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap, GetInplacedFunctionName, ParseInplaceInfo +from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase +from codegen_utils import yaml_types_mapping +from codegen_utils import ReadFwdFile, IsVectorTensorType, GetForwardFunctionName +from codegen_utils import ParseYamlForward, GetInplacedFunctionName ########################### ## Global Configurations ## @@ -121,7 +124,10 @@ FUNCTION_NAME_TEMPLATE = \ PYTHON_C_FUNCTION_REG_TEMPLATE = \ -"{{\"final_state_{}\", (PyCFunction)(void(*)(void)) {}eager_final_state_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}}" +""" +{{\"final_state_{}\", (PyCFunction)(void(*)(void)) 
{}eager_final_state_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}} + +""" PYTHON_C_WRAPPER_TEMPLATE = \ @@ -229,77 +235,39 @@ NAMESPACE_WRAPPER_TEMPLATE = \ ####################### ## Generator Classes ## ####################### -class PythonCSingleFunctionGenerator: - def __init__(self, fwd_api_contents, namespace): - self.fwd_api_contents = fwd_api_contents - self.namespace = namespace - - # Raw Contents - self.forward_api_name = "" - self.forward_args_str = "" - self.forward_returns_str = "" - - # Raw Data - self.forward_attrs_list = None #[ [attr_name, attr_type, default_value, orig_position], ...] - self.forward_inputs_list = None #[ [arg_name, arg_type, orig_position], ...] - self.forward_returns_list = None #[ [ret_name, ret_type, orig_position], ...] - - # Processed Data - self.forward_inputs_position_map = None #{ "name" : [type, fwd_position] } - self.forward_outputs_position_map = None #{ "name" : [type, fwd_position] } - - # Special Op Attributes - self.optional_inputs = [] #[name, ...] +class PythonCSingleFunctionGenerator(FunctionGeneratorBase): + def __init__(self, forward_api_contents, namespace): + # Members from Parent: + #self.namespace + #self.forward_api_contents + #self.forward_api_name + #self.orig_forward_inputs_list + #self.orig_forward_attrs_list + #self.orig_forward_returns_list + #self.forward_inputs_position_map + #self.forward_outputs_position_map + #self.optional_inputs + #self.no_need_buffers + #self.intermediate_outputs + #self.inplace_map + FunctionGeneratorBase.__init__(self, forward_api_contents, namespace) + self.is_forward_only = True # Generated Results self.python_c_function_str = "" self.python_c_function_reg_str = "" - def CollectRawContents(self): - fwd_api_contents = self.fwd_api_contents - - assert 'api' in fwd_api_contents.keys( - ), "Unable to find \"api\" in fwd_api_contents keys" - assert 'args' in fwd_api_contents.keys( - ), "Unable to find \"args\" in fwd_api_contents keys" - assert 'output' in fwd_api_contents.keys( - ), "Unable to find \"output\" in fwd_api_contents keys" - - self.forward_api_name = fwd_api_contents['api'] - self.forward_args_str = fwd_api_contents['args'] - self.forward_returns_str = fwd_api_contents['output'] - def CollectIsForwardOnly(self): - fwd_api_contents = self.fwd_api_contents - self.is_forward_only = False if 'backward' in fwd_api_contents.keys( + forward_api_contents = self.forward_api_contents + self.is_forward_only = False if 'backward' in forward_api_contents.keys( ) else True - def CollectOptionalInputs(self): - fwd_api_contents = self.fwd_api_contents - if 'optional' in fwd_api_contents.keys(): - self.optional_inputs = ParseDispensable(fwd_api_contents[ - 'optional']) - - def CollectForwardInOutAttr(self): - forward_args_str = self.forward_args_str - forward_returns_str = self.forward_returns_str - - self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForward( - forward_args_str, forward_returns_str) - - def CollectForwardPositionMap(self): - forward_inputs_list = self.forward_inputs_list - forward_returns_list = self.forward_returns_list - - self.forward_inputs_position_map, self.forward_outputs_position_map = DetermineForwardPositionMap( - forward_inputs_list, forward_returns_list) - - def GeneratePythonCFunction(self, inplace_map): + def GeneratePythonCFunction(self): namespace = self.namespace - forward_api_name = GetInplacedFunctionName( - self.forward_api_name) if inplace_map else self.forward_api_name - forward_attrs_list = 
self.forward_attrs_list + inplace_map = self.inplace_map + forward_api_name = self.forward_api_name + orig_forward_attrs_list = self.orig_forward_attrs_list forward_inputs_position_map = self.forward_inputs_position_map forward_outputs_position_map = self.forward_outputs_position_map optional_inputs = self.optional_inputs @@ -326,7 +294,7 @@ class PythonCSingleFunctionGenerator: parse_attributes_str = "" # Generate Python-C Attributes Parsing Logic - for name, atype, _, pos in forward_attrs_list: + for name, atype, _, pos in orig_forward_attrs_list: parsing_function_name = FindParsingFunctionFromAttributeType(atype) parse_attributes_str += PARSE_PYTHON_C_ARGS_TEMPLATE.format( name, pos, atype, name, parsing_function_name, name, @@ -334,11 +302,11 @@ class PythonCSingleFunctionGenerator: # Generate Dygraph Function Call Logic num_args = len(forward_inputs_position_map.keys()) + len( - forward_attrs_list) + orig_forward_attrs_list) dygraph_function_call_list = ["" for i in range(num_args)] for name, (_, pos) in forward_inputs_position_map.items(): dygraph_function_call_list[pos] = f"{name}" - for name, _, _, pos in forward_attrs_list: + for name, _, _, pos in orig_forward_attrs_list: dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_str = ",".join(dygraph_function_call_list) @@ -350,17 +318,7 @@ class PythonCSingleFunctionGenerator: fwd_function_name = FUNCTION_NAME_TEMPLATE.format( "::", namespace, GetForwardFunctionName(forward_api_name)) - if inplace_map: - assert len( - inplace_map - ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" - for inplace_input, inplace_output in inplace_map.items(): - return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format( - forward_api_name, inplace_input, forward_api_name, - inplace_output) - break - else: - return_str = " return ToPyObject(out);" + return_str = " return ToPyObject(out);" # Generate Record Event for performance profiling pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format( @@ -374,29 +332,56 @@ class PythonCSingleFunctionGenerator: self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format( forward_api_name, namespace, forward_api_name, forward_api_name) - def run(self, inplace_map): + if len(inplace_map) > 0: + inplaced_forward_api_name = GetInplacedFunctionName( + self.forward_api_name) + assert len( + inplace_map + ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" + for inplace_input, inplace_output in inplace_map.items(): + return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format( + inplaced_forward_api_name, inplace_input, + inplaced_forward_api_name, inplace_output) + break + + self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( + inplaced_forward_api_name, pythonc_record_event_str, + inplaced_forward_api_name, get_eager_tensor_str, + parse_attributes_str, fwd_function_name, + dygraph_function_call_str, return_str) + + # Generate Python-C Function Registration + self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format( + inplaced_forward_api_name, namespace, inplaced_forward_api_name, + inplaced_forward_api_name) + + def run(self): # Initialized is_forward_only self.CollectIsForwardOnly() - # Initialized forward_api_name, forward_args_str, forward_returns_str - self.CollectRawContents() - if SkipAPIGeneration(self.forward_api_name): return False - # Initialized optional_inputs - self.CollectOptionalInputs() + self.ParseDispensable() + + # Initialized 
inplace_map
+        self.ParseInplaceInfo()
 
-        # Initialized forward_inputs_list, forward_returns_list, forward_attrs_list
-        self.CollectForwardInOutAttr()
+        # Initialized orig_forward_inputs_list, orig_forward_returns_list, orig_forward_attrs_list
+        self.CollectOriginalForwardInfo()
         logging.info(
-            f"Parsed Original Forward Inputs List: \n{self.forward_inputs_list}")
+            f"Parsed Original Forward Inputs List: \n{self.orig_forward_inputs_list}"
+        )
         logging.info(
-            f"Prased Original Forward Attrs List: \n{self.forward_attrs_list}")
+            f"Parsed Original Forward Attrs List: \n{self.orig_forward_attrs_list}"
+        )
         logging.info(
-            f"Parsed Original Forward Returns List: \n{self.forward_returns_list}"
+            f"Parsed Original Forward Returns List: \n{self.orig_forward_returns_list}"
         )
 
+        if SkipAPIGeneration(self.forward_api_name): return False
+
         # Initialized forward_inputs_position_map, forward_outputs_position_map
-        self.CollectForwardPositionMap()
+        self.DetermineForwardPositionMap(self.orig_forward_inputs_list,
+                                         self.orig_forward_returns_list)
         logging.info(
             f"Generated Forward Input Position Map: {self.forward_inputs_position_map}"
         )
@@ -405,7 +390,7 @@ class PythonCSingleFunctionGenerator:
         )
 
         # Code Generation
-        self.GeneratePythonCFunction(inplace_map)
+        self.GeneratePythonCFunction()
         logging.info(
             f"Generated Python-C Function: {self.python_c_function_str}")
         logging.info(
@@ -415,21 +400,18 @@ class PythonCSingleFunctionGenerator:
         return True
 
 
-class PythonCYamlGenerator:
+class PythonCYamlGenerator(YamlGeneratorBase):
     def __init__(self, path):
-        self.yaml_path = path
-
-        self.namespace = ""
-        self.forward_api_list = []
+        # Parent members:
+        # self.namespace
+        # self.api_yaml_path
+        # self.forward_api_list
+        YamlGeneratorBase.__init__(self, path)
 
         # Generated Result
         self.python_c_functions_reg_str = ""
         self.python_c_functions_str = ""
 
-    def ParseYamlContents(self):
-        yaml_path = self.yaml_path
-        self.forward_api_list = ReadFwdFile(yaml_path)
-
     def GeneratePythonCFunctions(self):
         namespace = self.namespace
         forward_api_list = self.forward_api_list
@@ -437,28 +419,12 @@ class PythonCYamlGenerator:
         for forward_api_content in forward_api_list:
             f_generator = PythonCSingleFunctionGenerator(forward_api_content,
                                                          namespace)
-            status = f_generator.run({})
+            status = f_generator.run()
 
             if status == True:
                 self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
                 self.python_c_functions_str += f_generator.python_c_function_str + "\n"
 
-            if 'inplace' in forward_api_content.keys():
-                inplace_map = ParseInplaceInfo(forward_api_content['inplace'])
-
-                f_generator_inplace = PythonCSingleFunctionGenerator(
-                    forward_api_content, namespace)
-                status = f_generator_inplace.run(inplace_map)
-
-                if status == True:
-                    self.python_c_functions_reg_str += f_generator_inplace.python_c_function_reg_str + ",\n"
-                    self.python_c_functions_str += f_generator_inplace.python_c_function_str + "\n"
-
-    def InferNameSpace(self):
-        yaml_path = self.yaml_path
-        if "sparse" in yaml_path:
-            self.namespace = "sparse::"
-
     def AttachNamespace(self):
         namespace = self.namespace
         python_c_functions_str = self.python_c_functions_str
@@ -474,7 +440,7 @@ class PythonCYamlGenerator:
         self.InferNameSpace()
 
         # Read Yaml file
-        self.ParseYamlContents()
+        self.ParseForwardYamlContents()
 
         # Code Generation
         self.GeneratePythonCFunctions()
--
GitLab