From 6ef3f2ce185c588b450c42966d9b504867607025 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Fri, 24 Feb 2023 16:30:12 +0800 Subject: [PATCH] Fix KP operator Kernel selection error (#50178) --- paddle/fluid/framework/operator.cc | 10 +- paddle/fluid/imperative/prepared_operator.cc | 6 +- .../platform/device/xpu/xpu_op_kpfirst_list.h | 140 ------------------ .../fluid/platform/device/xpu/xpu_op_list.cc | 34 ++++- .../fluid/platform/device/xpu/xpu_op_list.h | 13 +- paddle/fluid/pybind/place.cc | 2 +- paddle/phi/backends/xpu/xpu_op_kpfirst_list.h | 95 ++++++++++++ paddle/phi/backends/xpu/xpu_op_list.cc | 21 +++ paddle/phi/backends/xpu/xpu_op_list.h | 5 + paddle/phi/core/kernel_factory.cc | 31 +++- paddle/phi/core/kernel_registry.h | 4 +- paddle/phi/kernels/elementwise_kernel.cc | 5 +- 12 files changed, 205 insertions(+), 161 deletions(-) delete mode 100644 paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h create mode 100644 paddle/phi/backends/xpu/xpu_op_kpfirst_list.h diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7550854709a..747a17bd05e 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1462,7 +1462,7 @@ bool OperatorWithKernel::SupportsKernelType( if (paddle::platform::is_xpu_place(kernel_type.place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type.data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(type_); @@ -1754,7 +1754,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (paddle::platform::is_xpu_place(kernel_type_->place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type_->data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(type_); @@ -1831,7 +1831,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (paddle::platform::is_xpu_place(kernel_type_->place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type_->data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(type_); @@ -1880,7 +1880,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, bool use_xpu_kp_kernel_rt = paddle::platform::is_xpu_place(kernel_type_->place_) && FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type_->data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_xpu_place(kernel_type_->place_) && @@ -2285,7 +2285,7 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(expected_kernel_key.data_type_)); bool use_xpu_kp_kernel_debug = diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index df315ba97ec..61502d186e3 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -296,7 +296,7 @@ PreparedOp PrepareImpl( #ifdef PADDLE_WITH_XPU_KP if (expected_kernel_key.backend() == phi::Backend::XPU) { bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op( + FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op( op.Type(), expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(op.Type()); @@ -369,8 +369,8 @@ PreparedOp PrepareImpl( bool use_xpu_kp_kernel_rt = expected_kernel_key.backend() == phi::Backend::XPU && FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op(op.Type(), - expected_kernel_key.dtype()); + paddle::platform::is_xpu_kp_support_op(op.Type(), + expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = expected_kernel_key.backend() == phi::Backend::XPU && paddle::platform::is_in_xpu_kpwhite_list(op.Type()); diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h deleted file mode 100644 index 4a4e370bf9a..00000000000 --- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#ifdef PADDLE_WITH_XPU_KP -#include -#include -#include - -#include "paddle/fluid/framework/op_kernel_type.h" - -namespace paddle { -namespace platform { - -using vartype = paddle::framework::proto::VarType; -using pOpKernelType = paddle::framework::OpKernelType; -using XPUKernelSet = - std::unordered_set; -using XPUOpMap = std::unordered_map; - -XPUOpMap& get_kp_ops() { - static XPUOpMap s_xpu_kp_kernels{ - {"elementwise_add", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_div", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_sub", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_max", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_min", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_mul", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_floordiv", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - // activation op - {"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reciprocal", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"celu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"silu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logsigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softshrink", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"ceil", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"floor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"log1p", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"brelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"soft_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softsign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu6", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_shrink", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_sigmoid", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"thresholded_relu", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - // bitwise logical & compare - {"bitwise_and", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"bitwise_or", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"bitwise_not", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"bitwise_xor", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - - {"logical_and", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"logical_or", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"logical_not", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"logical_xor", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - - {"less_than", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"less_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"greater_than", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"greater_equal", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - // reduce op - // {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL, - // XPUPlace())})}, - // {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, - // XPUPlace())})}, - // {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - {"pull_box_sparse", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"push_box_sparse", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_sync_calc_stream", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_sync_comm_stream", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_allreduce_sum", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - }; - - return s_xpu_kp_kernels; -} - -} // namespace platform -} // namespace paddle -#endif diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.cc b/paddle/fluid/platform/device/xpu/xpu_op_list.cc index 220a41c1b10..1b12b9ebf94 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" -#include "paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h" +#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h" #include "paddle/phi/backends/xpu/xpu_op_list.h" namespace paddle { @@ -39,6 +39,7 @@ static void tokenize(const std::string& ops, } #ifdef PADDLE_WITH_XPU_KP + bool is_in_xpu_kpwhite_list(const std::string& op_name) { static bool inited = false; static std::unordered_set xpu_kpwhite_list; @@ -63,6 +64,37 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) { } return false; } + +XPUOpListMap get_xpu_kp_op_list(phi::backends::xpu::XPUVersion version) { + auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 + ? phi::backends::xpu::get_kl1_ops() + : phi::backends::xpu::get_kp_ops(); + XPUOpListMap res; + for (auto& op : ops) { + std::vector op_types; + for (auto& item : op.second) { + op_types.push_back( + static_cast(phi::TransToProtoVarType(item))); + } + res[op.first] = std::move(op_types); + } + return res; +} + +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version) { + auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 + ? phi::backends::xpu::get_kl1_ops() + : phi::backends::xpu::get_kp_ops(); + std::vector res; + if (ops.find(op_name) != ops.end()) { + auto& dtypes = ops[op_name]; + for (auto& type : dtypes) { + res.push_back(static_cast(phi::TransToProtoVarType(type))); + } + } + return res; +} #endif std::vector get_xpu_op_support_type( diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.h b/paddle/fluid/platform/device/xpu/xpu_op_list.h index 3da4e7b190c..8804021b108 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.h @@ -17,21 +17,26 @@ limitations under the License. */ #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/phi/backends/xpu/xpu_op_list.h" +#ifdef PADDLE_WITH_XPU_KP +#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h" +#endif namespace paddle { namespace platform { using phi::backends::xpu::is_in_xpu_black_list; using phi::backends::xpu::is_xpu_support_op; +using vartype = paddle::framework::proto::VarType; +using XPUOpListMap = + std::unordered_map>; #ifdef PADDLE_WITH_XPU_KP +using phi::backends::xpu::is_xpu_kp_support_op; +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version); bool is_in_xpu_kpwhite_list(const std::string& op_name); #endif -using vartype = paddle::framework::proto::VarType; -using XPUOpListMap = - std::unordered_map>; - std::vector get_xpu_op_support_type( const std::string& op_name, phi::backends::xpu::XPUVersion version); XPUOpListMap get_xpu_op_list(phi::backends::xpu::XPUVersion version); diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 2f8a5a1c44c..a5a52f8af7e 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -460,7 +460,7 @@ void BindPlace(pybind11::module &m) { // NOLINT #ifdef PADDLE_WITH_XPU_KP m.def("get_xpu_device_op_support_types", [](const std::string &op_name, phi::backends::xpu::XPUVersion version) { - return platform::get_xpu_op_support_type(op_name, version); + return platform::get_xpu_kp_op_support_type(op_name, version); }); #else m.def("get_xpu_device_op_support_types", diff --git a/paddle/phi/backends/xpu/xpu_op_kpfirst_list.h b/paddle/phi/backends/xpu/xpu_op_kpfirst_list.h new file mode 100644 index 00000000000..878a2153d42 --- /dev/null +++ b/paddle/phi/backends/xpu/xpu_op_kpfirst_list.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#ifdef PADDLE_WITH_XPU_KP +#include +#include +#include +#include "paddle/phi/common/data_type.h" + +namespace phi { +namespace backends { +namespace xpu { + +using XPUKernelSet = std::unordered_set; +using XPUOpMap = std::unordered_map; + +XPUOpMap& get_kp_ops() { + static XPUOpMap s_xpu_kp_kernels{ + {"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_floordiv", XPUKernelSet({phi::DataType::INT32})}, + // activation op + {"exp", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softplus", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})}, + {"log", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"celu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, + {"square", XPUKernelSet({phi::DataType::FLOAT32})}, + {"silu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"logsigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softshrink", XPUKernelSet({phi::DataType::FLOAT32})}, + {"ceil", XPUKernelSet({phi::DataType::FLOAT32})}, + {"floor", XPUKernelSet({phi::DataType::FLOAT32})}, + {"log1p", XPUKernelSet({phi::DataType::FLOAT32})}, + {"brelu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"soft_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softsign", XPUKernelSet({phi::DataType::FLOAT32})}, + {"relu6", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_shrink", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"swish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"thresholded_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + // bitwise logical & compare + {"bitwise_and", + XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + {"bitwise_or", XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + {"bitwise_not", + XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + {"bitwise_xor", + XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + + {"logical_and", XPUKernelSet({phi::DataType::INT32})}, + {"logical_or", XPUKernelSet({phi::DataType::INT32})}, + {"logical_not", XPUKernelSet({phi::DataType::INT32})}, + {"logical_xor", XPUKernelSet({phi::DataType::INT32})}, + + {"less_than", XPUKernelSet({phi::DataType::INT32})}, + {"less_equal", XPUKernelSet({phi::DataType::INT32})}, + {"greater_than", XPUKernelSet({phi::DataType::INT32})}, + {"greater_equal", XPUKernelSet({phi::DataType::INT32})}, + {"equal", XPUKernelSet({phi::DataType::INT32})}, + {"not_equal", XPUKernelSet({phi::DataType::INT32})}, + {"pull_box_sparse", XPUKernelSet({phi::DataType::FLOAT32})}, + {"push_box_sparse", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, + }; + + return s_xpu_kp_kernels; +} + +} // namespace xpu +} // namespace backends +} // namespace phi +#endif diff --git a/paddle/phi/backends/xpu/xpu_op_list.cc b/paddle/phi/backends/xpu/xpu_op_list.cc index edcf81183be..86529c3bd7c 100644 --- a/paddle/phi/backends/xpu/xpu_op_list.cc +++ b/paddle/phi/backends/xpu/xpu_op_list.cc @@ -16,6 +16,10 @@ limitations under the License. */ #include #include "paddle/phi/backends/xpu/xpu_info.h" +#ifdef PADDLE_WITH_XPU_KP +#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h" +#endif + namespace phi { namespace backends { namespace xpu { @@ -60,6 +64,23 @@ bool is_in_xpu_black_list(const std::string& fluid_op_name) { return false; } +#ifdef PADDLE_WITH_XPU_KP +bool is_xpu_kp_support_op(const std::string& fluid_op_name, + const phi::DataType type) { + if (is_in_xpu_black_list(fluid_op_name)) return false; + auto v = get_xpu_version(0); + auto& ops = (v == phi::backends::xpu::XPUVersion::XPU1) + ? phi::backends::xpu::get_kl1_ops() + : phi::backends::xpu::get_kp_ops(); + + if (ops.find(fluid_op_name) != ops.end() && + ops[fluid_op_name].find(type) != ops[fluid_op_name].end()) { + return true; + } + return false; +} +#endif + bool is_xpu_support_op(const std::string& fluid_op_name, const phi::DataType type) { if (is_in_xpu_black_list(fluid_op_name)) return false; diff --git a/paddle/phi/backends/xpu/xpu_op_list.h b/paddle/phi/backends/xpu/xpu_op_list.h index 17b2f1c6965..975a5d02b16 100644 --- a/paddle/phi/backends/xpu/xpu_op_list.h +++ b/paddle/phi/backends/xpu/xpu_op_list.h @@ -26,6 +26,11 @@ using XPUOpMap = std::unordered_map; XPUOpMap& get_kl1_ops(); XPUOpMap& get_kl2_ops(); +#ifdef PADDLE_WITH_XPU_KP +bool is_xpu_kp_support_op(const std::string& fluid_op_name, + const phi::DataType type); +#endif + bool is_in_xpu_black_list(const std::string& fluid_op_name); bool is_xpu_support_op(const std::string& fluid_op_name, const phi::DataType type); diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 096207e06fc..6af7ac7b9b7 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -16,7 +16,7 @@ #include "glog/logging.h" #include "paddle/phi/core/enforce.h" -#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) +#if defined(PADDLE_WITH_XPU) #include "paddle/phi/backends/xpu/xpu_op_list.h" #include "paddle/phi/core/compat/convert_utils.h" #endif @@ -28,7 +28,7 @@ DECLARE_int32(low_precision_op_list); DECLARE_bool(enable_api_kernel_fallback); - +DECLARE_bool(run_kp_kernel); namespace phi { const static Kernel empty_kernel; // NOLINT @@ -189,7 +189,32 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_name, KernelSelectionErrorMessage(kernel_name, kernel_key))); -#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) +#if defined(PADDLE_WITH_XPU_KP) + auto fluid_op_name = TransToFluidOpName(kernel_name); + bool has_kp_kernel = false; + VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); + bool is_xpu_kp_supported = phi::backends::xpu::is_xpu_kp_support_op( + fluid_op_name, kernel_key.dtype()); + // Check in xpu_kp + if (is_xpu_kp_supported && FLAGS_run_kp_kernel) { + auto kernel_key_kp = + KernelKey(Backend::KPS, kernel_key.layout(), kernel_key.dtype()); + auto kernel_iter_kp = iter->second.find(kernel_key_kp); + has_kp_kernel = (kernel_iter_kp != iter->second.end()); + if (has_kp_kernel) { + kernel_key = kernel_key_kp; + kernel_iter = kernel_iter_kp; + } + } + // check in xpu + bool xpu_unsupport = + !phi::backends::xpu::is_xpu_support_op(fluid_op_name, kernel_key.dtype()); + VLOG(6) << "Current KernelKey is " << kernel_key; + // Fall back to CPU, when FLAGS_enable_api_kernel_fallback is true and op + // was unregistered in xpu and kp + if (FLAGS_enable_api_kernel_fallback && + (kernel_iter == iter->second.end() || (xpu_unsupport && !has_kp_kernel)) +#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) || !phi::backends::xpu::is_xpu_support_op(TransToFluidOpName(kernel_name), diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index bacfce613f6..4d984acd0c3 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -62,8 +62,10 @@ struct KernelArgsParseFunctor { #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || arg_type == std::type_index(typeid(const GPUContext&))) { -#elif defined(PADDLE_WITH_XPU) +#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || arg_type == std::type_index(typeid(const XPUContext&))) { +#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_KP) + || arg_type == std::type_index(typeid(const KPSContext&))) { #elif defined(PADDLE_WITH_CUSTOM_DEVICE) || arg_type == std::type_index(typeid(const CustomContext&))) { #else diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index f7e323d285c..290d1fd4fb6 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -283,7 +283,7 @@ PD_REGISTER_KERNEL(elementwise_pow, #endif -#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU) +#if defined(PADDLE_WITH_XPU_KP) && defined(PADDLE_WITH_XPU) PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {} PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {} PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {} @@ -368,8 +368,7 @@ PD_REGISTER_KERNEL(subtract, phi::dtype::float16, int64_t) {} #endif - -#if defined PADDLE_WITH_XPU +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) PD_REGISTER_KERNEL(floor_divide, XPU, ALL_LAYOUT, -- GitLab