From 328758404bce669cc7e1d3a148783109db280e2a Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 28 Apr 2019 12:18:39 +0800 Subject: [PATCH] make kernel creation support any PRECISION and any DATALAYOUT by default --- .../mir/variable_place_inference_pass_test.cc | 6 ------ paddle/fluid/lite/core/op_lite.cc | 18 +++++++++++++++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc index 394aa9ba87c..76d8f849172 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc @@ -39,15 +39,9 @@ TEST(variable_place_inference_pass, test) { Place{ TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW), }, - Place{ - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), - }, Place{ TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), }, - Place{ - TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny), - }, }); Program program(*desc, scope, places); diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index f093efb4edf..f214155981f 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -24,7 +24,7 @@ std::vector> OpLite::CreateKernels( std::vector> kernels; CHECK(!op_type_.empty()) << "op_type_ should be set first"; - for (auto place : places) { + auto pick_kernel = [&](const Place &place) { auto ks = KernelRegistry::Global().Create( (kernel_type.empty() ? op_type_ : kernel_type), place.target, place.precision, place.layout); @@ -32,6 +32,22 @@ std::vector> OpLite::CreateKernels( AttachKernel(it.get()); kernels.emplace_back(std::move(it)); } + }; + + std::set place_set; + for (auto place : places) { + place_set.insert(place); + // Pick kernels those support any Precision and any DataLayout + place.precision = PRECISION(kAny); + place_set.insert(place); + place.layout = DATALAYOUT(kAny); + place_set.insert(place); + } + + std::set targets; + for (auto place : place_set) { + pick_kernel(place); + targets.insert(place.target); } CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; -- GitLab