未验证 提交 c12f7d48 编写于 作者: Z zhangbo9674 提交者: GitHub

[AMP] Support amp for Intermediate_dygraph (#40623)

* approve amp for intermediate_dygraph

* add amp_utils for intermediate_dygraph

* add amp needcast check for mlu & npu

* test unittest

* add SetGradNode for set_stop_gradient && add checktensor for GradientHooks

* refine code

* refine unittest of imperative_amp for new dygraph

* inplace api skip amp

* add test_imperative_qat_amp for intermediate amp

* refine code

* refine test_amp ci strategy

* refine unittest code

* refine amp_utils code

* refine amp getpromotetype for some special op

* refine unittest code
上级 38d1fe34
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
namespace egr {
// Computes the promotion dtype for an op that is in neither the AMP allow
// list nor the block list: if any relevant input tensor is FP32 the op is
// promoted to FP32, otherwise it may run in `amp_dtype`.
static inline paddle::experimental::DataType GetPromoteType(
    const std::string& api_name,
    const std::vector<std::vector<paddle::experimental::Tensor>>&
        amp_tensors_vector,
    const paddle::experimental::DataType& amp_dtype) {
  auto dst_type = amp_dtype;
  if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
      "float16") {
    if (api_name == "batch_norm" || api_name == "layer_norm" ||
        api_name == "sync_batch_norm") {
      // These ops only consider their first input (X) for promotion; the
      // remaining inputs do not participate.
      if (amp_tensors_vector[0][0].dtype() ==
          paddle::experimental::DataType::FLOAT32) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
    } else if (api_name == "fused_attention") {
      for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
        // BUGFIX: the original condition used `||`, which is true for every
        // value of i, so no input was ever skipped. Inputs 3/4/9/10 —
        // presumably the LayerNorm scale/bias inputs that stay FP32 (confirm
        // against the fused_attention op definition) — must be excluded from
        // promotion, which requires `&&`.
        if (i != 3 && i != 4 && i != 9 && i != 10) {
          if (amp_tensors_vector[i][0].dtype() ==
              paddle::experimental::DataType::FLOAT32) {
            dst_type = paddle::experimental::DataType::FLOAT32;
            break;
          }
        }
      }
    } else if (api_name == "fused_feedforward") {
      for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
        // Same `||` -> `&&` fix as fused_attention above: exclude inputs
        // 7/8/9/10 from promotion.
        if (i != 7 && i != 8 && i != 9 && i != 10) {
          if (amp_tensors_vector[i][0].dtype() ==
              paddle::experimental::DataType::FLOAT32) {
            dst_type = paddle::experimental::DataType::FLOAT32;
            break;
          }
        }
      }
    } else {
      // Generic FP16 path: any FP32 input promotes the whole op to FP32.
      // NOTE: `break` only leaves the inner loop; since dst_type can never
      // change back from FLOAT32 this is harmless, just a few wasted
      // comparisons.
      for (const auto& tensors : amp_tensors_vector) {
        for (const auto& tensor : tensors) {
          if (tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
            dst_type = tensor.dtype();
            break;
          }
        }
      }
    }
  } else {
    // Non-fp16 AMP dtype (e.g. bfloat16): plain promotion over all inputs.
    for (const auto& tensors : amp_tensors_vector) {
      for (const auto& tensor : tensors) {
        if (tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
          dst_type = tensor.dtype();
          break;
        }
      }
    }
  }
  // NOTE(juncai): moving_average_abs_max_scale only considers the dtype of
  // input(X)
  if (api_name == "moving_average_abs_max_scale") {
    if (amp_tensors_vector[0][0].dtype() ==
        paddle::experimental::DataType::FLOAT16) {
      dst_type = paddle::experimental::DataType::FLOAT16;
    }
  }
  return dst_type;
}
// Decides the dtype an op should run in under AMP, based on the global AMP
// level/dtype and the per-op allow/block/unsupported lists.
//
// O1: allow-listed ops run in low precision, block-listed ops in FP32, and
//     the rest are promoted via GetPromoteType (then demoted back to FP32 if
//     the op is in the unsupported list for that low-precision dtype).
// O2: everything runs in low precision unless the op is block-listed or
//     unsupported.
// Any other level (e.g. O0) falls through to FLOAT32.
//
// BUGFIX: marked `inline` — this function is *defined* in a header
// (#pragma once above), so without `inline` every translation unit that
// includes it emits its own external-linkage definition, which is an ODR
// violation / multiple-definition link error.
inline paddle::experimental::DataType GetAmpDestDtype(
    const std::string& api_name,
    const std::vector<std::vector<paddle::experimental::Tensor>>&
        amp_tensors_vector) {
  auto amp_dtype =
      egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
  auto amp_level = egr::Controller::Instance().GetAMPLevel();
  VLOG(6) << "AMP GetAmpDestDtype:"
          << " op(" << api_name << ") amp_dtype(" << amp_dtype << ") amp_level("
          << static_cast<int>(amp_level) << ").";
  if (amp_dtype == "float16") {
    if (amp_level == paddle::imperative::AmpLevel::O1) {
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableAllowOps()
              ->count(api_name)) {
        return paddle::experimental::DataType::FLOAT16;
      } else if (paddle::imperative::AmpOperators::Instance()
                     .GetMutableBlockOps()
                     ->count(api_name)) {
        return paddle::experimental::DataType::FLOAT32;
      } else {
        // Neither allow- nor block-listed: promote, then demote back to FP32
        // if fp16 is not actually supported for this op.
        auto dst_type = GetPromoteType(api_name, amp_tensors_vector,
                                       paddle::experimental::DataType::FLOAT16);
        if (dst_type == paddle::experimental::DataType::FLOAT16 &&
            paddle::imperative::AmpOperators::Instance()
                .GetMutableUnsupportedFp16Ops()
                ->count(api_name)) {
          dst_type = paddle::experimental::DataType::FLOAT32;
        }
        return dst_type;
      }
    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
      auto dst_type = paddle::experimental::DataType::FLOAT16;
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableUnsupportedFp16Ops()
              ->count(api_name) ||
          paddle::imperative::AmpOperators::Instance()
              .GetMutableBlockOps()
              ->count(api_name)) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
      return dst_type;
    }
  } else if (amp_dtype == "bfloat16") {
    if (amp_level == paddle::imperative::AmpLevel::O1) {
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableAllowOps()
              ->count(api_name)) {
        return paddle::experimental::DataType::BFLOAT16;
      } else if (paddle::imperative::AmpOperators::Instance()
                     .GetMutableBlockOps()
                     ->count(api_name)) {
        return paddle::experimental::DataType::FLOAT32;
      } else {
        // Same promote-then-demote scheme as the fp16 branch, but against
        // the bf16 unsupported list.
        auto dst_type =
            GetPromoteType(api_name, amp_tensors_vector,
                           paddle::experimental::DataType::BFLOAT16);
        if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
            paddle::imperative::AmpOperators::Instance()
                .GetMutableUnsupportedBf16Ops()
                ->count(api_name)) {
          dst_type = paddle::experimental::DataType::FLOAT32;
        }
        return dst_type;
      }
    } else if (amp_level == paddle::imperative::AmpLevel::O2) {
      auto dst_type = paddle::experimental::DataType::BFLOAT16;
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableUnsupportedBf16Ops()
              ->count(api_name) ||
          paddle::imperative::AmpOperators::Instance()
              .GetMutableBlockOps()
              ->count(api_name)) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
      return dst_type;
    }
  }
  return paddle::experimental::DataType::FLOAT32;
}
// Returns true when `tensor` should be cast to `dst_dtype`: the tensor must
// live on a device place that supports AMP casting, hold one of the AMP
// floating dtypes (fp32/fp16/bf16), and differ from the target dtype.
static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
                            const paddle::experimental::DataType& dst_dtype) {
  using DataType = paddle::experimental::DataType;
  const auto place = tensor.inner_place();
  // CUDAPinnedPlace is accepted here for varbase created by dataloader.
  const bool place_supports_cast =
      paddle::platform::is_gpu_place(place) ||
      paddle::platform::is_cuda_pinned_place(place) ||
      paddle::platform::is_xpu_place(place) ||
      paddle::platform::is_mlu_place(place) ||
      paddle::platform::is_npu_place(place) ||
      paddle::platform::is_npu_pinned_place(place);
  if (!place_supports_cast) {
    return false;
  }
  const auto src_dtype = tensor.dtype();
  const bool is_amp_float = src_dtype == DataType::FLOAT32 ||
                            src_dtype == DataType::FLOAT16 ||
                            src_dtype == DataType::BFLOAT16;
  return is_amp_float && src_dtype != dst_dtype;
}
// Casts every tensor of a duplicable (vector) input to `dst_dtype` when
// NeedCast says so; tensors that need no cast are forwarded unchanged.
// `api_name` is kept for signature compatibility with the generated forward
// functions but is currently unused here.
//
// BUGFIX: marked `inline` — defined in a header, so a plain external-linkage
// definition violates the ODR once the header is included in several TUs.
inline std::vector<paddle::experimental::Tensor> AmpAutoCasts(
    const std::string& inputs_name,
    const std::vector<paddle::experimental::Tensor>& inputs,
    const paddle::experimental::DataType& dst_dtype, std::string api_name) {
  VLOG(6) << "AMP AmpAutoCasts:"
          << " inputs(" << inputs_name << ") dst_dtype("
          << paddle::framework::DataType2String(dst_dtype) << ").";
  std::vector<paddle::experimental::Tensor> inputs_casted;
  inputs_casted.reserve(inputs.size());  // one allocation, no regrowth
  for (auto& input : inputs) {
    if (NeedCast(input, dst_dtype)) {
      paddle::framework::AttributeMap cast_attrs = {
          {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
          {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
      // cast_dygraph_function returns a prvalue, so the previous std::move
      // around it was redundant.
      inputs_casted.emplace_back(cast_dygraph_function(input, cast_attrs));
    } else {
      inputs_casted.emplace_back(input);
    }
  }
  return inputs_casted;
}
// Casts a single (non-duplicable) input tensor to `dst_dtype` when needed.
// For the FP16 destination dtype a few ops keep specific inputs untouched:
//  - run_program: every input is returned as-is;
//  - batch_norm/layer_norm/sync_batch_norm: only input "X" is cast;
//  - fused_attention/fused_feedforward: the LayerNorm scale/bias inputs
//    (LnScale/LnBias/Ln1Scale/Ln1Bias/Ln2Scale/Ln2Bias) are returned as-is.
//
// BUGFIX: marked `inline` — defined in a header, so a plain external-linkage
// definition violates the ODR once the header is included in several TUs.
// Also corrected the VLOG tag, which previously said "AmpAutoCasts".
inline paddle::experimental::Tensor AmpAutoCast(
    const std::string& input_name, const paddle::experimental::Tensor& input,
    const paddle::experimental::DataType& dst_dtype, std::string api_name) {
  VLOG(6) << "AMP AmpAutoCast:"
          << " input(" << input_name << ") dst_dtype("
          << paddle::framework::DataType2String(dst_dtype) << ").";
  if (dst_dtype == paddle::experimental::DataType::FLOAT16) {
    if (api_name == "run_program") {
      return input;
    }
    if ((api_name == "batch_norm" || api_name == "layer_norm" ||
         api_name == "sync_batch_norm") &&
        input_name != "X") {
      return input;
    }
    if ((api_name == "fused_attention" || api_name == "fused_feedforward")) {
      if (input_name == "LnScale" || input_name == "LnBias" ||
          input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
          input_name == "Ln1Scale" || input_name == "Ln1Bias") {
        return input;
      }
    }
  }
  if (NeedCast(input, dst_dtype)) {
    paddle::framework::AttributeMap cast_attrs = {
        {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
        {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
    return cast_dygraph_function(input, cast_attrs);
  }
  return input;
}
} // namespace egr
......@@ -1379,6 +1379,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
paddle::string::Sprintf(FORWARD_FUNCTION_TEMPLATE, op_type);
std::string dygraph_function_args_str = "";
std::string amp_function_call_args_str = "";
core_ops_args_info[op_type] = {};
core_ops_args_type_info[op_type] = {};
core_ops_args_info[op_type].resize(in_vars.size());
......@@ -1391,6 +1392,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Ins Map
std::string ins_contents_str = "";
std::vector<std::string> input_args_str_list(in_vars.size());
std::vector<std::string> amp_function_call_args_str_list(in_vars.size());
std::string amp_tensors_vector_str = "";
std::string amp_auto_cast_str = "";
for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name();
size_t input_position = fwd_inputs_name_pos_map.at(input_name);
......@@ -1400,6 +1404,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
"const std::vector<paddle::experimental::Tensor>& %s";
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
amp_function_call_args_str_list[input_position] = " NEW_" + input_name;
core_ops_args_type_info[op_type][input_position] = "list";
} else {
......@@ -1420,6 +1425,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
amp_function_call_args_str_list[input_position] = " NEW_" + input_name;
core_ops_args_type_info[op_type][input_position] = "tensor";
}
......@@ -1431,10 +1437,31 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
"{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
input_name, input_name);
if (input.duplicable()) {
const char* AMP_TENSORS_VECTOR_TEMPLATE = "%s,";
amp_tensors_vector_str +=
paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name);
const char* AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = egr::AmpAutoCasts(\"%s\", %s, amp_dst_dtype, "
"\"%s\");\n";
amp_auto_cast_str += paddle::string::Sprintf(
AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type);
} else {
const char* AMP_TENSORS_VECTOR_TEMPLATE = "{%s},";
amp_tensors_vector_str +=
paddle::string::Sprintf(AMP_TENSORS_VECTOR_TEMPLATE, input_name);
const char* AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = egr::AmpAutoCast(\"%s\", %s, amp_dst_dtype, "
"\"%s\");\n";
amp_auto_cast_str += paddle::string::Sprintf(
AMP_AUTO_CAST_TEMPLATE, input_name, input_name, input_name, op_type);
}
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back(); // // Remove trailing ","
if (amp_tensors_vector_str.size() > 0) amp_tensors_vector_str.pop_back();
for (const std::string& arg : input_args_str_list) {
dygraph_function_args_str += arg;
dygraph_function_args_str += ",";
......@@ -1442,16 +1469,17 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (dygraph_function_args_str.size() > 0)
dygraph_function_args_str.pop_back();
const char* FWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerVariable>>> ins = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str);
generated_function_body += ins_map_str;
generated_function_body += "\n";
for (const std::string& arg : amp_function_call_args_str_list) {
amp_function_call_args_str += arg;
amp_function_call_args_str += ",";
}
if (amp_function_call_args_str.size() > 0)
amp_function_call_args_str.pop_back();
// Handle Dispensable Inputs
std::string dispensable_ins_contents_str = "";
std::string dispensable_amp_tensors_vector_str = "";
std::string dispensable_amp_auto_cast_str = "";
std::set<std::string> input_names;
for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name();
......@@ -1461,14 +1489,36 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.size() > 0) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
generated_function_body += paddle::string::Sprintf(
dispensable_ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE =
" if(%s.size() > 0) "
"amp_tensors_vector.push_back(%s);\n";
dispensable_amp_tensors_vector_str += paddle::string::Sprintf(
FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name);
const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(\"%s\", "
"%s, amp_dst_dtype, \"%s\") : %s);\n";
dispensable_amp_auto_cast_str += paddle::string::Sprintf(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name,
input_name, input_name, op_type, input_name);
} else {
const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.initialized()) "
"ins[\"%s\"] = egr::EagerUtils::TrySyncToVars(%s);\n";
generated_function_body += paddle::string::Sprintf(
dispensable_ins_contents_str += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
const char* FWD_AMP_TENSORS_VECTOR_TEMPLATE =
" if(%s.initialized()) "
"amp_tensors_vector.push_back({ %s });\n";
dispensable_amp_tensors_vector_str += paddle::string::Sprintf(
FWD_AMP_TENSORS_VECTOR_TEMPLATE, input_name, input_name);
const char* DISPENSABLE_AMP_AUTO_CAST_TEMPLATE =
" auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(\"%s\", "
"%s, amp_dst_dtype, \"%s\") : %s);\n";
dispensable_amp_auto_cast_str += paddle::string::Sprintf(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE, input_name, input_name,
input_name, input_name, op_type, input_name);
}
}
}
......@@ -1493,6 +1543,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + output_var_name);
core_ops_args_type_info[op_type].push_back("list");
} else {
......@@ -1500,6 +1551,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + output_var_name);
core_ops_args_type_info[op_type].push_back("tensor");
}
......@@ -1544,6 +1596,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str;
amp_function_call_args_str += (", " + outnum);
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::CreateVars(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
......@@ -1565,6 +1618,69 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (inplace_mapping_str.size() > 0)
inplace_mapping_str.pop_back(); // Remove trailing ","
if ((op_type != "cast") && (inplace_map.empty())) {
VLOG(6) << "Generating Dygraph Forward AMP";
const char* AMP_LOGIC_CONTEXT =
" if (egr::Controller::Instance().GetAMPLevel() != "
"paddle::imperative::AmpLevel::O0) {\n"
" VLOG(5) << \"Check and Prepare For AMP\";\n"
" \n"
"%s\n"
" }\n";
std::string amp_logic_str = "";
if (in_vars.size() != 0) {
const char* AMP_TENSORS_VECTOR_TEMPLATE =
" std::vector<std::vector<paddle::experimental::Tensor>> "
"amp_tensors_vector = { "
"%s };\n";
std::string amp_tensors_vector = paddle::string::Sprintf(
AMP_TENSORS_VECTOR_TEMPLATE, amp_tensors_vector_str);
amp_tensors_vector += dispensable_amp_tensors_vector_str;
amp_logic_str += amp_tensors_vector;
amp_logic_str += "\n";
const char* GET_AMP_GET_DST_DTYPE_CONTEXT =
" auto amp_dst_dtype = "
"egr::GetAmpDestDtype(\"%s\", "
"amp_tensors_vector);\n";
amp_logic_str +=
paddle::string::Sprintf(GET_AMP_GET_DST_DTYPE_CONTEXT, op_type);
amp_logic_str += "\n";
amp_logic_str += amp_auto_cast_str;
amp_logic_str += dispensable_amp_auto_cast_str;
amp_logic_str += "\n";
}
const char* CALL_BACK_TEMPLATE =
" {\n"
" paddle::imperative::AutoCastGuard "
"guard(egr::Controller::Instance().GetCurrentTracer(), "
"paddle::imperative::AmpLevel::O0);\n"
" return %s_dygraph_function(%s);\n"
" }";
amp_function_call_args_str += ", attr_map ";
if (amp_function_call_args_str.size() > 0) {
auto iter = amp_function_call_args_str.begin();
if ((*iter) == ',') amp_function_call_args_str.erase(iter);
}
std::string call_back_str = paddle::string::Sprintf(
CALL_BACK_TEMPLATE, op_type, amp_function_call_args_str);
amp_logic_str += call_back_str;
amp_logic_str += "\n";
std::string amp_context =
paddle::string::Sprintf(AMP_LOGIC_CONTEXT, amp_logic_str);
generated_function_body += amp_context;
generated_function_body += "\n";
}
// forward ins insert
const char* FWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerVariable>>> ins = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str);
ins_map_str += dispensable_ins_contents_str;
generated_function_body += ins_map_str;
generated_function_body += "\n";
// forward outs insert
const char* FWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerVariable>>> outs = { "
......@@ -2044,6 +2160,7 @@ static std::string GenerateSingleOpBase(
grad_attrs_str += paddle::string::Sprintf(CAST_GRAD, attrs_name, attrs_name,
attrs_name, attrs_name);
}
// Handle dynamic grad attributes
grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name);
generated_grad_function_body += grad_attrs_str;
......@@ -2469,6 +2586,7 @@ static void GenerateForwardDygraphFile(const std::string& forward_cc_path,
"#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n"
"#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n"
"#include \"paddle/fluid/eager/amp_utils.h\"\n"
"#include \"paddle/fluid/platform/profiler/event_tracing.h\"\n\n";
std::string forward_cc_include_str =
paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE);
......
......@@ -364,6 +364,7 @@ GradNodeBase::ApplyGradientHooks(
if (!outs[i][j].defined() || !outs[i][j].initialized()) {
outs[i][j] = tensors[i][j];
}
CheckTensor(tensors[i][j], outs[i][j]);
}
}
......
......@@ -279,4 +279,29 @@ class Edge {
std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};
// Verifies that a gradient hook did not alter the essential properties of a
// tensor: `pre` is the tensor before the hook ran, `post` the hook's output.
// Throws PermissionDenied (via PADDLE_THROW / PADDLE_ENFORCE_EQ) when the
// initialization state, dtype, or place disagree.
inline void CheckTensor(const paddle::experimental::Tensor& pre,
                        const paddle::experimental::Tensor& post) {
  // NOTE(review): only the uninitialized->initialized transition is rejected
  // here; an initialized `pre` with an uninitialized `post` passes silently —
  // confirm against the gradient-hook contract whether that is intentional.
  if (!pre.initialized() && post.initialized()) {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "The tensor in before and after hook are not consistent"));
  }
  if (pre.initialized() && post.initialized()) {
    // Debug aid: log both dtypes before enforcing equality.
    VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
            << paddle::framework::DataType2String(post.dtype());
    PADDLE_ENFORCE_EQ(
        pre.dtype(), post.dtype(),
        paddle::platform::errors::PermissionDenied(
            "The dtype of tensor before(%s) and after(%s) hook are not "
            "consistent",
            paddle::framework::DataType2String(pre.dtype()),
            paddle::framework::DataType2String(post.dtype())));
    PADDLE_ENFORCE_EQ(
        pre.inner_place(), post.inner_place(),
        paddle::platform::errors::PermissionDenied(
            "The place of tensor before(%s) and after(%s) "
            "hook are not consistent",
            pre.inner_place().DebugString(), post.inner_place().DebugString()));
  }
}
} // namespace egr
......@@ -145,6 +145,8 @@ DataType String2DataType(const std::string& str) {
return DataType::COMPLEX64;
} else if (str == "complex128") {
return DataType::COMPLEX128;
} else if (str == "bfloat16") {
return DataType::BFLOAT16;
} else {
return DataType::UNDEFINED;
}
......@@ -174,6 +176,8 @@ std::string DataType2String(DataType dtype) {
return "complex64";
case DataType::COMPLEX128:
return "complex128";
case DataType::BFLOAT16:
return "bfloat16";
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unknow phi::DataType, the int value = %d.",
......
......@@ -112,6 +112,9 @@ int tensor_properties_set_stop_gradient(TensorObject* self, PyObject* value,
EAGER_TRY
auto meta = egr::EagerUtils::autograd_meta(&self->tensor);
meta->SetStopGradient(CastPyArg2AttrBoolean(value, 0));
if (!meta->GradNode()) {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
}
return 0;
EAGER_CATCH_AND_THROW_RETURN_ZERO
}
......
......@@ -172,7 +172,7 @@ class TestImperativeQatAmp(unittest.TestCase):
acc_top1 = sum(acc_top1_list) / len(acc_top1_list)
return acc_top1
def test_ptq(self):
def ptq(self):
start_time = time.time()
self.set_vars()
......@@ -217,6 +217,11 @@ class TestImperativeQatAmp(unittest.TestCase):
end_time = time.time()
print("total time: %ss" % (end_time - start_time))
def test_ptq(self):
self.ptq()
with _test_eager_guard():
self.ptq()
if __name__ == '__main__':
unittest.main()
......@@ -997,17 +997,17 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_parallel_executor_transformer',
'test_tensor_scalar_type_promotion_dynamic',
'test_eager_deletion_delete_vars', 'test_asp_pruning_1d',
'test_imperative_auto_mixed_precision', 'test_imperative_using_non_zero_gpu',
'test_machine_translation', 'test_flatten_op', 'test_onnx_export',
'test_optimizer_for_varbase', 'test_fusion_transpose_flatten_concat_op',
'best_fit_allocator_test', 'test_ir_fusion_group_pass',
'test_trt_quant_conv2d_dequant_fuse_pass', 'test_allclose_op',
'test_ftrl_op', 'test_elementwise_add_op', 'test_instance_norm_op',
'test_lambv2_op', 'test_yolo_box_op', 'test_parallel_executor_drop_scope',
'test_generator_dataloader', 'test_conv2d_transpose_op_depthwise_conv',
'test_imperative_save_load_v2', 'test_lookahead',
'test_moving_average_abs_max_scale_op', 'test_roi_perspective_transform_op',
'test_tensorrt_engine', 'test_affine_grid_function', 'test_nonzero_api',
'test_imperative_using_non_zero_gpu', 'test_machine_translation',
'test_flatten_op', 'test_onnx_export', 'test_optimizer_for_varbase',
'test_fusion_transpose_flatten_concat_op', 'best_fit_allocator_test',
'test_ir_fusion_group_pass', 'test_trt_quant_conv2d_dequant_fuse_pass',
'test_allclose_op', 'test_ftrl_op', 'test_elementwise_add_op',
'test_instance_norm_op', 'test_lambv2_op', 'test_yolo_box_op',
'test_parallel_executor_drop_scope', 'test_generator_dataloader',
'test_conv2d_transpose_op_depthwise_conv', 'test_imperative_save_load_v2',
'test_lookahead', 'test_moving_average_abs_max_scale_op',
'test_roi_perspective_transform_op', 'test_tensorrt_engine',
'test_affine_grid_function', 'test_nonzero_api',
'test_ir_memory_optimize_pass', 'test_reduce_mkldnn_op',
'test_bilinear_interp_op', 'test_cvm_op', 'test_scale_op', 'test_matmul_op',
'test_sequence_pool', 'test_complex_simplenet', 'test_complex_reshape',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册