未验证 提交 705776ca 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP] fix bug in activation xpu kp kernel (#41219)

* fix bug in activation xpu kp kernel

* delete useless comment
上级 f0f2e2f9
...@@ -191,12 +191,23 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -191,12 +191,23 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
bool is_xpu_kp_support = bool is_xpu_kp_support =
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) { if (is_xpu_kp_support) {
auto expected_kernel_key_library_type =
expected_kernel_key.library_type_;
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
VLOG(3) << "modify XPU KP kernel: " << op.Type() VLOG(3) << "modifing XPU KP kernel: " << op.Type()
<< ", using_kernel_key:" << expected_kernel_key; << ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key;
}
} }
} }
#endif #endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key); pt_kernel_key);
...@@ -227,6 +238,20 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -227,6 +238,20 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& all_op_kernels = op.AllOpKernels(); auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type()); auto kernels_iter = all_op_kernels.find(op.Type());
#ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
}
#endif
if ((kernels_iter == all_op_kernels.end() || if ((kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) == kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end()) kernels_iter->second.end())
...@@ -255,6 +280,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -255,6 +280,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
platform::errors::NotFound( platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.", "There are no kernels which are registered in the %s operator.",
op.Type())); op.Type()));
auto& kernels = kernels_iter->second; auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -271,18 +297,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -271,18 +297,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) { if (use_xpu_kp_kernel_rt) {
VLOG(3) << "xpu_kp using rt mode "; VLOG(3) << "xpu_kp using rt mode ";
} }
if (use_xpu_kp_kernel_debug) { if (use_xpu_kp_kernel_debug) {
VLOG(3) << "xpu_kp using debug mode "; VLOG(3) << "xpu_kp using debug mode ";
} }
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) { if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
......
...@@ -59,6 +59,21 @@ KernelKeyMap KernelFactory::SelectKernelMap( ...@@ -59,6 +59,21 @@ KernelKeyMap KernelFactory::SelectKernelMap(
return iter->second; return iter->second;
} }
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end()) {
return false;
}
return true;
}
const Kernel& KernelFactory::SelectKernelOrThrowError( const Kernel& KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& kernel_key) const { const std::string& kernel_name, const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
......
...@@ -245,6 +245,9 @@ class KernelFactory { ...@@ -245,6 +245,9 @@ class KernelFactory {
DataLayout layout, DataLayout layout,
DataType dtype) const; DataType dtype) const;
bool IsSelectKernelValid(const std::string& kernel_name,
const KernelKey& kernel_key) const;
Kernel SelectKernel(const std::string& kernel_name, Kernel SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册