diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c6cc9befbd84698078b258bb661745d9be7c1282..d76e06bd4143e29635610826be570cf83fbe02b5 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -149,6 +149,48 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, kernel_signature_(std::move(kernel_signature)), phi_kernel_(phi_kernel) {} +#ifdef PADDLE_WITH_MLU + +static void tokenize(const std::string& ops, + char delim, + std::unordered_set* op_set) { + std::string::size_type beg = 0; + for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos; + ++end) { + op_set->insert(ops.substr(beg, end - beg)); + beg = end + 1; + } + + op_set->insert(ops.substr(beg)); +} + +static bool is_in_mlu_black_list(const std::string& op_name) { + static bool inited = false; + static std::unordered_set mlu_black_list; + static std::mutex s_mtx; + if (!inited) { + std::lock_guard guard(s_mtx); + if (!inited) { + if (std::getenv("MLU_BLACK_LIST") != nullptr) { + std::string ops(std::getenv("MLU_BLACK_LIST")); + tokenize(ops, ',', &mlu_black_list); + } + inited = true; + VLOG(3) << "MLU Black List: "; + for (auto iter = mlu_black_list.begin(); iter != mlu_black_list.end(); + ++iter) { + VLOG(3) << *iter << " "; + } + } + } + if (mlu_black_list.find(op_name) != mlu_black_list.end()) { + return true; + } + return false; +} + +#endif + template PreparedOp PrepareImpl( const NameVarMap& ins, @@ -212,6 +254,12 @@ PreparedOp PrepareImpl( paddle::platform::is_in_xpu_black_list(op.Type()); #endif +#ifdef PADDLE_WITH_MLU + if (is_in_mlu_black_list(op.Type())) { + expected_kernel_key.place_ = platform::CPUPlace(); + } +#endif + bool has_phi_kernel = false; const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type()); diff --git a/paddle/fluid/operators/strided_slice_op_mlu.cc b/paddle/fluid/operators/strided_slice_op_mlu.cc index 5800c167b0158b203c8c8ce8345d28616bc44408..6caf1ad5ad15ff5ff02cb4f7fb80b2a47a034c22 100644 --- a/paddle/fluid/operators/strided_slice_op_mlu.cc +++ b/paddle/fluid/operators/strided_slice_op_mlu.cc @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using Tensor = phi::DenseTensor; using Variable = framework::Variable; using LoDTensorArray = framework::LoDTensorArray; using DDim = framework::DDim;