未验证 提交 705776ca 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP] fix bug in activation xpu kp kernel (#41219)

* fix bug in activation xpu kp kernel

* delete useless comment
上级 f0f2e2f9
......@@ -191,12 +191,23 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
bool is_xpu_kp_support =
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
auto expected_kernel_key_library_type =
expected_kernel_key.library_type_;
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
VLOG(3) << "modify XPU KP kernel: " << op.Type()
VLOG(3) << "modifing XPU KP kernel: " << op.Type()
<< ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key;
}
}
}
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key);
......@@ -227,6 +238,20 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
#ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
}
#endif
if ((kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end())
......@@ -255,6 +280,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.",
op.Type()));
auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -271,18 +297,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) {
VLOG(3) << "xpu_kp using rt mode ";
}
if (use_xpu_kp_kernel_debug) {
VLOG(3) << "xpu_kp using debug mode ";
}
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
......
......@@ -59,6 +59,21 @@ KernelKeyMap KernelFactory::SelectKernelMap(
return iter->second;
}
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end()) {
return false;
}
return true;
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
......
......@@ -245,6 +245,9 @@ class KernelFactory {
DataLayout layout,
DataType dtype) const;
bool IsSelectKernelValid(const std::string& kernel_name,
const KernelKey& kernel_key) const;
Kernel SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册