Unverified commit a214e5dc authored by N niuliling123, committed by GitHub

Fix inaccurate return of low precision op list (#49391)

Parent c7899074
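
In short: the op counting behind get_low_precision_op_list() moves from the AMP dtype-decision path (AmpOperators::AddToAmpOpList) into phi::KernelFactory at kernel-selection time, so the list reflects the kernels that were actually selected. A minimal usage sketch, assuming the flag is read from the environment at startup; shapes and values are illustrative:

    import os

    # Level 1 records only float16/bfloat16 kernels; level 2 records every kernel.
    os.environ["FLAGS_low_precision_op_list"] = "1"

    import paddle

    conv2d = paddle.nn.Conv2D(4, 6, (3, 3))
    data = paddle.rand([1, 4, 8, 8])
    with paddle.amp.auto_cast(enable=True, level='O2'):
        conv = conv2d(data)

    # Prints the op-name -> call-count table collected by KernelFactory.
    paddle.amp.low_precision_op_list()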
......@@ -100,7 +100,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
return paddle::experimental::DataType::FLOAT16;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
......@@ -118,8 +117,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableUnsupportedFp16Ops()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......@@ -132,8 +129,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableBlockOps()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......@@ -142,7 +137,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
return paddle::experimental::DataType::BFLOAT16;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
......@@ -158,8 +152,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableUnsupportedBf16Ops()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......@@ -172,8 +164,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableBlockOps()
->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
}
return dst_type;
}
......
......@@ -22,7 +22,6 @@
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
DECLARE_bool(low_precision_op_list);
namespace paddle {
namespace imperative {
......@@ -194,16 +193,6 @@ AmpOperators::GetMutableUnsupportedBf16Ops() {
return unsupported_bf16_ops_;
}
void AmpOperators::AddToAmpOpList(const std::string& op_name) {
if (FLAGS_low_precision_op_list) {
current_amp_ops_[op_name] += 1;
}
}
std::map<const std::string, int> AmpOperators::GetAmpOpList() {
return current_amp_ops_;
}
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
......
......@@ -60,10 +60,6 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedBf16Ops();
void AddToAmpOpList(const std::string& op_name);
std::map<const std::string, int> GetAmpOpList();
private:
AmpOperators(); // forbid calling default constructor
......@@ -80,9 +76,6 @@ class AmpOperators {
  // The set of ops that have no bf16 CUDA kernel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
// The amp op list of current module.
std::map<const std::string, int> current_amp_ops_;
};
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
......
......@@ -2546,7 +2546,7 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("get_low_precision_op_list", [] {
return paddle::imperative::AmpOperators::Instance().GetAmpOpList();
return phi::KernelFactory::Instance().GetLowPrecisionKernelList();
});
m.def("autotune_status", [] {
......
......@@ -1200,6 +1200,9 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} const auto& kernel = kernel_result.kernel;
{code_indent} if (FLAGS_low_precision_op_list) {{
{code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
{code_indent} }}
{code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors}
......
......@@ -347,6 +347,7 @@ def source_include(header_file_path):
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -290,6 +290,7 @@ def source_include(header_file_path):
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -54,6 +54,8 @@ def source_include(header_file_path):
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -221,6 +221,9 @@ class SparseAPI(ForwardAPI):
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
const auto& phi_kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
}}
VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
......@@ -324,6 +327,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -134,6 +134,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/backward.h"
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -210,6 +210,9 @@ class StringsAPI(ForwardAPI):
VLOG(6) << "{self.api} api strings kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
}}
const auto& kernel = kernel_result.kernel;
VLOG(6) << "{self.api} api strings kernel: " << kernel;
......@@ -334,6 +337,8 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h"
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -55,16 +55,19 @@ PADDLE_DEFINE_EXPORTED_int32(paddle_num_threads,
/**
* Low Precision Op related FLAG
* Name: FLAGS_low_precision_op_list
* Since Version: 0.13.0
* Value Range: bool, default=false
* Since Version: 2.5.0
* Value Range: int32, default=0
* Example:
 * Note: Used to debug. Get the low precision op list of current module.
* - 1, return the low precision op list of current module.
* - 2, return the op list of current module.
*/
PADDLE_DEFINE_EXPORTED_bool(low_precision_op_list,
                            false,
                            "Whether to get the low precision op list of "
                            "current module. It will return the low "
                            "precision op list after the module runs.");
PADDLE_DEFINE_EXPORTED_int32(low_precision_op_list,
                             0,
                             "Set the level of low precision op "
                             "list printing. It will return the "
                             "low precision op list of current module.");
/**
* Operator related FLAG
......
......@@ -23,6 +23,7 @@
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/string/string_helper.h"
DECLARE_int32(low_precision_op_list);
DECLARE_bool(enable_api_kernel_fallback);
namespace phi {
......@@ -106,9 +107,33 @@ bool KernelFactory::HasKernel(const std::string& kernel_name,
return true;
}
void KernelFactory::AddToLowPrecisionKernelList(
const std::string& name,
const paddle::experimental::DataType& kernel_key_type) {
if (FLAGS_low_precision_op_list >= 1) {
auto op_name = phi::TransToFluidOpName(name);
if (op_name.find("_grad") != std::string::npos) {
return; // only record forward api
}
bool is_low_precision =
(kernel_key_type == paddle::experimental::DataType::FLOAT16 ||
kernel_key_type == paddle::experimental::DataType::BFLOAT16);
bool need_record =
FLAGS_low_precision_op_list == 1 ? is_low_precision : true;
if (need_record) {
low_precision_kernels_[op_name] += 1;
}
}
}
std::map<const std::string, int> KernelFactory::GetLowPrecisionKernelList() {
return low_precision_kernels_;
}
KernelResult KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& const_kernel_key) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
......
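
The recording rule introduced in AddToLowPrecisionKernelList above can be restated as a small Python sketch (the function name is hypothetical; dtypes are strings for brevity):

    def should_record(op_name, kernel_dtype, level):
        # Backward kernels are skipped: only forward APIs are recorded.
        if "_grad" in op_name:
            return False
        # Level 1 keeps only float16/bfloat16 kernels; level 2 keeps them all.
        is_low_precision = kernel_dtype in ("float16", "bfloat16")
        return is_low_precision if level == 1 else level >= 2

    assert should_record("conv2d", "float16", 1)
    assert not should_record("conv2d", "float32", 1)
    assert should_record("conv2d", "float32", 2)
    assert not should_record("conv2d_grad", "float16", 2)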
......@@ -14,12 +14,12 @@
#pragma once
#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
......@@ -305,10 +305,19 @@ class KernelFactory {
const KernelArgsDef& GetFirstKernelArgsDef(
const std::string& kernel_name) const;
void AddToLowPrecisionKernelList(
const std::string& name,
const paddle::experimental::DataType& kernel_key_type);
std::map<const std::string, int> GetLowPrecisionKernelList();
private:
KernelFactory() = default;
KernelNameMap kernels_;
  // The low precision kernel list of the current module.
std::map<const std::string, int> low_precision_kernels_;
};
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
......
......@@ -25,6 +25,7 @@
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/scale_kernel.h"
DECLARE_int32(low_precision_op_list);
namespace paddle {
namespace experimental {
......@@ -54,6 +55,10 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x,
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {
phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
"scale", kernel_data_type);
}
VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "scale API kernel: " << kernel;
......@@ -225,6 +230,10 @@ Tensor scale_switch_case(const Tensor& x,
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {
phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
"scale", kernel_data_type);
}
VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "scale API kernel: " << kernel;
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import os
import warnings
import paddle
......@@ -94,18 +95,23 @@ _g_amp_state_ = None
def low_precision_op_list():
    op_list = paddle.fluid.core.get_low_precision_op_list()
    op_count = 0
    print('<---------------- low precision op list ------------------->')
    print('<---- op name ------|------- op count---------------------->')
    for x in op_list:
        print(' %-18s| %4d' % (x, op_list[x]))
        op_count += 1
    print(
        '<------------- low precision op num:{:5d} ----------------->'.format(
            op_count
        )
    )
    if os.getenv("FLAGS_low_precision_op_list") is not None:
        level = int(os.getenv("FLAGS_low_precision_op_list"))
        if level == 0:
            return
        if level == 1:
            print('<{:-^60}>'.format(" low precision op list "))
        else:
            print('<{:-^60}>'.format(" op list "))
        op_list = paddle.fluid.core.get_low_precision_op_list()
        op_count = 0
        print(
            '<{:-^40}'.format(" op_name "), '|', '{:-^17}>'.format(" op count ")
        )
        for x in op_list:
            print(' %-40s| %-15d' % (x, op_list[x]))
            op_count += 1
        print('<{:-^60}>'.format(" op count: " + str(op_count) + " "))
def amp_state():
......
......@@ -25,12 +25,11 @@ class TestAMPList(unittest.TestCase):
b = paddle.rand([2, 3])
# amp list conv2d, cast
with paddle.amp.auto_cast():
with paddle.amp.auto_cast(enable=True, level='O2'):
conv = conv2d(data)
c = a + b
paddle.amp.low_precision_op_list()
op_list = paddle.fluid.core.get_low_precision_op_list()
print(conv.dtype)
if conv.dtype == paddle.float16:
self.assertTrue('elementwise_add' in op_list)
self.assertTrue('conv2d' in op_list)
......
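
For reference, the std::map returned by get_low_precision_op_list() arrives in Python as a dict of op name to call count, which is what the membership assertions above rely on; a hedged illustration:

    op_list = paddle.fluid.core.get_low_precision_op_list()
    # e.g. {'conv2d': 1, 'elementwise_add': 1} after the O2 auto_cast block above
    for name, count in op_list.items():
        print('%-40s| %d' % (name, count))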