Unverified commit 68c9e3e4 authored by Zhanlue Yang, committed by GitHub

[Refactor] refactored eager_gen.py PR #1 (#40815)

* [Refactor] refactored eager_gen.py PR #1

* [Refactor] refactored eager_gen.py PR #1

* Refactored version 2

* Added automatic code generation utils

* Fixed merge issues
Parent 7fa3a724
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import yaml
import re
import argparse
import os
########################
### Global Variables ###
########################
ops_to_fill_zero_for_empty_grads = set(["split"])
# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}
yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \
'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>',
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
'Scalar' : 'paddle::experimental::Scalar',
'ScalarArray' : 'paddle::experimental::ScalarArray'
}
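
# Illustrative sketch (not part of the generator): the mapping rewrites the
# type tokens used in the yaml configs into the C++ types emitted by codegen:
# >>> yaml_types_mapping['Tensor[]']
# 'std::vector<Tensor>'
# >>> yaml_types_mapping['ScalarArray']
# 'paddle::experimental::ScalarArray'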
#############################
### File Reader Helpers ###
#############################
def ReadFwdFile(filepath):
    with open(filepath, 'r') as f:
        contents = yaml.load(f, Loader=yaml.FullLoader)
    return contents
def ReadBwdFile(filepath):
    with open(filepath, 'r') as f:
        contents = yaml.load(f, Loader=yaml.FullLoader)
    ret = {}
    for content in contents:
        assert 'backward_api' in content.keys(), \
            "Unable to find \"backward_api\" in backward yaml entry"
        api_name = content['backward_api']
        ret[api_name] = content
    return ret
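
# Usage sketch for the readers above (illustrative only; the entry shape
# mirrors backward.yaml, the op name and args here are assumptions):
#
#   - backward_api : matmul_grad
#     forward : matmul (Tensor x, Tensor y) -> Tensor(out)
#     args : (Tensor x, Tensor y, Tensor out_grad)
#     output : Tensor(x_grad), Tensor(y_grad)
#
# ReadFwdFile returns the parsed list as-is, while ReadBwdFile re-keys it by
# op name, so ReadBwdFile(path)['matmul_grad'] yields the entry above.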
##################################
### Generic Helper Functions ###
##################################
def FindGradName(string):
return string + "_grad"
def FindForwardName(string):
if not string.endswith("_grad"):
return None
return string[:-5]
def IsPlainTensorType(string):
plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
if string in plain_tensor_types:
return True
return False
def IsVectorTensorType(string):
vector_tensor_types = [
'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
]
if string in vector_tensor_types:
return True
return False
def GetSavedName(string):
return string + "_"
def GetConstReference(string):
ret = string
if not string.startswith("const "):
ret = "const " + string
if not string.endswith("&"):
ret += "&"
return ret
def RemoveConstAndReference(string):
ret = string
if string.startswith("const "):
ret = ret[6:]
if string.endswith("&"):
ret = ret[:-1]
return ret
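
# Illustrative sketch: these two helpers convert between value types and
# const-reference types when emitting C++ signatures, e.g.
# >>> GetConstReference("paddle::experimental::Tensor")
# 'const paddle::experimental::Tensor&'
# >>> RemoveConstAndReference("const paddle::experimental::Tensor&")
# 'paddle::experimental::Tensor'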
def GetGradNodeName(string):
return f"FinalGradNode{string}"
def GetDygraphForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
def GetIntermediateAPIFunctionName(string):
return string + "_intermediate"
def GetAutoGradMetaName(string):
return f"{string}_autograd_meta"
def GetAutoGradMetaVectorName(string):
return f"{string}_autograd_meta_vec"
def RemoveSpecialSymbolsInName(string):
    # Strip any suffix starting at '@'
ret = string.split("@")[0]
return ret
def RecoverBaseNameOfInplaceFunction(function_name):
return function_name[:-1]
def GetInplacedFunctionName(function_name):
return function_name + "_"
def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
######################
### Yaml Parsers ###
######################
def ParseYamlArgs(string):
# Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
# inputs_list = [ [arg_name, arg_type, orig_position], ...]
inputs_list = []
# attrs_list = [ [arg_name, arg_type, default_value, orig_position], ...]
attrs_list = []
args = [x.strip() for x in string.strip().split(",")]
atype = r'((const )?\S+) '
aname = r'(.*)'
pattern = f'{atype}{aname}'
for i in range(len(args)):
arg = args[i]
m = re.search(pattern, arg)
arg_type = m.group(1).strip()
arg_name = m.group(3).split("=")[0].strip()
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys(
), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name)
if "Tensor" in arg_type:
assert default_value is None
inputs_list.append([arg_name, arg_type, i])
else:
attrs_list.append([arg_name, arg_type, default_value, i])
return inputs_list, attrs_list
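
# Worked example for ParseYamlArgs, using an args string in the yaml style
# the mapping above supports (plain "Tensor", not "const Tensor&"):
# >>> ParseYamlArgs("Tensor x, Tensor y, bool transpose_x = false")
# ([['x', 'Tensor', 0], ['y', 'Tensor', 1]],
#  [['transpose_x', 'bool', 'false', 2]])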
def ParseYamlReturns(string):
# Example0: Tensor(out), Tensor(out1)
# Example1: Tensor, Tensor
# Example2: Tensor[](out), Tensor
# list = [ [ret_name, ret_type, orig_position], ...]
returns_list = []
returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)):
ret = returns[i]
ret_name = ""
if "(" in ret and ")" in ret:
# Remove trailing ')'
ret = ret[:-1]
ret_type = ret.split("(")[0].strip()
ret_name = ret.split("(")[1].strip()
else:
ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys(
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i])
return returns_list
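
# Worked example for ParseYamlReturns; unnamed returns keep an empty name:
# >>> ParseYamlReturns("Tensor(out), Tensor")
# [['out', 'Tensor', 0], ['', 'Tensor', 1]]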
def ParseYamlForwardFromBackward(string):
# Example: matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out)
fname = r'(.*?)'
wspace = r'\s*'
fargs = r'(.*?)'
frets = r'(.*)'
    pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\){wspace}->{wspace}{frets}'
m = re.search(pattern, string)
function_name = m.group(1)
function_args = m.group(2)
function_returns = m.group(3)
forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args)
forward_returns_list = ParseYamlReturns(function_returns)
return forward_inputs_list, forward_attrs_list, forward_returns_list
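
# Worked example for the "forward" string recorded in a backward yaml entry
# (the signature shown is illustrative):
# >>> ParseYamlForwardFromBackward(
# ...     "matmul (Tensor x, Tensor y, bool transpose_x, bool transpose_y) -> Tensor(out)")
# ([['x', 'Tensor', 0], ['y', 'Tensor', 1]],
#  [['transpose_x', 'bool', None, 2], ['transpose_y', 'bool', None, 3]],
#  [['out', 'Tensor', 0]])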
def ParseYamlForward(args_str, returns_str):
# args Example: (const Tensor& x, const Tensor& y, bool transpose_x = false, bool transpose_y = false)
# returns Example: Tensor, Tensor
fargs = r'(.*?)'
wspace = r'\s*'
    args_pattern = fr'\({fargs}\)'
args_str = re.search(args_pattern, args_str).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturns(returns_str)
return inputs_list, attrs_list, returns_list
def ParseYamlBackward(args_str, returns_str):
# args Example: (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false)
# returns Example: Tensor(x_grad), Tensor(y_grad)
fargs = r'(.*?)'
wspace = r'\s*'
    args_pattern = fr'\({fargs}\)'
args_str = re.search(args_pattern, args_str).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturns(returns_str)
return inputs_list, attrs_list, returns_list
########################
### Generator Base ###
########################
class FunctionGeneratorBase:
def __init__(self, forward_api_contents, namespace):
self.forward_api_contents = forward_api_contents
self.namespace = namespace
self.forward_api_name = ""
self.orig_forward_inputs_list = [
] #[ [arg_name, arg_type, orig_position], ...]
self.orig_forward_attrs_list = [
] #[ [attr_name, attr_type, default_value, orig_position], ...]
self.orig_forward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...]
# Processed Forward Data
self.forward_inputs_position_map = {
} #{ "name" : [type, fwd_position] }
self.forward_outputs_position_map = {
} #{ "name" : [type, fwd_position] }
# Special Op Attributes
self.optional_inputs = [] #[name, ...]
self.no_need_buffers = [] #[name, ...]
self.intermediate_outputs = [] #[name, ...]
self.inplace_map = {} #{name : name, ...}
def ParseInplaceInfo(self):
forward_api_contents = self.forward_api_contents
if 'inplace' not in forward_api_contents.keys(): return
# inplace_map_str: "(x -> out0), (y -> out2)"
inplace_map_str = forward_api_contents['inplace']
for pair in inplace_map_str.split(","):
pair = pair.strip()
if pair.startswith("("):
pair = pair[1:]
if pair.endswith(")"):
pair = pair[:-1]
key = pair.split("->")[0].strip()
val = pair.split("->")[1].strip()
self.inplace_map[key] = val
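
    # Sketch of the parse above (field value is illustrative): an entry such
    # as   inplace : (x -> out)   leaves self.inplace_map == {'x': 'out'},
    # i.e. input x may share its buffer with output out.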
def ParseNoNeedBuffer(self):
forward_api_contents = self.forward_api_contents
if 'no_need_buffer' in forward_api_contents.keys():
no_need_buffer_str = forward_api_contents['no_need_buffer']
for name in no_need_buffer_str.split(","):
name = name.strip()
name = RemoveSpecialSymbolsInName(name)
self.no_need_buffers.append(name.strip())
def ParseDispensable(self):
forward_api_contents = self.forward_api_contents
if 'optional' in forward_api_contents.keys():
optional_inputs_str = forward_api_contents['optional']
for name in optional_inputs_str.split(","):
name = name.strip()
name = RemoveSpecialSymbolsInName(name)
self.optional_inputs.append(name)
def ParseIntermediate(self):
forward_api_contents = self.forward_api_contents
if 'intermediate' in forward_api_contents.keys():
intermediate_str = forward_api_contents['intermediate']
for name in intermediate_str.split(","):
name = name.strip()
name = RemoveSpecialSymbolsInName(name)
self.intermediate_outputs.append(name)
def CollectOriginalForwardInfo(self):
forward_api_contents = self.forward_api_contents
        assert 'api' in forward_api_contents.keys(
        ), "Unable to find \"api\" in forward_api_contents keys"
        assert 'args' in forward_api_contents.keys(
        ), "Unable to find \"args\" in forward_api_contents keys"
        assert 'output' in forward_api_contents.keys(
        ), "Unable to find \"output\" in forward_api_contents keys"

        self.forward_api_name = forward_api_contents['api']
        forward_args_str = forward_api_contents['args']
        forward_returns_str = forward_api_contents['output']
# Collect Original Forward Inputs/Outputs and then perform validation checks
self.orig_forward_inputs_list, self.orig_forward_attrs_list, self.orig_forward_returns_list = ParseYamlForward(
forward_args_str, forward_returns_str)
def DetermineForwardPositionMap(self, forward_inputs_list,
forward_returns_list):
for i in range(len(forward_inputs_list)):
forward_input = forward_inputs_list[i]
input_name = forward_input[0]
input_type = forward_input[1]
input_pos = forward_input[2]
self.forward_inputs_position_map[
input_name] = [input_type, input_pos]
for i in range(len(forward_returns_list)):
forward_return = forward_returns_list[i]
return_name = forward_return[0]
return_type = forward_return[1]
return_pos = forward_return[2]
self.forward_outputs_position_map[
return_name] = [return_type, return_pos]
print("Generated Forward Input Position Map: ",
self.forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
self.forward_outputs_position_map)
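
    # Positioning sketch: for a forward such as
    #   matmul (Tensor x, Tensor y, ...) -> Tensor(out)
    # the maps above come out as (illustrative):
    #   forward_inputs_position_map  == {'x': ['Tensor', 0], 'y': ['Tensor', 1]}
    #   forward_outputs_position_map == {'out': ['Tensor', 0]}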
class YamlGeneratorBase:
def __init__(self, api_yaml_path):
self.namespace = ""
self.api_yaml_path = api_yaml_path
self.forward_api_list = []
def ParseForwardYamlContents(self):
api_yaml_path = self.api_yaml_path
self.forward_api_list = ReadFwdFile(api_yaml_path)
def InferNameSpace(self):
api_yaml_path = self.api_yaml_path
if "sparse" in api_yaml_path:
self.namespace = "sparse::"
@@ -16,31 +16,25 @@ import yaml
import re
import argparse
import os

ops_to_fill_zero_for_empty_grads = set(list("split"))

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}

namespace = ""

yaml_types_mapping = {
    'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
    'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
    'str' : 'std::string', \
    'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
    'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
    'Tensor' : 'Tensor',
    'Tensor[]' : 'std::vector<Tensor>',
    'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
    'Scalar' : 'paddle::experimental::Scalar',
    'ScalarArray' : 'paddle::experimental::ScalarArray'
}

from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, ReadBwdFile
from codegen_utils import FindGradName, FindForwardName, GetSavedName, GetGradNodeName
from codegen_utils import IsPlainTensorType, IsVectorTensorType
from codegen_utils import GetConstReference, RemoveConstAndReference
from codegen_utils import GetDygraphForwardFunctionName, GetIntermediateAPIFunctionName
from codegen_utils import GetAutoGradMetaName, GetAutoGradMetaVectorName
from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFunction
from codegen_utils import GetInplacedFunctionName
from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads

###########
## Utils ##
###########
def ParseArguments():
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser')
@@ -55,845 +49,129 @@ def ParseArguments():
    return args

########################
## Code Gen Templates ##
########################
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
"""
   void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
     {} = egr::TensorWrapper({}, full_reserved, {});
   }}
"""

#################
###  Helpers  ###
#################
def RecoverBaseNameOfInplaceFunction(function_name):
    return function_name[:-1]
def GetInplacedFunctionName(function_name):
return function_name + "_"
def FindGradName(string):
return string + "_grad"
def FindForwardName(string):
if not string.endswith("_grad"):
return None
return string[:-5]
def IsPlainTensorType(string):
plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
if string in plain_tensor_types:
return True
return False
def IsVectorTensorType(string):
vector_tensor_types = [
'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
]
if string in vector_tensor_types:
return True
return False
def GetSavedName(string):
return string + "_"
def GetConstReference(string):
ret = string
if not string.startswith("const "):
ret = "const " + string
if not string.endswith("&"):
ret += "&"
return ret
def RemoveConstAndReference(string):
ret = string
if string.startswith("const "):
ret = ret[6:]
if string.endswith("&"):
ret = ret[:-1]
return ret
def GetGradNodeName(string):
return f"FinalGradNode{string}"
def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
def GetAutoGradMetaName(string):
return f"{string}_autograd_meta"
def GetAutoGradMetaVectorName(string):
return f"{string}_autograd_meta_vec"
######################
### File Readers ###
######################
def ReadFwdFile(filepath):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
f.close()
return contents
def ReadBwdFile(filepath):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {}
for content in contents:
if 'backward_api' in content.keys():
api_name = content['backward_api']
else:
assert False
ret[api_name] = content
f.close()
return ret
######################
### Yaml Parsers ###
######################
def ParseInplaceInfo(string):
# string: "(x -> out0), (y -> out2)"
inplace_map = {}
for pair in string.split(","):
pair = pair.strip()
if pair.startswith("("):
pair = pair[1:]
if pair.endswith(")"):
pair = pair[:-1]
key = pair.split("->")[0].strip()
val = pair.split("->")[1].strip()
inplace_map[key] = val
return inplace_map
def RemoveSpecialSymbolsInName(string):
# Remove any name after '@'
ret = string.split("@")[0]
return ret
def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
# intermediate_outputs : [name0, name1, ...]
# forward_returns_list : [[ret_name, type, orig_pos], ...]
"""
Check whether intermediate_outputs are positioned
at the very end of forward_returns_list
"""
intermediate_positions = range(
len(forward_returns_list) - len(intermediate_outputs),
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
def ParseDispensable(string):
# string: "X, Y"
string = RemoveSpecialSymbolsInName(string)
return [v.strip() for v in string.split(",")]
def ParseIntermediate(string):
string = RemoveSpecialSymbolsInName(string)
return [v.strip() for v in string.split(",")]
def ParseNoNeedBuffer(string):
# string: "x, y"
string = RemoveSpecialSymbolsInName(string)
no_need_buffer_set = set()
for name in string.split(","):
no_need_buffer_set.add(name.strip())
return no_need_buffer_set
def ParseYamlArgs(string):
# Example: const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y
# inputs_list = [ [arg_name, arg_type, orig_position], ...]
inputs_list = []
# attrs_list = [ [arg_name, arg_type, default_value, orig_position], ...]
attrs_list = []
args = [x.strip() for x in string.strip().split(",")]
atype = r'((const )?\S+) '
aname = r'(.*)'
pattern = f'{atype}{aname}'
for i in range(len(args)):
arg = args[i]
m = re.search(pattern, arg)
arg_type = m.group(1).strip()
arg_name = m.group(3).split("=")[0].strip()
default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys(
), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name)
if "Tensor" in arg_type:
assert default_value is None
inputs_list.append([arg_name, arg_type, i])
else:
attrs_list.append([arg_name, arg_type, default_value, i])
return inputs_list, attrs_list
def ParseYamlReturns(string):
# Example0: Tensor(out), Tensor(out1)
# Example1: Tensor, Tensor
# Example2: Tensor[](out), Tensor
# list = [ [ret_name, ret_type, orig_position], ...]
returns_list = []
returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)):
ret = returns[i]
ret_name = ""
if "(" in ret and ")" in ret:
# Remove trailing ')'
ret = ret[:-1]
ret_type = ret.split("(")[0].strip()
ret_name = ret.split("(")[1].strip()
else:
ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys(
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i])
return returns_list
def ParseYamlForwardFromBackward(string):
# Example: matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out)
fname = r'(.*?)'
wspace = r'\s*'
fargs = r'(.*?)'
frets = r'(.*)'
pattern = f'{fname}{wspace}\({wspace}{fargs}{wspace}\){wspace}->{wspace}{frets}'
m = re.search(pattern, string)
function_name = m.group(1)
function_args = m.group(2)
function_returns = m.group(3)
forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args)
forward_returns_list = ParseYamlReturns(function_returns)
return forward_inputs_list, forward_attrs_list, forward_returns_list
def ParseYamlForward(args_str, returns_str):
# args Example: (const Tensor& x, const Tensor& y, bool transpose_x = false, bool transpose_y = false)
# returns Example: Tensor, Tensor
fargs = r'(.*?)'
wspace = r'\s*'
args_pattern = f'\({fargs}\)'
args_str = re.search(args_pattern, args_str).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturns(returns_str)
return inputs_list, attrs_list, returns_list
def ParseYamlBackward(args_str, returns_str):
# args Example: (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false)
# returns Example: Tensor(x_grad), Tensor(y_grad)
fargs = r'(.*?)'
wspace = r'\s*'
args_pattern = f'\({fargs}\)'
args_str = re.search(args_pattern, args_str).group(1)
inputs_list, attrs_list = ParseYamlArgs(args_str)
returns_list = ParseYamlReturns(returns_str)
return inputs_list, attrs_list, returns_list
#######################
### Preprocessing ###
#######################
def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
forward_returns_list, orig_forward_inputs_list,
orig_forward_attrs_list, orig_forward_returns_list):
for i in range(len(forward_inputs_list)):
forward_input_name = forward_inputs_list[i][0]
forward_input_type = forward_inputs_list[i][1]
forward_input_pos = forward_inputs_list[i][2]
orig_input_name = orig_forward_inputs_list[i][0]
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]
assert forward_input_type == orig_input_type
assert forward_input_pos == orig_input_pos
for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
orig_attr_type = orig_forward_attrs_list[i][1]
orig_attr_default = orig_forward_attrs_list[i][2]
orig_attr_pos = orig_forward_attrs_list[i][3]
forward_attr_name = forward_attrs_list[i][0]
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][1]
orig_return_pos = orig_forward_returns_list[i][2]
forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2]
assert orig_return_type == forward_return_type
assert orig_return_pos == forward_return_pos
# Check Order: Inputs, Attributes
max_input_position = -1
for _, _, pos in forward_inputs_list:
max_input_position = max(max_input_position, pos)
max_attr_position = -1
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position
max_attr_position = max(max_attr_position, pos)
def BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map,
backward_attrs_list):
# Check Order: TensorWrappers, GradTensors, Attributes
max_fwd_input_position = -1
for _, (_, _, pos) in backward_fwd_input_map.items():
max_fwd_input_position = max(max_fwd_input_position, pos)
max_grad_tensor_position = -1
for _, (_, _, pos) in backward_grad_input_map.items():
assert pos > max_fwd_input_position
max_grad_tensor_position = max(max_grad_tensor_position, pos)
max_attr_position = -1
for _, _, _, pos in backward_attrs_list:
assert pos > max_grad_tensor_position
max_attr_position = max(max_attr_position, pos)
def DetermineForwardPositionMap(forward_inputs_list, forward_returns_list):
forward_inputs_position_map = {}
forward_outputs_position_map = {}
for i in range(len(forward_inputs_list)):
forward_input = forward_inputs_list[i]
input_name = forward_input[0]
input_type = forward_input[1]
input_pos = forward_input[2]
forward_inputs_position_map[input_name] = [input_type, input_pos]
for i in range(len(forward_returns_list)):
forward_return = forward_returns_list[i]
return_name = forward_return[0]
return_type = forward_return[1]
return_pos = forward_return[2]
forward_outputs_position_map[return_name] = [return_type, return_pos]
return forward_inputs_position_map, forward_outputs_position_map
def SlotNameMatching(backward_inputs_list, backward_returns_list,
forward_inputs_position_map, forward_outputs_position_map):
backward_fwd_input_map = {}
backward_grad_input_map = {}
backward_grad_output_map = {}
for backward_input in backward_inputs_list:
backward_input_name = backward_input[0]
backward_input_type = backward_input[1]
backward_input_pos = backward_input[2]
backward_fwd_name = FindForwardName(backward_input_name)
if backward_fwd_name:
# Grad Input
assert backward_fwd_name in forward_outputs_position_map.keys()
matched_forward_output_type = forward_outputs_position_map[
backward_fwd_name][0]
matched_forward_output_pos = forward_outputs_position_map[
backward_fwd_name][1]
backward_grad_input_map[backward_input_name] = [
backward_input_type, matched_forward_output_pos,
backward_input_pos
]
else:
# TensorWrapper Input
if backward_input_name in forward_inputs_position_map.keys():
tensor_wrapper_type = forward_inputs_position_map[
backward_input_name][0]
backward_fwd_input_map[backward_input_name] = [
backward_input_type, True, backward_input_pos
]
elif backward_input_name in forward_outputs_position_map.keys():
tensor_wrapper_type = forward_outputs_position_map[
backward_input_name][0]
backward_fwd_input_map[backward_input_name] = [
backward_input_type, False, backward_input_pos
]
else:
assert False, backward_input_name
for backward_output in backward_returns_list:
backward_output_name = backward_output[0]
backward_output_type = backward_output[1]
backward_output_pos = backward_output[2]
backward_fwd_name = FindForwardName(backward_output_name)
assert backward_fwd_name is not None
assert backward_fwd_name in forward_inputs_position_map.keys(
), backward_fwd_name
matched_forward_input_type = forward_inputs_position_map[
backward_fwd_name][0]
matched_forward_input_pos = forward_inputs_position_map[
backward_fwd_name][1]
backward_grad_output_map[backward_output_name] = [
backward_output_type, matched_forward_input_pos, backward_output_pos
]
return backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map
def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
backward_attrs_list, no_need_buffer_set):
# Inputs:
# fwd_api_name = ""
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Determine Node Name
forward_op_name = fwd_api_name
# SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set:
no_need_buffer = "true"
else:
no_need_buffer = "false"
tensor_wrapper_name = GetSavedName(tname)
if IsPlainTensorType(ttype):
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """
   void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
     {} = egr::TensorWrapper({}, full_reserved, {});
   }}
"""
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname, no_need_buffer)
            PLAIN_TENSOR_MEMBER_TEMPLATE = """
   egr::TensorWrapper {};
"""
            tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
                tensor_wrapper_name)

            CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
   {}.clear();
"""
            clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
                tensor_wrapper_name)
        else:
            assert IsVectorTensorType(ttype)
            SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
   void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
     for(const auto& eager_tensor : {}) {{
        {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) );
     }};
   }}
"""
            set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
                tname, tname, tname, tensor_wrapper_name, no_need_buffer)

            VECTOR_TENSOR_MEMBER_TEMPLATE = """
   std::vector<egr::TensorWrapper> {};
"""
            tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
                tensor_wrapper_name)

            CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
  for (auto tw: {}) {
    tw.clear();
  };
"""
            clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
                tensor_wrapper_name)

PLAIN_TENSOR_MEMBER_TEMPLATE = \
"""
   egr::TensorWrapper {};
"""

CLEAR_TENSOR_WRAPPER_TEMPLATE = \
"""
   {}.clear();
"""

SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
"""
   void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
     for(const auto& eager_tensor : {}) {{
        {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) );
     }};
   }}
"""

VECTOR_TENSOR_MEMBER_TEMPLATE = \
"""
   std::vector<egr::TensorWrapper> {};
"""

CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \
"""
  for (auto tw: {}) {
    tw.clear();
  };
"""
# End: SetTensorWrapper Methods & TensorWrapper Members
# SetAttributes & Attribute Members
set_attribute_methods_str = ""
attribute_members_str = ""
for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname)
SET_ATTR_METHOD_TEMPLATE = """
void SetAttribute{}({} {}) {{
{} = {};
}}
""" """
set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
aname, GetConstReference(atype), aname, saved_attr_name, aname)
SET_ATTR_METHOD_TEMPLATE = \
"""
   void SetAttribute{}({} {}) {{
     {} = {};
   }}
"""

ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE = \
"""
   {} {} = {};
"""

ATTRIBUTE_MEMBER_TEMPLATE = \
"""
   {} {};
"""

        if default_val:
            ATTRIBUTE_MEMBER_TEMPLATE = """
   {} {} = {};
"""
            attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
                RemoveConstAndReference(atype), saved_attr_name, default_val)
        else:
            ATTRIBUTE_MEMBER_TEMPLATE = """
   {} {};
"""
            attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
                RemoveConstAndReference(atype), saved_attr_name)
# End: SetAttributes & Attribute Members
grad_node_name = GetGradNodeName(fwd_api_name)
NODE_DECLARATION_TEMPLATE = """
class {} : public egr::GradNodeBase {{
public:
{}() : egr::GradNodeBase() {{}}
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
""" """
node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
return node_declaration_str
def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list):
# fwd_api_name = ""
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
grad_api_args_len = len(backward_fwd_input_map.keys()) + len(
backward_grad_input_map.keys()) + len(backward_attrs_list)
grad_api_args = ["" for i in range(grad_api_args_len)]
for name, (_, is_fwd_input,
grad_api_position), in backward_fwd_input_map.items():
tensor_wrapper_name = GetSavedName(name)
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_input_map.items():
if IsPlainTensorType(ttype):
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}][0]"
else:
assert IsVectorTensorType(ttype)
grad_api_args[grad_api_position] = f"hooked_grads[{fwd_position}]"
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
grad_api_args[grad_api_position] = f"this->{saved_attribute_name}"
grad_api_args_str = ", ".join(grad_api_args)
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_output_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({num_bwd_outputs});\n"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_output_map.items():
# Infer Grad API Return Type
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
returns_str += "returns[0] = { grad_api_returns };\n"
else:
assert IsVectorTensorType(ttype)
returns_str += "returns[0] = grad_api_returns;\n"
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
    grad_node_name = GetGradNodeName(fwd_api_name)

    fill_zero_str = ""
    if fwd_api_name in ops_to_fill_zero_for_empty_grads:
        fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"

    if len(namespace) > 0:
        grad_api_namespace = f"paddle::experimental::{namespace}"
    else:
        grad_api_namespace = f"paddle::experimental"

    FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
    {}
    auto hooked_grads = ApplyGradientHooks(grads);

    // Call grad_api function
    VLOG(3) << \"Final State Running: \" << \"{}\";
    auto grad_api_returns = {}::{}({});
    {}
}}
"""
    node_definition_str = FUNCTION_TEMPLATE.format(
        grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
        bwd_api_name, grad_api_args_str, returns_str)

    return node_definition_str

NODE_DECLARATION_TEMPLATE = \
"""
class {} : public egr::GradNodeBase {{
 public:
  {}() : egr::GradNodeBase() {{}}
  {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
      egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
  ~{}() override = default;

  virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
      std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
  std::string name() override {{ return \" {} \"; }}

  void ClearTensorWrappers() override {{
      {}
      is_tensor_wrappers_cleared = true;
  }}

  // SetTensorWrapperX, SetTensorWrapperY, ...
  {}

  // SetAttributes
  {}

  bool IsTensorWrappersCleared() override {{
      return is_tensor_wrappers_cleared;
  }}

 private:
  // TensorWrappers
  {}

  bool is_tensor_wrappers_cleared = false;

  // Attributes
  {}
}};
"""
def GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_call_str,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
inplace_map):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
# Get Output AutoGradMeta
outputs_autograd_meta_list = []
pass_stop_gradient_args_list = ["false"]
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
# ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += f"""
// Check Inplace
egr::EagerUtils::CheckInplace({inplace_name}, {inplace_autograd_meta_name}, require_any_grad);\n
""" """
bump_inplace_version_str += f""" FUNCTION_TEMPLATE = \
// Bump Inplace Version """
{inplace_name}.bump_inplace_version(); std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
VLOG(3) << \"Tensor(\" << {inplace_name}.name() << \") uses Inplace Strategy.\";\n {}
auto hooked_grads = ApplyGradientHooks(grads);
// Call grad_api function
VLOG(3) << \"Final State Running: \" << \"{}\";
auto grad_api_returns = {}{}({});
{}
}}
""" """
# Node Construction FORWARD_FUNCTION_TEMPLATE = \
num_bwd_inputs = len(backward_grad_input_map.keys()) """
num_bwd_outputs = len(backward_grad_output_map.keys()) {} {}({}) {{
grad_node_name = GetGradNodeName( {}
RecoverBaseNameOfInplaceFunction(
fwd_api_name)) if inplace_map else GetGradNodeName(fwd_api_name) {}
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"
# SetAttributes
set_attributes_list = []
forward_attrs_name_set = set()
for name, _, _, _ in forward_attrs_list:
forward_attrs_name_set.add(name)
for name, _, default_val_attr, _ in backward_attrs_list:
if name in forward_attrs_name_set:
set_attributes = f" grad_node->SetAttribute{name}({name});"
else:
set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
# SetTensorWrappers
set_tensor_wrappers_list = []
for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"
            if is_optional:
                set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
            else:
                set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
# SetGradOutMeta & SetEdges
set_grad_out_meta_list = []
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
set_edges_str = "\n".join(set_edges_list)
# SetOutRank & SetHistory & SetGradInMeta
set_out_rank_list = []
set_history_list = []
set_grad_in_meta_list = []
set_retain_grad_list = []
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = fwd_api_name + " node_creation"
NODE_CREATION_TEMPLATE = """
paddle::platform::RecordEvent node_creation_record_event(\"{}\", paddle::platform::TracerEventType::Operator, 1);\n
"""
node_creation_event_str = NODE_CREATION_TEMPLATE.format(node_event_name)
    NODE_CREATION_TEMPLATE = """
        // Get AutoGradMeta
{}
        bool trace_backward = egr::Controller::Instance().HasGrad();
@@ -924,185 +202,72 @@ def GenerateNodeCreationCodes(
            {}
        }}
    }}
"""
    node_creation_str = NODE_CREATION_TEMPLATE.format(
        inputs_autograd_meta_str, compute_require_grad_args_str,
        check_inplace_str, forward_call_str, bump_inplace_version_str,
        node_creation_event_str, outputs_autograd_meta_str,
        pass_stop_gradient_args_str, node_construction_str, set_attributes_str,
        set_tensor_wrappers_str, set_grad_out_meta_str, set_edges_str,
        set_out_rank_str, set_history_str, set_grad_in_meta_str,
        set_retain_grad_str)

    return node_creation_str

NODE_CREATION_TEMPLATE = \
"""
        // Get AutoGradMeta
{}
        bool trace_backward = egr::Controller::Instance().HasGrad();
            {}
        }}
    }}
"""

NAMESPACE_WRAPPER_TEMPLATE = \
"""
namespace {} {{
    {}
}}
"""
def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs, inplace_map):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# backward_fwd_input_map = { "name" : [type, is_fwd_input, orig_position] ...}
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = ["name0", ...]
# Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
))
inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype):
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
else:
if inplace_map and name in inplace_map.keys():
arg_str = f"paddle::experimental::Tensor& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
        inputs_args_definition_list[pos] = arg_str
        inputs_args_declaration_list[pos] = arg_str

NODE_CC_FILE_TEMPLATE = \
"""
#include "glog/logging.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/backward/sparse_bw_api.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"

#include "paddle/phi/api/include/sparse_api.h"

{}
"""

    for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
if default_val is not None:
inputs_args_declaration_list[
pos] = f"{atype} {name} = {default_val}"
else:
inputs_args_declaration_list[pos] = f"{atype} {name}"
inputs_args_definition_list[pos] = f"{atype} {name}"
inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
inputs_args_definition_str = ", ".join(inputs_args_definition_list)
inputs_call_args_str = ", ".join(inputs_call_list)
# Forward Full Logic
if len(intermediate_outputs) == 0:
function_name = fwd_api_name
else:
function_name = fwd_api_name + "_intermediate"
if len(namespace) > 0:
forward_call_str = f"auto api_result = paddle::experimental::{namespace}::{function_name}({inputs_call_args_str});"
else:
forward_call_str = f"auto api_result = paddle::experimental::{function_name}({inputs_call_args_str});"
# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1:
returns_list[0] = f"api_result"
else:
# Tuple api_result
returns_list[pos] = f"std::get<{pos}>(api_result)"
if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor"
else:
assert IsVectorTensorType(rtype)
returns_type_list[pos] = "std::vector<paddle::experimental::Tensor>"
if num_outputs == 1:
returns_str = returns_list[0]
returns_type_str = returns_type_list[0]
else:
returns_type_str = ", ".join(returns_type_list)
returns_type_str = f"std::tuple<{returns_type_str}>"
returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})"
node_creation_str = GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_call_str,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
inplace_map)
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{fwd_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
{}
{} {}
// Returns
return {};
}}
""" """
    forward_function_name = GetForwardFunctionName(fwd_api_name)
    forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
        returns_type_str, forward_function_name, inputs_args_definition_str,
        dygraph_event_str, node_creation_str, returns_str)
    forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"

    return forward_function_str, forward_function_declaration_str

NODE_H_FILE_TEMPLATE = \
"""
#pragma once
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/eager/grad_node_info.h"

{}
"""
def CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list):
# fwd_api_name : ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_returns = len(forward_outputs_position_map.keys())
final_state_fwd_api_name = "final_state_" + fwd_api_name
core_ops_returns_info[
final_state_fwd_api_name] = ["" for i in range(num_returns)]
core_ops_args_info[final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
if IsPlainTensorType(ttype):
core_ops_args_type_info[final_state_fwd_api_name][pos] = "tensor"
else:
assert IsVectorTensorType(ttype)
core_ops_args_type_info[final_state_fwd_api_name][pos] = "list"
for name, _, _, pos in forward_attrs_list:
core_ops_args_info[final_state_fwd_api_name][pos] = name
    for name, (ttype, pos) in forward_outputs_position_map.items():
        core_ops_returns_info[final_state_fwd_api_name][pos] = name

def GenerateCoreOpInfoDeclaration():
    core_ops_declaration_str = """
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
"""
    return core_ops_declaration_str

FORWARD_CC_FILE_TEMPLATE = \
"""
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"

#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

{}
{}
"""
FORWARD_H_FILE_TEMPLATE = \
"""
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/phi/api/all.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/eager/to_static/run_program_op_func.h"
{}
{}
"""

CORE_OPS_INFO_TEMPLATE = \
"""
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info = {{
{}
}};
{}
}};
"""

def GenerateCoreOpInfoDefinition():
    CORE_OPS_INFO_TEMPLATE = """
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info = {{
{}
}};
@@ -1114,6 +279,38 @@ std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_r
{}
}};
"""
CORE_OPS_DECLARATION_TEMPLATE = \
"""
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;
"""
CHECK_INPLACE_TEMPLATE = \
"""
// Check Inplace
egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
"""
BUMP_INPLACE_VERSION_TEMPLATE = \
"""
// Bump Inplace Version
{}.bump_inplace_version();
VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";\n
"""
#######################
## Generator Helpers ##
#######################
def GenerateCoreOpInfoDeclaration():
return CORE_OPS_DECLARATION_TEMPLATE
def GenerateCoreOpInfoDefinition():
    op_args_info_list = []
    for op_name, arg_list in core_ops_args_info.items():
        arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
@@ -1142,68 +339,864 @@ std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_r
    return core_ops_info_definition_str
#####################
## Generator Class ##
#####################
class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
self.forward_api_contents = forward_api_contents
# Members from Parent:
#self.namespace
#self.forward_api_contents
#self.forward_api_name
#self.orig_forward_inputs_list
#self.orig_forward_attrs_list
#self.orig_forward_returns_list
#self.forward_inputs_position_map
#self.forward_outputs_position_map
#self.optional_inputs
#self.no_need_buffers
#self.intermediate_outputs
#self.inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
self.grad_api_contents = grad_api_contents
# Raw Contents
self.backward_forward_str = ""
self.backward_api_name = ""
self.forward_attrs_list = [
] #[ [attr_name, attr_type, default_value, orig_position], ...]
self.forward_inputs_list = [
] #[ [arg_name, arg_type, orig_position], ...]
self.forward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...]
        self.backward_inputs_list = [
        ]  #[ [arg_name, arg_type, orig_position], ...]
        self.backward_attrs_list = [
        ]  #[ [attr_name, attr_type, default_value, orig_position], ...]
self.backward_returns_list = [
] #[ [ret_name, ret_type, orig_position], ...]
# SlotNameMatched Backward Data
self.backward_forward_inputs_map = {
} #{ "name" : [type, is_fwd_input, orig_position] ...}
self.backward_grad_inputs_map = {
} #{ "name" : [type, fwd_position, orig_position] ...}
self.backward_grad_outputs_map = {
} #{ "name" : [type, fwd_position, orig_position] ...}
# Generated Results
self.forward_definition_str = ""
self.forward_declaration_str = ""
self.node_declaration_str = ""
self.node_definition_str = ""
def DygraphYamlValidationCheck(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents
assert 'api' in forward_api_contents.keys()
assert 'args' in forward_api_contents.keys()
assert 'output' in forward_api_contents.keys()
assert 'backward' in forward_api_contents.keys()
assert 'args' in grad_api_contents.keys()
assert 'output' in grad_api_contents.keys()
assert 'forward' in grad_api_contents.keys()
def ForwardsValidationCheck(self):
forward_inputs_list = self.forward_inputs_list
forward_attrs_list = self.forward_attrs_list
forward_returns_list = self.forward_returns_list
orig_forward_inputs_list = self.orig_forward_inputs_list
orig_forward_attrs_list = self.orig_forward_attrs_list
orig_forward_returns_list = self.orig_forward_returns_list
for i in range(len(forward_inputs_list)):
forward_input_name = forward_inputs_list[i][0]
forward_input_type = forward_inputs_list[i][1]
forward_input_pos = forward_inputs_list[i][2]
orig_input_name = orig_forward_inputs_list[i][0]
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]
assert forward_input_type == orig_input_type
assert forward_input_pos == orig_input_pos
for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
orig_attr_type = orig_forward_attrs_list[i][1]
orig_attr_default = orig_forward_attrs_list[i][2]
orig_attr_pos = orig_forward_attrs_list[i][3]
forward_attr_name = forward_attrs_list[i][0]
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][1]
orig_return_pos = orig_forward_returns_list[i][2]
forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2]
assert orig_return_type == forward_return_type
assert orig_return_pos == forward_return_pos
# Check Order: Inputs, Attributes
max_input_position = -1
for _, _, pos in forward_inputs_list:
max_input_position = max(max_input_position, pos)
max_attr_position = -1
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position
max_attr_position = max(max_attr_position, pos)
def BackwardValidationCheck(self):
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_attrs_list = self.backward_attrs_list
# Check Order: TensorWrappers, GradTensors, Attributes
max_fwd_input_position = -1
for _, (_, _, pos) in backward_forward_inputs_map.items():
max_fwd_input_position = max(max_fwd_input_position, pos)
max_grad_tensor_position = -1
for _, (_, _, pos) in backward_grad_inputs_map.items():
assert pos > max_fwd_input_position
max_grad_tensor_position = max(max_grad_tensor_position, pos)
max_attr_position = -1
for _, _, _, pos in backward_attrs_list:
assert pos > max_grad_tensor_position
max_attr_position = max(max_attr_position, pos)
def IntermediateValidationCheck(self):
"""
Check whether intermediate_outputs are positioned
at the very end of forward_returns_list
"""
intermediate_outputs = self.intermediate_outputs
forward_returns_list = self.forward_returns_list
intermediate_positions = range(
len(forward_returns_list) - len(intermediate_outputs),
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
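# Example: if the forward api returns (out, xshape) and intermediate_outputs
# is ["xshape"], then "xshape" must sit in the last return slot, which is
# exactly what the assert above enforces.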
def CollectBackwardInfo(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents
self.backward_api_name = forward_api_contents['backward']
self.backward_forward_str = grad_api_contents['forward']
backward_args_str = grad_api_contents['args']
backward_returns_str = grad_api_contents['output']
self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
backward_args_str, backward_returns_str)
print("Parsed Backward Inputs List: ", self.backward_inputs_list)
print("Prased Backward Attrs List: ", self.backward_attrs_list)
print("Parsed Backward Returns List: ", self.backward_returns_list)
def CollectForwardInfoFromBackwardContents(self):
backward_forward_str = self.backward_forward_str
self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward(
backward_forward_str)
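# As a rough illustration, a 'forward' string such as
#   "relu (Tensor x) -> Tensor(out)"
# would be parsed into:
#   forward_inputs_list  : [["x", "Tensor", 0]]
#   forward_attrs_list   : []
#   forward_returns_list : [["out", "Tensor", 0]]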
def SlotNameMatching(self):
backward_inputs_list = self.backward_inputs_list
backward_returns_list = self.backward_returns_list
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
for backward_input in backward_inputs_list:
backward_input_name = backward_input[0]
backward_input_type = backward_input[1]
backward_input_pos = backward_input[2]
backward_fwd_name = FindForwardName(backward_input_name)
if backward_fwd_name:
# Grad Input
assert backward_fwd_name in forward_outputs_position_map.keys()
matched_forward_output_type = forward_outputs_position_map[
backward_fwd_name][0]
matched_forward_output_pos = forward_outputs_position_map[
backward_fwd_name][1]
self.backward_grad_inputs_map[backward_input_name] = [
backward_input_type, matched_forward_output_pos,
backward_input_pos
]
else:
# TensorWrapper Input
if backward_input_name in forward_inputs_position_map.keys():
tensor_wrapper_type = forward_inputs_position_map[
backward_input_name][0]
self.backward_forward_inputs_map[backward_input_name] = [
backward_input_type, True, backward_input_pos
]
elif backward_input_name in forward_outputs_position_map.keys():
tensor_wrapper_type = forward_outputs_position_map[
backward_input_name][0]
self.backward_forward_inputs_map[backward_input_name] = [
backward_input_type, False, backward_input_pos
]
else:
assert False, backward_input_name
for backward_output in backward_returns_list:
backward_output_name = backward_output[0]
backward_output_type = backward_output[1]
backward_output_pos = backward_output[2]
backward_fwd_name = FindForwardName(backward_output_name)
assert backward_fwd_name is not None
assert backward_fwd_name in forward_inputs_position_map.keys(
), f"Unable to find {backward_fwd_name} in forward inputs"
matched_forward_input_type = forward_inputs_position_map[
backward_fwd_name][0]
matched_forward_input_pos = forward_inputs_position_map[
backward_fwd_name][1]
self.backward_grad_outputs_map[backward_output_name] = [
backward_output_type, matched_forward_input_pos,
backward_output_pos
]
print("Generated Backward Fwd Input Map: ",
self.backward_forward_inputs_map)
print("Generated Backward Grad Input Map: ",
self.backward_grad_inputs_map)
print("Generated Backward Grad Output Map: ",
self.backward_grad_outputs_map)
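# Illustrative example with relu_grad(Tensor out, Tensor out_grad) -> Tensor(x_grad):
# "out_grad" strips its "_grad" suffix, matches forward output "out", and goes into
# backward_grad_inputs_map; "out" has no suffix and is recorded as a TensorWrapper
# in backward_forward_inputs_map; "x_grad" matches forward input "x" and goes into
# backward_grad_outputs_map.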
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_attrs_list = self.backward_attrs_list
no_need_buffers = self.no_need_buffers
# SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input,
_) in backward_forward_inputs_map.items():
no_need_buffer = "true" if tname in no_need_buffers else "false"
tensor_wrapper_name = GetSavedName(tname)
if IsPlainTensorType(ttype):
set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tensor_wrapper_name, tname, no_need_buffer)
tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format(
tensor_wrapper_name)
else:
assert IsVectorTensorType(ttype)
set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
tname, tname, tname, tensor_wrapper_name, no_need_buffer)
tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name)
clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
# SetAttributes & Attribute Members
set_attribute_methods_str = ""
attribute_members_str = ""
for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname)
set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
aname, GetConstReference(atype), aname, saved_attr_name, aname)
if default_val:
attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name,
default_val)
else:
attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name)
grad_node_name = GetGradNodeName(forward_op_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
print("Generated Node Declaration: ", self.node_declaration_str)
def GenerateNodeDefinition(self):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
backward_grad_inputs_map.keys()) + len(backward_attrs_list)
grad_api_args = ["" for i in range(grad_api_args_len)]
for name, (_, is_fwd_input,
grad_api_position) in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
if IsPlainTensorType(ttype):
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}][0]"
else:
assert IsVectorTensorType(ttype)
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}]"
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
grad_api_args[grad_api_position] = f"this->{saved_attribute_name}"
grad_api_args_str = ", ".join(grad_api_args)
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_outputs_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({num_bwd_outputs});\n"
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
# Infer Grad API Return Type
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
returns_str += "returns[0] = { grad_api_returns };\n"
else:
assert IsVectorTensorType(ttype)
returns_str += "returns[0] = grad_api_returns;\n"
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(forward_api_name)
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
grad_api_namespace = f"paddle::experimental::{namespace}"
self.node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)
print("Generated Node Definition: ", self.node_definition_str)
def GenerateForwardDefinition(self, is_inplaced):
namespace = self.namespace
forward_api_name = GetInplacedFunctionName(
self.forward_api_name) if is_inplaced else self.forward_api_name
backward_api_name = self.backward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
inplace_map = self.inplace_map
# Get Function Args
num_inputs = len(forward_attrs_list) + len(
forward_inputs_position_map.keys())
inputs_args_definition_list = ["" for i in range(num_inputs)]
inputs_args_declaration_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype):
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
else:
if inplace_map and name in inplace_map.keys():
arg_str = f"paddle::experimental::Tensor& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_definition_list[pos] = arg_str
inputs_args_declaration_list[pos] = arg_str
for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
if default_val is not None:
inputs_args_declaration_list[
pos] = f"{atype} {name} = {default_val}"
else:
inputs_args_declaration_list[pos] = f"{atype} {name}"
inputs_args_definition_list[pos] = f"{atype} {name}"
inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
inputs_args_definition_str = ", ".join(inputs_args_definition_list)
inputs_call_args_str = ", ".join(inputs_call_list)
# Forward Full Logic
function_name = forward_api_name
if len(intermediate_outputs) > 0:
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1:
returns_list[0] = f"api_result"
else:
# Tuple api_result
returns_list[pos] = f"std::get<{pos}>(api_result)"
if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor"
else:
assert IsVectorTensorType(rtype)
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>"
if num_outputs == 1:
returns_str = returns_list[0]
returns_type_str = returns_type_list[0]
else:
returns_type_str = ", ".join(returns_type_list)
returns_type_str = f"std::tuple<{returns_type_str}>"
returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})"
self.GenerateNodeCreationCodes(forward_call_str)
node_creation_str = self.node_creation_str
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
print("Generated Forward Definition: ", self.forward_definition_str)
print("Generated Forward Declaration: ", self.forward_declaration_str)
def GenerateNodeCreationCodes(self, forward_call_str):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_forward_inputs_map = self.backward_forward_inputs_map
backward_grad_inputs_map = self.backward_grad_inputs_map
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
inplace_map = self.inplace_map
# Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
# Get Output AutoGradMeta
outputs_autograd_meta_list = []
pass_stop_gradient_args_list = ["false"]
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
# ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
# Node Construction
num_backward_inputs = len(backward_grad_inputs_map.keys())
num_backward_outputs = len(backward_grad_outputs_map.keys())
grad_node_name = GetGradNodeName(forward_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_backward_inputs}, {num_backward_outputs});"
# SetAttributes
set_attributes_list = []
forward_attrs_name_set = set()
for name, _, _, _ in forward_attrs_list:
forward_attrs_name_set.add(name)
for name, _, default_val_attr, _ in backward_attrs_list:
if name in forward_attrs_name_set:
set_attributes = f" grad_node->SetAttribute{name}({name});"
else:
set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
# SetTensorWrappers
set_tensor_wrappers_list = []
for name, (atype, is_fwd_input,
pos) in backward_forward_inputs_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"
if is_optional:
set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
# SetGradOutMeta & SetEdges
set_grad_out_meta_list = []
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
set_edges_str = "\n".join(set_edges_list)
# SetOutRank & SetHistory & SetGradInMeta
set_out_rank_list = []
set_history_list = []
set_grad_in_meta_list = []
set_retain_grad_list = []
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
self.node_creation_str = NODE_CREATION_TEMPLATE.format(
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str,
set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
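# Pieced together, node_creation_str roughly expands to the following for a
# single-input, single-output op (a sketch assembled from the f-strings above):
#   auto grad_node = std::make_shared<FinalGradNoderelu>(1, 1);
#   grad_node->SetTensorWrapperout(api_result, false);
#   grad_node->SetGradOutMeta(x, 0);
#   grad_node->AddEdges(x_autograd_meta, 0);
#   egr::EagerUtils::SetOutRankWithSlot(out_autograd_meta, 0);
#   egr::EagerUtils::SetHistory(out_autograd_meta, grad_node);
#   grad_node->SetGradInMeta(api_result, 0);
#   egr::EagerUtils::CheckAndRetainGrad(api_result);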
def GenerateInplacedForwardDygraphFunctions(self):
# Inplaced Version Dygraph Function Generation
forward_api_name = self.forward_api_name
forward_api_contents = self.forward_api_contents
if forward_api_name != "sum" and "inplace" in forward_api_contents.keys(
):
# Node Definition Generation
self.GenerateForwardDefinition(is_inplaced=True)
self.UpdateCoreOpsInformation(is_inplaced=True)
def UpdateCoreOpsInformation(self, is_inplaced):
forward_api_name = GetInplacedFunctionName(
self.forward_api_name) if is_inplaced else self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
num_args = len(forward_inputs_position_map.keys()) + len(
forward_attrs_list)
num_returns = len(forward_outputs_position_map.keys())
final_state_fwd_api_name = "final_state_" + forward_api_name
core_ops_returns_info[
final_state_fwd_api_name] = ["" for i in range(num_returns)]
core_ops_args_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
if IsPlainTensorType(ttype):
core_ops_args_type_info[final_state_fwd_api_name][
pos] = "tensor"
else:
assert IsVectorTensorType(ttype)
core_ops_args_type_info[final_state_fwd_api_name][pos] = "list"
for name, _, _, pos in forward_attrs_list:
core_ops_args_info[final_state_fwd_api_name][pos] = name
for name, (ttype, pos) in forward_outputs_position_map.items():
core_ops_returns_info[final_state_fwd_api_name][pos] = name
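# Illustrative result for a forward api "relu(Tensor x) -> Tensor(out)":
#   core_ops_args_info["final_state_relu"]      = ["x"]
#   core_ops_args_type_info["final_state_relu"] = ["tensor"]
#   core_ops_returns_info["final_state_relu"]   = ["out"]
# (attributes keep an empty string in core_ops_args_type_info, as only
# tensor inputs are typed above)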
def run(self):
# Basic Validation Check
self.DygraphYamlValidationCheck()
##########################
## Parsing Raw Contents ##
##########################
# Parse inplace_map
self.ParseInplaceInfo()
# Parse no_need_buffer
self.ParseNoNeedBuffer()
# Parse optional_inputs
self.ParseDispensable()
# Parse intermediate_outputs
self.ParseIntermediate()
self.IntermediateValidationCheck()
# Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
self.CollectBackwardInfo()
# Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
self.CollectForwardInfoFromBackwardContents()
# Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
self.CollectOriginalForwardInfo()
# Forwards Validation Check
self.ForwardsValidationCheck()
#############################
## Process Parsed Contents ##
#############################
# Initialize forward_inputs_position_map, forward_outputs_position_map
self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list)
# Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
self.SlotNameMatching()
# Backward Validation Check
self.BackwardValidationCheck()
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition()
self.GenerateForwardDefinition(is_inplaced=False)
self.UpdateCoreOpsInformation(is_inplaced=False)
self.GenerateInplacedForwardDygraphFunctions()
class DygraphYamlGenerator(YamlGeneratorBase):
def __init__(self, api_yaml_path, backward_yaml_path):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
YamlGeneratorBase.__init__(self, api_yaml_path)
self.backward_yaml_path = backward_yaml_path
self.grad_api_dict = {}
self.forward_definition_str = ""
self.forward_declaration_str = ""
self.node_declaration_str = ""
self.node_definition_str = ""
def ParseYamlContents(self):
self.ParseForwardYamlContents()
backward_yaml_path = self.backward_yaml_path
self.grad_api_dict = ReadBwdFile(backward_yaml_path)
def GetBackwardAPIContents(self, forward_api_contents):
grad_api_dict = self.grad_api_dict
if 'backward' not in forward_api_contents.keys(): return None
backward_api_name = forward_api_contents['backward']
assert backward_api_name in grad_api_dict.keys()
backward_api_contents = grad_api_dict[backward_api_name]
return backward_api_contents
def GenerateCode(self):
forward_api_list = self.forward_api_list
grad_api_dict = self.grad_api_dict
namespace = self.namespace
for forward_api_contents in forward_api_list:
backward_api_contents = self.GetBackwardAPIContents(
forward_api_contents)
if backward_api_contents is None: continue
d_generator = DygraphSingleFunctionGenerator(
forward_api_contents, backward_api_contents, namespace)
d_generator.run()
self.forward_definition_str += d_generator.forward_definition_str + "\n"
self.forward_declaration_str += d_generator.forward_declaration_str + "\n"
self.node_declaration_str += d_generator.node_declaration_str + "\n"
self.node_definition_str += d_generator.node_definition_str + "\n"
if len(namespace) > 0:
if namespace.endswith("::"):
namespace = namespace[:-2]
self.forward_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
namespace, self.forward_definition_str)
self.forward_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
namespace, self.forward_declaration_str)
self.node_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
namespace, self.node_declaration_str)
self.node_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
namespace, self.node_definition_str)
def run(self):
self.ParseYamlContents()
self.InferNameSpace()
self.GenerateCode()
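# Minimal driver sketch (the real entry point below also iterates over
# multiple yaml paths and aggregates the generated strings):
#   generator = DygraphYamlGenerator(api_yaml_path, backward_yaml_path)
#   generator.run()
#   GenerateNodeHFile(nodes_h_path, generator.node_declaration_str)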
##################
## File Writers ##
##################
def GenerateNodeCCFile(filepath, node_definition_str):
if os.path.exists(filepath):
os.remove(filepath)
file_contents = NODE_CC_FILE_TEMPLATE.format(node_definition_str)
with open(filepath, 'a') as f:
f.write(file_contents)
def GenerateNodeHFile(filepath, node_declaration_str):
if os.path.exists(filepath):
os.remove(filepath)
file_contents = NODE_H_FILE_TEMPLATE.format(node_declaration_str)
with open(filepath, 'a') as f:
f.write(file_contents)
def GenerateForwardCCFile(filepath, forward_definition_str):
if os.path.exists(filepath):
os.remove(filepath)
core_ops_info_str = GenerateCoreOpInfoDefinition()
file_contents = FORWARD_CC_FILE_TEMPLATE.format(core_ops_info_str,
forward_definition_str)
with open(filepath, 'a') as f:
f.write(file_contents)
def GenerateForwardHFile(filepath, forward_function_declaration_str):
if os.path.exists(filepath):
os.remove(filepath)
core_ops_info_str = GenerateCoreOpInfoDeclaration()
file_contents = FORWARD_H_FILE_TEMPLATE.format(
core_ops_info_str, forward_function_declaration_str)
with open(filepath, 'a') as f:
f.write(file_contents)
@@ -1224,199 +1217,13 @@ if __name__ == "__main__":
api_yaml_path = api_yaml_paths[i]
backward_yaml_path = backward_yaml_paths[i]
generator = DygraphYamlGenerator(api_yaml_path, backward_yaml_path)
generator.run()
node_declaration_str += generator.node_declaration_str + "\n"
node_definition_str += generator.node_definition_str + "\n"
forward_definition_str += generator.forward_definition_str + "\n"
forward_declaration_str += generator.forward_declaration_str + "\n"
# Generate Files
nodes_h_path = args.nodes_h_path
@@ -1424,12 +1231,6 @@ if __name__ == "__main__":
forwards_h_path = args.forwards_h_path
forwards_cc_path = args.forwards_cc_path
GenerateNodeCCFile(nodes_cc_path, node_definition_str)
GenerateNodeHFile(nodes_h_path, node_declaration_str)
GenerateForwardCCFile(forwards_cc_path, forward_definition_str)
...
@@ -15,7 +15,10 @@
import os
import argparse
import logging
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, IsVectorTensorType, GetForwardFunctionName
from codegen_utils import ParseYamlForward, GetInplacedFunctionName
###########################
## Global Configurations ##
@@ -121,7 +124,10 @@ FUNCTION_NAME_TEMPLATE = \
PYTHON_C_FUNCTION_REG_TEMPLATE = \
"""
{{\"final_state_{}\", (PyCFunction)(void(*)(void)) {}eager_final_state_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}}
"""
PYTHON_C_WRAPPER_TEMPLATE = \
@@ -229,77 +235,39 @@ NAMESPACE_WRAPPER_TEMPLATE = \
#######################
## Generator Classes ##
#######################
class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
def __init__(self, forward_api_contents, namespace):
# Members from Parent:
#self.namespace
#self.forward_api_contents
#self.forward_api_name
#self.orig_forward_inputs_list
#self.orig_forward_attrs_list
#self.orig_forward_returns_list
#self.forward_inputs_position_map
#self.forward_outputs_position_map
#self.optional_inputs
#self.no_need_buffers
#self.intermediate_outputs
#self.inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)
self.is_forward_only = True
# Generated Results
self.python_c_function_str = ""
self.python_c_function_reg_str = ""
def CollectIsForwardOnly(self):
forward_api_contents = self.forward_api_contents
self.is_forward_only = False if 'backward' in forward_api_contents.keys(
) else True
def GeneratePythonCFunction(self):
namespace = self.namespace
inplace_map = self.inplace_map
forward_api_name = self.forward_api_name
orig_forward_attrs_list = self.orig_forward_attrs_list
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
optional_inputs = self.optional_inputs
@@ -326,7 +294,7 @@ class PythonCSingleFunctionGenerator:
parse_attributes_str = ""
# Generate Python-C Attributes Parsing Logic
for name, atype, _, pos in orig_forward_attrs_list:
parsing_function_name = FindParsingFunctionFromAttributeType(atype)
parse_attributes_str += PARSE_PYTHON_C_ARGS_TEMPLATE.format(
name, pos, atype, name, parsing_function_name, name,
@@ -334,11 +302,11 @@ class PythonCSingleFunctionGenerator:
# Generate Dygraph Function Call Logic
num_args = len(forward_inputs_position_map.keys()) + len(
orig_forward_attrs_list)
dygraph_function_call_list = ["" for i in range(num_args)]
for name, (_, pos) in forward_inputs_position_map.items():
dygraph_function_call_list[pos] = f"{name}"
for name, _, _, pos in orig_forward_attrs_list:
dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
@@ -350,17 +318,7 @@ class PythonCSingleFunctionGenerator:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name))
return_str = " return ToPyObject(out);"
# Generate Record Event for performance profiling
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
@@ -374,29 +332,56 @@ class PythonCSingleFunctionGenerator:
self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name, namespace, forward_api_name, forward_api_name)
if len(inplace_map) > 0:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input,
inplaced_forward_api_name, inplace_output)
break
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, fwd_function_name,
dygraph_function_call_str, return_str)
# Generate Python-C Function Registration
self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format(
inplaced_forward_api_name, namespace, inplaced_forward_api_name,
inplaced_forward_api_name)
def run(self):
# Initialized is_forward_only
self.CollectIsForwardOnly()
# Initialized optional_inputs
self.ParseDispensable()
# Initialized inplace_map
self.ParseInplaceInfo()
# Initialized orig_forward_inputs_list, orig_forward_returns_list, orig_forward_attrs_list
self.CollectOriginalForwardInfo()
logging.info(
f"Parsed Original Forward Inputs List: \n{self.orig_forward_inputs_list}"
)
logging.info(
f"Parsed Original Forward Attrs List: \n{self.orig_forward_attrs_list}"
)
logging.info(
f"Parsed Original Forward Returns List: \n{self.orig_forward_returns_list}"
)
if SkipAPIGeneration(self.forward_api_name): return False
# Initialized forward_inputs_position_map, forward_outputs_position_map
self.DetermineForwardPositionMap(self.orig_forward_inputs_list,
self.orig_forward_returns_list)
logging.info(
f"Generated Forward Input Position Map: {self.forward_inputs_position_map}"
)
@@ -405,7 +390,7 @@ class PythonCSingleFunctionGenerator:
)
# Code Generation
self.GeneratePythonCFunction()
logging.info(
f"Generated Python-C Function: {self.python_c_function_str}")
logging.info(
@@ -415,21 +400,18 @@ class PythonCSingleFunctionGenerator:
return True
class PythonCYamlGenerator(YamlGeneratorBase):
def __init__(self, path):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
YamlGeneratorBase.__init__(self, path)
# Generated Result
self.python_c_functions_reg_str = ""
self.python_c_functions_str = ""
def GeneratePythonCFunctions(self):
namespace = self.namespace
forward_api_list = self.forward_api_list
@@ -437,28 +419,12 @@ class PythonCYamlGenerator:
for forward_api_content in forward_api_list:
f_generator = PythonCSingleFunctionGenerator(forward_api_content,
namespace)
status = f_generator.run()
if status == True:
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
self.python_c_functions_str += f_generator.python_c_function_str + "\n"
def AttachNamespace(self):
namespace = self.namespace
python_c_functions_str = self.python_c_functions_str
@@ -474,7 +440,7 @@ class PythonCYamlGenerator:
self.InferNameSpace()
# Read Yaml file
self.ParseForwardYamlContents()
# Code Generation
self.GeneratePythonCFunctions()
...
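# By analogy with DygraphYamlGenerator above, a minimal driver sketch for
# this generator (the real __main__ block is elided here):
#   y_generator = PythonCYamlGenerator(api_yaml_path)
#   y_generator.run()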