diff --git a/paddle/fluid/eager/amp_utils.h b/paddle/fluid/eager/amp_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..229af41a38ad0aadba663c9fbe40634a7fd25466 --- /dev/null +++ b/paddle/fluid/eager/amp_utils.h @@ -0,0 +1,245 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/imperative/amp_auto_cast.h" + +namespace egr { + +static inline paddle::experimental::DataType GetPromoteType( + const std::string& api_name, + const std::vector>& + amp_tensors_vector, + const paddle::experimental::DataType& amp_dtype) { + auto dst_type = amp_dtype; + if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() == + "float16") { + if (api_name == "batch_norm" || api_name == "layer_norm" || + api_name == "sync_batch_norm") { + if (amp_tensors_vector[0][0].dtype() == + paddle::experimental::DataType::FLOAT32) { + dst_type = paddle::experimental::DataType::FLOAT32; + } + } else if (api_name == "fused_attention") { + for (size_t i = 0; i < amp_tensors_vector.size(); i++) { + if (i != 3 || i != 4 || i != 9 || i != 10) { + if (amp_tensors_vector[i][0].dtype() == + paddle::experimental::DataType::FLOAT32) { + dst_type = paddle::experimental::DataType::FLOAT32; + break; + } + } + } + } else if (api_name == "fused_feedforward") { + for (size_t i = 0; i < amp_tensors_vector.size(); i++) { + if (i != 7 || i != 8 || i != 9 || i != 10) { + if (amp_tensors_vector[i][0].dtype() == + paddle::experimental::DataType::FLOAT32) { + dst_type = paddle::experimental::DataType::FLOAT32; + break; + } + } + } + } else { + for (const auto& tensors : amp_tensors_vector) { + for (const auto& tensor : tensors) { + if (tensor.dtype() == paddle::experimental::DataType::FLOAT32) { + dst_type = tensor.dtype(); + break; + } + } + } + } + } else { + for (const auto& tensors : amp_tensors_vector) { + for (const auto& tensor : tensors) { + if (tensor.dtype() == paddle::experimental::DataType::FLOAT32) { + dst_type = tensor.dtype(); + break; + } + } + } + } + // NOTE(juncai): moving_average_abs_max_scale only consider the dtype of + // input(X) + if (api_name == "moving_average_abs_max_scale") { + if (amp_tensors_vector[0][0].dtype() == + paddle::experimental::DataType::FLOAT16) { + dst_type = paddle::experimental::DataType::FLOAT16; + } + } + return dst_type; +} + +paddle::experimental::DataType GetAmpDestDtype( + const std::string& api_name, + const std::vector>& + amp_tensors_vector) { + auto amp_dtype = + egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype(); + auto amp_level = egr::Controller::Instance().GetAMPLevel(); + VLOG(6) << "AMP GetAmpDestDtype:" + << " op(" << api_name << ") amp_dtype(" << amp_dtype << ") amp_level(" + << static_cast(amp_level) << ")."; + if (amp_dtype == "float16") { + if (amp_level == paddle::imperative::AmpLevel::O1) { + if (paddle::imperative::AmpOperators::Instance() + .GetMutableAllowOps() + ->count(api_name)) { + return paddle::experimental::DataType::FLOAT16; + } else if (paddle::imperative::AmpOperators::Instance() + .GetMutableBlockOps() + ->count(api_name)) { + return paddle::experimental::DataType::FLOAT32; + } else { + auto dst_type = GetPromoteType(api_name, amp_tensors_vector, + paddle::experimental::DataType::FLOAT16); + if (dst_type == paddle::experimental::DataType::FLOAT16 && + paddle::imperative::AmpOperators::Instance() + .GetMutableUnsupportedFp16Ops() + ->count(api_name)) { + dst_type = paddle::experimental::DataType::FLOAT32; + } + return dst_type; + } + } else if (amp_level == paddle::imperative::AmpLevel::O2) { + auto dst_type = paddle::experimental::DataType::FLOAT16; + if (paddle::imperative::AmpOperators::Instance() + .GetMutableUnsupportedFp16Ops() + ->count(api_name) || + paddle::imperative::AmpOperators::Instance() + .GetMutableBlockOps() + ->count(api_name)) { + dst_type = paddle::experimental::DataType::FLOAT32; + } + return dst_type; + } + } else if (amp_dtype == "bfloat16") { + if (amp_level == paddle::imperative::AmpLevel::O1) { + if (paddle::imperative::AmpOperators::Instance() + .GetMutableAllowOps() + ->count(api_name)) { + return paddle::experimental::DataType::BFLOAT16; + } else if (paddle::imperative::AmpOperators::Instance() + .GetMutableBlockOps() + ->count(api_name)) { + return paddle::experimental::DataType::FLOAT32; + } else { + auto dst_type = + GetPromoteType(api_name, amp_tensors_vector, + paddle::experimental::DataType::BFLOAT16); + if (dst_type == paddle::experimental::DataType::BFLOAT16 && + paddle::imperative::AmpOperators::Instance() + .GetMutableUnsupportedBf16Ops() + ->count(api_name)) { + dst_type = paddle::experimental::DataType::FLOAT32; + } + return dst_type; + } + } else if (amp_level == paddle::imperative::AmpLevel::O2) { + auto dst_type = paddle::experimental::DataType::BFLOAT16; + if (paddle::imperative::AmpOperators::Instance() + .GetMutableUnsupportedBf16Ops() + ->count(api_name) || + paddle::imperative::AmpOperators::Instance() + .GetMutableBlockOps() + ->count(api_name)) { + dst_type = paddle::experimental::DataType::FLOAT32; + } + return dst_type; + } + } + return paddle::experimental::DataType::FLOAT32; +} + +static inline bool NeedCast(const paddle::experimental::Tensor& tensor, + const paddle::experimental::DataType& dst_dtype) { + auto place = tensor.inner_place(); + auto data_type = tensor.dtype(); + if (paddle::platform::is_gpu_place(place) || + paddle::platform::is_cuda_pinned_place(place) || + paddle::platform::is_xpu_place(place) || + paddle::platform::is_mlu_place(place) || + paddle::platform::is_npu_place(place) || + paddle::platform::is_npu_pinned_place(place)) { + // CudaPinndePlace is added for varbase created by dataloader + if ((data_type == paddle::experimental::DataType::FLOAT32 || + data_type == paddle::experimental::DataType::FLOAT16 || + data_type == paddle::experimental::DataType::BFLOAT16) && + (data_type != dst_dtype)) { + return true; + } + } + return false; +} + +std::vector AmpAutoCasts( + const std::string& inputs_name, + const std::vector& inputs, + const paddle::experimental::DataType& dst_dtype, std::string api_name) { + VLOG(6) << "AMP AmpAutoCasts:" + << " inputs(" << inputs_name << ") dst_dtype(" + << paddle::framework::DataType2String(dst_dtype) << ")."; + std::vector inputs_casted; + for (auto& input : inputs) { + if (NeedCast(input, dst_dtype)) { + paddle::framework::AttributeMap cast_attrs = { + {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())}, + {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}}; + inputs_casted.emplace_back( + std::move(cast_dygraph_function(input, cast_attrs))); + } else { + inputs_casted.emplace_back(input); + } + } + return inputs_casted; +} + +paddle::experimental::Tensor AmpAutoCast( + const std::string& input_name, const paddle::experimental::Tensor& input, + const paddle::experimental::DataType& dst_dtype, std::string api_name) { + VLOG(6) << "AMP AmpAutoCasts:" + << " input(" << input_name << ") dst_dtype(" + << paddle::framework::DataType2String(dst_dtype) << ")."; + if (dst_dtype == paddle::experimental::DataType::FLOAT16) { + if (api_name == "run_program") { + return input; + } + if ((api_name == "batch_norm" || api_name == "layer_norm" || + api_name == "sync_batch_norm") && + input_name != "X") { + return input; + } + if ((api_name == "fused_attention" || api_name == "fused_feedforward")) { + if (input_name == "LnScale" || input_name == "LnBias" || + input_name == "Ln2Scale" || input_name == "Ln2Bias" || + input_name == "Ln1Scale" || input_name == "Ln1Bias") { + return input; + } + } + } + if (NeedCast(input, dst_dtype)) { + paddle::framework::AttributeMap cast_attrs = { + {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())}, + {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}}; + return cast_dygraph_function(input, cast_attrs); + } + return input; +} +} // namespace egr diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 229817596423cd46c29db3f0dae589d0655b8485..fc3ea234929bc58ffc38fd694dacbe3bf60ffee4 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -1379,6 +1379,7 @@ static std::pair GenerateForwardFunctionContents( paddle::string::Sprintf(FORWARD_FUNCTION_TEMPLATE, op_type); std::string dygraph_function_args_str = ""; + std::string amp_function_call_args_str = ""; core_ops_args_info[op_type] = {}; core_ops_args_type_info[op_type] = {}; core_ops_args_info[op_type].resize(in_vars.size()); @@ -1391,6 +1392,9 @@ static std::pair GenerateForwardFunctionContents( // [Generation] Get Ins Map std::string ins_contents_str = ""; std::vector input_args_str_list(in_vars.size()); + std::vector amp_function_call_args_str_list(in_vars.size()); + std::string amp_tensors_vector_str = ""; + std::string amp_auto_cast_str = ""; for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); size_t input_position = fwd_inputs_name_pos_map.at(input_name); @@ -1400,6 +1404,7 @@ static std::pair GenerateForwardFunctionContents( "const std::vector& %s"; input_args_str_list[input_position] = paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); + amp_function_call_args_str_list[input_position] = " NEW_" + input_name; core_ops_args_type_info[op_type][input_position] = "list"; } else { @@ -1420,6 +1425,7 @@ static std::pair GenerateForwardFunctionContents( } input_args_str_list[input_position] = paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); + amp_function_call_args_str_list[input_position] = " NEW_" + input_name; core_ops_args_type_info[op_type][input_position] = "tensor"; } @@ -1431,10 +1437,31 @@ static std::pair GenerateForwardFunctionContents( "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE, input_name, input_name); + if (input.duplicable()) { + const char* AMP_TENSORS_VECTOR_TEMPLATE = "%s,"; + amp_tensors_vector_str += + paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name); + const char* AMP_AUTO_CAST_TEMPLATE = + " auto NEW_%s = egr::AmpAutoCasts(\"%s\", %s, amp_dst_dtype, " + "\"%s\");\n"; + amp_auto_cast_str += paddle::string::Sprintf( + AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type); + } else { + const char* AMP_TENSORS_VECTOR_TEMPLATE = "{%s},"; + amp_tensors_vector_str += + paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name); + const char* AMP_AUTO_CAST_TEMPLATE = + " auto NEW_%s = egr::AmpAutoCast(\"%s\", %s, amp_dst_dtype, " + "\"%s\");\n"; + amp_auto_cast_str += paddle::string::Sprintf( + AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type); + } } if (ins_contents_str.size() > 0) ins_contents_str.pop_back(); // // Remove trailing "," + if (amp_tensors_vector_str.size() > 0) amp_tensors_vector_str.pop_back(); + for (const std::string& arg : input_args_str_list) { dygraph_function_args_str += arg; dygraph_function_args_str += ","; @@ -1442,16 +1469,17 @@ static std::pair GenerateForwardFunctionContents( if (dygraph_function_args_str.size() > 0) dygraph_function_args_str.pop_back(); - const char* FWD_INS_MAP_TEMPLATE = - " std::map>> ins = { " - "%s };\n"; - std::string ins_map_str = - paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str); - generated_function_body += ins_map_str; - generated_function_body += "\n"; + for (const std::string& arg : amp_function_call_args_str_list) { + amp_function_call_args_str += arg; + amp_function_call_args_str += ","; + } + if (amp_function_call_args_str.size() > 0) + amp_function_call_args_str.pop_back(); // Handle Dispensable Inputs + std::string dispensable_ins_contents_str = ""; + std::string dispensable_amp_tensors_vector_str = ""; + std::string dispensable_amp_auto_cast_str = ""; std::set input_names; for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); @@ -1461,14 +1489,36 @@ static std::pair GenerateForwardFunctionContents( const char* FWD_INS_CONTENT_TEMPLATE = " if(%s.size() > 0) " "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n"; - generated_function_body += paddle::string::Sprintf( + dispensable_ins_contents_str += paddle::string::Sprintf( FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); + const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE = + " if(%s.size() > 0) " + "amp_tensors_vector.push_back(%s);\n"; + dispensable_amp_tensors_vector_str += paddle::string::Sprintf( + FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name); + const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE = + " auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(\"%s\", " + "%s, amp_dst_dtype, \"%s\") : %s);\n"; + dispensable_amp_auto_cast_str += paddle::string::Sprintf( + DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name, + input_name, input_name, op_type, input_name); } else { const char* FWD_INS_CONTENT_TEMPLATE = " if(%s.initialized()) " "ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n"; - generated_function_body += paddle::string::Sprintf( + dispensable_ins_contents_str += paddle::string::Sprintf( FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name); + const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE = + " if(%s.initialized()) " + "amp_tensors_vector.push_back({ %s });\n"; + dispensable_amp_tensors_vector_str += paddle::string::Sprintf( + FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name); + const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE = + " auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(\"%s\", " + "%s, amp_dst_dtype, \"%s\") : %s);\n"; + dispensable_amp_auto_cast_str += paddle::string::Sprintf( + DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name, + input_name, input_name, op_type, input_name); } } } @@ -1493,6 +1543,7 @@ static std::pair GenerateForwardFunctionContents( std::string arg_str = paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); dygraph_function_args_str += arg_str; + amp_function_call_args_str += (", " + output_var_name); core_ops_args_type_info[op_type].push_back("list"); } else { @@ -1500,6 +1551,7 @@ static std::pair GenerateForwardFunctionContents( std::string arg_str = paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); dygraph_function_args_str += arg_str; + amp_function_call_args_str += (", " + output_var_name); core_ops_args_type_info[op_type].push_back("tensor"); } @@ -1544,6 +1596,7 @@ static std::pair GenerateForwardFunctionContents( std::string arg_str = paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); dygraph_function_args_str += arg_str; + amp_function_call_args_str += (", " + outnum); const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", egr::EagerUtils::CreateVars(%s) },"; outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, @@ -1565,6 +1618,69 @@ static std::pair GenerateForwardFunctionContents( if (inplace_mapping_str.size() > 0) inplace_mapping_str.pop_back(); // Remove trailing "," + if ((op_type != "cast") && (inplace_map.empty())) { + VLOG(6) << "Generating Dygraph Forward AMP"; + const char* AMP_LOGIC_CONTEXT = + " if (egr::Controller::Instance().GetAMPLevel() != " + "paddle::imperative::AmpLevel::O0) {\n" + " VLOG(5) << \"Check and Prepare For AMP\";\n" + " \n" + "%s\n" + " }\n"; + std::string amp_logic_str = ""; + if (in_vars.size() != 0) { + const char* AMP_TENSORS_VECTOR_TEMPLATE = + " std::vector> " + "amp_tensors_vector = { " + "%s };\n"; + std::string amp_tensors_vector = paddle::string::Sprintf( + AMP_TENSORS_VECTOR_TEMPLATE, amp_tensors_vector_str); + amp_tensors_vector += dispensable_amp_tensors_vector_str; + amp_logic_str += amp_tensors_vector; + amp_logic_str += "\n"; + const char* GET_AMP_GET_DST_DTYPE_CONTEXT = + " auto amp_dst_dtype = " + "egr::GetAmpDestDtype(\"%s\", " + "amp_tensors_vector);\n"; + amp_logic_str += + paddle::string::Sprintf(GET_AMP_GET_DST_DTYPE_CONTEXT, op_type); + amp_logic_str += "\n"; + amp_logic_str += amp_auto_cast_str; + amp_logic_str += dispensable_amp_auto_cast_str; + amp_logic_str += "\n"; + } + const char* CALL_BACK_TEMPLATE = + " {\n" + " paddle::imperative::AutoCastGuard " + "guard(egr::Controller::Instance().GetCurrentTracer(), " + "paddle::imperative::AmpLevel::O0);\n" + " return %s_dygraph_function(%s);\n" + " }"; + amp_function_call_args_str += ", attr_map "; + if (amp_function_call_args_str.size() > 0) { + auto iter = amp_function_call_args_str.begin(); + if ((*iter) == ',') amp_function_call_args_str.erase(iter); + } + std::string call_back_str = paddle::string::Sprintf( + CALL_BACK_TEMPLATE, op_type, amp_function_call_args_str); + amp_logic_str += call_back_str; + amp_logic_str += "\n"; + std::string amp_context = + paddle::string::Sprintf(AMP_LOGIC_CONTEXT, amp_logic_str); + generated_function_body += amp_context; + generated_function_body += "\n"; + } + // forward ins insert + const char* FWD_INS_MAP_TEMPLATE = + " std::map>> ins = { " + "%s };\n"; + std::string ins_map_str = + paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str); + ins_map_str += dispensable_ins_contents_str; + generated_function_body += ins_map_str; + generated_function_body += "\n"; + // forward outs insert const char* FWD_OUTS_MAP_TEMPLATE = " std::map>> outs = { " @@ -2044,6 +2160,7 @@ static std::string GenerateSingleOpBase( grad_attrs_str += paddle::string::Sprintf(CAST_GRAD, attrs_name, attrs_name, attrs_name, attrs_name); } + // Handle dynamic grad attributes grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name); generated_grad_function_body += grad_attrs_str; @@ -2469,6 +2586,7 @@ static void GenerateForwardDygraphFile(const std::string& forward_cc_path, "#include " "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n" "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" + "#include \"paddle/fluid/eager/amp_utils.h\"\n" "#include \"paddle/fluid/platform/profiler/event_tracing.h\"\n\n"; std::string forward_cc_include_str = paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE); diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 25610a3f95fe5d969ffafa8379842b1ef2333b54..5f3dfe8e513ed4c828445413e27054ad0982005e 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -364,6 +364,7 @@ GradNodeBase::ApplyGradientHooks( if (!outs[i][j].defined() || !outs[i][j].initialized()) { outs[i][j] = tensors[i][j]; } + CheckTensor(tensors[i][j], outs[i][j]); } } diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 4dec1c1f9f4e5c0088fc05a8e581bb637117b4a4..81470f38cc37b779174903f810d3a59e7775557e 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -279,4 +279,29 @@ class Edge { std::shared_ptr grad_node_{nullptr}; }; +inline void CheckTensor(const paddle::experimental::Tensor& pre, + const paddle::experimental::Tensor& post) { + if (!pre.initialized() && post.initialized()) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "The tensor in before and after hook are not consistent")); + } + if (pre.initialized() && post.initialized()) { + VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " " + << paddle::framework::DataType2String(post.dtype()); + PADDLE_ENFORCE_EQ( + pre.dtype(), post.dtype(), + paddle::platform::errors::PermissionDenied( + "The dtype of tensor before(%s) and after(%s) hook are not " + "consistent", + paddle::framework::DataType2String(pre.dtype()), + paddle::framework::DataType2String(post.dtype()))); + PADDLE_ENFORCE_EQ( + pre.inner_place(), post.inner_place(), + paddle::platform::errors::PermissionDenied( + "The place of tensor before(%s) and after(%s) " + "hook are not consistent", + pre.inner_place().DebugString(), post.inner_place().DebugString())); + } +} + } // namespace egr diff --git a/paddle/fluid/framework/convert_utils.cc b/paddle/fluid/framework/convert_utils.cc index df5cc6d82042c262467b35f6a7cbe097a4ad7776..1144bc5150906a60026bcd7e02ad70fb4db0125d 100644 --- a/paddle/fluid/framework/convert_utils.cc +++ b/paddle/fluid/framework/convert_utils.cc @@ -145,6 +145,8 @@ DataType String2DataType(const std::string& str) { return DataType::COMPLEX64; } else if (str == "complex128") { return DataType::COMPLEX128; + } else if (str == "bfloat16") { + return DataType::BFLOAT16; } else { return DataType::UNDEFINED; } @@ -174,6 +176,8 @@ std::string DataType2String(DataType dtype) { return "complex64"; case DataType::COMPLEX128: return "complex128"; + case DataType::BFLOAT16: + return "bfloat16"; default: PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Unknow phi::DataType, the int value = %d.", diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc index a610c31ee8946b9e3e9f3bfdf50c7448d4755c8d..d8c297b1a94c7b50dfdad13b27c6190a278965c3 100644 --- a/paddle/fluid/pybind/eager_properties.cc +++ b/paddle/fluid/pybind/eager_properties.cc @@ -112,6 +112,9 @@ int tensor_properties_set_stop_gradient(TensorObject* self, PyObject* value, EAGER_TRY auto meta = egr::EagerUtils::autograd_meta(&self->tensor); meta->SetStopGradient(CastPyArg2AttrBoolean(value, 0)); + if (!meta->GradNode()) { + meta->SetGradNode(std::make_shared(meta)); + } return 0; EAGER_CATCH_AND_THROW_RETURN_ZERO } diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py index 2dcf7a6f168e20393a3e1a3432ac75b652e2a063..76a6e11d98dffe821501b02683b7565808f9c7c1 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py @@ -172,7 +172,7 @@ class TestImperativeQatAmp(unittest.TestCase): acc_top1 = sum(acc_top1_list) / len(acc_top1_list) return acc_top1 - def test_ptq(self): + def ptq(self): start_time = time.time() self.set_vars() @@ -217,6 +217,11 @@ class TestImperativeQatAmp(unittest.TestCase): end_time = time.time() print("total time: %ss" % (end_time - start_time)) + def test_ptq(self): + self.ptq() + with _test_eager_guard(): + self.ptq() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 2011a35db682e53b851cbcf3eb71cbc3706fd7c4..aa1b10db441f81fa5c4c50a62eb68248d8484a92 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -21,6 +21,7 @@ from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_p import paddle.nn as nn from paddle.static import InputSpec from paddle.autograd import PyLayer +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode, in_dygraph_mode if fluid.core.is_compiled_with_cuda(): fluid.set_flags({"FLAGS_cudnn_deterministic": True}) @@ -51,7 +52,7 @@ class SimpleConv(fluid.dygraph.Layer): class TestAutoCast(unittest.TestCase): - def test_amp_guard_white_op(self): + def amp_guard_white_op(self): data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') with fluid.dygraph.guard(): conv2d = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) @@ -66,7 +67,12 @@ class TestAutoCast(unittest.TestCase): self.assertTrue(out_fp16.dtype == fluid.core.VarDesc.VarType.FP16) self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32) - def test_amp_guard_black_op(self): + def test_amp_guard_white_op(self): + with _test_eager_guard(): + self.amp_guard_white_op() + self.amp_guard_white_op() + + def amp_guard_black_op(self): data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') with fluid.dygraph.guard(): data = fluid.dygraph.to_variable(data) @@ -76,7 +82,12 @@ class TestAutoCast(unittest.TestCase): self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32) self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32) - def test_custom_op_list(self): + def test_amp_guard_black_op(self): + with _test_eager_guard(): + self.amp_guard_black_op() + self.amp_guard_black_op() + + def custom_op_list(self): with fluid.dygraph.guard(): tracer = fluid.framework._dygraph_tracer() base_white_list = fluid.dygraph.amp.auto_cast.WHITE_LIST @@ -107,7 +118,12 @@ class TestAutoCast(unittest.TestCase): set(black_list) == (set(base_black_list) - {"log"}) | {"conv2d"}) - def test_custom_op_list_exception(self): + def test_custom_op_list(self): + with _test_eager_guard(): + self.custom_op_list() + self.custom_op_list() + + def custom_op_list_exception(self): inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) def func(): @@ -118,7 +134,6 @@ class TestAutoCast(unittest.TestCase): filter_size=7, stride=2, act='relu') - with fluid.dygraph.amp_guard( custom_white_list=["conv2d"], custom_black_list=["conv2d"]): @@ -127,7 +142,12 @@ class TestAutoCast(unittest.TestCase): self.assertRaises(ValueError, func) - def test_amp_guard_upsupported_fp16_op(self): + def test_custom_op_list_exception(self): + with _test_eager_guard(): + self.custom_op_list_exception() + self.custom_op_list_exception() + + def amp_guard_upsupported_fp16_op(self): data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') with fluid.dygraph.guard(): conv2d = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) @@ -143,7 +163,6 @@ class TestAutoCast(unittest.TestCase): out_purefp16_fp32 = paddle.expand_as( out_purefp16_fp16, out_purefp16_fp16) # expand_as_v2 has no fp16 kernel - self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32) self.assertTrue(out_amp_fp16.dtype == fluid.core.VarDesc.VarType.FP16) self.assertTrue(out_amp_fp32.dtype == fluid.core.VarDesc.VarType.FP32) @@ -152,7 +171,12 @@ class TestAutoCast(unittest.TestCase): self.assertTrue( out_purefp16_fp32.dtype == fluid.core.VarDesc.VarType.FP32) - def test_mode_exception(self): + def test_amp_guard_upsupported_fp16_op(self): + with _test_eager_guard(): + self.amp_guard_upsupported_fp16_op() + self.amp_guard_upsupported_fp16_op() + + def mode_exception(self): def func(): data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') with fluid.dygraph.guard(): @@ -164,9 +188,14 @@ class TestAutoCast(unittest.TestCase): self.assertRaises(ValueError, func) + def test_mode_exception(self): + with _test_eager_guard(): + self.mode_exception() + self.mode_exception() + class TestAmpScaler(unittest.TestCase): - def test_scale(self): + def scale(self): with fluid.dygraph.guard(): data = paddle.rand([10, 1024]) scaler = paddle.fluid.dygraph.AmpScaler(init_loss_scaling=1024) @@ -174,7 +203,12 @@ class TestAmpScaler(unittest.TestCase): self.assertEqual( np.array_equal(scaled_data.numpy(), data.numpy() * 1024), True) - def test_minimize(self): + def test_scale(self): + with _test_eager_guard(): + self.scale() + self.scale() + + def minimize(self): inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) def run_simple_conv(inp_np, use_scaler=True): @@ -223,7 +257,12 @@ class TestAmpScaler(unittest.TestCase): np.allclose(outs_with_scaler[1][i][0].numpy(), outs_no_scaler[1][i][0].numpy()), True) - def test_step(self): + def test_minimize(self): + with _test_eager_guard(): + self.minimize() + self.minimize() + + def step(self): inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) def run_simple_conv(inp_np, use_scaler=True): @@ -264,7 +303,12 @@ class TestAmpScaler(unittest.TestCase): np.allclose(outs_with_scaler[i].numpy(), outs_no_scaler[i].numpy()), True) - def test_nan_inf(self): + def test_step(self): + with _test_eager_guard(): + self.step() + self.step() + + def nan_inf(self): inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) inp_np[0][1][2][3] = np.nan with fluid.dygraph.guard(): @@ -294,7 +338,12 @@ class TestAmpScaler(unittest.TestCase): self.assertTrue( np.array_equal(param.numpy(), params_init[param.name])) - def test_step_update_exception(self): + def test_nan_inf(self): + with _test_eager_guard(): + self.nan_inf() + self.nan_inf() + + def step_update_exception(self): def func1(): model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -340,6 +389,11 @@ class TestAmpScaler(unittest.TestCase): self.assertRaises(RuntimeError, func3) + def test_step_update_exception(self): + with _test_eager_guard(): + self.step_update_exception() + self.step_update_exception() + def test_get_and_set(self): with fluid.dygraph.guard(): scaler = paddle.amp.GradScaler( @@ -504,14 +558,19 @@ class TestGradScalerStateDict(unittest.TestCase): return dy_out, dy_param_value, dy_grad_value def test_with_state_dict(self): - with fluid.dygraph.guard(): - out_use_state_dict = self.train_resnet( - enable_amp=True, use_data_loader=True, use_save_load=True) - out_no_state_dict = self.train_resnet( - enable_amp=True, use_data_loader=True, use_save_load=False) - print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) - self.assertTrue( - np.allclose(out_use_state_dict[0], out_no_state_dict[0])) + def func_isinstance(): + with fluid.dygraph.guard(): + out_use_state_dict = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=True) + out_no_state_dict = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=False) + print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) + self.assertTrue( + np.allclose(out_use_state_dict[0], out_no_state_dict[0])) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() class TestAmpDecorator(unittest.TestCase): @@ -765,17 +824,23 @@ class TestPureFp16SaveLoad(unittest.TestCase): return dy_out, dy_param_value, dy_grad_value def test_with_save_load(self): - with fluid.dygraph.guard(): - out_use_save_load = self.train_resnet( - enable_amp=True, use_data_loader=True, use_save_load=True) - out_no_save_load = self.train_resnet( - enable_amp=True, use_data_loader=True, use_save_load=False) - print('save_load:', out_use_save_load[0], out_no_save_load[0]) - self.assertTrue(np.allclose(out_use_save_load[0], out_no_save_load[0])) + def func_isinstance(): + with fluid.dygraph.guard(): + out_use_save_load = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=True) + out_no_save_load = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=False) + print('save_load:', out_use_save_load[0], out_no_save_load[0]) + self.assertTrue( + np.allclose(out_use_save_load[0], out_no_save_load[0])) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() class TestPureFp16InferenceSaveLoad(unittest.TestCase): - def test_inference_save_load(self): + def inference_save_load(self): BATCH_SIZE = 16 BATCH_NUM = 4 EPOCH_NUM = 4 @@ -861,7 +926,15 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase): results = exe.run(inference_program, feed={feed_target_names[0]: tensor_img}, fetch_list=fetch_targets) + print("pred.numpy()", pred.numpy()) + print("results", results) self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5)) + paddle.disable_static() + + def test_inference_save_load(self): + self.inference_save_load() + with _test_eager_guard(): + self.inference_save_load() class TestResnet2(unittest.TestCase): @@ -987,38 +1060,63 @@ class TestResnet2(unittest.TestCase): return dy_out, dy_param_value, dy_grad_value def test_resnet(self): - with fluid.dygraph.guard(): - out_fp32 = self.train_resnet(enable_amp=False) - out_amp = self.train_resnet(enable_amp=True) - out_pure_fp16 = self.train_resnet(enable_amp=True, level='O2') - print(out_fp32[0], out_amp[0], out_pure_fp16[0]) - self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) - self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) + def func_isinstance(): + with fluid.dygraph.guard(): + out_fp32 = self.train_resnet(enable_amp=False) + out_amp = self.train_resnet(enable_amp=True) + out_pure_fp16 = self.train_resnet(enable_amp=True, level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) + self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) + self.assertTrue( + np.allclose( + out_fp32[0], out_pure_fp16[0], atol=1.e-2)) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() def test_with_data_loader(self): - with fluid.dygraph.guard(): - out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True) - out_amp = self.train_resnet(enable_amp=True, use_data_loader=True) - out_pure_fp16 = self.train_resnet( - enable_amp=True, use_data_loader=True, level='O2') - print(out_fp32[0], out_amp[0], out_pure_fp16[0]) - self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) - self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) + def func_isinstance(): + with fluid.dygraph.guard(): + out_fp32 = self.train_resnet( + enable_amp=False, use_data_loader=True) + out_amp = self.train_resnet( + enable_amp=True, use_data_loader=True) + out_pure_fp16 = self.train_resnet( + enable_amp=True, use_data_loader=True, level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) + self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) + self.assertTrue( + np.allclose( + out_fp32[0], out_pure_fp16[0], atol=1.e-2)) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() def test_param_group(self): - with fluid.dygraph.guard(): - out_fp32 = self.train_resnet( - enable_amp=False, use_data_loader=True, use_param_group=True) - out_amp = self.train_resnet( - enable_amp=True, use_data_loader=True, use_param_group=True) - out_pure_fp16 = self.train_resnet( - enable_amp=True, - use_data_loader=True, - use_param_group=True, - level='O2') - print(out_fp32[0], out_amp[0], out_pure_fp16[0]) - self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) - self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) + def func_isinstance(): + with fluid.dygraph.guard(): + out_fp32 = self.train_resnet( + enable_amp=False, + use_data_loader=True, + use_param_group=True) + out_amp = self.train_resnet( + enable_amp=True, use_data_loader=True, use_param_group=True) + out_pure_fp16 = self.train_resnet( + enable_amp=True, + use_data_loader=True, + use_param_group=True, + level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) + self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) + self.assertTrue( + np.allclose( + out_fp32[0], out_pure_fp16[0], atol=1.e-2)) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() class TestResnet(unittest.TestCase): @@ -1102,12 +1200,19 @@ class TestResnet(unittest.TestCase): return dy_out, dy_param_value, dy_grad_value def test_resnet(self): - out_fp32 = self.train_resnet(enable_amp=False) - out_amp = self.train_resnet(enable_amp=True) - out_pure_fp16 = self.train_resnet(enable_amp=True, level='O2') - print(out_fp32[0], out_amp[0], out_pure_fp16[0]) - self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) - self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-1)) + def func_isinstance(): + out_fp32 = self.train_resnet(enable_amp=False) + out_amp = self.train_resnet(enable_amp=True) + out_pure_fp16 = self.train_resnet(enable_amp=True, level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) + self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) + self.assertTrue( + np.allclose( + out_fp32[0], out_pure_fp16[0], atol=1.e-1)) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() class TestLayerNormFp16(unittest.TestCase): @@ -1116,14 +1221,20 @@ class TestLayerNormFp16(unittest.TestCase): ''' def test_layer_norm_fp16(self): - if fluid.is_compiled_with_cuda(): - with fluid.dygraph.guard(fluid.CUDAPlace(0)): - x = paddle.rand([2, 2, 2, 3]) - layer_norm = paddle.nn.LayerNorm(x.shape[1:]) - with paddle.amp.auto_cast(custom_white_list=['layer_norm']): - out = layer_norm(x) + def func_isinstance(): + if fluid.is_compiled_with_cuda(): + with fluid.dygraph.guard(fluid.CUDAPlace(0)): + x = paddle.rand([2, 2, 2, 3]) + layer_norm = paddle.nn.LayerNorm(x.shape[1:]) + with paddle.amp.auto_cast(custom_white_list=['layer_norm']): + out = layer_norm(x) - self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16) + self.assertTrue( + out.dtype == fluid.core.VarDesc.VarType.FP16) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() class TestBf16(unittest.TestCase): @@ -1142,18 +1253,23 @@ class TestBf16(unittest.TestCase): return output.numpy() def test_bf16(self): - if fluid.core.is_compiled_with_cuda(): - cudnn_version = paddle.device.get_cudnn_version() - if cudnn_version is not None and cudnn_version >= 8100: - out_fp32 = self.train(enable_amp=False) - out_bf16_O1 = self.train(enable_amp=True, amp_level='O1') - out_bf16_O2 = self.train(enable_amp=True, amp_level='O2') - self.assertTrue( - np.allclose( - out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1)) - self.assertTrue( - np.allclose( - out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1)) + def func_isinstance(): + if fluid.core.is_compiled_with_cuda(): + cudnn_version = paddle.device.get_cudnn_version() + if cudnn_version is not None and cudnn_version >= 8100: + out_fp32 = self.train(enable_amp=False) + out_bf16_O1 = self.train(enable_amp=True, amp_level='O1') + out_bf16_O2 = self.train(enable_amp=True, amp_level='O2') + self.assertTrue( + np.allclose( + out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1)) + self.assertTrue( + np.allclose( + out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1)) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() class TestAmpWithPyLyer(unittest.TestCase): @@ -1176,44 +1292,54 @@ class TestAmpWithPyLyer(unittest.TestCase): x.stop_gradient = False y.stop_gradient = False - with paddle.amp.auto_cast(): - res = MyMM.apply(x, y) - loss = paddle.mean(res) + # with paddle.amp.auto_cast(): + res = MyMM.apply(x, y) + loss = paddle.mean(res) loss.backward() class TestAmpWithHook(unittest.TestCase): def test_hook_change_dtype(self): - with paddle.fluid.dygraph.guard(): - v = paddle.rand([3, 3]) - v.stop_gradient = False - - def foo(grad): - print('grad', grad, grad.dtype) # grad's dtype is float32 - res = paddle.mm(grad, grad) # mm runs in fp16 - print('res', res, res.dtype) # res's dtype is float16 - return res - - v.register_hook(foo) - with paddle.amp.auto_cast(): - a = paddle.mm(v, v) - loss = a.sum() - self.assertRaises(RuntimeError, loss.backward) + def func_isinstance(): + with paddle.fluid.dygraph.guard(): + v = paddle.rand([3, 3]) + v.stop_gradient = False + + def foo(grad): + print('grad', grad, grad.dtype) # grad's dtype is float32 + res = paddle.mm(grad, grad) # mm runs in fp16 + print('res', res, res.dtype) # res's dtype is float16 + return res + + v.register_hook(foo) + with paddle.amp.auto_cast(): + a = paddle.mm(v, v) + loss = a.sum() + self.assertRaises(RuntimeError, loss.backward) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() def test_hook_change_place(self): - with paddle.fluid.dygraph.guard(): - v = paddle.rand([3, 3]) - v.stop_gradient = False - - def foo(grad): - res = grad.cpu() # change place - return res - - v.register_hook(foo) - with paddle.amp.auto_cast(): - a = paddle.mm(v, v) - loss = a.sum() - self.assertRaises(RuntimeError, loss.backward) + def func_isinstance(): + with paddle.fluid.dygraph.guard(): + v = paddle.rand([3, 3]) + v.stop_gradient = False + + def foo(grad): + res = grad.cpu() # change place + return res + + v.register_hook(foo) + with paddle.amp.auto_cast(): + a = paddle.mm(v, v) + loss = a.sum() + self.assertRaises(RuntimeError, loss.backward) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() if __name__ == '__main__': diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index 7f8e516496f32352fa18f950a4687d5b52f4d10d..f075439e54fe7b26458300948f7f85d8f4e63007 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -997,17 +997,17 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [ 'test_parallel_executor_transformer', 'test_tensor_scalar_type_promotion_dynamic', 'test_eager_deletion_delete_vars', 'test_asp_pruning_1d', - 'test_imperative_auto_mixed_precision', 'test_imperative_using_non_zero_gpu', - 'test_machine_translation', 'test_flatten_op', 'test_onnx_export', - 'test_optimizer_for_varbase', 'test_fusion_transpose_flatten_concat_op', - 'best_fit_allocator_test', 'test_ir_fusion_group_pass', - 'test_trt_quant_conv2d_dequant_fuse_pass', 'test_allclose_op', - 'test_ftrl_op', 'test_elementwise_add_op', 'test_instance_norm_op', - 'test_lambv2_op', 'test_yolo_box_op', 'test_parallel_executor_drop_scope', - 'test_generator_dataloader', 'test_conv2d_transpose_op_depthwise_conv', - 'test_imperative_save_load_v2', 'test_lookahead', - 'test_moving_average_abs_max_scale_op', 'test_roi_perspective_transform_op', - 'test_tensorrt_engine', 'test_affine_grid_function', 'test_nonzero_api', + 'test_imperative_using_non_zero_gpu', 'test_machine_translation', + 'test_flatten_op', 'test_onnx_export', 'test_optimizer_for_varbase', + 'test_fusion_transpose_flatten_concat_op', 'best_fit_allocator_test', + 'test_ir_fusion_group_pass', 'test_trt_quant_conv2d_dequant_fuse_pass', + 'test_allclose_op', 'test_ftrl_op', 'test_elementwise_add_op', + 'test_instance_norm_op', 'test_lambv2_op', 'test_yolo_box_op', + 'test_parallel_executor_drop_scope', 'test_generator_dataloader', + 'test_conv2d_transpose_op_depthwise_conv', 'test_imperative_save_load_v2', + 'test_lookahead', 'test_moving_average_abs_max_scale_op', + 'test_roi_perspective_transform_op', 'test_tensorrt_engine', + 'test_affine_grid_function', 'test_nonzero_api', 'test_ir_memory_optimize_pass', 'test_reduce_mkldnn_op', 'test_bilinear_interp_op', 'test_cvm_op', 'test_scale_op', 'test_matmul_op', 'test_sequence_pool', 'test_complex_simplenet', 'test_complex_reshape',