diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0f7012940d76b0f2846a11710e082db22204bbb9..6a9f5577705335d8185a158b15169d87bf2314d2 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1254,9 +1254,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, } #endif #ifdef PADDLE_WITH_XPU - if (kernel_iter == kernels.end() && - is_xpu_place(expected_kernel_key.place_) && - !paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) { + if ((kernel_iter == kernels.end() && + is_xpu_place(expected_kernel_key.place_) && + !paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) || + paddle::platform::is_in_xpu_black_list(type_)) { VLOG(3) << "missing XPU kernel: " << type_ << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 619d31c4f5b257d841ea3410d4f96067b34f320c..93f2fd38a7306417324f097761707a8e7ef2195a 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -131,9 +131,10 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); #ifdef PADDLE_WITH_XPU - if (kernel_iter == kernels.end() && - is_xpu_place(expected_kernel_key.place_) && - !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) { + if ((kernel_iter == kernels.end() && + is_xpu_place(expected_kernel_key.place_) && + !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) || + paddle::platform::is_in_xpu_black_list(op.Type())) { VLOG(3) << "missing XPU kernel: " << op.Type() << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index efd25bc89294097c7d60c802395d4d4d05dcab7a..97c81568e673e84eabee9b94d6db08064574feab 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -70,7 +70,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) if(WITH_XPU) cc_library(xpu_info SRCS xpu/xpu_info.cc DEPS gflags glog enforce xpulib) -cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib) +cc_library(xpu_op_list SRCS xpu/xpu_op_list.cc DEPS gflags glog enforce xpulib device_context) endif() if(WITH_ASCEND) diff --git a/paddle/fluid/platform/xpu/xpu1_op_list.h b/paddle/fluid/platform/xpu/xpu1_op_list.h index 131525718cac759f9310831e47f29caff9945f5c..cdd60a856fbc90865ee29a1e3a1c371352b87618 100644 --- a/paddle/fluid/platform/xpu/xpu1_op_list.h +++ b/paddle/fluid/platform/xpu/xpu1_op_list.h @@ -55,25 +55,51 @@ XPUOpMap& get_kl1_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"affine_channel_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace())})}, {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"batch_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"clip_by_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"coalesce_tensor", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"c_reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"c_allreduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logicalor", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logicaland", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"logicalnot", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logicalor", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logicaland", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"logicalnot", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"depthwise_conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"depthwise_conv2d_grad", @@ -116,7 +142,11 @@ XPUOpMap& get_kl1_ops() { {"elementwise_min_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"fill_constant", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"gaussian_random", @@ -140,7 +170,11 @@ XPUOpMap& get_kl1_ops() { {"layer_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"load", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"load", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"log_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"log_loss_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, @@ -158,15 +192,20 @@ XPUOpMap& get_kl1_ops() { {"accuracy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"one_hot", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"one_hot", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, + {"one_hot_v2", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"adam", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rmsprop", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"lamb", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"pool2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"pool2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"range", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, @@ -175,30 +214,67 @@ XPUOpMap& get_kl1_ops() { {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_max_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reshape2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"reshape2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"shape", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"sign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softmax_with_cross_entropy", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"squeeze", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"squeeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"squeeze_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"squeeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"squeeze2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"top_k", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, @@ -212,12 +288,36 @@ XPUOpMap& get_kl1_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"uniform_random", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"unsqueeze", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"unsqueeze_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"unsqueeze2_grad", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::UINT8, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})} // AddMore }; diff --git a/paddle/fluid/platform/xpu/xpu_op_list.cc b/paddle/fluid/platform/xpu/xpu_op_list.cc index b3349407942bd17e2e4597c3a60aec833e14f839..0c10436f397898b7e7f96bb9f23eb1249d58b31d 100644 --- a/paddle/fluid/platform/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/xpu/xpu_op_list.cc @@ -9,7 +9,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU +#include #include +#include #include "paddle/fluid/platform/xpu/xpu1_op_list.h" #include "paddle/fluid/platform/xpu/xpu2_op_list.h" @@ -19,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace platform { -bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) { +bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type) { auto& ops = get_kl1_ops(); auto v = get_xpu_version(BOOST_GET_CONST(platform::XPUPlace, type.place_).device); @@ -34,6 +36,45 @@ bool is_xpu_support_op(std::string op_name, const pOpKernelType& type) { return false; } +// ops_string contains op_list(e.g., 'mul,mul_grad'), parse the op string and +// insert op to op set +static void tokenize(const std::string& ops, char delim, + std::unordered_set* op_set) { + std::string::size_type beg = 0; + for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos; + ++end) { + op_set->insert(ops.substr(beg, end - beg)); + beg = end + 1; + } + + op_set->insert(ops.substr(beg)); +} + +bool is_in_xpu_black_list(const std::string& op_name) { + static bool inited = false; + static std::unordered_set xpu_black_list; + static std::mutex s_mtx; + if (!inited) { + std::lock_guard guard(s_mtx); + if (!inited) { + if (std::getenv("XPU_BLACK_LIST") != nullptr) { + std::string ops(std::getenv("XPU_BLACK_LIST")); + tokenize(ops, ',', &xpu_black_list); + } + inited = true; + VLOG(3) << "XPU Black List: "; + for (auto iter = xpu_black_list.begin(); iter != xpu_black_list.end(); + ++iter) { + VLOG(3) << *iter << " "; + } + } + } + if (xpu_black_list.find(op_name) != xpu_black_list.end()) { + return true; + } + return false; +} + } // namespace platform } // namespace paddle #endif diff --git a/paddle/fluid/platform/xpu/xpu_op_list.h b/paddle/fluid/platform/xpu/xpu_op_list.h index 487bc8ac48b66feefc6016632ffd5bfc0f09d56a..705f701e13634a30a9ff124fe1d9ee82db9b23fc 100644 --- a/paddle/fluid/platform/xpu/xpu_op_list.h +++ b/paddle/fluid/platform/xpu/xpu_op_list.h @@ -20,7 +20,8 @@ namespace platform { using pOpKernelType = paddle::framework::OpKernelType; -bool is_xpu_support_op(std::string op_name, const pOpKernelType& type); +bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type); +bool is_in_xpu_black_list(const std::string& op_name); } // namespace platform } // namespace paddle