diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc
index 7ec8d18b0e886948f4fb951e17875584413771db..e7087e063cbe8839716e3648d55cd25cc778f06f 100644
--- a/paddle/framework/init.cc
+++ b/paddle/framework/init.cc
@@ -72,7 +72,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
     LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
   }
   platform::DeviceContextPool::Init(places);
-  framework::UseALL();
+  //  framework::UseALL();
   return true;
 }
 
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index f7a10ada809e6943e60c2d8cde05b8a9e2a7a2c2..66f07b6757fe1fe613e61ac66057be43ef5aced7 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
   paddle::framework::UseCPU();
   op->Run(scope, cpu_place);
-  EXPECT_EQ(op_test_value, -20);
+  EXPECT_EQ(op_test_value, -9);
 
   // add cuda kernels
   paddle::framework::UseCUDA();
   op->Run(scope, cuda_place);
-  EXPECT_EQ(op_test_value, -30);
+  EXPECT_EQ(op_test_value, -10);
 
   // use cudnn kernel
   paddle::framework::UseCUDNN();
   op->Run(scope, cuda_place);
-  EXPECT_EQ(op_test_value, -40);
+  EXPECT_EQ(op_test_value, -20);
 }
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index fe8096835d6f4e01e4e8a190722d4eb4fdb4095f..35ebe48ba682f135b7f85edb3b2999db7c29e51a 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -495,6 +495,22 @@ void OperatorWithKernel::Run(const Scope& scope,
   ExecutionContext ctx(*this, scope, *dev_ctx);
   auto expected_kernel_key = this->GetExpectedKernelType(ctx);
 
+  OpKernelMap& kernels = kernels_iter->second;
+
+  for (auto& candidate : kKernelPriority) {
+    auto candidate_key =
+        OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate),
+                     expected_kernel_key.data_layout_, std::get<1>(candidate));
+
+    if ((candidate_key == expected_kernel_key) ||
+        (kernels.count(candidate_key))) {
+      expected_kernel_key = candidate_key;
+      break;
+    }
+  }
+
+  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+
   Scope& new_scope = scope.NewScope();
 
   for (auto& var_name_item : this->Inputs()) {
@@ -525,10 +541,10 @@ void OperatorWithKernel::Run(const Scope& scope,
     }
   }
 
-  OpKernelMap& kernels = kernels_iter->second;
   auto kernel_iter = kernels.find(expected_kernel_key);
 
-  kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
+  kernel_iter->second->Compute(ExecutionContext(
+      *this, new_scope, *pool.Get(expected_kernel_key.place_)));
 }
 
 proto::DataType OperatorWithKernel::IndicateDataType(
diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc
index 387d1e0a747f71d85826b52d140c2838112227f6..48c01f984f825208d911a06c6e48b802fa24aa0e 100644
--- a/paddle/operators/fetch_op.cc
+++ b/paddle/operators/fetch_op.cc
@@ -53,7 +53,7 @@ class FetchOp : public framework::OperatorBase {
     // FIXME(yuyang18): Should we assume the fetch operator always generate
     // CPU outputs?
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
+    auto &dev_ctx = *pool.Get(src_item.place());
 
     CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
     dev_ctx.Wait();