提交 32875840 编写于 作者: S Superjomn

make kernel creation support any PRECISION and any DATALAYOUT

by default
上级 0245a2dd
......@@ -39,15 +39,9 @@ TEST(variable_place_inference_pass, test) {
Place{
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW),
},
Place{
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny),
},
Place{
TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
},
Place{
TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny),
},
});
Program program(*desc, scope, places);
......
......@@ -24,7 +24,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
std::vector<std::unique_ptr<KernelBase>> kernels;
CHECK(!op_type_.empty()) << "op_type_ should be set first";
for (auto place : places) {
auto pick_kernel = [&](const Place &place) {
auto ks = KernelRegistry::Global().Create(
(kernel_type.empty() ? op_type_ : kernel_type), place.target,
place.precision, place.layout);
......@@ -32,6 +32,22 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
AttachKernel(it.get());
kernels.emplace_back(std::move(it));
}
};
std::set<Place> place_set;
for (auto place : places) {
place_set.insert(place);
// Pick kernels those support any Precision and any DataLayout
place.precision = PRECISION(kAny);
place_set.insert(place);
place.layout = DATALAYOUT(kAny);
place_set.insert(place);
}
std::set<TargetType> targets;
for (auto place : place_set) {
pick_kernel(place);
targets.insert(place.target);
}
CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册