未验证 提交 a14ae84b 编写于 作者: H HongyuJia 提交者: GitHub

opt kernel_selection error msg (#48864)

上级 52116b16
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#endif #endif
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/string/string_helper.h"
DECLARE_bool(enable_api_kernel_fallback); DECLARE_bool(enable_api_kernel_fallback);
...@@ -28,8 +29,8 @@ namespace phi { ...@@ -28,8 +29,8 @@ namespace phi {
const static Kernel empty_kernel; // NOLINT const static Kernel empty_kernel; // NOLINT
std::string kernel_selection_error_message(const std::string& kernel_name, std::string KernelSelectionErrorMessage(const std::string& kernel_name,
const KernelKey& target_key); const KernelKey& target_key);
uint32_t KernelKey::Hash::operator()(const KernelKey& key) const { uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
uint32_t hash_value = 0; uint32_t hash_value = 0;
...@@ -146,7 +147,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -146,7 +147,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
"The kernel with key %s of kernel `%s` is not registered. %s", "The kernel with key %s of kernel `%s` is not registered. %s",
kernel_key, kernel_key,
kernel_name, kernel_name,
kernel_selection_error_message(kernel_name, kernel_key))); KernelSelectionErrorMessage(kernel_name, kernel_key)));
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name);
...@@ -176,7 +177,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -176,7 +177,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
"fail to fallback to CPU one. %s", "fail to fallback to CPU one. %s",
kernel_key, kernel_key,
kernel_name, kernel_name,
kernel_selection_error_message(kernel_name, kernel_key))); KernelSelectionErrorMessage(kernel_name, kernel_key)));
VLOG(3) << "missing " << kernel_key.backend() << " kernel: " << kernel_name VLOG(3) << "missing " << kernel_key.backend() << " kernel: " << kernel_name
<< ", expected_kernel_key:" << kernel_key << ", expected_kernel_key:" << kernel_key
...@@ -195,7 +196,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -195,7 +196,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
" to CPU one, please set the flag true before run again.", " to CPU one, please set the flag true before run again.",
kernel_key, kernel_key,
kernel_name, kernel_name,
kernel_selection_error_message(kernel_name, kernel_key))); KernelSelectionErrorMessage(kernel_name, kernel_key)));
return {kernel_iter->second, false}; return {kernel_iter->second, false};
} }
...@@ -368,8 +369,8 @@ std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) { ...@@ -368,8 +369,8 @@ std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
// (GPU, Undefined(AnyLayout), [float32, float64, ...]); // (GPU, Undefined(AnyLayout), [float32, float64, ...]);
// ... // ...
// } // }
std::string kernel_selection_error_message(const std::string& kernel_name, std::string KernelSelectionErrorMessage(const std::string& kernel_name,
const KernelKey& target_key) { const KernelKey& target_key) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
KernelFactory::Instance().kernels().find(kernel_name), KernelFactory::Instance().kernels().find(kernel_name),
KernelFactory::Instance().kernels().end(), KernelFactory::Instance().kernels().end(),
...@@ -402,12 +403,7 @@ std::string kernel_selection_error_message(const std::string& kernel_name, ...@@ -402,12 +403,7 @@ std::string kernel_selection_error_message(const std::string& kernel_name,
// 1. If target_key not supports target backend, output "Selected wrong // 1. If target_key not supports target backend, output "Selected wrong
// Backend ..." // Backend ..."
if (!support_backend) { if (!support_backend) {
std::string error_message = ""; std::string error_message = paddle::string::join_strings(backend_set, ", ");
for (auto iter = backend_set.begin(); iter != backend_set.end(); ++iter) {
error_message += *iter;
error_message += ", ";
}
error_message = error_message.substr(0, error_message.length() - 2);
return "Selected wrong Backend `" + return "Selected wrong Backend `" +
paddle::experimental::BackendToString(target_key.backend()) + paddle::experimental::BackendToString(target_key.backend()) +
"`. Paddle support following Backends: " + error_message + "."; "`. Paddle support following Backends: " + error_message + ".";
...@@ -415,12 +411,7 @@ std::string kernel_selection_error_message(const std::string& kernel_name, ...@@ -415,12 +411,7 @@ std::string kernel_selection_error_message(const std::string& kernel_name,
// 2. If target_key not supports target datatype, output "Selected wrong // 2. If target_key not supports target datatype, output "Selected wrong
// DataType ..." // DataType ..."
if (!support_dtype) { if (!support_dtype) {
std::string error_message = ""; std::string error_message = paddle::string::join_strings(dtype_set, ", ");
for (auto iter = dtype_set.begin(); iter != dtype_set.end(); ++iter) {
error_message += *iter;
error_message += ", ";
}
error_message = error_message.substr(0, error_message.length() - 2);
return "Selected wrong DataType `" + return "Selected wrong DataType `" +
paddle::experimental::DataTypeToString(target_key.dtype()) + paddle::experimental::DataTypeToString(target_key.dtype()) +
"`. Paddle support following DataTypes: " + error_message + "."; "`. Paddle support following DataTypes: " + error_message + ".";
...@@ -431,14 +422,9 @@ std::string kernel_selection_error_message(const std::string& kernel_name, ...@@ -431,14 +422,9 @@ std::string kernel_selection_error_message(const std::string& kernel_name,
kernel_name + "`: { "; kernel_name + "`: { ";
for (auto iter = all_kernel_key.begin(); iter != all_kernel_key.end(); for (auto iter = all_kernel_key.begin(); iter != all_kernel_key.end();
++iter) { ++iter) {
message += "(" + iter->first + ", [";
std::vector<std::string>& dtype_vec = iter->second; std::vector<std::string>& dtype_vec = iter->second;
for (std::size_t i = 0; i < dtype_vec.size(); ++i) { message += "(" + iter->first + ", [";
message += dtype_vec[i]; message += paddle::string::join_strings(dtype_vec, ", ");
if (i + 1 != dtype_vec.size()) {
message += ", ";
}
}
message += "]); "; message += "]); ";
} }
message += "}."; message += "}.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册