diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5ab14a1daba226f02e92db4d0d172bf2ac549646..0f558b46872a2d8e94cbe7141119b0d81e696e90 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -32,6 +32,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/ops/compat/signatures.h" namespace pten { @@ -598,6 +599,17 @@ std::vector ExecutionContext::MultiOutput( } bool OpSupportGPU(const std::string& op_type) { + // check in new Function kernel first + auto& kernel_factory = pten::KernelFactory::Instance(); + auto kernel_key_map = + kernel_factory.SelectKernelMap(pten::TransToPtenKernelName(op_type)); + for (auto& kernel : kernel_key_map) { + if (platform::is_gpu_place( + pten::TransToFluidPlace(kernel.first.backend()))) { + return true; + } + } + auto& all_kernels = OperatorWithKernel::AllOpKernels(); auto it = all_kernels.find(op_type); if (it == all_kernels.end()) { @@ -609,6 +621,7 @@ bool OpSupportGPU(const std::string& op_type) { return true; } } + return false; }