未验证 提交 933db9d4 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Forword only add dygraph func (#45153)

* [Eager draft] forward_only interface migrate to autograd_api

* strings api add dygraph forward function

* rm useless comments

* draft version for check CI

* fix ci

* forward-only no need compute_require_grad and pass stop_gradient, rm useless comments

* polish yaml and using CPUPlace = phi::CPUPlace

* rm useless comments

* polish yaml and update some test case

* rm useless funcs

* polish eager_gen code

* polish code
上级 f706d95d
...@@ -38,7 +38,7 @@ add_custom_target( ...@@ -38,7 +38,7 @@ add_custom_target(
COMMAND COMMAND
"${PYTHON_EXECUTABLE}" "${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py"
"--api_yaml_path=${api_yaml_path}" "--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
"--backward_yaml_path=${backward_yaml_path}" "--backward_yaml_path=${backward_yaml_path}"
"--forwards_cc_path=${tmp_forwards_cc_path}" "--forwards_cc_path=${tmp_forwards_cc_path}"
"--forwards_h_path=${tmp_forwards_h_path}" "--forwards_h_path=${tmp_forwards_h_path}"
......
...@@ -353,6 +353,9 @@ class FunctionGeneratorBase: ...@@ -353,6 +353,9 @@ class FunctionGeneratorBase:
self.forward_api_contents = forward_api_contents self.forward_api_contents = forward_api_contents
self.namespace = namespace self.namespace = namespace
self.is_forward_only = False if 'backward' in forward_api_contents.keys(
) else True
self.forward_api_name = "" self.forward_api_name = ""
self.orig_forward_inputs_list = [ self.orig_forward_inputs_list = [
......
...@@ -209,6 +209,26 @@ FORWARD_FUNCTION_TEMPLATE = \ ...@@ -209,6 +209,26 @@ FORWARD_FUNCTION_TEMPLATE = \
}} }}
""" """
FORWARD_ONLY_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
// Dygraph Record Event
{}
// AMP Logic
{}
// Forward API Call
VLOG(3) << \"Final State Running: \" << \"{}\";
{}
// Get Outputs
{}
// Returns
return {};
}}
"""
FORWARD_BODY_TEMPLATE = \ FORWARD_BODY_TEMPLATE = \
""" if(require_any_grad) {{ """ if(require_any_grad) {{
{} {}
...@@ -297,6 +317,7 @@ FORWARD_CC_FILE_TEMPLATE = \ ...@@ -297,6 +317,7 @@ FORWARD_CC_FILE_TEMPLATE = \
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/api/include/sparse_api.h" #include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
...@@ -321,6 +342,7 @@ FORWARD_H_FILE_TEMPLATE = \ ...@@ -321,6 +342,7 @@ FORWARD_H_FILE_TEMPLATE = \
#include "paddle/fluid/eager/to_static/run_program_op_func.h" #include "paddle/fluid/eager/to_static/run_program_op_func.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
using CPUPlace = phi::CPUPlace;
{} {}
{} {}
""" """
...@@ -406,6 +428,27 @@ CHECK_NAN_AND_INF_TEMPLATE = \ ...@@ -406,6 +428,27 @@ CHECK_NAN_AND_INF_TEMPLATE = \
""" if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }} """ if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
""" """
# This list contains ops that do not need to generate AMP logic.
# All optimizer ops are in this list.
no_amp_list = [
'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
'sparse_momentum_', 'sparse_momentum', 'full_'
]
inplace_optional_out_type_map = {
"Tensor":
"paddle::optional<paddle::experimental::Tensor>&",
"std::vector<Tensor>":
"paddle::optional<std::vector<paddle::experimental::Tensor>>&"
}
####################### #######################
## Generator Helpers ## ## Generator Helpers ##
...@@ -513,9 +556,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -513,9 +556,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
), "Unable to find \"args\" in api.yaml" ), "Unable to find \"args\" in api.yaml"
assert 'output' in forward_api_contents.keys( assert 'output' in forward_api_contents.keys(
), "Unable to find \"output\" in api.yaml" ), "Unable to find \"output\" in api.yaml"
if grad_api_contents is not None:
assert 'backward' in forward_api_contents.keys( assert 'backward' in forward_api_contents.keys(
), "Unable to find \"backward\" in api.yaml" ), "Unable to find \"backward\" in api.yaml"
assert 'args' in grad_api_contents.keys( assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml" ), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys( assert 'output' in grad_api_contents.keys(
...@@ -629,6 +673,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -629,6 +673,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward( self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward(
backward_forward_str) backward_forward_str)
def CollectForwardInfoFromYamlForward(self):
self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward(
self.forward_api_contents['args'] + " -> " +
self.forward_api_contents['output'])
def SlotNameMatching(self): def SlotNameMatching(self):
backward_inputs_list = self.backward_inputs_list backward_inputs_list = self.backward_inputs_list
backward_returns_list = self.backward_returns_list backward_returns_list = self.backward_returns_list
...@@ -694,6 +743,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -694,6 +743,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
backward_output_pos backward_output_pos
] ]
def GetPassStopGradientArgsList(self, forward_outputs_position_map):
pass_stop_gradient_args_list = ["false"]
for name, (_, _) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
return pass_stop_gradient_args_str
def GenerateNodeCreationCodes(self, for_backward=False): def GenerateNodeCreationCodes(self, for_backward=False):
forward_api_name = self.forward_api_name forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map forward_inputs_position_map = self.forward_inputs_position_map
...@@ -706,11 +763,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -706,11 +763,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
optional_inputs = self.optional_inputs optional_inputs = self.optional_inputs
# Pass Stop Gradient Args # Pass Stop Gradient Args
pass_stop_gradient_args_list = ["false"] pass_stop_gradient_args_str = self.GetPassStopGradientArgsList(
for name, (_, _) in forward_outputs_position_map.items(): forward_outputs_position_map)
output_autograd_meta_name = GetAutoGradMetaName(name)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# Node Construction # Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys()) num_backward_inputs = len(forward_outputs_position_map.keys())
...@@ -851,8 +905,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -851,8 +905,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
########################## ##########################
# Parse forward and backward inplace_map # Parse forward and backward inplace_map
self.ParseForwardInplaceInfo() self.ParseForwardInplaceInfo()
if self.grad_api_contents is not None:
self.ParseBackwardInplaceInfo() self.ParseBackwardInplaceInfo()
# Parse no_need_buffer # Parse no_need_buffer
self.ParseNoNeedBuffer() self.ParseNoNeedBuffer()
...@@ -863,12 +917,16 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -863,12 +917,16 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.ParseIntermediate() self.ParseIntermediate()
self.IntermediateValidationCheck() self.IntermediateValidationCheck()
if self.grad_api_contents is not None:
# Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
self.CollectBackwardInfo() self.CollectBackwardInfo()
# Initialize forward_inputs_list, forward_attrs_list, forward_returns_list # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
self.CollectForwardInfoFromBackwardContents() self.CollectForwardInfoFromBackwardContents()
if self.is_forward_only:
self.CollectForwardInfoFromYamlForward()
# Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
self.CollectOriginalForwardInfo() self.CollectOriginalForwardInfo()
...@@ -882,9 +940,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -882,9 +940,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.DetermineForwardPositionMap(self.forward_inputs_list, self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list) self.forward_returns_list)
if self.grad_api_contents is not None:
# Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map # Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
self.SlotNameMatching() self.SlotNameMatching()
# Backward Validation Check # Backward Validation Check
self.BackwardValidationCheck() self.BackwardValidationCheck()
...@@ -909,6 +967,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -909,6 +967,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
forward_inputs_position_map = self.forward_inputs_position_map forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list forward_attrs_list = self.forward_attrs_list
if not self.is_forward_only:
backward_grad_outputs_map = self.backward_grad_outputs_map backward_grad_outputs_map = self.backward_grad_outputs_map
optional_inputs = self.optional_inputs optional_inputs = self.optional_inputs
...@@ -934,6 +993,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -934,6 +993,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
is_optional = (name in optional_inputs) is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
if is_optional: if is_optional:
if self.is_forward_only and is_inplaced and forward_inplace_map and name in forward_inplace_map.keys(
):
arg_str = f"paddle::optional<paddle::experimental::Tensor>& {name}"
else:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}" arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
amp_tensors_vector_optional_list.append( amp_tensors_vector_optional_list.append(
f"if ({name}) amp_tensors_vector.push_back({{ *{name} }});\n" f"if ({name}) amp_tensors_vector.push_back({{ *{name} }});\n"
...@@ -1028,6 +1091,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1028,6 +1091,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if IsPlainTensorType(rtype): if IsPlainTensorType(rtype):
if is_inplaced and forward_inplace_map and name in forward_inplace_map.values( if is_inplaced and forward_inplace_map and name in forward_inplace_map.values(
): ):
ind = list(forward_inplace_map.values()).index(name)
if list(forward_inplace_map.keys()
)[ind] in self.optional_inputs:
returns_type_list[pos] = inplace_optional_out_type_map[
rtype]
else:
returns_type_list[pos] = "paddle::experimental::Tensor&" returns_type_list[pos] = "paddle::experimental::Tensor&"
else: else:
returns_type_list[pos] = "paddle::experimental::Tensor" returns_type_list[pos] = "paddle::experimental::Tensor"
...@@ -1035,6 +1104,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1035,6 +1104,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
assert IsVectorTensorType(rtype) assert IsVectorTensorType(rtype)
if is_inplaced and forward_inplace_map and name in forward_inplace_map.values( if is_inplaced and forward_inplace_map and name in forward_inplace_map.values(
): ):
ind = list(forward_inplace_map.values()).index(name)
if list(forward_inplace_map.keys()
)[ind] in self.optional_inputs:
returns_type_list[pos] = inplace_optional_out_type_map[
rtype]
else:
returns_type_list[ returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>&" pos] = "std::vector<paddle::experimental::Tensor>&"
else: else:
...@@ -1052,18 +1127,21 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1052,18 +1127,21 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
# Node Creation Pre-Processing # Node Creation Pre-Processing
# 1. Get Input AutoGradMeta # 1. Get Input AutoGradMeta
if not self.is_forward_only:
inputs_autograd_meta_list = [] inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"] compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
# Has corresponding grad output # Has corresponding grad output
has_corresponding_grad_output = False has_corresponding_grad_output = False
if not self.is_forward_only:
for _, (_, corresponding_pos, for _, (_, corresponding_pos,
_) in backward_grad_outputs_map.items(): _) in backward_grad_outputs_map.items():
if pos == corresponding_pos: if pos == corresponding_pos:
has_corresponding_grad_output = True has_corresponding_grad_output = True
if has_corresponding_grad_output or ( if has_corresponding_grad_output or (
name in forward_inplace_map name in forward_inplace_map and forward_api_name
and forward_api_name not in inplace_check_blacklist): not in inplace_check_blacklist) or self.is_forward_only:
input_autograd_meta_name = GetAutoGradMetaName(name) input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
...@@ -1074,13 +1152,18 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1074,13 +1152,18 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta) inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name) compute_require_grad_args_list.append(
input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list) compute_require_grad_args_str = ",".join(
compute_require_grad_args_list)
# 2. Get Output AutoGradMeta # 2. Get Output AutoGradMeta
if not self.is_forward_only:
outputs_autograd_meta_list = [] outputs_autograd_meta_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys()) num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items(): for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name) output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
...@@ -1117,8 +1200,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1117,8 +1200,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inplace_name, inplace_name) inplace_name, inplace_name)
# Node Creation # Node Creation
if not self.is_forward_only:
self.GenerateNodeCreationCodes() self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str node_creation_str = self.node_creation_str
else:
node_creation_str = ""
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n" dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name) forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
...@@ -1144,13 +1230,30 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1144,13 +1230,30 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_autocast_list_str, amp_call_str) amp_autocast_list_str, amp_call_str)
# Generate forward_definition_str and forward_declaration_str # Generate forward_definition_str and forward_declaration_str
if not self.is_forward_only:
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str, returns_type_str, forward_function_name,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, inputs_args_definition_str, dygraph_event_str, amp_logic_str,
forward_function_name, forward_call_str, check_nan_inf_str, inputs_autograd_meta_str, forward_function_name,
get_outputs_str, outputs_autograd_meta_str, forward_call_str, check_nan_inf_str, get_outputs_str,
compute_require_grad_args_str, check_inplace_str, outputs_autograd_meta_str, compute_require_grad_args_str,
bump_inplace_version_str, node_creation_str, returns_str) check_inplace_str, bump_inplace_version_str, node_creation_str,
returns_str)
else:
if (len(amp_tensors_vector_list) > 0) and (self.forward_api_name
not in no_amp_list):
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str,
amp_logic_str, forward_function_name, forward_call_str,
get_outputs_str, returns_str)
else:
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str, " ",
forward_function_name, forward_call_str, get_outputs_str,
returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
def GenerateInplacedForwardDygraphFunctions(self): def GenerateInplacedForwardDygraphFunctions(self):
...@@ -1648,10 +1751,17 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): ...@@ -1648,10 +1751,17 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
self.node_declaration_str = "" self.node_declaration_str = ""
self.node_definition_str = "" self.node_definition_str = ""
def CollectIsForwardOnly(self, forward_api_contents):
self.is_forward_only = False if 'backward' in forward_api_contents.keys(
) else True
def ParseYamlContents(self): def ParseYamlContents(self):
self.ParseForwardYamlContents() self.ParseForwardYamlContents()
backward_yaml_path = self.backward_yaml_path backward_yaml_path = self.backward_yaml_path
# strings api is forward-only; it has no corresponding backward yaml
if backward_yaml_path is not None:
self.grad_api_dict = ReadBwdFile(backward_yaml_path) self.grad_api_dict = ReadBwdFile(backward_yaml_path)
def GetBackwardAPIContents(self, forward_api_contents): def GetBackwardAPIContents(self, forward_api_contents):
...@@ -1674,9 +1784,13 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): ...@@ -1674,9 +1784,13 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
for forward_api_contents in forward_api_list: for forward_api_contents in forward_api_list:
if forward_api_contents['api'] in black_ops_list: continue if forward_api_contents['api'] in black_ops_list: continue
self.CollectIsForwardOnly(forward_api_contents)
if self.is_forward_only:
backward_api_contents = None
else:
backward_api_contents = self.GetBackwardAPIContents( backward_api_contents = self.GetBackwardAPIContents(
forward_api_contents) forward_api_contents)
if backward_api_contents is None: continue
# Generate Dygraph Forward Function # Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator( function_generator = DygraphForwardFunctionGenerator(
...@@ -1688,6 +1802,8 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): ...@@ -1688,6 +1802,8 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
# Generate Dygraph GradNode Function # Generate Dygraph GradNode Function
while True: while True:
if backward_api_contents is None:
break
next_grad_api_contents = self.GetBackwardAPIContents( next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents) backward_api_contents)
...@@ -1787,7 +1903,12 @@ if __name__ == "__main__": ...@@ -1787,7 +1903,12 @@ if __name__ == "__main__":
for i in range(len(api_yaml_paths)): for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i] api_yaml_path = api_yaml_paths[i]
# string api is forward only
if not api_yaml_path.endswith('strings_api.yaml'):
backward_yaml_path = backward_yaml_paths[i] backward_yaml_path = backward_yaml_paths[i]
else:
backward_yaml_path = None
generator = DygraphForwardAndNodesGenerator(api_yaml_path, generator = DygraphForwardAndNodesGenerator(api_yaml_path,
backward_yaml_path) backward_yaml_path)
......
...@@ -51,20 +51,6 @@ atype_to_parsing_function = { ...@@ -51,20 +51,6 @@ atype_to_parsing_function = {
"paddle::experimental::DataType": "CastPyArg2DataType", "paddle::experimental::DataType": "CastPyArg2DataType",
} }
# This list contains ops that do not need to generate amp logic
# All optimizer ops in this list
no_amp_list = [
'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
'sparse_momentum_', 'sparse_momentum', 'full_'
]
def FindParsingFunctionFromAttributeType(atype): def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys(): if atype not in atype_to_parsing_function.keys():
...@@ -131,41 +117,6 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj ...@@ -131,41 +117,6 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n" NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
decltype({}({})) out;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out = {}({});
}} else {{
out = {}({});
}}
"""
INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
using result_type = decltype({}({}));
std::unique_ptr<result_type> out_ptr;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out_ptr = std::make_unique<result_type>({}({}));
}} else {{
out_ptr = std::make_unique<result_type>({}({}));
}}
result_type& out = *out_ptr;
"""
FUNCTION_SET_DEVICE_TEMPLATE = \ FUNCTION_SET_DEVICE_TEMPLATE = \
"""{} if (paddle::platform::is_gpu_place(place)) {{ """{} if (paddle::platform::is_gpu_place(place)) {{
...@@ -405,21 +356,13 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -405,21 +356,13 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
num_args = len( num_args = len(
forward_inputs_position_map.keys()) + len(orig_forward_attrs_list) forward_inputs_position_map.keys()) + len(orig_forward_attrs_list)
dygraph_function_call_list = ["" for i in range(num_args)] dygraph_function_call_list = ["" for i in range(num_args)]
amp_dygraph_function_call_list = ["" for i in range(num_args)]
for name, (_, pos) in forward_inputs_position_map.items(): for name, (_, pos) in forward_inputs_position_map.items():
dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"NEW_{name}"
for name, _, _, pos in orig_forward_attrs_list: for name, _, _, pos in orig_forward_attrs_list:
dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list) dygraph_function_call_str = ",".join(dygraph_function_call_list)
amp_dygraph_function_call_str = ",".join(amp_dygraph_function_call_list)
# Generate Python-C Function Definitions # Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, forward_api_name)
else:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format( fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name)) "::", namespace, GetForwardFunctionName(forward_api_name))
...@@ -429,78 +372,11 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -429,78 +372,11 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format( pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func") "pythonc_record_event", forward_api_name, "pybind_imperative_func")
# Forward amp logic
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
amp_autocast_optional_list = []
for name, (ttype, pos) in forward_inputs_position_map.items():
is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
amp_tensors_vector_list.append(f"{name}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({{{name}.get()}});\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
if forward_inplace_map and name in forward_inplace_map.keys(
):
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
amp_tensors_vector_list_str = "{ " + ",".join(
amp_tensors_vector_list) + " }"
amp_tensors_vector_optional_list_str = "".join(
amp_tensors_vector_optional_list)
amp_autocast_list_str = " ".join(
amp_autocast_list) + " " + " ".join(
amp_autocast_optional_list)
kernel_trans2_op_name_str = f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format( noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str, fwd_function_name, fwd_function_name, dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str) dygraph_function_call_str)
amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, fwd_function_name,
amp_dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str)
# Generate Python-C Function Definetion # Generate Python-C Function Definetion
if (is_forward_only) and (len(amp_tensors_vector_list) >
0) and (forward_api_name not in no_amp_list):
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
amp_dygraph_function_str, return_str)
else:
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name, forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str, get_eager_tensor_str, parse_attributes_str, set_device_str,
...@@ -518,11 +394,6 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -518,11 +394,6 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
if forward_inplace_map: if forward_inplace_map:
inplaced_forward_api_name = GetInplacedFunctionName( inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name) self.forward_api_name)
if is_forward_only:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
else:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format( inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, "::", namespace,
GetForwardFunctionName(inplaced_forward_api_name)) GetForwardFunctionName(inplaced_forward_api_name))
...@@ -531,14 +402,6 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -531,14 +402,6 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_fwd_function_name, dygraph_function_call_str, inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str) inplaced_fwd_function_name, dygraph_function_call_str)
inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, inplaced_fwd_function_name,
amp_dygraph_function_call_str, inplaced_fwd_function_name,
dygraph_function_call_str)
return_str = " std::map<ssize_t, ssize_t> inplace_var_idx_map;" return_str = " std::map<ssize_t, ssize_t> inplace_var_idx_map;"
for inplace_input, inplace_output in forward_inplace_map.items(): for inplace_input, inplace_output in forward_inplace_map.items():
return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format( return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
...@@ -547,14 +410,6 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -547,14 +410,6 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return_str += " return ToPyObject(out, args, inplace_var_idx_map);" return_str += " return ToPyObject(out, args, inplace_var_idx_map);"
# Generate Python-C Function Definetion # Generate Python-C Function Definetion
if (is_forward_only) and (len(amp_tensors_vector_list) > 0) and (
inplaced_forward_api_name not in no_amp_list):
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_amp_dygraph_function_str, return_str)
else:
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format( python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str, inplaced_forward_api_name, get_eager_tensor_str,
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
- api : bernoulli - api : bernoulli
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
kernel : kernel :
......
...@@ -184,7 +184,7 @@ ...@@ -184,7 +184,7 @@
- api : arange - api : arange
args : (Tensor start, Tensor end, Tensor step, DataType dtype, Place place={}) args : (Tensor start, Tensor end, Tensor step, DataType dtype, Place place={})
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ArangeInferMeta func : ArangeInferMeta
param : [start, end, step] param : [start, end, step]
...@@ -199,7 +199,7 @@ ...@@ -199,7 +199,7 @@
# arg_max # arg_max
- api : argmax - api : argmax
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype) args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ArgMinMaxInferMeta func : ArgMinMaxInferMeta
kernel : kernel :
...@@ -208,7 +208,7 @@ ...@@ -208,7 +208,7 @@
# arg_min # arg_min
- api : argmin - api : argmin
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype) args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ArgMinMaxInferMeta func : ArgMinMaxInferMeta
kernel : kernel :
...@@ -366,7 +366,7 @@ ...@@ -366,7 +366,7 @@
# bitwise_and # bitwise_and
- api : bitwise_and - api : bitwise_and
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
...@@ -375,7 +375,7 @@ ...@@ -375,7 +375,7 @@
# bitwise_not # bitwise_not
- api : bitwise_not - api : bitwise_not
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
kernel : kernel :
...@@ -384,7 +384,7 @@ ...@@ -384,7 +384,7 @@
# bitwise_or # bitwise_or
- api : bitwise_or - api : bitwise_or
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
...@@ -393,7 +393,7 @@ ...@@ -393,7 +393,7 @@
# bitwise_xor # bitwise_xor
- api : bitwise_xor - api : bitwise_xor
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
...@@ -557,7 +557,7 @@ ...@@ -557,7 +557,7 @@
- api : copy_to - api : copy_to
args : (Tensor x, Place place, bool blocking) args : (Tensor x, Place place, bool blocking)
output : Tensor output : Tensor(out)
invoke : copy_to_impl(x, place, blocking) invoke : copy_to_impl(x, place, blocking)
# cos # cos
...@@ -672,7 +672,7 @@ ...@@ -672,7 +672,7 @@
- api : diag_embed - api : diag_embed
args : (Tensor x, int offset, int dim1, int dim2) args : (Tensor x, int offset, int dim1, int dim2)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : DiagEmbedInferMeta func : DiagEmbedInferMeta
kernel : kernel :
...@@ -720,7 +720,7 @@ ...@@ -720,7 +720,7 @@
- api : eigvals - api : eigvals
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : EigvalsInferMeta func : EigvalsInferMeta
kernel : kernel :
...@@ -773,7 +773,7 @@ ...@@ -773,7 +773,7 @@
- api : empty - api : empty
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor output: Tensor(out)
infer_meta : infer_meta :
func : CreateInferMeta func : CreateInferMeta
param : [shape, dtype] param : [shape, dtype]
...@@ -785,7 +785,7 @@ ...@@ -785,7 +785,7 @@
- api : empty_like - api : empty_like
args : (Tensor x, DataType dtype = DataType::UNDEFINED, Place place = {}) args : (Tensor x, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor output: Tensor(out)
infer_meta : infer_meta :
func : CreateLikeInferMeta func : CreateLikeInferMeta
param : [x, dtype] param : [x, dtype]
...@@ -797,7 +797,7 @@ ...@@ -797,7 +797,7 @@
- api : equal - api : equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
...@@ -805,7 +805,7 @@ ...@@ -805,7 +805,7 @@
- api : equal_all - api : equal_all
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareAllInferMeta func : CompareAllInferMeta
kernel : kernel :
...@@ -986,7 +986,7 @@ ...@@ -986,7 +986,7 @@
- api : full - api : full
args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor output: Tensor(out)
infer_meta : infer_meta :
func : CreateInferMeta func : CreateInferMeta
param : [shape, dtype] param : [shape, dtype]
...@@ -1012,7 +1012,7 @@ ...@@ -1012,7 +1012,7 @@
- api : full_batch_size_like - api : full_batch_size_like
args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace()) args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace())
output: Tensor output: Tensor(out)
infer_meta : infer_meta :
func : FullBatchSizeLikeInferMeta func : FullBatchSizeLikeInferMeta
param : [input, shape, value, dtype, input_dim_idx, output_dim_idx] param : [input, shape, value, dtype, input_dim_idx, output_dim_idx]
...@@ -1024,7 +1024,7 @@ ...@@ -1024,7 +1024,7 @@
- api : full_like - api : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {}) args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor output: Tensor(out)
infer_meta : infer_meta :
func : CreateLikeInferMeta func : CreateLikeInferMeta
param : [x, dtype] param : [x, dtype]
...@@ -1058,7 +1058,7 @@ ...@@ -1058,7 +1058,7 @@
- api : gather_tree - api : gather_tree
args : (Tensor ids, Tensor parents) args : (Tensor ids, Tensor parents)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : GatherTreeMeta func : GatherTreeMeta
kernel : kernel :
...@@ -1066,7 +1066,7 @@ ...@@ -1066,7 +1066,7 @@
- api : gaussian_random - api : gaussian_random
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={}) args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor output: Tensor(out)
infer_meta : infer_meta :
func : GaussianRandomInferMeta func : GaussianRandomInferMeta
param : [shape, mean, std, seed, dtype] param : [shape, mean, std, seed, dtype]
...@@ -1118,7 +1118,7 @@ ...@@ -1118,7 +1118,7 @@
- api : greater_equal - api : greater_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
...@@ -1126,7 +1126,7 @@ ...@@ -1126,7 +1126,7 @@
- api : greater_than - api : greater_than
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
...@@ -1211,7 +1211,7 @@ ...@@ -1211,7 +1211,7 @@
# histogram # histogram
- api : histogram - api : histogram
args : (Tensor x, int64_t bins, int min, int max) args : (Tensor x, int64_t bins, int min, int max)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : HistogramInferMeta func : HistogramInferMeta
kernel : kernel :
...@@ -1238,7 +1238,7 @@ ...@@ -1238,7 +1238,7 @@
# increment # increment
- api : increment - api : increment
args : (Tensor x, float value) args : (Tensor x, float value)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : IncrementInferMeta func : IncrementInferMeta
kernel : kernel :
...@@ -1288,7 +1288,7 @@ ...@@ -1288,7 +1288,7 @@
# is_empty # is_empty
- api : is_empty - api : is_empty
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : IsEmptyInferMeta func : IsEmptyInferMeta
kernel : kernel :
...@@ -1306,7 +1306,7 @@ ...@@ -1306,7 +1306,7 @@
# isfinite # isfinite
- api : isfinite - api : isfinite
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : IsfiniteInferMeta func : IsfiniteInferMeta
kernel : kernel :
...@@ -1316,7 +1316,7 @@ ...@@ -1316,7 +1316,7 @@
# isinf # isinf
- api : isinf - api : isinf
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : IsfiniteInferMeta func : IsfiniteInferMeta
kernel : kernel :
...@@ -1326,7 +1326,7 @@ ...@@ -1326,7 +1326,7 @@
# isnan # isnan
- api : isnan - api : isnan
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : IsfiniteInferMeta func : IsfiniteInferMeta
kernel : kernel :
...@@ -1419,7 +1419,7 @@ ...@@ -1419,7 +1419,7 @@
- api : less_equal - api : less_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
...@@ -1427,7 +1427,7 @@ ...@@ -1427,7 +1427,7 @@
- api : less_than - api : less_than
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
...@@ -1446,7 +1446,7 @@ ...@@ -1446,7 +1446,7 @@
- api : linspace - api : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype) args : (Tensor start, Tensor stop, Tensor number, DataType dtype)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : LinspaceInferMeta func : LinspaceInferMeta
kernel : kernel :
...@@ -1520,7 +1520,7 @@ ...@@ -1520,7 +1520,7 @@
# logical_and # logical_and
- api : logical_and - api : logical_and
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
...@@ -1529,7 +1529,7 @@ ...@@ -1529,7 +1529,7 @@
# logical_not # logical_not
- api : logical_not - api : logical_not
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
kernel : kernel :
...@@ -1538,7 +1538,7 @@ ...@@ -1538,7 +1538,7 @@
# logical_or # logical_or
- api : logical_or - api : logical_or
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
...@@ -1547,7 +1547,7 @@ ...@@ -1547,7 +1547,7 @@
# logical_xor # logical_xor
- api : logical_xor - api : logical_xor
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
kernel : kernel :
...@@ -1827,7 +1827,7 @@ ...@@ -1827,7 +1827,7 @@
# multinomial # multinomial
- api : multinomial - api : multinomial
args : (Tensor x, int num_samples, bool replacement) args : (Tensor x, int num_samples, bool replacement)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : MultinomialInferMeta func : MultinomialInferMeta
kernel : kernel :
...@@ -1895,7 +1895,7 @@ ...@@ -1895,7 +1895,7 @@
- api : not_equal - api : not_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
...@@ -1903,7 +1903,7 @@ ...@@ -1903,7 +1903,7 @@
- api : one_hot - api : one_hot
args : (Tensor x, Scalar(int) num_classes) args : (Tensor x, Scalar(int) num_classes)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : OneHotInferMeta func : OneHotInferMeta
kernel : kernel :
...@@ -1911,12 +1911,12 @@ ...@@ -1911,12 +1911,12 @@
- api : ones - api : ones
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor output : Tensor(out)
invoke : full(shape, 1, dtype, place) invoke : full(shape, 1, dtype, place)
- api : ones_like - api : ones_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place={}) args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place={})
output : Tensor output : Tensor(out)
invoke : full_like(x, 1, dtype, place) invoke : full_like(x, 1, dtype, place)
- api : p_norm - api : p_norm
...@@ -2061,7 +2061,7 @@ ...@@ -2061,7 +2061,7 @@
- api : randperm - api : randperm
args : (int n, DataType dtype, Place place={}) args : (int n, DataType dtype, Place place={})
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : RandpermInferMeta func : RandpermInferMeta
param : [n, dtype] param : [n, dtype]
...@@ -2322,7 +2322,7 @@ ...@@ -2322,7 +2322,7 @@
- api : shape - api : shape
args : (Tensor input) args : (Tensor input)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ShapeInferMeta func : ShapeInferMeta
kernel : kernel :
...@@ -2334,7 +2334,7 @@ ...@@ -2334,7 +2334,7 @@
# shard_index # shard_index
- api : shard_index - api : shard_index
args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value) args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : ShardIndexInferMeta func : ShardIndexInferMeta
kernel : kernel :
...@@ -2362,7 +2362,7 @@ ...@@ -2362,7 +2362,7 @@
- api : sign - api : sign
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
kernel : kernel :
...@@ -2401,7 +2401,7 @@ ...@@ -2401,7 +2401,7 @@
# size # size
- api : size - api : size
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(size)
infer_meta : infer_meta :
func : SizeInferMeta func : SizeInferMeta
kernel : kernel :
...@@ -2716,7 +2716,7 @@ ...@@ -2716,7 +2716,7 @@
# python API: paddle.nn.initializer.TruncatedNormal # python API: paddle.nn.initializer.TruncatedNormal
- api : truncated_gaussian_random - api : truncated_gaussian_random
args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={}) args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : TruncatedGaussianRandomInferMeta func : TruncatedGaussianRandomInferMeta
param : [shape, mean, std, seed, dtype] param : [shape, mean, std, seed, dtype]
...@@ -2831,7 +2831,7 @@ ...@@ -2831,7 +2831,7 @@
# where_index # where_index
- api : where_index - api : where_index
args : (Tensor condition) args : (Tensor condition)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : WhereIndexInferMeta func : WhereIndexInferMeta
kernel : kernel :
...@@ -2861,12 +2861,12 @@ ...@@ -2861,12 +2861,12 @@
- api : zeros - api : zeros
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor output : Tensor(out)
invoke : full(shape, 0, dtype, place) invoke : full(shape, 0, dtype, place)
- api : zeros_like - api : zeros_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {}) args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {})
output : Tensor output : Tensor(out)
invoke : full_like(x, 0, dtype, place) invoke : full_like(x, 0, dtype, place)
- api: broadcast_tensors - api: broadcast_tensors
...@@ -2881,7 +2881,7 @@ ...@@ -2881,7 +2881,7 @@
# dirichlet # dirichlet
- api: dirichlet - api: dirichlet
args: (Tensor alpha) args: (Tensor alpha)
output: Tensor output: Tensor(out)
infer_meta: infer_meta:
func: DirichletInferMeta func: DirichletInferMeta
kernel: kernel:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册