未验证 提交 21beef91 编写于 作者: Q QingshuChen 提交者: GitHub

support kunlun black list and add kl1 op (#34605)

* support kunlun black list and add kl1 op

* xpu_op_list add device_context dependence
上级 fa16c21f
......@@ -1254,9 +1254,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
}
#endif
#ifdef PADDLE_WITH_XPU
if (kernel_iter == kernels.end() &&
if ((kernel_iter == kernels.end() &&
is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) {
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) ||
paddle::platform::is_in_xpu_black_list(type_)) {
VLOG(3) << "missing XPU kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
......
......@@ -131,9 +131,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU
if (kernel_iter == kernels.end() &&
if ((kernel_iter == kernels.end() &&
is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) {
!paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
......
......@@ -70,7 +70,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
if(WITH_XPU)
cc_library(xpu_info SRCS xpu/xpu_info.cc DEPS gflags glog enforce xpulib)
cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib)
cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib device_context)
endif()
if(WITH_ASCEND)
......
......@@ -55,25 +55,51 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"affine_channel_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"clip_by_norm",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"coalesce_tensor",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"c_reduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_allreduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicaland", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalnot", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad",
......@@ -116,7 +142,11 @@ XPUOpMap& get_kl1_ops() {
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"fill_constant",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gaussian_random",
......@@ -140,7 +170,11 @@ XPUOpMap& get_kl1_ops() {
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"load", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"load", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......@@ -158,15 +192,20 @@ XPUOpMap& get_kl1_ops() {
{"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"range", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......@@ -175,30 +214,67 @@ XPUOpMap& get_kl1_ops() {
{"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"sign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"top_k", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......@@ -212,12 +288,36 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"uniform_random",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
// AddMore
};
......
......@@ -9,7 +9,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <mutex>
#include <string>
#include <unordered_set>
#include "paddle/fluid/platform/xpu/xpu1_op_list.h"
#include "paddle/fluid/platform/xpu/xpu2_op_list.h"
......@@ -19,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace platform {
bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) {
bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type) {
auto& ops = get_kl1_ops();
auto v =
get_xpu_version(BOOST_GET_CONST(platform::XPUPlace, type.place_).device);
......@@ -34,6 +36,45 @@ bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) {
return false;
}
// ops_string contains op_list(e.g., 'mul,mul_grad'), parse the op string and
// insert op to op set
static void tokenize(const std::string& ops, char delim,
std::unordered_set<std::string>* op_set) {
std::string::size_type beg = 0;
for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos;
++end) {
op_set->insert(ops.substr(beg, end - beg));
beg = end + 1;
}
op_set->insert(ops.substr(beg));
}
bool is_in_xpu_black_list(const std::string& op_name) {
static bool inited = false;
static std::unordered_set<std::string> xpu_black_list;
static std::mutex s_mtx;
if (!inited) {
std::lock_guard<std::mutex> guard(s_mtx);
if (!inited) {
if (std::getenv("XPU_BLACK_LIST") != nullptr) {
std::string ops(std::getenv("XPU_BLACK_LIST"));
tokenize(ops, ',', &xpu_black_list);
}
inited = true;
VLOG(3) << "XPU Black List: ";
for (auto iter = xpu_black_list.begin(); iter != xpu_black_list.end();
++iter) {
VLOG(3) << *iter << " ";
}
}
}
if (xpu_black_list.find(op_name) != xpu_black_list.end()) {
return true;
}
return false;
}
} // namespace platform
} // namespace paddle
#endif
......@@ -20,7 +20,8 @@ namespace platform {
using pOpKernelType = paddle::framework::OpKernelType;
bool is_xpu_support_op(std::string op_name, const pOpKernelType& type);
bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type);
bool is_in_xpu_black_list(const std::string& op_name);
} // namespace platform
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册