diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc index ac822e3764ff44264ae6662b714abe6a8b7c2837..614ee990a9811ab74ceedb4fa000fa385698d679 100644 --- a/lite/core/arena/framework.cc +++ b/lite/core/arena/framework.cc @@ -59,6 +59,8 @@ void TestCase::CreateInstruction() { CHECK(it != kernels.end()) << "failed to create the kernel in " << place_.DebugString() << " with alias: " << alias_; + // reset final place + place_ = (*it)->place(); // prepare context (*it)->SetContext(std::move(ctx_)); instruction_.reset(new Instruction(op, std::move(*it))); diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index 0936a44a66e4777633b84dadf0a1dc049213faab..c76e369466a9b998b2ad6fde67b97117649fddc0 100644 --- a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -47,18 +47,19 @@ std::vector> OpLite::CreateKernels( return kernels; } - std::set place_set; - for (auto place : places) { - place_set.insert(place); - // Pick kernels those support any Precision and any DataLayout - place.precision = PRECISION(kAny); - place_set.insert(place); - place.layout = DATALAYOUT(kAny); - place_set.insert(place); + std::set expanded_places(places.begin(), places.end()); + for (auto &place : places) { + // Pick kernels those support any Precision and any DataLayout, For example: + // kARM,kFloat,kNCHW -> kARM,kFloat,kAny; kARM,kAny,kNCHW; kARM,kAny,kAny + expanded_places.insert( + Place(place.target, place.precision, DATALAYOUT(kAny))); + expanded_places.insert(Place(place.target, PRECISION(kAny), place.layout)); + expanded_places.insert( + Place(place.target, PRECISION(kAny), DATALAYOUT(kAny))); } std::set targets; - for (auto place : place_set) { + for (auto place : expanded_places) { pick_kernel(place); targets.insert(place.target); } diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 26ae22ce9d27cffcb6adf53ca16b01181edddf9e..bfa5c85522927a8767a3c9cc0488408faeec8194 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -91,7 +91,6 @@ add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(logical_compute_arm ARM extra SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(less_than_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(compare_compute_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/compare_compute.cc b/lite/kernels/arm/compare_compute.cc index 490d3c302871edbcbef30b2ebfb35a08ab754b53..709942a0d9f385e4ba55be32657633c0edc378cf 100644 --- a/lite/kernels/arm/compare_compute.cc +++ b/lite/kernels/arm/compare_compute.cc @@ -73,8 +73,6 @@ inline void get_mid_dims(const lite::DDim &x_dims, (*post) *= x_dims[i]; } } -template