未验证 提交 ba882657 编写于 作者: H hong 提交者: GitHub

Update op support gpu impl (#39386)

* find gpu kernel in pten factory; test=develop

* check in functional kernel first; test=develop
上级 196dbfc2
......@@ -32,6 +32,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/ops/compat/signatures.h"
namespace pten {
......@@ -598,6 +599,17 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
}
bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first
auto& kernel_factory = pten::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(pten::TransToPtenKernelName(op_type));
for (auto& kernel : kernel_key_map) {
if (platform::is_gpu_place(
pten::TransToFluidPlace(kernel.first.backend()))) {
return true;
}
}
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
......@@ -609,6 +621,7 @@ bool OpSupportGPU(const std::string& op_type) {
return true;
}
}
return false;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册