Commit ca90356b authored by qiaolongfei

add back priority

Parent 0f353ab4
@@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
   paddle::framework::UseCPU();
   op->Run(scope, cpu_place);
-  EXPECT_EQ(op_test_value, -20);
+  EXPECT_EQ(op_test_value, -9);

   // add cuda kernels
   paddle::framework::UseCUDA();
   op->Run(scope, cuda_place);
-  EXPECT_EQ(op_test_value, -30);
+  EXPECT_EQ(op_test_value, -10);

   // use cudnn kernel
   paddle::framework::UseCUDNN();
   op->Run(scope, cuda_place);
-  EXPECT_EQ(op_test_value, -40);
+  EXPECT_EQ(op_test_value, -20);
 }
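Editor's note on the hunk above: the test runs the same op three times while switching the preferred kernel library (UseCPU, UseCUDA, UseCUDNN) and checks, through the side effects on op_test_value, that a different kernel was dispatched each time. The kernel bodies are not part of this excerpt, so the sketch below is only one consistent reading of the updated expectations (a counter assumed to start at -10, incremented by the CPU kernel, decremented by the CUDA kernel, and reduced by 10 by the CUDNN kernel); all names and values here are illustrative, not the actual test fixture.

```cpp
#include <cassert>
#include <functional>
#include <map>

// Illustrative stand-in for op_test_value; the starting value of -10 is an
// assumption, not taken from the excerpt above.
static int test_value = -10;

enum class Library { kCPU, kCUDA, kCUDNN };

int main() {
  // One side-effecting "kernel" per library, so the assertions can tell which
  // one actually ran -- the same trick the test above relies on.
  std::map<Library, std::function<void()>> kernels = {
      {Library::kCPU, [] { ++test_value; }},
      {Library::kCUDA, [] { --test_value; }},
      {Library::kCUDNN, [] { test_value -= 10; }},
  };

  kernels.at(Library::kCPU)();    // as if UseCPU() had selected this kernel
  assert(test_value == -9);
  kernels.at(Library::kCUDA)();   // as if UseCUDA() had selected this kernel
  assert(test_value == -10);
  kernels.at(Library::kCUDNN)();  // as if UseCUDNN() had selected this kernel
  assert(test_value == -20);
  return 0;
}
```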
@@ -474,6 +474,20 @@ void OperatorWithKernel::Run(const Scope& scope,
   ExecutionContext ctx(*this, scope, *dev_ctx);
   auto expected_kernel_key = this->GetExpectedKernelType(ctx);

+  OpKernelMap& kernels = kernels_iter->second;
+
+  for (auto& candidate : kKernelPriority) {
+    auto candidate_key =
+        OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate),
+                     expected_kernel_key.data_layout_, std::get<1>(candidate));
+
+    if ((candidate_key == expected_kernel_key) ||
+        (kernels.count(candidate_key))) {
+      expected_kernel_key = candidate_key;
+      break;
+    }
+  }
+
   Scope& new_scope = scope.NewScope();

   for (auto& var_name_item : this->Inputs()) {
@@ -504,7 +504,6 @@ void OperatorWithKernel::Run(const Scope& scope,
     }
   }

-  OpKernelMap& kernels = kernels_iter->second;
   auto kernel_iter = kernels.find(expected_kernel_key);

   kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
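Editor's note on the kKernelPriority loop added above: it walks the priority list in order and rewrites expected_kernel_key to the first candidate that either already equals the expected key or has a registered kernel, so the later kernels.find(expected_kernel_key) lookup picks up the highest-priority available kernel. Below is a minimal, standalone sketch of that selection pattern; Place, Library, KernelKey, and Select are simplified stand-ins for illustration, not Paddle's actual OpKernelType machinery.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Simplified stand-ins for the real OpKernelType / kKernelPriority types.
enum class Place { kCUDA, kCPU };
enum class Library { kCUDNN, kPlain };

struct KernelKey {
  Place place;
  Library library;
  bool operator<(const KernelKey& o) const {
    return std::tie(place, library) < std::tie(o.place, o.library);
  }
  bool operator==(const KernelKey& o) const {
    return place == o.place && library == o.library;
  }
};

// Highest-priority candidates first, playing the role of kKernelPriority.
const std::vector<std::pair<Place, Library>> kPriority = {
    {Place::kCUDA, Library::kCUDNN},
    {Place::kCUDA, Library::kPlain},
    {Place::kCPU, Library::kPlain},
};

// Mirrors the loop in the hunk above: take the first candidate that matches
// the expected key or has a registered kernel; otherwise keep the expected key.
KernelKey Select(const KernelKey& expected,
                 const std::map<KernelKey, std::string>& kernels) {
  for (const auto& candidate : kPriority) {
    KernelKey key{std::get<0>(candidate), std::get<1>(candidate)};
    if (key == expected || kernels.count(key) > 0) {
      return key;
    }
  }
  return expected;
}

int main() {
  std::map<KernelKey, std::string> kernels = {
      {{Place::kCPU, Library::kPlain}, "plain CPU kernel"},
      {{Place::kCUDA, Library::kPlain}, "plain CUDA kernel"},
  };
  // The op expects a plain CUDA kernel; since no cuDNN kernel is registered,
  // the plain CUDA candidate wins. Registering a cuDNN kernel would make the
  // higher-priority (kCUDA, kCUDNN) candidate win instead.
  KernelKey expected{Place::kCUDA, Library::kPlain};
  std::cout << kernels.at(Select(expected, kernels)) << "\n";
  return 0;
}
```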