Unverified commit d0f46aac authored by Liu-xiandong, committed by GitHub

[KP] fix bug in phi static graph mode (#41269)

* [KP] fix bug in phi static graph mode

* clean up useless code
Parent 14b91f60
@@ -1293,16 +1293,54 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
} else {
pt_kernel_name = pt_kernel_signature_->name;
// NOTE(Liu-xiandong): Kernels registered for KP have library_type[KP],
// but the default library_type is Plain, so we need to modify the
// library_type here, otherwise the KP kernel cannot be selected.
#ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
if (use_xpu_kp_kernel_rt) {
VLOG(3) << "phi xpu_kp using rt mode in static graph";
}
if (use_xpu_kp_kernel_debug) {
VLOG(3) << "phi xpu_kp using debug mode in static graph";
}
bool is_xpu_kp_support =
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
auto expected_kernel_key_library_type = kernel_type_->library_type_;
kernel_type_->library_type_ = LibraryType::kKP;
VLOG(3) << "modifing XPU KP kernel in static graph: " << type_
<< ", using_kernel_key:" << *kernel_type_.get();
auto try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
kernel_type_->library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel in static graph: " << type_
<< " is failed " << *kernel_type_.get();
}
}
}
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
}
#ifdef PADDLE_WITH_XPU
// NOTE(Liu-xiandong): Determine whether the selected kernel is valid.
// If not, use the kernel registered in fluid, and if fluid does not
// contain the related heterogeneous kernel, use the phi CPU kernel.
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
bool is_xpu_unsupport =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
!paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
paddle::platform::is_in_xpu_black_list(type_);
#endif
if (pt_kernel_->IsValid()
#ifdef PADDLE_WITH_XPU
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&& !is_xpu_unsupport
#endif
) {
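The hunk above follows a try-then-restore pattern: temporarily switch the kernel key's library_type to kKP, ask the phi kernel factory whether such a kernel is registered, and roll back to the original library_type if it is not. Below is a minimal standalone sketch of that pattern; KernelKey, LibraryType and the registry here are simplified stand-ins for illustration, not Paddle's real types.

// Minimal sketch (assumed, simplified types): try the KP registration first,
// validate it against a registry, and restore the original library_type on miss.
#include <iostream>
#include <set>
#include <string>
#include <tuple>

enum class LibraryType { kPlain, kKP };

struct KernelKey {
  std::string op;
  LibraryType library_type;
  bool operator<(const KernelKey& o) const {
    return std::tie(op, library_type) < std::tie(o.op, o.library_type);
  }
};

// Stand-in for phi::KernelFactory::Instance().IsSelectKernelValid(...).
bool IsSelectKernelValid(const std::set<KernelKey>& registry,
                         const KernelKey& key) {
  return registry.count(key) > 0;
}

int main() {
  // Pretend only the KP registration of this op exists in phi.
  std::set<KernelKey> registry = {{"elementwise_add", LibraryType::kKP}};

  KernelKey key{"elementwise_add", LibraryType::kPlain};
  LibraryType backup = key.library_type;  // cache the original library_type
  key.library_type = LibraryType::kKP;    // try the KP registration first
  if (!IsSelectKernelValid(registry, key)) {
    key.library_type = backup;            // roll back if no KP kernel is found
  }
  std::cout << "selected library_type: "
            << (key.library_type == LibraryType::kKP ? "KP" : "Plain") << '\n';
  return 0;
}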
@@ -1310,10 +1348,29 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} else {
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
// NOTE(Liu-xiandong): If we can't find a heterogeneous kernel in phi,
// we need to select the heterogeneous kernel in fluid, but kernels
// registered for KP use library_type[KP], so we need to modify it.
#ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
paddle::platform::is_in_xpu_kpwhite_list(type_);
bool is_xpu_kp_support =
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
kernel_type_->library_type_ = LibraryType::kKP;
}
#endif
if (kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(*kernel_type_.get()) ==
kernels_iter->second.end()
#ifdef PADDLE_WITH_XPU
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| is_xpu_unsupport
#endif
) {
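When the phi kernel is not valid (or the op is unsupported on XPU), the hunk above falls back to fluid's per-op kernel map, first switching the key's library_type to kKP when KP is supported for the op. The sketch below illustrates that selection order with assumed, simplified names; it is not Paddle's actual API.

// Sketch of the fallback order with assumed names (not Paddle's real API):
// if the phi kernel is unusable, look the key up in a fluid-style per-op map,
// switching library_type to kKP first when the op supports KP.
#include <iostream>
#include <map>
#include <string>

enum class LibraryType { kPlain, kKP };

struct OpKernelType {
  LibraryType library_type;
  bool operator<(const OpKernelType& o) const {
    return library_type < o.library_type;
  }
};

int main() {
  bool phi_kernel_valid = false;  // pretend phi has no usable kernel here
  bool is_xpu_kp_support = true;  // the op is KP-supported in this example

  // Fluid-style registry: op name -> (kernel key -> kernel "body").
  std::map<std::string, std::map<OpKernelType, std::string>> all_op_kernels = {
      {"elementwise_add", {{OpKernelType{LibraryType::kKP}, "xpu_kp_kernel"}}}};

  OpKernelType key{LibraryType::kPlain};
  if (!phi_kernel_valid) {
    if (is_xpu_kp_support) {
      key.library_type = LibraryType::kKP;  // fluid registers KP kernels as kKP
    }
    auto it = all_op_kernels.find("elementwise_add");
    if (it != all_op_kernels.end() && it->second.count(key) > 0) {
      std::cout << "fluid kernel: " << it->second.at(key) << '\n';
    } else {
      std::cout << "no fluid kernel either; later code reports an error\n";
    }
  }
  return 0;
}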
@@ -1552,10 +1609,22 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
}
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
auto cache_expected_kernel_key_library_type =
expected_kernel_key.library_type_;
expected_kernel_key.library_type_ = LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
VLOG(3) << "using XPU KP kernel: " << type_
<< ", using_kernel_key:" << expected_kernel_key;
// If the corresponding kernel can't be found when is_xpu_kp_support is
// on and fluid does not register the related kernel either, it can't
// work and raises an error as before.
if (kernel_iter == kernels.end()) {
expected_kernel_key.library_type_ =
cache_expected_kernel_key_library_type;
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
} else {
VLOG(3) << "using XPU KP kernel: " << type_
<< ", using_kernel_key:" << expected_kernel_key;
}
}
bool is_xpu_unsupport =
(!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
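The added branch in ChooseKernel keeps the previous behaviour as a safety net: if looking the key up with library_type kKP misses in fluid's kernel map, the cached library_type is restored and the place is switched to CPUPlace before retrying. A compact sketch of that restore-and-retarget step, using simplified stand-in types, is shown below.

// Sketch of the restore-and-retarget fallback (simplified stand-in types).
#include <iostream>
#include <map>
#include <string>

enum class LibraryType { kPlain, kKP };
enum class Place { kXPU, kCPU };

struct Key {
  LibraryType library_type;
  Place place;
  bool operator<(const Key& o) const {
    if (library_type != o.library_type) return library_type < o.library_type;
    return place < o.place;
  }
};

int main() {
  // In this scenario fluid registered only a plain CPU kernel for the op.
  std::map<Key, std::string> kernels = {
      {Key{LibraryType::kPlain, Place::kCPU}, "cpu_kernel"}};

  Key expected{LibraryType::kPlain, Place::kXPU};
  LibraryType backup = expected.library_type;  // cache the original type
  expected.library_type = LibraryType::kKP;    // try the KP registration
  auto it = kernels.find(expected);
  if (it == kernels.end()) {
    expected.library_type = backup;  // restore library_type
    expected.place = Place::kCPU;    // and fall back to CPUPlace
    it = kernels.find(expected);
  }
  std::cout << (it != kernels.end() ? it->second : std::string("not found"))
            << '\n';
  return 0;
}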
@@ -174,7 +174,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
// modify the expected_kernel_key for KP in phi
// NOTE(Liu-xiandong): Kernels registered for KP have library_type[KP],
// but the default library_type is Plain, so we need to modify the
// library_type here, otherwise the KP kernel cannot be selected.
#ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
@@ -238,6 +240,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
// NOTE(Liu-xiandong): If we can't find a heterogeneous kernel in phi,
// we need to select the heterogeneous kernel in fluid, but kernels
// registered for KP use library_type[KP], so we need to modify it.
#ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
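The dygraph path in PrepareImpl mirrors the static-graph gating: the runtime path requires FLAGS_run_kp_kernel plus the op being KP-supported, the debug path only requires the KP whitelist, and either condition switches the key to KP. The sketch below shows how those booleans compose; the helper functions and lists are hypothetical stand-ins for the paddle::platform checks.

// Hypothetical stand-ins for the flag and whitelist checks used above.
#include <iostream>
#include <set>
#include <string>

static bool FLAGS_run_kp_kernel = true;  // plays the role of the gflag

bool is_xpu_place(const std::string& place) { return place == "XPU"; }
bool is_xpu_kp_support_op(const std::string& op,
                          const std::set<std::string>& kp_ops) {
  return kp_ops.count(op) > 0;
}
bool is_in_xpu_kpwhite_list(const std::string& op,
                            const std::set<std::string>& white_list) {
  return white_list.count(op) > 0;
}

int main() {
  std::set<std::string> kp_ops = {"elementwise_add"};
  std::set<std::string> kp_white_list;  // the debug whitelist is empty here
  std::string op = "elementwise_add", place = "XPU";

  bool use_xpu_kp_kernel_rt = is_xpu_place(place) && FLAGS_run_kp_kernel &&
                              is_xpu_kp_support_op(op, kp_ops);
  bool use_xpu_kp_kernel_debug =
      is_xpu_place(place) && is_in_xpu_kpwhite_list(op, kp_white_list);
  bool is_xpu_kp_support = use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug;

  std::cout << std::boolalpha << "is_xpu_kp_support = " << is_xpu_kp_support
            << '\n';  // true via the runtime path in this example
  return 0;
}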