diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f23a266ef03641bc8f8d273b15ab4982e377cb03..ad01adf1a25b9d65c6ca85bc7e7a40d4b1fd0198 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1456,7 +1456,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { kernel_iter = kernels.find(expected_kernel_key); } #endif -#ifdef PADDLE_WITH_XPU + +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) if (platform::is_xpu_place(expected_kernel_key.place_) && (kernel_iter == kernels.end() || !paddle::platform::is_xpu_support_op(type_, expected_kernel_key) || @@ -1470,17 +1471,36 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { #endif #ifdef PADDLE_WITH_XPU_KP - bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && - paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key); - bool use_xpu_kp_kernel_debug = - paddle::platform::is_in_xpu_kpwhite_list(type_); - if (platform::is_xpu_place(expected_kernel_key.place_) && - (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) { - expected_kernel_key.library_type_ = LibraryType::kKP; - kernel_iter = kernels.find(expected_kernel_key); - VLOG(3) << "using XPU KP kernel: " << type_ - << ", using_kernel_key:" << expected_kernel_key; + if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { + bool use_xpu_kp_kernel_rt = + FLAGS_run_kp_kernel && + paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_in_xpu_kpwhite_list(type_); + if (use_xpu_kp_kernel_rt) { + VLOG(3) << "xpu_kp using rt mode "; + } + if (use_xpu_kp_kernel_debug) { + VLOG(3) << "xpu_kp using debug mode "; + } + bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + expected_kernel_key.library_type_ = LibraryType::kKP; + kernel_iter = kernels.find(expected_kernel_key); + VLOG(3) << "using XPU KP kernel: " << type_ + << ", using_kernel_key:" << expected_kernel_key; + } + bool is_xpu_unsupport = + (!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) || + paddle::platform::is_in_xpu_black_list(type_)); + if (!is_xpu_kp_support && + (kernel_iter == kernels.end() || is_xpu_unsupport)) { + VLOG(3) << "missing XPU kernel: " << type_ + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + expected_kernel_key.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(expected_kernel_key); + } } #endif diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index bae49fb381a475dd8227d1dc855a6db28c9cd273..a427b9b8199116098d149689961cedf14e86e5e1 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -234,7 +234,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); -#ifdef PADDLE_WITH_XPU +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && (kernel_iter == kernels.end() || is_xpu_unsupport)) { VLOG(3) << "missing XPU kernel: " << op.Type() @@ -243,29 +243,36 @@ PreparedOp PrepareImpl(const NameVarMap& ins, expected_kernel_key.place_ = platform::CPUPlace(); kernel_iter = kernels.find(expected_kernel_key); } - #endif #ifdef PADDLE_WITH_XPU_KP - expected_kernel_key.place_ = platform::XPUPlace(); - bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && - paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); - bool use_xpu_kp_kernel_debug = - paddle::platform::is_in_xpu_kpwhite_list(op.Type()); - if (use_xpu_kp_kernel_rt) { - VLOG(3) << "xpu_kp using rt mode "; - } - if (use_xpu_kp_kernel_debug) { - VLOG(3) << "xpu_kp using debug mode "; - } - if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && - (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) { - expected_kernel_key.place_ = platform::XPUPlace(); - expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; - kernel_iter = kernels.find(expected_kernel_key); - VLOG(3) << "using XPU KP kernel: " << op.Type() - << ", using_kernel_key:" << expected_kernel_key; + if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { + bool use_xpu_kp_kernel_rt = + FLAGS_run_kp_kernel && + paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_in_xpu_kpwhite_list(op.Type()); + if (use_xpu_kp_kernel_rt) { + VLOG(3) << "xpu_kp using rt mode "; + } + if (use_xpu_kp_kernel_debug) { + VLOG(3) << "xpu_kp using debug mode "; + } + bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + kernel_iter = kernels.find(expected_kernel_key); + VLOG(3) << "using XPU KP kernel: " << op.Type() + << ", using_kernel_key:" << expected_kernel_key; + } + if (!is_xpu_kp_support && + (kernel_iter == kernels.end() || is_xpu_unsupport)) { + VLOG(3) << "missing XPU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + expected_kernel_key.place_ = platform::CPUPlace(); + kernel_iter = kernels.find(expected_kernel_key); + } } #endif