未验证 提交 e8d78a70 编写于 作者: Z zhangbo9674 提交者: GitHub

[AMP] Add amp logic in python_C (#44309)

* add amp logic in python_C

* fix inplace bug
上级 d15b490a
......@@ -50,6 +50,45 @@ atype_to_parsing_function = {
"paddle::experimental::DataType": "CastPyArg2DataType",
}
# This list contains ops that do not need to generate amp logic
# All optimizer ops in this list
no_amp_list = [
'adam_',
'adam',
'adamw_',
'adamw',
'decayed_adagrad_',
'decayed_adagrad',
'dgc_momentum_',
'dgc_momentum',
'distributed_fused_lamb_',
'distributed_fused_lamb',
'dpsgd_',
'dpsgd',
'ftrl_',
'ftrl',
'lamb_',
'lamb',
'lars_momentum_',
'lars_momentum',
'merged_adam_',
'merged_adam',
'merged_momentum_',
'merged_momentum',
'momentum_',
'momentum',
'proximal_adagrad_',
'proximal_adagrad',
'proximal_gd_',
'proximal_gd',
'rmsprop_',
'rmsprop',
'sgd_',
'sgd',
'sparse_momentum_',
'sparse_momentum',
]
def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys():
......@@ -99,7 +138,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID
{}
// Call dygraph function
decltype({}({})) out = {}({});
{}
PyEval_RestoreThread(tstate);
tstate = nullptr;
......@@ -114,6 +153,25 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
}}
"""
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
decltype({}({})) out;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out = {}({});
}} else {{
out = {}({});
}}
"""
FUNCTION_SET_DEVICE_TEMPLATE = \
"""{} if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -149,6 +207,8 @@ PYTHON_C_WRAPPER_TEMPLATE = \
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/eager_final_state_custom_python_api.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
namespace paddle {{
namespace pybind {{
......@@ -335,11 +395,15 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
num_args = len(
forward_inputs_position_map.keys()) + len(orig_forward_attrs_list)
dygraph_function_call_list = ["" for i in range(num_args)]
amp_dygraph_function_call_list = ["" for i in range(num_args)]
for name, (_, pos) in forward_inputs_position_map.items():
dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"NEW_{name}"
for name, _, _, pos in orig_forward_attrs_list:
dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
amp_dygraph_function_call_str = ",".join(amp_dygraph_function_call_list)
# Generate Python-C Function Definitions
if is_forward_only:
......@@ -355,12 +419,82 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func")
# Generate Python-C Function Definetion
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
# Forward amp logic
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
amp_autocast_optional_list = []
for name, (ttype, pos) in forward_inputs_position_map.items():
is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
amp_tensors_vector_list.append(f"{name}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({{{name}.get()}});\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
if forward_inplace_map and name in forward_inplace_map.keys(
):
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
amp_tensors_vector_list_str = "{ " + ",".join(
amp_tensors_vector_list) + " }"
amp_tensors_vector_optional_list_str = "".join(
amp_tensors_vector_optional_list)
amp_autocast_list_str = " ".join(
amp_autocast_list) + " " + " ".join(
amp_autocast_optional_list)
kernel_trans2_op_name_str = f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str, return_str)
dygraph_function_call_str)
amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, fwd_function_name,
amp_dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str)
# Generate Python-C Function Definetion
if (is_forward_only) and (len(amp_tensors_vector_list) >
0) and (forward_api_name not in no_amp_list):
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
amp_dygraph_function_str, return_str)
else:
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
noamp_dygraph_function_str, return_str)
# Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::")
......@@ -383,6 +517,18 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
inplace_noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str)
inplace_amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, inplaced_fwd_function_name,
amp_dygraph_function_call_str, inplaced_fwd_function_name,
dygraph_function_call_str)
return_str = " std::map<ssize_t, ssize_t> inplace_var_idx_map;"
for inplace_input, inplace_output in forward_inplace_map.items():
return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
......@@ -391,13 +537,19 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return_str += " return ToPyObject(out, args, inplace_var_idx_map);"
# Generate Python-C Function Definetion
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str,
return_str)
if (is_forward_only) and (len(amp_tensors_vector_list) > 0) and (
inplaced_forward_api_name not in no_amp_list):
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_amp_dygraph_function_str, return_str)
else:
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_noamp_dygraph_function_str, return_str)
python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name_prefix, inplaced_forward_api_name, namespace,
......
......@@ -43,15 +43,21 @@ inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
const std::string& inputs_name,
const std::vector<paddle::experimental::Tensor>& inputs,
const paddle::experimental::DataType& dst_dtype,
std::string op_name) {
std::string op_name,
bool trace_backward = true) {
VLOG(6) << "AMP AmpAutoCasts:"
<< " inputs(" << inputs_name << ") dst_dtype("
<< paddle::framework::DataType2String(dst_dtype) << ").";
std::vector<paddle::experimental::Tensor> inputs_casted;
for (auto& input : inputs) {
if (NeedCast(input, dst_dtype)) {
inputs_casted.emplace_back(
std::move(cast_final_state_dygraph_function(input, dst_dtype)));
if (trace_backward) {
inputs_casted.emplace_back(
std::move(cast_final_state_dygraph_function(input, dst_dtype)));
} else {
inputs_casted.emplace_back(
std::move(paddle::experimental::cast(input, dst_dtype)));
}
} else {
inputs_casted.emplace_back(input);
}
......@@ -63,7 +69,8 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
const std::string& input_name,
const paddle::experimental::Tensor& input,
const paddle::experimental::DataType& dst_dtype,
const std::string& op_name) {
const std::string& op_name,
bool trace_backward = true) {
VLOG(6) << "AMP AmpAutoCasts:"
<< " input(" << input_name << ") dst_dtype("
<< paddle::framework::DataType2String(dst_dtype) << ").";
......@@ -85,7 +92,11 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
}
}
if (NeedCast(input, dst_dtype)) {
return cast_final_state_dygraph_function(input, dst_dtype);
if (trace_backward) {
return cast_final_state_dygraph_function(input, dst_dtype);
} else {
return paddle::experimental::cast(input, dst_dtype);
}
}
return input;
}
......@@ -94,9 +105,11 @@ inline paddle::optional<paddle::experimental::Tensor> EagerAmpAutoCast(
const std::string& input_name,
const paddle::optional<paddle::experimental::Tensor>& input,
const paddle::experimental::DataType& dst_dtype,
const std::string& op_name) {
const std::string& op_name,
bool trace_backward = true) {
if (input) {
return EagerAmpAutoCast(input_name, *input, dst_dtype, op_name);
return EagerAmpAutoCast(
input_name, *input, dst_dtype, op_name, trace_backward);
}
return paddle::none;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册