Unverified commit 933db9d4, authored by Weilong Wu, committed by GitHub

[Eager] Forward-only add dygraph func (#45153)

* [Eager draft] forward_only interface migrated to autograd_api

* strings api add dygraph forward function

* rm useless comments

* draft version for check CI

* fix ci

* forward-only needs no compute_require_grad or pass stop_gradient; rm useless comments

* polish yaml and using CPUPlace = phi::CPUPlace

* rm useless comments

* polish yaml and update some test case

* rm useless funcs

* polish eager_gen code

* polish code
Parent f706d95d
......@@ -38,7 +38,7 @@ add_custom_target(
COMMAND
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py"
"--api_yaml_path=${api_yaml_path}"
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
"--backward_yaml_path=${backward_yaml_path}"
"--forwards_cc_path=${tmp_forwards_cc_path}"
"--forwards_h_path=${tmp_forwards_h_path}"
......
......@@ -353,6 +353,9 @@ class FunctionGeneratorBase:
self.forward_api_contents = forward_api_contents
self.namespace = namespace
self.is_forward_only = False if 'backward' in forward_api_contents.keys(
) else True
self.forward_api_name = ""
self.orig_forward_inputs_list = [
......
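For context, the check above marks an api as forward-only exactly when its parsed api.yaml entry carries no `backward` field. A minimal sketch of that rule; the YAML content and dicts below are hypothetical stand-ins for what the generator receives:

import yaml

# Hypothetical forward-only entry, in the shape of the api.yaml entries edited below
# (e.g. bernoulli): no `backward` field is present.
forward_only_entry = yaml.safe_load("""
api: bernoulli
args: (Tensor x)
output: Tensor(out)
""")

# A hypothetical entry that does register a backward api.
entry_with_grad = {"api": "matmul", "backward": "matmul_grad"}

def is_forward_only(forward_api_contents):
    # Same rule as FunctionGeneratorBase: forward-only iff no `backward` key.
    return "backward" not in forward_api_contents

print(is_forward_only(forward_only_entry))  # True
print(is_forward_only(entry_with_grad))     # False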
......@@ -51,20 +51,6 @@ atype_to_parsing_function = {
"paddle::experimental::DataType": "CastPyArg2DataType",
}
# This list contains ops that do not need to generate amp logic
# All optimizer ops in this list
no_amp_list = [
'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
'sparse_momentum_', 'sparse_momentum', 'full_'
]
def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys():
......@@ -131,41 +117,6 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
decltype({}({})) out;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out = {}({});
}} else {{
out = {}({});
}}
"""
INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
using result_type = decltype({}({}));
std::unique_ptr<result_type> out_ptr;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out_ptr = std::make_unique<result_type>({}({}));
}} else {{
out_ptr = std::make_unique<result_type>({}({}));
}}
result_type& out = *out_ptr;
"""
FUNCTION_SET_DEVICE_TEMPLATE = \
"""{} if (paddle::platform::is_gpu_place(place)) {{
......@@ -405,23 +356,15 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
num_args = len(
forward_inputs_position_map.keys()) + len(orig_forward_attrs_list)
dygraph_function_call_list = ["" for i in range(num_args)]
amp_dygraph_function_call_list = ["" for i in range(num_args)]
for name, (_, pos) in forward_inputs_position_map.items():
dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"NEW_{name}"
for name, _, _, pos in orig_forward_attrs_list:
dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
amp_dygraph_function_call_str = ",".join(amp_dygraph_function_call_list)
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, forward_api_name)
else:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name))
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name))
return_str = " return ToPyObject(out);"
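As a rough illustration of the call-string assembly kept above: inputs and attributes are written into a single positional list, which is then joined into the argument string passed to the dygraph function. The maps below are hypothetical, shaped like the generator's parsing results rather than copied from a real api:

# Hypothetical parsing results for an argmax-like api.
forward_inputs_position_map = {"x": ("Tensor", 0)}
orig_forward_attrs_list = [("axis", "int64_t", "-1", 1), ("keepdims", "bool", "false", 2)]

num_args = len(forward_inputs_position_map) + len(orig_forward_attrs_list)
dygraph_function_call_list = ["" for _ in range(num_args)]
for name, (_, pos) in forward_inputs_position_map.items():
    dygraph_function_call_list[pos] = name
for name, _, _, pos in orig_forward_attrs_list:
    dygraph_function_call_list[pos] = name

dygraph_function_call_str = ",".join(dygraph_function_call_list)
print(dygraph_function_call_str)  # x,axis,keepdims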
......@@ -429,82 +372,15 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func")
# Forward amp logic
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
amp_autocast_optional_list = []
for name, (ttype, pos) in forward_inputs_position_map.items():
is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
amp_tensors_vector_list.append(f"{name}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({{{name}.get()}});\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
if forward_inplace_map and name in forward_inplace_map.keys(
):
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
amp_tensors_vector_list_str = "{ " + ",".join(
amp_tensors_vector_list) + " }"
amp_tensors_vector_optional_list_str = "".join(
amp_tensors_vector_optional_list)
amp_autocast_list_str = " ".join(
amp_autocast_list) + " " + " ".join(
amp_autocast_optional_list)
kernel_trans2_op_name_str = f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str)
amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, fwd_function_name,
amp_dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str)
# Generate Python-C Function Definition
if (is_forward_only) and (len(amp_tensors_vector_list) >
0) and (forward_api_name not in no_amp_list):
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
amp_dygraph_function_str, return_str)
else:
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
noamp_dygraph_function_str, return_str)
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
noamp_dygraph_function_str, return_str)
# Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::")
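The strip above turns a namespace such as `strings::` into a plain prefix. How the prefix is attached to the api name is not shown in this hunk, so the combination below is an assumption used only for illustration:

namespace = "strings::"          # hypothetical namespace value
prefix = namespace.strip("::")   # strips the ':' characters -> "strings"
forward_api_name = "empty"       # hypothetical api name
# Assumed combination; the surrounding comment only says the prefix avoids name conflicts.
prefixed_name = f"{prefix}_{forward_api_name}" if prefix else forward_api_name
print(prefixed_name)             # strings_empty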
......@@ -518,27 +394,14 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
if forward_inplace_map:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
if is_forward_only:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
else:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
inplace_noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str)
inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, inplaced_fwd_function_name,
amp_dygraph_function_call_str, inplaced_fwd_function_name,
dygraph_function_call_str)
return_str = " std::map<ssize_t, ssize_t> inplace_var_idx_map;"
for inplace_input, inplace_output in forward_inplace_map.items():
return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
......@@ -547,19 +410,11 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return_str += " return ToPyObject(out, args, inplace_var_idx_map);"
# Generate Python-C Function Definition
if (is_forward_only) and (len(amp_tensors_vector_list) > 0) and (
inplaced_forward_api_name not in no_amp_list):
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_amp_dygraph_function_str, return_str)
else:
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_noamp_dygraph_function_str, return_str)
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_noamp_dygraph_function_str, return_str)
python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name_prefix, inplaced_forward_api_name, namespace,
......
......@@ -9,7 +9,7 @@
- api : bernoulli
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......
......@@ -184,7 +184,7 @@
- api : arange
args : (Tensor start, Tensor end, Tensor step, DataType dtype, Place place={})
output : Tensor
output : Tensor(out)
infer_meta :
func : ArangeInferMeta
param : [start, end, step]
......@@ -199,7 +199,7 @@
# arg_max
- api : argmax
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
kernel :
......@@ -208,7 +208,7 @@
# arg_min
- api : argmin
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
kernel :
......@@ -366,7 +366,7 @@
# bitwise_and
- api : bitwise_and
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -375,7 +375,7 @@
# bitwise_not
- api : bitwise_not
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......@@ -384,7 +384,7 @@
# bitwise_or
- api : bitwise_or
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -393,7 +393,7 @@
# bitwise_xor
- api : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -557,7 +557,7 @@
- api : copy_to
args : (Tensor x, Place place, bool blocking)
output : Tensor
output : Tensor(out)
invoke : copy_to_impl(x, place, blocking)
# cos
......@@ -672,7 +672,7 @@
- api : diag_embed
args : (Tensor x, int offset, int dim1, int dim2)
output : Tensor
output : Tensor(out)
infer_meta :
func : DiagEmbedInferMeta
kernel :
......@@ -720,7 +720,7 @@
- api : eigvals
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : EigvalsInferMeta
kernel :
......@@ -773,7 +773,7 @@
- api : empty
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateInferMeta
param : [shape, dtype]
......@@ -785,7 +785,7 @@
- api : empty_like
args : (Tensor x, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateLikeInferMeta
param : [x, dtype]
......@@ -797,7 +797,7 @@
- api : equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -805,7 +805,7 @@
- api : equal_all
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareAllInferMeta
kernel :
......@@ -986,7 +986,7 @@
- api : full
args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateInferMeta
param : [shape, dtype]
......@@ -1012,7 +1012,7 @@
- api : full_batch_size_like
args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace())
output: Tensor
output: Tensor(out)
infer_meta :
func : FullBatchSizeLikeInferMeta
param : [input, shape, value, dtype, input_dim_idx, output_dim_idx]
......@@ -1024,7 +1024,7 @@
- api : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateLikeInferMeta
param : [x, dtype]
......@@ -1058,7 +1058,7 @@
- api : gather_tree
args : (Tensor ids, Tensor parents)
output : Tensor
output : Tensor(out)
infer_meta :
func : GatherTreeMeta
kernel :
......@@ -1066,7 +1066,7 @@
- api : gaussian_random
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor
output: Tensor(out)
infer_meta :
func : GaussianRandomInferMeta
param : [shape, mean, std, seed, dtype]
......@@ -1118,7 +1118,7 @@
- api : greater_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1126,7 +1126,7 @@
- api : greater_than
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1211,7 +1211,7 @@
# histogram
- api : histogram
args : (Tensor x, int64_t bins, int min, int max)
output : Tensor
output : Tensor(out)
infer_meta :
func : HistogramInferMeta
kernel :
......@@ -1238,7 +1238,7 @@
# increment
- api : increment
args : (Tensor x, float value)
output : Tensor
output : Tensor(out)
infer_meta :
func : IncrementInferMeta
kernel :
......@@ -1288,7 +1288,7 @@
# is_empty
- api : is_empty
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsEmptyInferMeta
kernel :
......@@ -1306,7 +1306,7 @@
# isfinite
- api : isfinite
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
kernel :
......@@ -1316,7 +1316,7 @@
# isinf
- api : isinf
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
kernel :
......@@ -1326,7 +1326,7 @@
# isnan
- api : isnan
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
kernel :
......@@ -1419,7 +1419,7 @@
- api : less_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1427,7 +1427,7 @@
- api : less_than
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1446,7 +1446,7 @@
- api : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : LinspaceInferMeta
kernel :
......@@ -1520,7 +1520,7 @@
# logical_and
- api : logical_and
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -1529,7 +1529,7 @@
# logical_not
- api : logical_not
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......@@ -1538,7 +1538,7 @@
# logical_or
- api : logical_or
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -1547,7 +1547,7 @@
# logical_xor
- api : logical_xor
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -1827,7 +1827,7 @@
# multinomial
- api : multinomial
args : (Tensor x, int num_samples, bool replacement)
output : Tensor
output : Tensor(out)
infer_meta :
func : MultinomialInferMeta
kernel :
......@@ -1895,7 +1895,7 @@
- api : not_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1903,7 +1903,7 @@
- api : one_hot
args : (Tensor x, Scalar(int) num_classes)
output : Tensor
output : Tensor(out)
infer_meta :
func : OneHotInferMeta
kernel :
......@@ -1911,12 +1911,12 @@
- api : ones
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor
output : Tensor(out)
invoke : full(shape, 1, dtype, place)
- api : ones_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place={})
output : Tensor
output : Tensor(out)
invoke : full_like(x, 1, dtype, place)
- api : p_norm
......@@ -2061,7 +2061,7 @@
- api : randperm
args : (int n, DataType dtype, Place place={})
output : Tensor
output : Tensor(out)
infer_meta :
func : RandpermInferMeta
param : [n, dtype]
......@@ -2322,7 +2322,7 @@
- api : shape
args : (Tensor input)
output : Tensor
output : Tensor(out)
infer_meta :
func : ShapeInferMeta
kernel :
......@@ -2334,7 +2334,7 @@
# shard_index
- api : shard_index
args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value)
output : Tensor
output : Tensor(out)
infer_meta :
func : ShardIndexInferMeta
kernel :
......@@ -2362,7 +2362,7 @@
- api : sign
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......@@ -2401,7 +2401,7 @@
# size
- api : size
args : (Tensor x)
output : Tensor
output : Tensor(size)
infer_meta :
func : SizeInferMeta
kernel :
......@@ -2716,7 +2716,7 @@
# python API: paddle.nn.initializer.TruncatedNormal
- api : truncated_gaussian_random
args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor
output : Tensor(out)
infer_meta :
func : TruncatedGaussianRandomInferMeta
param : [shape, mean, std, seed, dtype]
......@@ -2831,7 +2831,7 @@
# where_index
- api : where_index
args : (Tensor condition)
output : Tensor
output : Tensor(out)
infer_meta :
func : WhereIndexInferMeta
kernel :
......@@ -2861,12 +2861,12 @@
- api : zeros
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor
output : Tensor(out)
invoke : full(shape, 0, dtype, place)
- api : zeros_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {})
output : Tensor
output : Tensor(out)
invoke : full_like(x, 0, dtype, place)
- api: broadcast_tensors
......@@ -2881,7 +2881,7 @@
# dirichlet
- api: dirichlet
args: (Tensor alpha)
output: Tensor
output: Tensor(out)
infer_meta:
func: DirichletInferMeta
kernel:
......
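The yaml edits in this commit only add explicit output names, turning `output : Tensor` into `output : Tensor(out)` (or `Tensor(size)` for the size api). A hedged sketch of splitting such a declaration into a type and a name; the regex and the defaulting rule are illustrative assumptions, not the generator's actual parser:

import re

def parse_output(output_str, default_name="out"):
    # "Tensor(out)" -> ("Tensor", "out"); a bare "Tensor" falls back to default_name.
    m = re.fullmatch(r"\s*(\w+)\s*(?:\((\w+)\))?\s*", output_str)
    assert m, f"unrecognized output declaration: {output_str!r}"
    return m.group(1), m.group(2) or default_name

print(parse_output("Tensor(out)"))   # ('Tensor', 'out')
print(parse_output("Tensor"))        # ('Tensor', 'out'), name defaulted
print(parse_output("Tensor(size)"))  # ('Tensor', 'size')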