Unverified commit a214e5dc, authored by niuliling123, committed by GitHub

Fix inaccurate return of low precision op list (#49391)

Parent: c7899074
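The changes below move low-precision op recording out of the imperative AMP layer (AmpOperators::AddToAmpOpList) and into phi::KernelFactory, where the selected kernel is actually known, so the reported list matches the kernels that really ran. A minimal usage sketch under this change (assumptions: a CUDA build with float16 kernels; the layer and shapes are illustrative; the flag must be set in the environment before Paddle initializes, since both the C++ flag and the Python printer read it from there):

    import os
    # 0: disabled; 1: record fp16/bf16 forward kernels; 2: record all forward kernels.
    os.environ["FLAGS_low_precision_op_list"] = "1"

    import paddle

    conv2d = paddle.nn.Conv2D(4, 6, (3, 3))   # illustrative layer
    data = paddle.rand([10, 4, 32, 32])       # illustrative input
    with paddle.amp.auto_cast(enable=True, level='O2'):
        conv = conv2d(data)

    paddle.amp.low_precision_op_list()        # prints recorded ops with hit counts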
...@@ -100,7 +100,6 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -100,7 +100,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
if (paddle::imperative::AmpOperators::Instance() if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps() .GetMutableAllowOps()
->count(op_name)) { ->count(op_name)) {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
return paddle::experimental::DataType::FLOAT16; return paddle::experimental::DataType::FLOAT16;
} else if (paddle::imperative::AmpOperators::Instance() } else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps() .GetMutableBlockOps()
...@@ -118,8 +117,6 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -118,8 +117,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableUnsupportedFp16Ops() .GetMutableUnsupportedFp16Ops()
->count(op_name)) { ->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32; dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
} }
return dst_type; return dst_type;
} }
...@@ -132,8 +129,6 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -132,8 +129,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableBlockOps() .GetMutableBlockOps()
->count(op_name)) { ->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32; dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
} }
return dst_type; return dst_type;
} }
...@@ -142,7 +137,6 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -142,7 +137,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
if (paddle::imperative::AmpOperators::Instance() if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps() .GetMutableAllowOps()
->count(op_name)) { ->count(op_name)) {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
return paddle::experimental::DataType::BFLOAT16; return paddle::experimental::DataType::BFLOAT16;
} else if (paddle::imperative::AmpOperators::Instance() } else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps() .GetMutableBlockOps()
...@@ -158,8 +152,6 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -158,8 +152,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableUnsupportedBf16Ops() .GetMutableUnsupportedBf16Ops()
->count(op_name)) { ->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32; dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
} }
return dst_type; return dst_type;
} }
...@@ -172,8 +164,6 @@ inline paddle::experimental::DataType GetAmpDestDtype( ...@@ -172,8 +164,6 @@ inline paddle::experimental::DataType GetAmpDestDtype(
.GetMutableBlockOps() .GetMutableBlockOps()
->count(op_name)) { ->count(op_name)) {
dst_type = paddle::experimental::DataType::FLOAT32; dst_type = paddle::experimental::DataType::FLOAT32;
} else {
paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
} }
return dst_type; return dst_type;
} }
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h" #include "paddle/fluid/imperative/var_helper.h"
DECLARE_bool(low_precision_op_list);
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -194,16 +193,6 @@ AmpOperators::GetMutableUnsupportedBf16Ops() { ...@@ -194,16 +193,6 @@ AmpOperators::GetMutableUnsupportedBf16Ops() {
return unsupported_bf16_ops_; return unsupported_bf16_ops_;
} }
void AmpOperators::AddToAmpOpList(const std::string& op_name) {
if (FLAGS_low_precision_op_list) {
current_amp_ops_[op_name] += 1;
}
}
std::map<const std::string, int> AmpOperators::GetAmpOpList() {
return current_amp_ops_;
}
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) { std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: "; os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps(); auto allow_ops = ops.GetMutableAllowOps();
......
...@@ -60,10 +60,6 @@ class AmpOperators { ...@@ -60,10 +60,6 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>> std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedBf16Ops(); GetMutableUnsupportedBf16Ops();
void AddToAmpOpList(const std::string& op_name);
std::map<const std::string, int> GetAmpOpList();
private: private:
AmpOperators(); // forbid calling default constructor AmpOperators(); // forbid calling default constructor
...@@ -80,9 +76,6 @@ class AmpOperators { ...@@ -80,9 +76,6 @@ class AmpOperators {
// The set of ops that has no bf16 CUDA kennel. // The set of ops that has no bf16 CUDA kennel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_; std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
// The amp op list of current module.
std::map<const std::string, int> current_amp_ops_;
}; };
std::ostream& operator<<(std::ostream& os, AmpOperators& ops); std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
......
...@@ -2546,7 +2546,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2546,7 +2546,7 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); }); [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("get_low_precision_op_list", [] { m.def("get_low_precision_op_list", [] {
return paddle::imperative::AmpOperators::Instance().GetAmpOpList(); return phi::KernelFactory::Instance().GetLowPrecisionKernelList();
}); });
m.def("autotune_status", [] { m.def("autotune_status", [] {
......
...@@ -1200,6 +1200,9 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -1200,6 +1200,9 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( {code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); {code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} const auto& kernel = kernel_result.kernel; {code_indent} const auto& kernel = kernel_result.kernel;
{code_indent} if (FLAGS_low_precision_op_list) {{
{code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
{code_indent} }}
{code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel; {code_indent} VLOG(6) << "{kernel_name} kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); {code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
{input_tensors} {input_tensors}
......
...@@ -347,6 +347,7 @@ def source_include(header_file_path): ...@@ -347,6 +347,7 @@ def source_include(header_file_path):
#include "paddle/fluid/platform/profiler/supplement_tracing.h" #include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_bool(conv2d_disable_cudnn); DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -290,6 +290,7 @@ def source_include(header_file_path): ...@@ -290,6 +290,7 @@ def source_include(header_file_path):
#include "paddle/fluid/platform/profiler/supplement_tracing.h" #include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_bool(conv2d_disable_cudnn); DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -54,6 +54,8 @@ def source_include(header_file_path): ...@@ -54,6 +54,8 @@ def source_include(header_file_path):
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h" #include "paddle/fluid/platform/profiler/supplement_tracing.h"
DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -221,6 +221,9 @@ class SparseAPI(ForwardAPI): ...@@ -221,6 +221,9 @@ class SparseAPI(ForwardAPI):
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
const auto& phi_kernel = kernel_result.kernel; const auto& phi_kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
}}
VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel; VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
...@@ -324,6 +327,8 @@ def source_include(header_file_path): ...@@ -324,6 +327,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/sparse/unary.h" #include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h" #include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h" #include "paddle/phi/infermeta/sparse/multiary.h"
DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -134,6 +134,8 @@ def source_include(header_file_path): ...@@ -134,6 +134,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/sparse/unary.h" #include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h" #include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/backward.h" #include "paddle/phi/infermeta/sparse/backward.h"
DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -210,6 +210,9 @@ class StringsAPI(ForwardAPI): ...@@ -210,6 +210,9 @@ class StringsAPI(ForwardAPI):
VLOG(6) << "{self.api} api strings kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; VLOG(6) << "{self.api} api strings kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type);
}}
const auto& kernel = kernel_result.kernel; const auto& kernel = kernel_result.kernel;
VLOG(6) << "{self.api} api strings kernel: " << kernel; VLOG(6) << "{self.api} api strings kernel: " << kernel;
...@@ -334,6 +337,8 @@ def source_include(header_file_path): ...@@ -334,6 +337,8 @@ def source_include(header_file_path):
#include "paddle/phi/api/lib/api_registry.h" #include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -55,16 +55,19 @@ PADDLE_DEFINE_EXPORTED_int32(paddle_num_threads, ...@@ -55,16 +55,19 @@ PADDLE_DEFINE_EXPORTED_int32(paddle_num_threads,
/** /**
* Low Precision Op related FLAG * Low Precision Op related FLAG
* Name: FLAGS_low_precision_op_list * Name: FLAGS_low_precision_op_list
* Since Version: 0.13.0 * Since Version: 2.5.0
* Value Range: bool, default=false * Value Range: int32, default=0
* Example: * Example:
* Note: Used to debug. Get the low precision op list of current module. * Note: Used to debug. Get the low precision op list of current module.
* FLAGS_check_nan_inf is set.
* - 1, return the low precision op list of current module.
* - 2, return the op list of current module.
*/ */
PADDLE_DEFINE_EXPORTED_bool(low_precision_op_list, PADDLE_DEFINE_EXPORTED_int32(low_precision_op_list,
false, 0,
"Checking whether get the low precision op list of " "Setting the level of low precision op"
"current module. It will be " "list printing. It will be return the "
"rerun the low precision list after module."); "low precision op list of current module.");
/** /**
* Operator related FLAG * Operator related FLAG
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/string/string_helper.h" #include "paddle/utils/string/string_helper.h"
DECLARE_int32(low_precision_op_list);
DECLARE_bool(enable_api_kernel_fallback); DECLARE_bool(enable_api_kernel_fallback);
namespace phi { namespace phi {
...@@ -106,9 +107,33 @@ bool KernelFactory::HasKernel(const std::string& kernel_name, ...@@ -106,9 +107,33 @@ bool KernelFactory::HasKernel(const std::string& kernel_name,
return true; return true;
} }
void KernelFactory::AddToLowPrecisionKernelList(
const std::string& name,
const paddle::experimental::DataType& kernel_key_type) {
if (FLAGS_low_precision_op_list >= 1) {
auto op_name = phi::TransToFluidOpName(name);
if (op_name.find("_grad") != std::string::npos) {
return; // only record forward api
}
bool is_low_precision =
(kernel_key_type == paddle::experimental::DataType::FLOAT16 ||
kernel_key_type == paddle::experimental::DataType::BFLOAT16);
bool need_record =
FLAGS_low_precision_op_list == 1 ? is_low_precision : true;
if (need_record) {
low_precision_kernels_[op_name] += 1;
}
}
}
std::map<const std::string, int> KernelFactory::GetLowPrecisionKernelList() {
return low_precision_kernels_;
}
KernelResult KernelFactory::SelectKernelOrThrowError( KernelResult KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& const_kernel_key) const { const std::string& kernel_name, const KernelKey& const_kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, iter,
kernels_.end(), kernels_.end(),
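Restating the recording rule above as a Python sketch (the helper name and dtype strings are illustrative, not Paddle APIs; the C++ version additionally maps phi kernel names to fluid op names via TransToFluidOpName): gradient kernels are skipped, level 1 records only FLOAT16/BFLOAT16 hits, and level 2 records every forward hit.

    def add_to_low_precision_kernel_list(records, op_name, dtype, level):
        # Mirrors KernelFactory::AddToLowPrecisionKernelList (sketch).
        if level < 1:
            return
        if '_grad' in op_name:  # only forward ops are recorded
            return
        is_low_precision = dtype in ('float16', 'bfloat16')
        if level == 1 and not is_low_precision:
            return
        records[op_name] = records.get(op_name, 0) + 1

    records = {}
    add_to_low_precision_kernel_list(records, 'conv2d', 'float16', 1)
    add_to_low_precision_kernel_list(records, 'conv2d_grad', 'float16', 1)  # skipped
    add_to_low_precision_kernel_list(records, 'softmax', 'float32', 1)      # skipped at level 1
    assert records == {'conv2d': 1}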
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#pragma once #pragma once
#include <map>
#include <ostream> #include <ostream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
...@@ -305,10 +305,19 @@ class KernelFactory { ...@@ -305,10 +305,19 @@ class KernelFactory {
const KernelArgsDef& GetFirstKernelArgsDef( const KernelArgsDef& GetFirstKernelArgsDef(
const std::string& kernel_name) const; const std::string& kernel_name) const;
void AddToLowPrecisionKernelList(
const std::string& name,
const paddle::experimental::DataType& kernel_key_type);
std::map<const std::string, int> GetLowPrecisionKernelList();
private: private:
KernelFactory() = default; KernelFactory() = default;
KernelNameMap kernels_; KernelNameMap kernels_;
// Get the low precision kernel list of current module.
std::map<const std::string, int> low_precision_kernels_;
}; };
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/scale_kernel.h" #include "paddle/phi/kernels/scale_kernel.h"
DECLARE_int32(low_precision_op_list);
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
...@@ -54,6 +55,10 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x, ...@@ -54,6 +55,10 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x,
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type}); "scale", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel; const auto& kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {
phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
"scale", kernel_data_type);
}
VLOG(6) << "scale API kernel key: [" << kernel_backend << ", " VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]"; << kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "scale API kernel: " << kernel; VLOG(6) << "scale API kernel: " << kernel;
...@@ -225,6 +230,10 @@ Tensor scale_switch_case(const Tensor& x, ...@@ -225,6 +230,10 @@ Tensor scale_switch_case(const Tensor& x,
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"scale", {kernel_backend, kernel_layout, kernel_data_type}); "scale", {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel; const auto& kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {
phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
"scale", kernel_data_type);
}
VLOG(6) << "scale API kernel key: [" << kernel_backend << ", " VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]"; << kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "scale API kernel: " << kernel; VLOG(6) << "scale API kernel: " << kernel;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os
import warnings import warnings
import paddle import paddle
...@@ -94,18 +95,23 @@ _g_amp_state_ = None ...@@ -94,18 +95,23 @@ _g_amp_state_ = None
def low_precision_op_list(): def low_precision_op_list():
op_list = paddle.fluid.core.get_low_precision_op_list() if os.getenv("FLAGS_low_precision_op_list") is not None:
op_count = 0 level = int(os.getenv("FLAGS_low_precision_op_list"))
print('<---------------- low precision op list ------------------->') if level == 0:
print('<---- op name ------|------- op count---------------------->') return
for x in op_list: if level == 1:
print(' %-18s| %4d' % (x, op_list[x])) print('<{:-^60}>'.format(" low precision op list "))
op_count += 1 else:
print( print('<{:-^60}>'.format(" op list "))
'<------------- low precision op num:{:5d} ----------------->'.format( op_list = paddle.fluid.core.get_low_precision_op_list()
op_count op_count = 0
print(
'<{:-^40}'.format(" op_name "), '|', '{:-^17}>'.format(" op count ")
) )
) for x in op_list:
print(' %-40s| %-15d' % (x, op_list[x]))
op_count += 1
print('<{:-^60}>'.format(" op count: " + str(op_count) + " "))
def amp_state(): def amp_state():
......
...@@ -25,12 +25,11 @@ class TestAMPList(unittest.TestCase): ...@@ -25,12 +25,11 @@ class TestAMPList(unittest.TestCase):
b = paddle.rand([2, 3]) b = paddle.rand([2, 3])
# amp list conv2d, cast # amp list conv2d, cast
with paddle.amp.auto_cast(): with paddle.amp.auto_cast(enable=True, level='O2'):
conv = conv2d(data) conv = conv2d(data)
c = a + b c = a + b
paddle.amp.low_precision_op_list() paddle.amp.low_precision_op_list()
op_list = paddle.fluid.core.get_low_precision_op_list() op_list = paddle.fluid.core.get_low_precision_op_list()
print(conv.dtype)
if conv.dtype == paddle.float16: if conv.dtype == paddle.float16:
self.assertTrue('elementwise_add' in op_list) self.assertTrue('elementwise_add' in op_list)
self.assertTrue('conv2d' in op_list) self.assertTrue('conv2d' in op_list)
......