diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc index 394aa9ba87c2ad12f8acff31fabaa434b3ec982a..76d8f8491729bfa8ce7ba7e4eaf270b0cf18807b 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc @@ -39,15 +39,9 @@ TEST(variable_place_inference_pass, test) { Place{ TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW), }, - Place{ - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), - }, Place{ TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), }, - Place{ - TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny), - }, }); Program program(*desc, scope, places); diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index f093efb4edf172425849272ac544047a86bb0a23..f214155981fdb613524a9be4772856b22f5ae0ce 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -24,7 +24,7 @@ std::vector> OpLite::CreateKernels( std::vector> kernels; CHECK(!op_type_.empty()) << "op_type_ should be set first"; - for (auto place : places) { + auto pick_kernel = [&](const Place &place) { auto ks = KernelRegistry::Global().Create( (kernel_type.empty() ? op_type_ : kernel_type), place.target, place.precision, place.layout); @@ -32,6 +32,22 @@ std::vector> OpLite::CreateKernels( AttachKernel(it.get()); kernels.emplace_back(std::move(it)); } + }; + + std::set place_set; + for (auto place : places) { + place_set.insert(place); + // Pick kernels those support any Precision and any DataLayout + place.precision = PRECISION(kAny); + place_set.insert(place); + place.layout = DATALAYOUT(kAny); + place_set.insert(place); + } + + std::set targets; + for (auto place : place_set) { + pick_kernel(place); + targets.insert(place.target); } CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;