From e8d78a70007f43eb361a8a23a0961bdf4674a634 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Thu, 14 Jul 2022 10:43:55 +0800
Subject: [PATCH] [AMP] Add amp logic in python_C (#44309)

* add amp logic in python_C

* fix inplace bug
---
 .../final_state_generator/python_c_gen.py | 178 ++++++++++++++++--
 paddle/fluid/eager/eager_amp_auto_cast.h  |  27 ++-
 2 files changed, 185 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
index d1e7885bae4..c6ac5a12f56 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
@@ -50,6 +50,45 @@ atype_to_parsing_function = {
     "paddle::experimental::DataType": "CastPyArg2DataType",
 }
 
+# This list contains ops for which no AMP logic needs to be generated
+# (all optimizer ops are in this list)
+no_amp_list = [
+    'adam_',
+    'adam',
+    'adamw_',
+    'adamw',
+    'decayed_adagrad_',
+    'decayed_adagrad',
+    'dgc_momentum_',
+    'dgc_momentum',
+    'distributed_fused_lamb_',
+    'distributed_fused_lamb',
+    'dpsgd_',
+    'dpsgd',
+    'ftrl_',
+    'ftrl',
+    'lamb_',
+    'lamb',
+    'lars_momentum_',
+    'lars_momentum',
+    'merged_adam_',
+    'merged_adam',
+    'merged_momentum_',
+    'merged_momentum',
+    'momentum_',
+    'momentum',
+    'proximal_adagrad_',
+    'proximal_adagrad',
+    'proximal_gd_',
+    'proximal_gd',
+    'rmsprop_',
+    'rmsprop',
+    'sgd_',
+    'sgd',
+    'sparse_momentum_',
+    'sparse_momentum',
+]
+
 
 def FindParsingFunctionFromAttributeType(atype):
     if atype not in atype_to_parsing_function.keys():
@@ -99,7 +138,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
     // Set Device ID
 {}
     // Call dygraph function
-    decltype({}({})) out = {}({});
+{}
 
     PyEval_RestoreThread(tstate);
     tstate = nullptr;
@@ -114,6 +153,25 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
 }}
 """

+NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
+
+AMP_DYGRAPH_FUNCTION_TEMPLATE = \
+"""
+    decltype({}({})) out;
+    // AMP Logic
+    if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
+      VLOG(5) << "Check and Prepare For AMP";
+      {}
+      paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
+      {}
+      {}
+      {}
+      out = {}({});
+    }} else {{
+      out = {}({});
+    }}
+"""
+
 FUNCTION_SET_DEVICE_TEMPLATE = \
 """{}    if (paddle::platform::is_gpu_place(place)) {{
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -149,6 +207,8 @@ PYTHON_C_WRAPPER_TEMPLATE = \
 #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/pybind/eager_final_state_custom_python_api.h"
 #include "paddle/fluid/pybind/eager.h"
+#include "paddle/fluid/eager/amp_utils.h"
+#include "paddle/fluid/eager/eager_amp_auto_cast.h"
 
 namespace paddle {{
 namespace pybind {{
@@ -335,11 +395,15 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
         num_args = len(
             forward_inputs_position_map.keys()) + len(orig_forward_attrs_list)
         dygraph_function_call_list = ["" for i in range(num_args)]
+        amp_dygraph_function_call_list = ["" for i in range(num_args)]
         for name, (_, pos) in forward_inputs_position_map.items():
             dygraph_function_call_list[pos] = f"{name}"
+            amp_dygraph_function_call_list[pos] = f"NEW_{name}"
         for name, _, _, pos in orig_forward_attrs_list:
             dygraph_function_call_list[pos] = f"{name}"
+            amp_dygraph_function_call_list[pos] = f"{name}"
         dygraph_function_call_str = ",".join(dygraph_function_call_list)
+        amp_dygraph_function_call_str = ",".join(amp_dygraph_function_call_list)
 
         # Generate Python-C Function Definitions
         if is_forward_only:
@@ -355,12 +419,82 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
         pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
             "pythonc_record_event", forward_api_name, "pybind_imperative_func")
 
-        # Generate Python-C Function Definition
-        self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
-            forward_api_name, pythonc_record_event_str, forward_api_name,
-            get_eager_tensor_str, parse_attributes_str, set_device_str,
+        # Forward amp logic
+        amp_tensors_vector_list = []
+        amp_tensors_vector_optional_list = []
+        amp_autocast_list = []
+        amp_autocast_optional_list = []
+
+        for name, (ttype, pos) in forward_inputs_position_map.items():
+            is_optional = (name in optional_inputs)
+            if IsVectorTensorType(ttype):
+                if is_optional:
+                    amp_tensors_vector_optional_list.append(
+                        f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
+                    )
+                    amp_autocast_optional_list.append(
+                        f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
+                    )
+                else:
+                    amp_tensors_vector_list.append(f"{name}")
+                    amp_autocast_list.append(
+                        f"auto NEW_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
+                    )
+            else:
+                if is_optional:
+                    amp_tensors_vector_optional_list.append(
+                        f"if ({name}.is_initialized()) amp_tensors_vector.push_back({{{name}.get()}});\n"
+                    )
+                    amp_autocast_optional_list.append(
+                        f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
+                    )
+                else:
+                    if forward_inplace_map and name in forward_inplace_map.keys(
+                    ):
+                        amp_tensors_vector_list.append(f"{{{name}}}")
+                        amp_autocast_list.append(
+                            f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
+                        )
+                    else:
+                        amp_tensors_vector_list.append(f"{{{name}}}")
+                        amp_autocast_list.append(
+                            f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
+                        )
+        amp_tensors_vector_list_str = "{ " + ",".join(
+            amp_tensors_vector_list) + " }"
+        amp_tensors_vector_optional_list_str = "".join(
+            amp_tensors_vector_optional_list)
+        amp_autocast_list_str = " ".join(
+            amp_autocast_list) + " " + " ".join(
+                amp_autocast_optional_list)
+
+        kernel_trans2_op_name_str = f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
+        amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
+
+        noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
             fwd_function_name, dygraph_function_call_str, fwd_function_name,
-            dygraph_function_call_str, return_str)
+            dygraph_function_call_str)
+
+        amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
+            fwd_function_name, dygraph_function_call_str,
+            kernel_trans2_op_name_str, amp_tensors_vector_list_str,
+            amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
+            amp_autocast_list_str, fwd_function_name,
+            amp_dygraph_function_call_str, fwd_function_name,
+            dygraph_function_call_str)
+
+        # Generate Python-C Function Definition
+        if (is_forward_only) and (len(amp_tensors_vector_list) >
+                                  0) and (forward_api_name not in no_amp_list):
+            self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
+                forward_api_name,
+                pythonc_record_event_str, forward_api_name,
+                get_eager_tensor_str, parse_attributes_str, set_device_str,
+                amp_dygraph_function_str, return_str)
+        else:
+            self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
+                forward_api_name, pythonc_record_event_str, forward_api_name,
+                get_eager_tensor_str, parse_attributes_str, set_device_str,
+                noamp_dygraph_function_str, return_str)
 
         # Set prefix of forward_api_name to avoid conflicts
         prefix = self.namespace.strip("::")
@@ -383,6 +517,18 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
                 "::", namespace,
                 GetForwardFunctionName(inplaced_forward_api_name))
 
+            inplace_noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
+                inplaced_fwd_function_name, dygraph_function_call_str,
+                inplaced_fwd_function_name, dygraph_function_call_str)
+
+            inplace_amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
+                inplaced_fwd_function_name, dygraph_function_call_str,
+                kernel_trans2_op_name_str, amp_tensors_vector_list_str,
+                amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
+                amp_autocast_list_str, inplaced_fwd_function_name,
+                amp_dygraph_function_call_str, inplaced_fwd_function_name,
+                dygraph_function_call_str)
+
             return_str = "    std::map<ssize_t, ssize_t> inplace_var_idx_map;"
             for inplace_input, inplace_output in forward_inplace_map.items():
                 return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
@@ -391,13 +537,19 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
             return_str += "    return ToPyObject(out, args, inplace_var_idx_map);"
 
             # Generate Python-C Function Definition
-            python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
-                inplaced_forward_api_name, pythonc_record_event_str,
-                inplaced_forward_api_name, get_eager_tensor_str,
-                parse_attributes_str, set_device_str,
-                inplaced_fwd_function_name, dygraph_function_call_str,
-                inplaced_fwd_function_name, dygraph_function_call_str,
-                return_str)
+            if (is_forward_only) and (len(amp_tensors_vector_list) > 0) and (
+                    inplaced_forward_api_name not in no_amp_list):
+                python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
+                    inplaced_forward_api_name, pythonc_record_event_str,
+                    inplaced_forward_api_name, get_eager_tensor_str,
+                    parse_attributes_str, set_device_str,
+                    inplace_amp_dygraph_function_str, return_str)
+            else:
+                python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
+                    inplaced_forward_api_name, pythonc_record_event_str,
+                    inplaced_forward_api_name, get_eager_tensor_str,
+                    parse_attributes_str, set_device_str,
+                    inplace_noamp_dygraph_function_str, return_str)
 
             python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
                 forward_api_name_prefix, inplaced_forward_api_name, namespace,
diff --git a/paddle/fluid/eager/eager_amp_auto_cast.h b/paddle/fluid/eager/eager_amp_auto_cast.h
index 438ccbaca8a..26af2b98ca0 100644
--- a/paddle/fluid/eager/eager_amp_auto_cast.h
+++ b/paddle/fluid/eager/eager_amp_auto_cast.h
@@ -43,15 +43,21 @@ inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
     const std::string& inputs_name,
     const std::vector<paddle::experimental::Tensor>& inputs,
     const paddle::experimental::DataType& dst_dtype,
-    std::string op_name) {
+    std::string op_name,
+    bool trace_backward = true) {
   VLOG(6) << "AMP AmpAutoCasts:"
           << " inputs(" << inputs_name << ") dst_dtype("
           << paddle::framework::DataType2String(dst_dtype) << ").";
   std::vector<paddle::experimental::Tensor> inputs_casted;
   for (auto& input : inputs) {
     if (NeedCast(input, dst_dtype)) {
-      inputs_casted.emplace_back(
-          std::move(cast_final_state_dygraph_function(input, dst_dtype)));
+      if (trace_backward) {
+        inputs_casted.emplace_back(
+            std::move(cast_final_state_dygraph_function(input, dst_dtype)));
+      } else {
+        inputs_casted.emplace_back(
+            std::move(paddle::experimental::cast(input, dst_dtype)));
+      }
     } else {
       inputs_casted.emplace_back(input);
     }
@@ -63,7 +69,8 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
     const std::string& input_name,
     const paddle::experimental::Tensor& input,
     const paddle::experimental::DataType& dst_dtype,
-    const std::string& op_name) {
+    const std::string& op_name,
+    bool trace_backward = true) {
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << input_name << ") dst_dtype("
           << paddle::framework::DataType2String(dst_dtype) << ").";
@@ -85,7 +92,11 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
     }
   }
   if (NeedCast(input, dst_dtype)) {
-    return cast_final_state_dygraph_function(input, dst_dtype);
+    if (trace_backward) {
+      return cast_final_state_dygraph_function(input, dst_dtype);
+    } else {
+      return paddle::experimental::cast(input, dst_dtype);
+    }
   }
   return input;
 }
@@ -94,9 +105,11 @@ inline paddle::optional<paddle::experimental::Tensor> EagerAmpAutoCast(
     const std::string& input_name,
     const paddle::optional<paddle::experimental::Tensor>& input,
     const paddle::experimental::DataType& dst_dtype,
-    const std::string& op_name) {
+    const std::string& op_name,
+    bool trace_backward = true) {
   if (input) {
-    return EagerAmpAutoCast(input_name, *input, dst_dtype, op_name);
+    return EagerAmpAutoCast(
+        input_name, *input, dst_dtype, op_name, trace_backward);
   }
   return paddle::none;
 }
-- 
GitLab
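
For context, here is a minimal sketch of what AMP_DYGRAPH_FUNCTION_TEMPLATE expands to for a hypothetical forward-only op "relu" with a single tensor input x; the op name, the input name, and ::relu_final_state_dygraph_function are illustrative placeholders, not names taken from this patch:

    // Illustrative expansion of AMP_DYGRAPH_FUNCTION_TEMPLATE for a
    // hypothetical forward-only op "relu" with one tensor input `x`.
    decltype(::relu_final_state_dygraph_function(x)) out;
    // AMP Logic
    if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {
      VLOG(5) << "Check and Prepare For AMP";
      auto op_name = phi::TransToFluidOpName("relu");
      paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                           egr::kSlotSmallVectorSize> amp_tensors_vector = { {x} };
      auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);
      // Cast with trace_backward=false (see the eager_amp_auto_cast.h hunks above).
      auto NEW_x = egr::EagerAmpAutoCast("x", x, amp_dst_dtype, op_name, false);
      out = ::relu_final_state_dygraph_function(NEW_x);
    } else {
      out = ::relu_final_state_dygraph_function(x);
    }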
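On the design of the new trace_backward parameter in eager_amp_auto_cast.h: with trace_backward=true (the default, preserving the previous behavior), a needed cast runs through cast_final_state_dygraph_function, which records the cast for autograd; with trace_backward=false, it runs through paddle::experimental::cast, a plain data-type conversion that leaves no backward trace. The generated wrappers above pass false because their AMP branch is emitted only when is_forward_only holds, so no gradient graph is needed for the cast. A minimal usage sketch, assuming x, dst_dtype, and op_name are in scope:

    // Default: the cast is recorded so autograd can run through it.
    auto traced = egr::EagerAmpAutoCast("x", x, dst_dtype, op_name);
    // Forward-only path: plain dtype conversion, no grad node is created.
    auto untraced = egr::EagerAmpAutoCast("x", x, dst_dtype, op_name, false);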