未验证 提交 21beef91 编写于 作者: Q QingshuChen 提交者: GitHub

support kunlun black list and add kl1 op (#34605)

* support kunlun black list and add kl1 op

* xpu_op_list add device_context dependence
上级 fa16c21f
...@@ -1254,9 +1254,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, ...@@ -1254,9 +1254,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
} }
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (kernel_iter == kernels.end() && if ((kernel_iter == kernels.end() &&
is_xpu_place(expected_kernel_key.place_) && is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) { !paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) ||
paddle::platform::is_in_xpu_black_list(type_)) {
VLOG(3) << "missing XPU kernel: " << type_ VLOG(3) << "missing XPU kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
......
...@@ -131,9 +131,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -131,9 +131,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& kernels = kernels_iter->second; auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (kernel_iter == kernels.end() && if ((kernel_iter == kernels.end() &&
is_xpu_place(expected_kernel_key.place_) && is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) { !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "missing XPU kernel: " << op.Type() VLOG(3) << "missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
......
...@@ -70,7 +70,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) ...@@ -70,7 +70,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
if(WITH_XPU) if(WITH_XPU)
cc_library(xpu_info SRCS xpu/xpu_info.cc DEPS gflags glog enforce xpulib) cc_library(xpu_info SRCS xpu/xpu_info.cc DEPS gflags glog enforce xpulib)
cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib) cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib device_context)
endif() endif()
if(WITH_ASCEND) if(WITH_ASCEND)
......
...@@ -55,25 +55,51 @@ XPUOpMap& get_kl1_ops() { ...@@ -55,25 +55,51 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"affine_channel_grad", {"affine_channel_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad", {"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"clip_by_norm", {"clip_by_norm",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"coalesce_tensor", {"coalesce_tensor",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"c_reduce_sum", {"c_reduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_allreduce_sum", {"c_allreduce_sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
{"logicaland", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::INT8, XPUPlace()),
{"logicalnot", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d", {"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad", {"depthwise_conv2d_grad",
...@@ -116,7 +142,11 @@ XPUOpMap& get_kl1_ops() { ...@@ -116,7 +142,11 @@ XPUOpMap& get_kl1_ops() {
{"elementwise_min_grad", {"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"fill_constant", {"fill_constant",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"gaussian_random", {"gaussian_random",
...@@ -140,7 +170,11 @@ XPUOpMap& get_kl1_ops() { ...@@ -140,7 +170,11 @@ XPUOpMap& get_kl1_ops() {
{"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"layer_norm_grad", {"layer_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"load", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"load", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_loss_grad", {"log_loss_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -158,15 +192,20 @@ XPUOpMap& get_kl1_ops() { ...@@ -158,15 +192,20 @@ XPUOpMap& get_kl1_ops() {
{"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
{"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"range", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad", {"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -175,30 +214,67 @@ XPUOpMap& get_kl1_ops() { ...@@ -175,30 +214,67 @@ XPUOpMap& get_kl1_ops() {
{"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_max_grad", {"reduce_max_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"reshape2_grad", {"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align_grad", {"roi_align_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"shape", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"sign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_grad", {"softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softmax_with_cross_entropy", {"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze_grad", {"squeeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2_grad", {"squeeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"top_k", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"top_k", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
...@@ -212,12 +288,36 @@ XPUOpMap& get_kl1_ops() { ...@@ -212,12 +288,36 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"uniform_random", {"uniform_random",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze_grad", {"unsqueeze_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
{"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2_grad", {"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})} {"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
// AddMore // AddMore
}; };
......
...@@ -9,7 +9,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,7 +9,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include <mutex>
#include <string> #include <string>
#include <unordered_set>
#include "paddle/fluid/platform/xpu/xpu1_op_list.h" #include "paddle/fluid/platform/xpu/xpu1_op_list.h"
#include "paddle/fluid/platform/xpu/xpu2_op_list.h" #include "paddle/fluid/platform/xpu/xpu2_op_list.h"
...@@ -19,7 +21,7 @@ limitations under the License. */ ...@@ -19,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) { bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type) {
auto& ops = get_kl1_ops(); auto& ops = get_kl1_ops();
auto v = auto v =
get_xpu_version(BOOST_GET_CONST(platform::XPUPlace, type.place_).device); get_xpu_version(BOOST_GET_CONST(platform::XPUPlace, type.place_).device);
...@@ -34,6 +36,45 @@ bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) { ...@@ -34,6 +36,45 @@ bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) {
return false; return false;
} }
// ops_string contains op_list(e.g., 'mul,mul_grad'), parse the op string and
// insert op to op set
static void tokenize(const std::string& ops, char delim,
std::unordered_set<std::string>* op_set) {
std::string::size_type beg = 0;
for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos;
++end) {
op_set->insert(ops.substr(beg, end - beg));
beg = end + 1;
}
op_set->insert(ops.substr(beg));
}
bool is_in_xpu_black_list(const std::string& op_name) {
static bool inited = false;
static std::unordered_set<std::string> xpu_black_list;
static std::mutex s_mtx;
if (!inited) {
std::lock_guard<std::mutex> guard(s_mtx);
if (!inited) {
if (std::getenv("XPU_BLACK_LIST") != nullptr) {
std::string ops(std::getenv("XPU_BLACK_LIST"));
tokenize(ops, ',', &xpu_black_list);
}
inited = true;
VLOG(3) << "XPU Black List: ";
for (auto iter = xpu_black_list.begin(); iter != xpu_black_list.end();
++iter) {
VLOG(3) << *iter << " ";
}
}
}
if (xpu_black_list.find(op_name) != xpu_black_list.end()) {
return true;
}
return false;
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
#endif #endif
...@@ -20,7 +20,8 @@ namespace platform { ...@@ -20,7 +20,8 @@ namespace platform {
using pOpKernelType = paddle::framework::OpKernelType; using pOpKernelType = paddle::framework::OpKernelType;
bool is_xpu_support_op(std::string op_name, const pOpKernelType& type); bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type);
bool is_in_xpu_black_list(const std::string& op_name);
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册