未验证 提交 d762e07e 编写于 作者: Qiao Longfei 提交者: GitHub

Merge pull request #7294 from jacquesqiao/add-back-priority

Add back kernel priority
...@@ -72,7 +72,7 @@ bool InitDevices(const std::vector<std::string> &devices) { ...@@ -72,7 +72,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
LOG(WARNING) << "Not specified CPU device, create CPU by Default."; LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
} }
platform::DeviceContextPool::Init(places); platform::DeviceContextPool::Init(places);
framework::UseALL(); // framework::UseALL();
return true; return true;
} }
......
...@@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) { ...@@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::UseCPU(); paddle::framework::UseCPU();
op->Run(scope, cpu_place); op->Run(scope, cpu_place);
EXPECT_EQ(op_test_value, -20); EXPECT_EQ(op_test_value, -9);
// add cuda kernels // add cuda kernels
paddle::framework::UseCUDA(); paddle::framework::UseCUDA();
op->Run(scope, cuda_place); op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -30); EXPECT_EQ(op_test_value, -10);
// use cudnn kernel // use cudnn kernel
paddle::framework::UseCUDNN(); paddle::framework::UseCUDNN();
op->Run(scope, cuda_place); op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -40); EXPECT_EQ(op_test_value, -20);
} }
...@@ -495,6 +495,22 @@ void OperatorWithKernel::Run(const Scope& scope, ...@@ -495,6 +495,22 @@ void OperatorWithKernel::Run(const Scope& scope,
ExecutionContext ctx(*this, scope, *dev_ctx); ExecutionContext ctx(*this, scope, *dev_ctx);
auto expected_kernel_key = this->GetExpectedKernelType(ctx); auto expected_kernel_key = this->GetExpectedKernelType(ctx);
OpKernelMap& kernels = kernels_iter->second;
for (auto& candidate : kKernelPriority) {
auto candidate_key =
OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate),
expected_kernel_key.data_layout_, std::get<1>(candidate));
if ((candidate_key == expected_kernel_key) ||
(kernels.count(candidate_key))) {
expected_kernel_key = candidate_key;
break;
}
}
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
Scope& new_scope = scope.NewScope(); Scope& new_scope = scope.NewScope();
for (auto& var_name_item : this->Inputs()) { for (auto& var_name_item : this->Inputs()) {
...@@ -525,10 +541,10 @@ void OperatorWithKernel::Run(const Scope& scope, ...@@ -525,10 +541,10 @@ void OperatorWithKernel::Run(const Scope& scope,
} }
} }
OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx)); kernel_iter->second->Compute(ExecutionContext(
*this, new_scope, *pool.Get(expected_kernel_key.place_)));
} }
proto::DataType OperatorWithKernel::IndicateDataType( proto::DataType OperatorWithKernel::IndicateDataType(
......
...@@ -53,7 +53,7 @@ class FetchOp : public framework::OperatorBase { ...@@ -53,7 +53,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate // FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs? // CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(src_item.place());
CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait(); dev_ctx.Wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册