diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7550854709ac80c9f76d332cf80f2d0a3b1d54fb..747a17bd05e08d2362acabccd0efc6372d40e36b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1462,7 +1462,7 @@ bool OperatorWithKernel::SupportsKernelType( if (paddle::platform::is_xpu_place(kernel_type.place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type.data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(type_); @@ -1754,7 +1754,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (paddle::platform::is_xpu_place(kernel_type_->place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type_->data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(type_); @@ -1831,7 +1831,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (paddle::platform::is_xpu_place(kernel_type_->place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type_->data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(type_); @@ -1880,7 +1880,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, bool use_xpu_kp_kernel_rt = paddle::platform::is_xpu_place(kernel_type_->place_) && FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(kernel_type_->data_type_)); bool use_xpu_kp_kernel_debug = paddle::platform::is_xpu_place(kernel_type_->place_) && @@ -2285,7 +2285,7 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op( + paddle::platform::is_xpu_kp_support_op( type_, framework::TransToPhiDataType(expected_kernel_key.data_type_)); bool use_xpu_kp_kernel_debug = diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index df315ba97ec9a76bd8077adecb145e2f63d74a5c..61502d186e356b852f70328e95a963263b88f0c1 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -296,7 +296,7 @@ PreparedOp PrepareImpl( #ifdef PADDLE_WITH_XPU_KP if (expected_kernel_key.backend() == phi::Backend::XPU) { bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op( + FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op( op.Type(), expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = paddle::platform::is_in_xpu_kpwhite_list(op.Type()); @@ -369,8 +369,8 @@ PreparedOp PrepareImpl( bool use_xpu_kp_kernel_rt = expected_kernel_key.backend() == phi::Backend::XPU && FLAGS_run_kp_kernel && - paddle::platform::is_xpu_support_op(op.Type(), - expected_kernel_key.dtype()); + paddle::platform::is_xpu_kp_support_op(op.Type(), + expected_kernel_key.dtype()); bool use_xpu_kp_kernel_debug = expected_kernel_key.backend() == phi::Backend::XPU && paddle::platform::is_in_xpu_kpwhite_list(op.Type()); diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h deleted file mode 100644 index 4a4e370bf9aa4244199413de0a8e2112a18c28a4..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#ifdef PADDLE_WITH_XPU_KP -#include -#include -#include - -#include "paddle/fluid/framework/op_kernel_type.h" - -namespace paddle { -namespace platform { - -using vartype = paddle::framework::proto::VarType; -using pOpKernelType = paddle::framework::OpKernelType; -using XPUKernelSet = - std::unordered_set; -using XPUOpMap = std::unordered_map; - -XPUOpMap& get_kp_ops() { - static XPUOpMap s_xpu_kp_kernels{ - {"elementwise_add", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_div", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_sub", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_max", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_min", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_mul", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elementwise_floordiv", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - // activation op - {"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reciprocal", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"elu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"celu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"silu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logsigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softshrink", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"ceil", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"floor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"log1p", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"brelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"soft_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"softsign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"relu6", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_shrink", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"hard_sigmoid", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"thresholded_relu", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - // bitwise logical & compare - {"bitwise_and", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"bitwise_or", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"bitwise_not", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - {"bitwise_xor", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), - pOpKernelType(vartype::BOOL, XPUPlace())})}, - - {"logical_and", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"logical_or", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"logical_not", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"logical_xor", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - - {"less_than", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"less_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"greater_than", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"greater_equal", - XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, - // reduce op - // {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL, - // XPUPlace())})}, - // {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, - // XPUPlace())})}, - // {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - // {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32, - // XPUPlace())})}, - {"pull_box_sparse", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"push_box_sparse", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_sync_calc_stream", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_sync_comm_stream", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"c_allreduce_sum", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - }; - - return s_xpu_kp_kernels; -} - -} // namespace platform -} // namespace paddle -#endif diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.cc b/paddle/fluid/platform/device/xpu/xpu_op_list.cc index 220a41c1b104258e98fb1f4501df78d5973f6925..1b12b9ebf946837b522bbd9780b914b42de10cd1 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" -#include "paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h" +#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h" #include "paddle/phi/backends/xpu/xpu_op_list.h" namespace paddle { @@ -39,6 +39,7 @@ static void tokenize(const std::string& ops, } #ifdef PADDLE_WITH_XPU_KP + bool is_in_xpu_kpwhite_list(const std::string& op_name) { static bool inited = false; static std::unordered_set xpu_kpwhite_list; @@ -63,6 +64,37 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) { } return false; } + +XPUOpListMap get_xpu_kp_op_list(phi::backends::xpu::XPUVersion version) { + auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 + ? phi::backends::xpu::get_kl1_ops() + : phi::backends::xpu::get_kp_ops(); + XPUOpListMap res; + for (auto& op : ops) { + std::vector op_types; + for (auto& item : op.second) { + op_types.push_back( + static_cast(phi::TransToProtoVarType(item))); + } + res[op.first] = std::move(op_types); + } + return res; +} + +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version) { + auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 + ? phi::backends::xpu::get_kl1_ops() + : phi::backends::xpu::get_kp_ops(); + std::vector res; + if (ops.find(op_name) != ops.end()) { + auto& dtypes = ops[op_name]; + for (auto& type : dtypes) { + res.push_back(static_cast(phi::TransToProtoVarType(type))); + } + } + return res; +} #endif std::vector get_xpu_op_support_type( diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.h b/paddle/fluid/platform/device/xpu/xpu_op_list.h index 3da4e7b190c4174b02cb52b8993f3eea82efd20f..8804021b108126934499271fc81aeb165a957a99 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.h @@ -17,21 +17,26 @@ limitations under the License. */ #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/phi/backends/xpu/xpu_op_list.h" +#ifdef PADDLE_WITH_XPU_KP +#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h" +#endif namespace paddle { namespace platform { using phi::backends::xpu::is_in_xpu_black_list; using phi::backends::xpu::is_xpu_support_op; +using vartype = paddle::framework::proto::VarType; +using XPUOpListMap = + std::unordered_map>; #ifdef PADDLE_WITH_XPU_KP +using phi::backends::xpu::is_xpu_kp_support_op; +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version); bool is_in_xpu_kpwhite_list(const std::string& op_name); #endif -using vartype = paddle::framework::proto::VarType; -using XPUOpListMap = - std::unordered_map>; - std::vector get_xpu_op_support_type( const std::string& op_name, phi::backends::xpu::XPUVersion version); XPUOpListMap get_xpu_op_list(phi::backends::xpu::XPUVersion version); diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 2f8a5a1c44c1c5f292bc161d59f5c0e85eb23e5c..a5a52f8af7e5c78d5a38487ca18bfb95772218ef 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -460,7 +460,7 @@ void BindPlace(pybind11::module &m) { // NOLINT #ifdef PADDLE_WITH_XPU_KP m.def("get_xpu_device_op_support_types", [](const std::string &op_name, phi::backends::xpu::XPUVersion version) { - return platform::get_xpu_op_support_type(op_name, version); + return platform::get_xpu_kp_op_support_type(op_name, version); }); #else m.def("get_xpu_device_op_support_types", diff --git a/paddle/phi/backends/xpu/xpu_op_kpfirst_list.h b/paddle/phi/backends/xpu/xpu_op_kpfirst_list.h new file mode 100644 index 0000000000000000000000000000000000000000..878a2153d421e05dae75e212b933464bf5d9641d --- /dev/null +++ b/paddle/phi/backends/xpu/xpu_op_kpfirst_list.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#ifdef PADDLE_WITH_XPU_KP +#include +#include +#include +#include "paddle/phi/common/data_type.h" + +namespace phi { +namespace backends { +namespace xpu { + +using XPUKernelSet = std::unordered_set; +using XPUOpMap = std::unordered_map; + +XPUOpMap& get_kp_ops() { + static XPUOpMap s_xpu_kp_kernels{ + {"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_floordiv", XPUKernelSet({phi::DataType::INT32})}, + // activation op + {"exp", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softplus", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})}, + {"log", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"celu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, + {"square", XPUKernelSet({phi::DataType::FLOAT32})}, + {"silu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"logsigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softshrink", XPUKernelSet({phi::DataType::FLOAT32})}, + {"ceil", XPUKernelSet({phi::DataType::FLOAT32})}, + {"floor", XPUKernelSet({phi::DataType::FLOAT32})}, + {"log1p", XPUKernelSet({phi::DataType::FLOAT32})}, + {"brelu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"soft_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softsign", XPUKernelSet({phi::DataType::FLOAT32})}, + {"relu6", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_shrink", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"swish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"thresholded_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + // bitwise logical & compare + {"bitwise_and", + XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + {"bitwise_or", XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + {"bitwise_not", + XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + {"bitwise_xor", + XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})}, + + {"logical_and", XPUKernelSet({phi::DataType::INT32})}, + {"logical_or", XPUKernelSet({phi::DataType::INT32})}, + {"logical_not", XPUKernelSet({phi::DataType::INT32})}, + {"logical_xor", XPUKernelSet({phi::DataType::INT32})}, + + {"less_than", XPUKernelSet({phi::DataType::INT32})}, + {"less_equal", XPUKernelSet({phi::DataType::INT32})}, + {"greater_than", XPUKernelSet({phi::DataType::INT32})}, + {"greater_equal", XPUKernelSet({phi::DataType::INT32})}, + {"equal", XPUKernelSet({phi::DataType::INT32})}, + {"not_equal", XPUKernelSet({phi::DataType::INT32})}, + {"pull_box_sparse", XPUKernelSet({phi::DataType::FLOAT32})}, + {"push_box_sparse", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, + }; + + return s_xpu_kp_kernels; +} + +} // namespace xpu +} // namespace backends +} // namespace phi +#endif diff --git a/paddle/phi/backends/xpu/xpu_op_list.cc b/paddle/phi/backends/xpu/xpu_op_list.cc index edcf81183be426107465f5972106b78cd983218e..86529c3bd7c102e2734fb85aefb5f60b9cbcbd79 100644 --- a/paddle/phi/backends/xpu/xpu_op_list.cc +++ b/paddle/phi/backends/xpu/xpu_op_list.cc @@ -16,6 +16,10 @@ limitations under the License. */ #include #include "paddle/phi/backends/xpu/xpu_info.h" +#ifdef PADDLE_WITH_XPU_KP +#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h" +#endif + namespace phi { namespace backends { namespace xpu { @@ -60,6 +64,23 @@ bool is_in_xpu_black_list(const std::string& fluid_op_name) { return false; } +#ifdef PADDLE_WITH_XPU_KP +bool is_xpu_kp_support_op(const std::string& fluid_op_name, + const phi::DataType type) { + if (is_in_xpu_black_list(fluid_op_name)) return false; + auto v = get_xpu_version(0); + auto& ops = (v == phi::backends::xpu::XPUVersion::XPU1) + ? phi::backends::xpu::get_kl1_ops() + : phi::backends::xpu::get_kp_ops(); + + if (ops.find(fluid_op_name) != ops.end() && + ops[fluid_op_name].find(type) != ops[fluid_op_name].end()) { + return true; + } + return false; +} +#endif + bool is_xpu_support_op(const std::string& fluid_op_name, const phi::DataType type) { if (is_in_xpu_black_list(fluid_op_name)) return false; diff --git a/paddle/phi/backends/xpu/xpu_op_list.h b/paddle/phi/backends/xpu/xpu_op_list.h index 17b2f1c6965a63bb868c09330f8f3534611d2db5..975a5d02b16b2bf87e3e3bbf4121f014ee8f0bf6 100644 --- a/paddle/phi/backends/xpu/xpu_op_list.h +++ b/paddle/phi/backends/xpu/xpu_op_list.h @@ -26,6 +26,11 @@ using XPUOpMap = std::unordered_map; XPUOpMap& get_kl1_ops(); XPUOpMap& get_kl2_ops(); +#ifdef PADDLE_WITH_XPU_KP +bool is_xpu_kp_support_op(const std::string& fluid_op_name, + const phi::DataType type); +#endif + bool is_in_xpu_black_list(const std::string& fluid_op_name); bool is_xpu_support_op(const std::string& fluid_op_name, const phi::DataType type); diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 096207e06fcde1e030618f0c18e75b742481c0d3..6af7ac7b9b74cc60073c32b5f0be446aa7b5ae88 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -16,7 +16,7 @@ #include "glog/logging.h" #include "paddle/phi/core/enforce.h" -#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) +#if defined(PADDLE_WITH_XPU) #include "paddle/phi/backends/xpu/xpu_op_list.h" #include "paddle/phi/core/compat/convert_utils.h" #endif @@ -28,7 +28,7 @@ DECLARE_int32(low_precision_op_list); DECLARE_bool(enable_api_kernel_fallback); - +DECLARE_bool(run_kp_kernel); namespace phi { const static Kernel empty_kernel; // NOLINT @@ -189,7 +189,32 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_name, KernelSelectionErrorMessage(kernel_name, kernel_key))); -#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) +#if defined(PADDLE_WITH_XPU_KP) + auto fluid_op_name = TransToFluidOpName(kernel_name); + bool has_kp_kernel = false; + VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); + bool is_xpu_kp_supported = phi::backends::xpu::is_xpu_kp_support_op( + fluid_op_name, kernel_key.dtype()); + // Check in xpu_kp + if (is_xpu_kp_supported && FLAGS_run_kp_kernel) { + auto kernel_key_kp = + KernelKey(Backend::KPS, kernel_key.layout(), kernel_key.dtype()); + auto kernel_iter_kp = iter->second.find(kernel_key_kp); + has_kp_kernel = (kernel_iter_kp != iter->second.end()); + if (has_kp_kernel) { + kernel_key = kernel_key_kp; + kernel_iter = kernel_iter_kp; + } + } + // check in xpu + bool xpu_unsupport = + !phi::backends::xpu::is_xpu_support_op(fluid_op_name, kernel_key.dtype()); + VLOG(6) << "Current KernelKey is " << kernel_key; + // Fall back to CPU, when FLAGS_enable_api_kernel_fallback is true and op + // was unregistered in xpu and kp + if (FLAGS_enable_api_kernel_fallback && + (kernel_iter == iter->second.end() || (xpu_unsupport && !has_kp_kernel)) +#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) || !phi::backends::xpu::is_xpu_support_op(TransToFluidOpName(kernel_name), diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index bacfce613f66cdfb4b3c3a3922e7388ef8a6cb36..4d984acd0c35250be3e7ce62512ccf11ae769d13 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -62,8 +62,10 @@ struct KernelArgsParseFunctor { #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || arg_type == std::type_index(typeid(const GPUContext&))) { -#elif defined(PADDLE_WITH_XPU) +#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || arg_type == std::type_index(typeid(const XPUContext&))) { +#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_KP) + || arg_type == std::type_index(typeid(const KPSContext&))) { #elif defined(PADDLE_WITH_CUSTOM_DEVICE) || arg_type == std::type_index(typeid(const CustomContext&))) { #else diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index f7e323d285cfc461fbba380c0c11f2a338f03167..290d1fd4fb6dfbc53e6767a845ce53bd2d97d4ec 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -283,7 +283,7 @@ PD_REGISTER_KERNEL(elementwise_pow, #endif -#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU) +#if defined(PADDLE_WITH_XPU_KP) && defined(PADDLE_WITH_XPU) PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {} PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {} PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {} @@ -368,8 +368,7 @@ PD_REGISTER_KERNEL(subtract, phi::dtype::float16, int64_t) {} #endif - -#if defined PADDLE_WITH_XPU +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) PD_REGISTER_KERNEL(floor_divide, XPU, ALL_LAYOUT,