diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index a427b9b8199116098d149689961cedf14e86e5e1..ffac264b51d502242aa4e8c7afebc88b03c5b6f0 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -161,24 +161,48 @@ PreparedOp PrepareImpl(const NameVarMap& ins, framework::KernelSignature pt_kernel_signature; phi::KernelKey pt_kernel_key; std::string pt_kernel_name; -#ifdef PADDLE_WITH_XPU +#if defined(PADDLE_WITH_XPU) bool is_xpu_unsupport = paddle::platform::is_xpu_place(expected_kernel_key.place_) && !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key) || paddle::platform::is_in_xpu_black_list(op.Type()); + #endif if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx); VLOG(6) << pt_kernel_signature; pt_kernel_name = pt_kernel_signature.name; +// modify the expected_kernel_key for KP in phi +#ifdef PADDLE_WITH_XPU_KP + if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { + bool use_xpu_kp_kernel_rt = + FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op( + op.Type(), expected_kernel_key); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_in_xpu_kpwhite_list(op.Type()); + if (use_xpu_kp_kernel_rt) { + VLOG(3) << "phi xpu_kp using rt mode "; + } + if (use_xpu_kp_kernel_debug) { + VLOG(3) << "phi xpu_kp using debug mode "; + } + bool is_xpu_kp_support = + (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + VLOG(3) << "modify XPU KP kernel: " << op.Type() + << ", using_kernel_key:" << expected_kernel_key; + } + } +#endif pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, pt_kernel_key); if (pt_kernel.IsValid() -#ifdef PADDLE_WITH_XPU +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) && !is_xpu_unsupport #endif ) { @@ -206,7 +230,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, if ((kernels_iter == all_op_kernels.end() || kernels_iter->second.find(expected_kernel_key) == kernels_iter->second.end()) -#ifdef PADDLE_WITH_XPU +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || is_xpu_unsupport #endif ) {