未验证 提交 6ef3f2ce 编写于 作者: N niuliling123 提交者: GitHub

Fix KP operator Kernel selection error (#50178)

上级 6664a232
......@@ -1462,7 +1462,7 @@ bool OperatorWithKernel::SupportsKernelType(
if (paddle::platform::is_xpu_place(kernel_type.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(
paddle::platform::is_xpu_kp_support_op(
type_, framework::TransToPhiDataType(kernel_type.data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
......@@ -1754,7 +1754,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(
paddle::platform::is_xpu_kp_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
......@@ -1831,7 +1831,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(
paddle::platform::is_xpu_kp_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
......@@ -1880,7 +1880,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(
paddle::platform::is_xpu_kp_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
......@@ -2285,7 +2285,7 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(
paddle::platform::is_xpu_kp_support_op(
type_,
framework::TransToPhiDataType(expected_kernel_key.data_type_));
bool use_xpu_kp_kernel_debug =
......
......@@ -296,7 +296,7 @@ PreparedOp PrepareImpl(
#ifdef PADDLE_WITH_XPU_KP
if (expected_kernel_key.backend() == phi::Backend::XPU) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op(
FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(
op.Type(), expected_kernel_key.dtype());
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
......@@ -369,8 +369,8 @@ PreparedOp PrepareImpl(
bool use_xpu_kp_kernel_rt =
expected_kernel_key.backend() == phi::Backend::XPU &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key.dtype());
paddle::platform::is_xpu_kp_support_op(op.Type(),
expected_kernel_key.dtype());
bool use_xpu_kp_kernel_debug =
expected_kernel_key.backend() == phi::Backend::XPU &&
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU_KP
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/op_kernel_type.h"
namespace paddle {
namespace platform {
using vartype = paddle::framework::proto::VarType;
using pOpKernelType = paddle::framework::OpKernelType;
using XPUKernelSet =
std::unordered_set<pOpKernelType, paddle::framework::OpKernelType::Hash>;
using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;
XPUOpMap& get_kp_ops() {
static XPUOpMap s_xpu_kp_kernels{
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
// activation op
{"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"leaky_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reciprocal", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"celu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"silu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logsigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softshrink", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"ceil", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"floor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log1p", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"brelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"soft_relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softsign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu6", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_shrink", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_sigmoid",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"thresholded_relu",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// bitwise logical & compare
{"bitwise_and",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"bitwise_or",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"bitwise_not",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"bitwise_xor",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"logical_and",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"logical_or", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"logical_not",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"logical_xor",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"less_than", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"less_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"greater_than",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"not_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
// reduce op
// {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
// {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
// {"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
// {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
// {"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
// {"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL,
// XPUPlace())})},
// {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL,
// XPUPlace())})},
// {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
// {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32,
// XPUPlace())})},
{"pull_box_sparse",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"push_box_sparse",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_sync_calc_stream",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_sync_comm_stream",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_allreduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
};
return s_xpu_kp_kernels;
}
} // namespace platform
} // namespace paddle
#endif
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h"
#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h"
#include "paddle/phi/backends/xpu/xpu_op_list.h"
namespace paddle {
......@@ -39,6 +39,7 @@ static void tokenize(const std::string& ops,
}
#ifdef PADDLE_WITH_XPU_KP
bool is_in_xpu_kpwhite_list(const std::string& op_name) {
static bool inited = false;
static std::unordered_set<std::string> xpu_kpwhite_list;
......@@ -63,6 +64,37 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) {
}
return false;
}
XPUOpListMap get_xpu_kp_op_list(phi::backends::xpu::XPUVersion version) {
auto& ops = version == phi::backends::xpu::XPUVersion::XPU1
? phi::backends::xpu::get_kl1_ops()
: phi::backends::xpu::get_kp_ops();
XPUOpListMap res;
for (auto& op : ops) {
std::vector<vartype::Type> op_types;
for (auto& item : op.second) {
op_types.push_back(
static_cast<vartype::Type>(phi::TransToProtoVarType(item)));
}
res[op.first] = std::move(op_types);
}
return res;
}
std::vector<vartype::Type> get_xpu_kp_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version) {
auto& ops = version == phi::backends::xpu::XPUVersion::XPU1
? phi::backends::xpu::get_kl1_ops()
: phi::backends::xpu::get_kp_ops();
std::vector<vartype::Type> res;
if (ops.find(op_name) != ops.end()) {
auto& dtypes = ops[op_name];
for (auto& type : dtypes) {
res.push_back(static_cast<vartype::Type>(phi::TransToProtoVarType(type)));
}
}
return res;
}
#endif
std::vector<vartype::Type> get_xpu_op_support_type(
......
......@@ -17,21 +17,26 @@ limitations under the License. */
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/phi/backends/xpu/xpu_op_list.h"
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h"
#endif
namespace paddle {
namespace platform {
using phi::backends::xpu::is_in_xpu_black_list;
using phi::backends::xpu::is_xpu_support_op;
using vartype = paddle::framework::proto::VarType;
using XPUOpListMap =
std::unordered_map<std::string, std::vector<vartype::Type>>;
#ifdef PADDLE_WITH_XPU_KP
using phi::backends::xpu::is_xpu_kp_support_op;
std::vector<vartype::Type> get_xpu_kp_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version);
bool is_in_xpu_kpwhite_list(const std::string& op_name);
#endif
using vartype = paddle::framework::proto::VarType;
using XPUOpListMap =
std::unordered_map<std::string, std::vector<vartype::Type>>;
std::vector<vartype::Type> get_xpu_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version);
XPUOpListMap get_xpu_op_list(phi::backends::xpu::XPUVersion version);
......
......@@ -460,7 +460,7 @@ void BindPlace(pybind11::module &m) { // NOLINT
#ifdef PADDLE_WITH_XPU_KP
m.def("get_xpu_device_op_support_types",
[](const std::string &op_name, phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_op_support_type(op_name, version);
return platform::get_xpu_kp_op_support_type(op_name, version);
});
#else
m.def("get_xpu_device_op_support_types",
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU_KP
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace backends {
namespace xpu {
using XPUKernelSet = std::unordered_set<phi::DataType>;
using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;
XPUOpMap& get_kp_ops() {
static XPUOpMap s_xpu_kp_kernels{
{"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_floordiv", XPUKernelSet({phi::DataType::INT32})},
// activation op
{"exp", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"softplus", XPUKernelSet({phi::DataType::FLOAT32})},
{"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})},
{"log", XPUKernelSet({phi::DataType::FLOAT32})},
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"elu", XPUKernelSet({phi::DataType::FLOAT32})},
{"celu", XPUKernelSet({phi::DataType::FLOAT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32})},
{"square", XPUKernelSet({phi::DataType::FLOAT32})},
{"silu", XPUKernelSet({phi::DataType::FLOAT32})},
{"logsigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"softshrink", XPUKernelSet({phi::DataType::FLOAT32})},
{"ceil", XPUKernelSet({phi::DataType::FLOAT32})},
{"floor", XPUKernelSet({phi::DataType::FLOAT32})},
{"log1p", XPUKernelSet({phi::DataType::FLOAT32})},
{"brelu", XPUKernelSet({phi::DataType::FLOAT32})},
{"soft_relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"softsign", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_shrink", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32})},
{"thresholded_relu", XPUKernelSet({phi::DataType::FLOAT32})},
// bitwise logical & compare
{"bitwise_and",
XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})},
{"bitwise_or", XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})},
{"bitwise_not",
XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})},
{"bitwise_xor",
XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL})},
{"logical_and", XPUKernelSet({phi::DataType::INT32})},
{"logical_or", XPUKernelSet({phi::DataType::INT32})},
{"logical_not", XPUKernelSet({phi::DataType::INT32})},
{"logical_xor", XPUKernelSet({phi::DataType::INT32})},
{"less_than", XPUKernelSet({phi::DataType::INT32})},
{"less_equal", XPUKernelSet({phi::DataType::INT32})},
{"greater_than", XPUKernelSet({phi::DataType::INT32})},
{"greater_equal", XPUKernelSet({phi::DataType::INT32})},
{"equal", XPUKernelSet({phi::DataType::INT32})},
{"not_equal", XPUKernelSet({phi::DataType::INT32})},
{"pull_box_sparse", XPUKernelSet({phi::DataType::FLOAT32})},
{"push_box_sparse", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
};
return s_xpu_kp_kernels;
}
} // namespace xpu
} // namespace backends
} // namespace phi
#endif
......@@ -16,6 +16,10 @@ limitations under the License. */
#include <unordered_set>
#include "paddle/phi/backends/xpu/xpu_info.h"
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_op_kpfirst_list.h"
#endif
namespace phi {
namespace backends {
namespace xpu {
......@@ -60,6 +64,23 @@ bool is_in_xpu_black_list(const std::string& fluid_op_name) {
return false;
}
#ifdef PADDLE_WITH_XPU_KP
bool is_xpu_kp_support_op(const std::string& fluid_op_name,
const phi::DataType type) {
if (is_in_xpu_black_list(fluid_op_name)) return false;
auto v = get_xpu_version(0);
auto& ops = (v == phi::backends::xpu::XPUVersion::XPU1)
? phi::backends::xpu::get_kl1_ops()
: phi::backends::xpu::get_kp_ops();
if (ops.find(fluid_op_name) != ops.end() &&
ops[fluid_op_name].find(type) != ops[fluid_op_name].end()) {
return true;
}
return false;
}
#endif
bool is_xpu_support_op(const std::string& fluid_op_name,
const phi::DataType type) {
if (is_in_xpu_black_list(fluid_op_name)) return false;
......
......@@ -26,6 +26,11 @@ using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;
XPUOpMap& get_kl1_ops();
XPUOpMap& get_kl2_ops();
#ifdef PADDLE_WITH_XPU_KP
bool is_xpu_kp_support_op(const std::string& fluid_op_name,
const phi::DataType type);
#endif
bool is_in_xpu_black_list(const std::string& fluid_op_name);
bool is_xpu_support_op(const std::string& fluid_op_name,
const phi::DataType type);
......
......@@ -16,7 +16,7 @@
#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU)
#include "paddle/phi/backends/xpu/xpu_op_list.h"
#include "paddle/phi/core/compat/convert_utils.h"
#endif
......@@ -28,7 +28,7 @@
DECLARE_int32(low_precision_op_list);
DECLARE_bool(enable_api_kernel_fallback);
DECLARE_bool(run_kp_kernel);
namespace phi {
const static Kernel empty_kernel; // NOLINT
......@@ -189,7 +189,32 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_name,
KernelSelectionErrorMessage(kernel_name, kernel_key)));
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU_KP)
auto fluid_op_name = TransToFluidOpName(kernel_name);
bool has_kp_kernel = false;
VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name);
bool is_xpu_kp_supported = phi::backends::xpu::is_xpu_kp_support_op(
fluid_op_name, kernel_key.dtype());
// Check in xpu_kp
if (is_xpu_kp_supported && FLAGS_run_kp_kernel) {
auto kernel_key_kp =
KernelKey(Backend::KPS, kernel_key.layout(), kernel_key.dtype());
auto kernel_iter_kp = iter->second.find(kernel_key_kp);
has_kp_kernel = (kernel_iter_kp != iter->second.end());
if (has_kp_kernel) {
kernel_key = kernel_key_kp;
kernel_iter = kernel_iter_kp;
}
}
// check in xpu
bool xpu_unsupport =
!phi::backends::xpu::is_xpu_support_op(fluid_op_name, kernel_key.dtype());
VLOG(6) << "Current KernelKey is " << kernel_key;
// Fall back to CPU, when FLAGS_enable_api_kernel_fallback is true and op
// was unregistered in xpu and kp
if (FLAGS_enable_api_kernel_fallback &&
(kernel_iter == iter->second.end() || (xpu_unsupport && !has_kp_kernel))
#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name);
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) ||
!phi::backends::xpu::is_xpu_support_op(TransToFluidOpName(kernel_name),
......
......@@ -62,8 +62,10 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
|| arg_type == std::type_index(typeid(const GPUContext&))) {
#elif defined(PADDLE_WITH_XPU)
#elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| arg_type == std::type_index(typeid(const XPUContext&))) {
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_KP)
|| arg_type == std::type_index(typeid(const KPSContext&))) {
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
|| arg_type == std::type_index(typeid(const CustomContext&))) {
#else
......
......@@ -283,7 +283,7 @@ PD_REGISTER_KERNEL(elementwise_pow,
#endif
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
#if defined(PADDLE_WITH_XPU_KP) && defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {}
PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {}
PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {}
......@@ -368,8 +368,7 @@ PD_REGISTER_KERNEL(subtract,
phi::dtype::float16,
int64_t) {}
#endif
#if defined PADDLE_WITH_XPU
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(floor_divide,
XPU,
ALL_LAYOUT,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册