未验证 提交 db7639ca 编写于 作者: S Shibo Tao 提交者: GitHub

optimize register mechanism (#3745)

* refactor register mechanism, current so size: 1.20MB. test=develop

* fix KernelRegistry::Global().Create. test=develop

* fix cpplint errors. test=develop

* fix test_subgraph_pass bug. test=develop

* register kernel with target,precision,datalayout combination. test=develop

* fix test_paddle_api no op found bug. test=develop

* enhance comment

* fix lite/kernels/arm/elementwise_compute_test.cc. test=develop

* fix code style

* revert format of unchanged files. test=develop

* fix code format according to cpplint 1.5.1. test=develop

* remove redundant include header. test=develop
上级 732bb91b
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
#include "lite/api/light_api.h" #include "lite/api/light_api.h"
#include <algorithm> #include <algorithm>
#include <map> #include <map>
#include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
#include "lite/api/paddle_api.h" #include "lite/api/paddle_api.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/io.h" #include "lite/utils/io.h"
DEFINE_string(model_dir, "", ""); DEFINE_string(model_dir, "", "");
namespace paddle { namespace paddle {
......
...@@ -13,8 +13,12 @@ ...@@ -13,8 +13,12 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include "lite/api/paddle_api.h" #include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/test_helper.h" #include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/string.h" #include "lite/utils/string.h"
......
...@@ -17,277 +17,5 @@ ...@@ -17,277 +17,5 @@
#include <set> #include <set>
namespace paddle { namespace paddle {
namespace lite { namespace lite {} // namespace lite
const std::map<std::string, std::string> &GetOp2PathDict() {
return OpKernelInfoCollector::Global().GetOp2PathDict();
}
std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
const std::string &op_type,
TargetType target,
PrecisionType precision,
DataLayoutType layout) {
Place place{target, precision, layout};
VLOG(5) << "creating " << op_type << " kernel for " << place.DebugString();
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
case DATALAYOUT(kNCHW): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kNCHW)>(op_type); \
case DATALAYOUT(kAny): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kAny)>(op_type); \
case DATALAYOUT(kNHWC): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kNHWC)>(op_type); \
case DATALAYOUT(kImageDefault): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageDefault)>(op_type); \
case DATALAYOUT(kImageFolder): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageFolder)>(op_type); \
case DATALAYOUT(kImageNW): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageNW)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
CREATE_KERNEL1(target__, kFloat); \
case PRECISION(kInt8): \
CREATE_KERNEL1(target__, kInt8); \
case PRECISION(kFP16): \
CREATE_KERNEL1(target__, kFP16); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
case PRECISION(kInt32): \
CREATE_KERNEL1(target__, kInt32); \
case PRECISION(kInt64): \
CREATE_KERNEL1(target__, kInt64); \
default: \
CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
}
switch (target) {
case TARGET(kHost): {
CREATE_KERNEL(kHost);
} break;
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_X86)
case TARGET(kX86): {
CREATE_KERNEL(kX86);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_CUDA)
case TARGET(kCUDA): {
CREATE_KERNEL(kCUDA);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_ARM)
case TARGET(kARM): {
CREATE_KERNEL(kARM);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_OPENCL)
case TARGET(kOpenCL): {
CREATE_KERNEL(kOpenCL);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_NPU)
case TARGET(kNPU): {
CREATE_KERNEL(kNPU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_APU)
case TARGET(kAPU): {
CREATE_KERNEL(kAPU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_XPU)
case TARGET(kXPU): {
CREATE_KERNEL(kXPU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_FPGA)
case TARGET(kFPGA): {
CREATE_KERNEL(kFPGA);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_BM)
case TARGET(kBM): {
CREATE_KERNEL(kBM);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_MLU)
case TARGET(kMLU): {
CREATE_KERNEL(kMLU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_RKNPU)
case TARGET(kRKNPU): {
CREATE_KERNEL(kRKNPU);
} break;
#endif
default:
CHECK(false) << "not supported kernel target " << TargetToStr(target);
}
#undef CREATE_KERNEL
return std::list<std::unique_ptr<KernelBase>>();
}
KernelRegistry::KernelRegistry() : registries_() {
#define INIT_FOR(target__, precision__, layout__) \
registries_[std::make_tuple(TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__))] \
.set<KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_CUDA)
INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kFloat, kNHWC);
INIT_FOR(kCUDA, kInt8, kNCHW);
INIT_FOR(kCUDA, kFP16, kNCHW);
INIT_FOR(kCUDA, kFP16, kNHWC);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kCUDA, kAny, kAny);
INIT_FOR(kCUDA, kInt8, kNHWC);
INIT_FOR(kCUDA, kInt64, kNCHW);
INIT_FOR(kCUDA, kInt64, kNHWC);
INIT_FOR(kCUDA, kInt32, kNCHW);
INIT_FOR(kCUDA, kInt32, kNHWC);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_MLU)
INIT_FOR(kMLU, kFloat, kNHWC);
INIT_FOR(kMLU, kFloat, kNCHW);
INIT_FOR(kMLU, kFP16, kNHWC);
INIT_FOR(kMLU, kFP16, kNCHW);
INIT_FOR(kMLU, kInt8, kNHWC);
INIT_FOR(kMLU, kInt8, kNCHW);
INIT_FOR(kMLU, kInt16, kNHWC);
INIT_FOR(kMLU, kInt16, kNCHW);
#endif
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kAny, kNHWC);
INIT_FOR(kHost, kAny, kAny);
INIT_FOR(kHost, kBool, kNCHW);
INIT_FOR(kHost, kBool, kNHWC);
INIT_FOR(kHost, kBool, kAny);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
INIT_FOR(kHost, kFloat, kAny);
INIT_FOR(kHost, kFP16, kNCHW);
INIT_FOR(kHost, kFP16, kNHWC);
INIT_FOR(kHost, kFP16, kAny);
INIT_FOR(kHost, kInt8, kNCHW);
INIT_FOR(kHost, kInt8, kNHWC);
INIT_FOR(kHost, kInt8, kAny);
INIT_FOR(kHost, kInt16, kNCHW);
INIT_FOR(kHost, kInt16, kNHWC);
INIT_FOR(kHost, kInt16, kAny);
INIT_FOR(kHost, kInt32, kNCHW);
INIT_FOR(kHost, kInt32, kNHWC);
INIT_FOR(kHost, kInt32, kAny);
INIT_FOR(kHost, kInt64, kNCHW);
INIT_FOR(kHost, kInt64, kNHWC);
INIT_FOR(kHost, kInt64, kAny);
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_X86)
INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kAny, kNCHW);
INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kX86, kInt64, kNCHW);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_ARM)
INIT_FOR(kARM, kFloat, kNCHW);
INIT_FOR(kARM, kFloat, kNHWC);
INIT_FOR(kARM, kInt8, kNCHW);
INIT_FOR(kARM, kInt8, kNHWC);
INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny);
INIT_FOR(kARM, kInt32, kNCHW);
INIT_FOR(kARM, kInt64, kNCHW);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_OPENCL)
INIT_FOR(kOpenCL, kFloat, kNCHW);
INIT_FOR(kOpenCL, kFloat, kNHWC);
INIT_FOR(kOpenCL, kAny, kNCHW);
INIT_FOR(kOpenCL, kAny, kNHWC);
INIT_FOR(kOpenCL, kFloat, kAny);
INIT_FOR(kOpenCL, kInt8, kNCHW);
INIT_FOR(kOpenCL, kAny, kAny);
INIT_FOR(kOpenCL, kFP16, kNCHW);
INIT_FOR(kOpenCL, kFP16, kNHWC);
INIT_FOR(kOpenCL, kFP16, kImageDefault);
INIT_FOR(kOpenCL, kFP16, kImageFolder);
INIT_FOR(kOpenCL, kFP16, kImageNW);
INIT_FOR(kOpenCL, kFloat, kImageDefault);
INIT_FOR(kOpenCL, kFloat, kImageFolder);
INIT_FOR(kOpenCL, kFloat, kImageNW);
INIT_FOR(kOpenCL, kAny, kImageDefault);
INIT_FOR(kOpenCL, kAny, kImageFolder);
INIT_FOR(kOpenCL, kAny, kImageNW);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_NPU)
INIT_FOR(kNPU, kFloat, kNCHW);
INIT_FOR(kNPU, kFloat, kNHWC);
INIT_FOR(kNPU, kInt8, kNCHW);
INIT_FOR(kNPU, kInt8, kNHWC);
INIT_FOR(kNPU, kAny, kNCHW);
INIT_FOR(kNPU, kAny, kNHWC);
INIT_FOR(kNPU, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_APU)
INIT_FOR(kAPU, kInt8, kNCHW);
INIT_FOR(kXPU, kFloat, kNCHW);
INIT_FOR(kXPU, kInt8, kNCHW);
INIT_FOR(kXPU, kAny, kNCHW);
INIT_FOR(kXPU, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_FPGA)
INIT_FOR(kFPGA, kFP16, kNHWC);
INIT_FOR(kFPGA, kFP16, kAny);
INIT_FOR(kFPGA, kFloat, kNHWC);
INIT_FOR(kFPGA, kAny, kNHWC);
INIT_FOR(kFPGA, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_BM)
INIT_FOR(kBM, kFloat, kNCHW);
INIT_FOR(kBM, kInt8, kNCHW);
INIT_FOR(kBM, kAny, kNCHW);
INIT_FOR(kBM, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_RKNPU)
INIT_FOR(kRKNPU, kFloat, kNCHW);
INIT_FOR(kRKNPU, kInt8, kNCHW);
INIT_FOR(kRKNPU, kAny, kNCHW);
INIT_FOR(kRKNPU, kAny, kAny);
#endif
#undef INIT_FOR
}
KernelRegistry &KernelRegistry::Global() {
static auto *x = new KernelRegistry;
return *x;
}
} // namespace lite
} // namespace paddle } // namespace paddle
此差异已折叠。
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/argmax_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstdlib> #include <cstdlib>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/argmax_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -66,9 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) { ...@@ -66,9 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
} }
TEST(argmax_arm, retrive_op) { TEST(argmax_arm, retrive_op) {
auto argmax = auto argmax = KernelRegistry::Global().Create("arg_max");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"arg_max");
ASSERT_FALSE(argmax.empty()); ASSERT_FALSE(argmax.empty());
ASSERT_TRUE(argmax.front()); ASSERT_TRUE(argmax.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/axpy_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/axpy_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -61,8 +63,7 @@ void axpy_compute_ref(const operators::AxpyParam& param) { ...@@ -61,8 +63,7 @@ void axpy_compute_ref(const operators::AxpyParam& param) {
} }
TEST(axpy_arm, retrive_op) { TEST(axpy_arm, retrive_op) {
auto axpy = auto axpy = KernelRegistry::Global().Create("axpy");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("axpy");
ASSERT_FALSE(axpy.empty()); ASSERT_FALSE(axpy.empty());
ASSERT_TRUE(axpy.front()); ASSERT_TRUE(axpy.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/batch_norm_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/batch_norm_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -78,9 +80,7 @@ void batch_norm_compute_ref(const operators::BatchNormParam& param) { ...@@ -78,9 +80,7 @@ void batch_norm_compute_ref(const operators::BatchNormParam& param) {
} }
TEST(batch_norm_arm, retrive_op) { TEST(batch_norm_arm, retrive_op) {
auto batch_norm = auto batch_norm = KernelRegistry::Global().Create("batch_norm");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"batch_norm");
ASSERT_FALSE(batch_norm.empty()); ASSERT_FALSE(batch_norm.empty());
ASSERT_TRUE(batch_norm.front()); ASSERT_TRUE(batch_norm.front());
} }
......
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/concat_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <limits> #include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/kernels/arm/concat_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -221,8 +223,7 @@ TEST(concat_arm, compute_input_multi) { ...@@ -221,8 +223,7 @@ TEST(concat_arm, compute_input_multi) {
} }
TEST(concat, retrive_op) { TEST(concat, retrive_op) {
auto concat = auto concat = KernelRegistry::Global().Create("concat");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>("concat");
ASSERT_FALSE(concat.empty()); ASSERT_FALSE(concat.empty());
ASSERT_TRUE(concat.front()); ASSERT_TRUE(concat.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/decode_bboxes_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/decode_bboxes_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -115,9 +117,7 @@ void decode_bboxes_compute_ref(const operators::DecodeBboxesParam& param) { ...@@ -115,9 +117,7 @@ void decode_bboxes_compute_ref(const operators::DecodeBboxesParam& param) {
} }
TEST(decode_bboxes_arm, retrive_op) { TEST(decode_bboxes_arm, retrive_op) {
auto decode_bboxes = auto decode_bboxes = KernelRegistry::Global().Create("decode_bboxes");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"decode_bboxes");
ASSERT_FALSE(decode_bboxes.empty()); ASSERT_FALSE(decode_bboxes.empty());
ASSERT_TRUE(decode_bboxes.front()); ASSERT_TRUE(decode_bboxes.front());
} }
......
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/dropout_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/dropout_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -30,9 +32,7 @@ TEST(dropout_arm, init) { ...@@ -30,9 +32,7 @@ TEST(dropout_arm, init) {
} }
TEST(dropout, retrive_op) { TEST(dropout, retrive_op) {
auto dropout = auto dropout = KernelRegistry::Global().Create("dropout");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"dropout");
ASSERT_FALSE(dropout.empty()); ASSERT_FALSE(dropout.empty());
ASSERT_TRUE(dropout.front()); ASSERT_TRUE(dropout.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/elementwise_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/elementwise_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace arm { namespace arm {
TEST(elementwise_add_arm, retrive_op) { TEST(elementwise_add_arm, retrive_op) {
auto elementwise_add = auto elementwise_add = KernelRegistry::Global().Create("elementwise_add");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_add");
ASSERT_FALSE(elementwise_add.empty()); ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front()); ASSERT_TRUE(elementwise_add.front());
} }
...@@ -336,8 +336,7 @@ TEST(elementwise_add, compute) { ...@@ -336,8 +336,7 @@ TEST(elementwise_add, compute) {
TEST(fusion_elementwise_add_activation_arm, retrive_op) { TEST(fusion_elementwise_add_activation_arm, retrive_op) {
auto fusion_elementwise_add_activation = auto fusion_elementwise_add_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( KernelRegistry::Global().Create("fusion_elementwise_add_activation");
"fusion_elementwise_add_activation");
ASSERT_FALSE(fusion_elementwise_add_activation.empty()); ASSERT_FALSE(fusion_elementwise_add_activation.empty());
ASSERT_TRUE(fusion_elementwise_add_activation.front()); ASSERT_TRUE(fusion_elementwise_add_activation.front());
} }
...@@ -435,9 +434,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) { ...@@ -435,9 +434,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
} }
TEST(elementwise_mul_arm, retrive_op) { TEST(elementwise_mul_arm, retrive_op) {
auto elementwise_mul = auto elementwise_mul = KernelRegistry::Global().Create("elementwise_mul");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_mul");
ASSERT_FALSE(elementwise_mul.empty()); ASSERT_FALSE(elementwise_mul.empty());
ASSERT_TRUE(elementwise_mul.front()); ASSERT_TRUE(elementwise_mul.front());
} }
...@@ -530,8 +527,7 @@ TEST(elementwise_mul, compute) { ...@@ -530,8 +527,7 @@ TEST(elementwise_mul, compute) {
TEST(fusion_elementwise_mul_activation_arm, retrive_op) { TEST(fusion_elementwise_mul_activation_arm, retrive_op) {
auto fusion_elementwise_mul_activation = auto fusion_elementwise_mul_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( KernelRegistry::Global().Create("fusion_elementwise_mul_activation");
"fusion_elementwise_mul_activation");
ASSERT_FALSE(fusion_elementwise_mul_activation.empty()); ASSERT_FALSE(fusion_elementwise_mul_activation.empty());
ASSERT_TRUE(fusion_elementwise_mul_activation.front()); ASSERT_TRUE(fusion_elementwise_mul_activation.front());
} }
...@@ -629,9 +625,7 @@ TEST(fusion_elementwise_mul_activation_arm, compute) { ...@@ -629,9 +625,7 @@ TEST(fusion_elementwise_mul_activation_arm, compute) {
} }
TEST(elementwise_max_arm, retrive_op) { TEST(elementwise_max_arm, retrive_op) {
auto elementwise_max = auto elementwise_max = KernelRegistry::Global().Create("elementwise_max");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_max");
ASSERT_FALSE(elementwise_max.empty()); ASSERT_FALSE(elementwise_max.empty());
ASSERT_TRUE(elementwise_max.front()); ASSERT_TRUE(elementwise_max.front());
} }
...@@ -724,8 +718,7 @@ TEST(elementwise_max, compute) { ...@@ -724,8 +718,7 @@ TEST(elementwise_max, compute) {
TEST(fusion_elementwise_max_activation_arm, retrive_op) { TEST(fusion_elementwise_max_activation_arm, retrive_op) {
auto fusion_elementwise_max_activation = auto fusion_elementwise_max_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( KernelRegistry::Global().Create("fusion_elementwise_max_activation");
"fusion_elementwise_max_activation");
ASSERT_FALSE(fusion_elementwise_max_activation.empty()); ASSERT_FALSE(fusion_elementwise_max_activation.empty());
ASSERT_TRUE(fusion_elementwise_max_activation.front()); ASSERT_TRUE(fusion_elementwise_max_activation.front());
} }
...@@ -823,9 +816,7 @@ TEST(fusion_elementwise_max_activation_arm, compute) { ...@@ -823,9 +816,7 @@ TEST(fusion_elementwise_max_activation_arm, compute) {
} }
TEST(elementwise_mod_int64_arm, retrive_op) { TEST(elementwise_mod_int64_arm, retrive_op) {
auto elementwise_mod = auto elementwise_mod = KernelRegistry::Global().Create("elementwise_mod");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kInt64)>(
"elementwise_mod");
ASSERT_FALSE(elementwise_mod.empty()); ASSERT_FALSE(elementwise_mod.empty());
ASSERT_TRUE(elementwise_mod.front()); ASSERT_TRUE(elementwise_mod.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/layer_norm_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -181,9 +183,7 @@ TEST(layer_norm_arm, compute) { ...@@ -181,9 +183,7 @@ TEST(layer_norm_arm, compute) {
} }
TEST(layer_norm, retrive_op) { TEST(layer_norm, retrive_op) {
auto layer_norm = auto layer_norm = KernelRegistry::Global().Create("layer_norm");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"layer_norm");
ASSERT_FALSE(layer_norm.empty()); ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front()); ASSERT_TRUE(layer_norm.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/lrn_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/lrn_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -133,8 +135,7 @@ void lrn_compute_ref(const operators::LrnParam& param) { ...@@ -133,8 +135,7 @@ void lrn_compute_ref(const operators::LrnParam& param) {
} }
TEST(lrn_arm, retrive_op) { TEST(lrn_arm, retrive_op) {
auto lrn = auto lrn = KernelRegistry::Global().Create("lrn");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("lrn");
ASSERT_FALSE(lrn.empty()); ASSERT_FALSE(lrn.empty());
ASSERT_TRUE(lrn.front()); ASSERT_TRUE(lrn.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +28,7 @@ namespace kernels { ...@@ -26,9 +28,7 @@ namespace kernels {
namespace arm { namespace arm {
TEST(merge_lod_tensor_arm, retrive_op) { TEST(merge_lod_tensor_arm, retrive_op) {
auto kernel = auto kernel = KernelRegistry::Global().Create("merge_lod_tensor");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"merge_lod_tensor");
ASSERT_FALSE(kernel.empty()); ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front()); ASSERT_TRUE(kernel.front());
} }
......
...@@ -12,16 +12,18 @@ ...@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/mul_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <random> #include <random>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/mul_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -69,8 +71,7 @@ void FillData(T* a, ...@@ -69,8 +71,7 @@ void FillData(T* a,
} }
TEST(mul_arm, retrive_op) { TEST(mul_arm, retrive_op) {
auto mul = auto mul = KernelRegistry::Global().Create("mul");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
ASSERT_FALSE(mul.empty()); ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front()); ASSERT_TRUE(mul.front());
} }
......
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/pool_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/pool_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -341,8 +343,7 @@ TEST(pool_arm, compute) { ...@@ -341,8 +343,7 @@ TEST(pool_arm, compute) {
} }
TEST(pool_arm, retrive_op) { TEST(pool_arm, retrive_op) {
auto pool = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( auto pool = KernelRegistry::Global().Create("pool2d");
"pool2d");
ASSERT_FALSE(pool.empty()); ASSERT_FALSE(pool.empty());
ASSERT_TRUE(pool.front()); ASSERT_TRUE(pool.front());
} }
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/scale_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/scale_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -103,8 +105,7 @@ TEST(scale_arm, compute) { ...@@ -103,8 +105,7 @@ TEST(scale_arm, compute) {
} }
TEST(scale, retrive_op) { TEST(scale, retrive_op) {
auto scale = auto scale = KernelRegistry::Global().Create("scale");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("scale");
ASSERT_FALSE(scale.empty()); ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front()); ASSERT_TRUE(scale.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/softmax_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) { ...@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) {
} }
TEST(softmax, retrive_op) { TEST(softmax, retrive_op) {
auto softmax = auto softmax = KernelRegistry::Global().Create("softmax");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"softmax");
ASSERT_FALSE(softmax.empty()); ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front()); ASSERT_TRUE(softmax.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/split_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstring> #include <cstring>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/split_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -165,8 +167,7 @@ TEST(split_arm, compute) { ...@@ -165,8 +167,7 @@ TEST(split_arm, compute) {
} }
TEST(split, retrive_op) { TEST(split, retrive_op) {
auto split = auto split = KernelRegistry::Global().Create("split");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("split");
ASSERT_FALSE(split.empty()); ASSERT_FALSE(split.empty());
ASSERT_TRUE(split.front()); ASSERT_TRUE(split.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/split_lod_tensor_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/arm/split_lod_tensor_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +28,7 @@ namespace kernels { ...@@ -26,9 +28,7 @@ namespace kernels {
namespace arm { namespace arm {
TEST(split_lod_tensor_arm, retrive_op) { TEST(split_lod_tensor_arm, retrive_op) {
auto kernel = auto kernel = KernelRegistry::Global().Create("split_lod_tensor");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"split_lod_tensor");
ASSERT_FALSE(kernel.empty()); ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front()); ASSERT_TRUE(kernel.front());
} }
......
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/transpose_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <limits> #include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/kernels/arm/transpose_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -121,9 +123,7 @@ TEST(transpose_arm, compute_shape_nchw) { ...@@ -121,9 +123,7 @@ TEST(transpose_arm, compute_shape_nchw) {
} }
TEST(transpose, retrive_op) { TEST(transpose, retrive_op) {
auto transpose = auto transpose = KernelRegistry::Global().Create("transpose");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"transpose");
ASSERT_FALSE(transpose.empty()); ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front()); ASSERT_TRUE(transpose.front());
} }
...@@ -189,9 +189,7 @@ TEST(transpose2_arm, compute_shape_nchw) { ...@@ -189,9 +189,7 @@ TEST(transpose2_arm, compute_shape_nchw) {
} }
TEST(transpose2, retrive_op) { TEST(transpose2, retrive_op) {
auto transpose2 = auto transpose2 = KernelRegistry::Global().Create("transpose2");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"transpose2");
ASSERT_FALSE(transpose2.empty()); ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front()); ASSERT_TRUE(transpose2.front());
} }
......
...@@ -12,14 +12,16 @@ ...@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/cuda/lookup_table_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/cuda/lookup_table_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -56,9 +58,7 @@ void LookupTableComputeRef(const operators::LookupTableParam& param) { ...@@ -56,9 +58,7 @@ void LookupTableComputeRef(const operators::LookupTableParam& param) {
} }
TEST(lookup_table_cuda, retrieve_op) { TEST(lookup_table_cuda, retrieve_op) {
auto lookup_table = auto lookup_table = KernelRegistry::Global().Create("lookup_table");
KernelRegistry::Global().Create<TARGET(kCUDA), PRECISION(kFloat)>(
"lookup_table");
ASSERT_FALSE(lookup_table.empty()); ASSERT_FALSE(lookup_table.empty());
ASSERT_TRUE(lookup_table.front()); ASSERT_TRUE(lookup_table.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/fpga/activation_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/fpga/activation_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -37,8 +39,7 @@ void activation_compute_ref(const operators::ActivationParam& param) { ...@@ -37,8 +39,7 @@ void activation_compute_ref(const operators::ActivationParam& param) {
} }
TEST(activation_fpga, retrive_op) { TEST(activation_fpga, retrive_op) {
auto activation = auto activation = KernelRegistry::Global().Create("relu");
KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>("relu");
ASSERT_FALSE(activation.empty()); ASSERT_FALSE(activation.empty());
ASSERT_TRUE(activation.front()); ASSERT_TRUE(activation.front());
} }
......
...@@ -12,15 +12,17 @@ ...@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/fpga/fc_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <random> #include <random>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/fpga/fc_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -76,8 +78,7 @@ void FillData(T* a, ...@@ -76,8 +78,7 @@ void FillData(T* a,
} }
TEST(fc_fpga, retrive_op) { TEST(fc_fpga, retrive_op) {
auto fc = auto fc = KernelRegistry::Global().Create("fc");
KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>("fc");
ASSERT_FALSE(fc.empty()); ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front()); ASSERT_TRUE(fc.front());
} }
......
...@@ -12,14 +12,15 @@ ...@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/fpga/pooling_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <limits> #include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/pooling_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -277,8 +278,7 @@ TEST(pool_fpga, compute) { ...@@ -277,8 +278,7 @@ TEST(pool_fpga, compute) {
} }
TEST(pool_fpga, retrive_op) { TEST(pool_fpga, retrive_op) {
auto pool = KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>( auto pool = KernelRegistry::Global().Create("pool2d");
"pool2d");
ASSERT_FALSE(pool.empty()); ASSERT_FALSE(pool.empty());
ASSERT_TRUE(pool.front()); ASSERT_TRUE(pool.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/fpga/softmax_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/fpga/softmax_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) { ...@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) {
} }
TEST(softmax, retrive_op) { TEST(softmax, retrive_op) {
auto softmax = auto softmax = KernelRegistry::Global().Create("softmax");
KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>(
"softmax");
ASSERT_FALSE(softmax.empty()); ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front()); ASSERT_TRUE(softmax.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/activation_compute.cc"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,8 +28,7 @@ namespace kernels { ...@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(relu_x86, retrive_op) { TEST(relu_x86, retrive_op) {
auto relu = auto relu = KernelRegistry::Global().Create("relu");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
ASSERT_FALSE(relu.empty()); ASSERT_FALSE(relu.empty());
ASSERT_TRUE(relu.front()); ASSERT_TRUE(relu.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -81,8 +83,7 @@ int get_max_len(const LoD& lod) { ...@@ -81,8 +83,7 @@ int get_max_len(const LoD& lod) {
TEST(attention_padding_mask_x86, retrive_op) { TEST(attention_padding_mask_x86, retrive_op) {
auto attention_padding_mask = auto attention_padding_mask =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>( KernelRegistry::Global().Create("attention_padding_mask");
"attention_padding_mask");
ASSERT_FALSE(attention_padding_mask.empty()); ASSERT_FALSE(attention_padding_mask.empty());
ASSERT_TRUE(attention_padding_mask.front()); ASSERT_TRUE(attention_padding_mask.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/batch_norm_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/batch_norm_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +28,7 @@ namespace kernels { ...@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(batch_norm_x86, retrive_op) { TEST(batch_norm_x86, retrive_op) {
auto batch_norm = auto batch_norm = KernelRegistry::Global().Create("batch_norm");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"batch_norm");
ASSERT_FALSE(batch_norm.empty()); ASSERT_FALSE(batch_norm.empty());
ASSERT_TRUE(batch_norm.front()); ASSERT_TRUE(batch_norm.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/cast_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/cast_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,8 +27,7 @@ namespace kernels { ...@@ -25,8 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(cast_x86, retrive_op) { TEST(cast_x86, retrive_op) {
auto cast = auto cast = KernelRegistry::Global().Create("cast");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("cast");
ASSERT_FALSE(cast.empty()); ASSERT_FALSE(cast.empty());
ASSERT_TRUE(cast.front()); ASSERT_TRUE(cast.front());
} }
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/concat_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/concat_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -23,9 +25,7 @@ namespace kernels { ...@@ -23,9 +25,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(concat_x86, retrive_op) { TEST(concat_x86, retrive_op) {
auto concat = auto concat = KernelRegistry::Global().Create("concat");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"concat");
ASSERT_FALSE(concat.empty()); ASSERT_FALSE(concat.empty());
ASSERT_TRUE(concat.front()); ASSERT_TRUE(concat.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/conv_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/conv_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(conv_x86, retrive_op) { TEST(conv_x86, retrive_op) {
auto conv2d = auto conv2d = KernelRegistry::Global().Create("conv2d");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"conv2d");
ASSERT_FALSE(conv2d.empty()); ASSERT_FALSE(conv2d.empty());
ASSERT_TRUE(conv2d.front()); ASSERT_TRUE(conv2d.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/dropout_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/dropout_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +28,7 @@ namespace kernels { ...@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(dropout_x86, retrive_op) { TEST(dropout_x86, retrive_op) {
auto dropout = auto dropout = KernelRegistry::Global().Create("dropout");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"dropout");
ASSERT_FALSE(dropout.empty()); ASSERT_FALSE(dropout.empty());
ASSERT_TRUE(dropout.front()); ASSERT_TRUE(dropout.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/elementwise_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +28,7 @@ namespace kernels { ...@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(elementwise_add_x86, retrive_op) { TEST(elementwise_add_x86, retrive_op) {
auto elementwise_add = auto elementwise_add = KernelRegistry::Global().Create("elementwise_add");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"elementwise_add");
ASSERT_FALSE(elementwise_add.empty()); ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front()); ASSERT_TRUE(elementwise_add.front());
} }
......
...@@ -12,13 +12,16 @@ ...@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -26,8 +29,7 @@ namespace x86 { ...@@ -26,8 +29,7 @@ namespace x86 {
TEST(fill_constant_batch_size_like_x86, retrive_op) { TEST(fill_constant_batch_size_like_x86, retrive_op) {
auto fill_constant_batch_size_like = auto fill_constant_batch_size_like =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>( KernelRegistry::Global().Create("fill_constant_batch_size_like");
"fill_constant_batch_size_like");
ASSERT_FALSE(fill_constant_batch_size_like.empty()); ASSERT_FALSE(fill_constant_batch_size_like.empty());
ASSERT_TRUE(fill_constant_batch_size_like.front()); ASSERT_TRUE(fill_constant_batch_size_like.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/gather_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/gather_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(gather_x86, retrive_op) { TEST(gather_x86, retrive_op) {
auto gather = auto gather = KernelRegistry::Global().Create("gather");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"gather");
ASSERT_FALSE(gather.empty()); ASSERT_FALSE(gather.empty());
int cnt = 0; int cnt = 0;
for (auto item = gather.begin(); item != gather.end(); ++item) { for (auto item = gather.begin(); item != gather.end(); ++item) {
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc" #include "lite/kernels/x86/activation_compute.cc"
...@@ -26,8 +28,7 @@ namespace kernels { ...@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(gelu_x86, retrive_op) { TEST(gelu_x86, retrive_op) {
auto gelu = auto gelu = KernelRegistry::Global().Create("gelu");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gelu");
ASSERT_FALSE(gelu.empty()); ASSERT_FALSE(gelu.empty());
ASSERT_TRUE(gelu.front()); ASSERT_TRUE(gelu.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/gru_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/gru_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,8 +28,7 @@ namespace kernels { ...@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(gru_x86, retrive_op) { TEST(gru_x86, retrive_op) {
auto gru = auto gru = KernelRegistry::Global().Create("gru");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gru");
ASSERT_FALSE(gru.empty()); ASSERT_FALSE(gru.empty());
ASSERT_TRUE(gru.front()); ASSERT_TRUE(gru.front());
} }
......
...@@ -12,15 +12,17 @@ ...@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/layer_norm_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/backends/x86/jit/helper.h" #include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h" #include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h" #include "lite/backends/x86/jit/kernels.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/layer_norm_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -74,9 +76,7 @@ std::vector<float> ref(lite::Tensor* x, ...@@ -74,9 +76,7 @@ std::vector<float> ref(lite::Tensor* x,
// layer_norm // layer_norm
TEST(layer_norm_x86, retrive_op) { TEST(layer_norm_x86, retrive_op) {
auto layer_norm = auto layer_norm = KernelRegistry::Global().Create("layer_norm");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"layer_norm");
ASSERT_FALSE(layer_norm.empty()); ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front()); ASSERT_TRUE(layer_norm.front());
} }
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.h" #include "lite/kernels/x86/activation_compute.h"
...@@ -24,9 +26,7 @@ namespace kernels { ...@@ -24,9 +26,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(leaky_relu_x86, retrive_op) { TEST(leaky_relu_x86, retrive_op) {
auto leaky_relu = auto leaky_relu = KernelRegistry::Global().Create("leaky_relu");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"leaky_relu");
ASSERT_FALSE(leaky_relu.empty()); ASSERT_FALSE(leaky_relu.empty());
ASSERT_TRUE(leaky_relu.front()); ASSERT_TRUE(leaky_relu.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(match_matrix_tensor_x86, retrive_op) { TEST(match_matrix_tensor_x86, retrive_op) {
auto kernel = auto kernel = KernelRegistry::Global().Create("match_matrix_tensor");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"match_matrix_tensor");
ASSERT_FALSE(kernel.empty()); ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front()); ASSERT_TRUE(kernel.front());
} }
......
...@@ -12,22 +12,23 @@ ...@@ -12,22 +12,23 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/matmul_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/matmul_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace x86 { namespace x86 {
TEST(matmul_x86, retrive_op) { TEST(matmul_x86, retrive_op) {
auto matmul = auto matmul = KernelRegistry::Global().Create("matmul");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"matmul");
ASSERT_FALSE(matmul.empty()); ASSERT_FALSE(matmul.empty());
ASSERT_TRUE(matmul.front()); ASSERT_TRUE(matmul.front());
} }
......
...@@ -12,21 +12,23 @@ ...@@ -12,21 +12,23 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/mul_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace x86 { namespace x86 {
TEST(mul_x86, retrive_op) { TEST(mul_x86, retrive_op) {
auto mul = auto mul = KernelRegistry::Global().Create("mul");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("mul");
ASSERT_FALSE(mul.empty()); ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front()); ASSERT_TRUE(mul.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/pool_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/pool_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +28,7 @@ namespace kernels { ...@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(pool_x86, retrive_op) { TEST(pool_x86, retrive_op) {
auto pool2d = auto pool2d = KernelRegistry::Global().Create("pool2d");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"pool2d");
ASSERT_FALSE(pool2d.empty()); ASSERT_FALSE(pool2d.empty());
ASSERT_TRUE(pool2d.front()); ASSERT_TRUE(pool2d.front());
} }
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.h" #include "lite/kernels/x86/activation_compute.h"
...@@ -24,8 +26,7 @@ namespace kernels { ...@@ -24,8 +26,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(relu_x86, retrive_op) { TEST(relu_x86, retrive_op) {
auto relu = auto relu = KernelRegistry::Global().Create("relu");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
ASSERT_FALSE(relu.empty()); ASSERT_FALSE(relu.empty());
ASSERT_TRUE(relu.front()); ASSERT_TRUE(relu.front());
} }
......
...@@ -12,13 +12,16 @@ ...@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/reshape_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/reshape_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -26,9 +29,7 @@ namespace x86 { ...@@ -26,9 +29,7 @@ namespace x86 {
// reshape // reshape
TEST(reshape_x86, retrive_op) { TEST(reshape_x86, retrive_op) {
auto reshape = auto reshape = KernelRegistry::Global().Create("reshape");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape");
ASSERT_FALSE(reshape.empty()); ASSERT_FALSE(reshape.empty());
ASSERT_TRUE(reshape.front()); ASSERT_TRUE(reshape.front());
} }
...@@ -86,9 +87,7 @@ TEST(reshape_x86, run_test) { ...@@ -86,9 +87,7 @@ TEST(reshape_x86, run_test) {
// reshape2 // reshape2
TEST(reshape2_x86, retrive_op) { TEST(reshape2_x86, retrive_op) {
auto reshape2 = auto reshape2 = KernelRegistry::Global().Create("reshape2");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape2");
ASSERT_FALSE(reshape2.empty()); ASSERT_FALSE(reshape2.empty());
ASSERT_TRUE(reshape2.front()); ASSERT_TRUE(reshape2.front());
} }
......
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/scale_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/scale_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -24,8 +26,7 @@ namespace kernels { ...@@ -24,8 +26,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(scale_x86, retrive_op) { TEST(scale_x86, retrive_op) {
auto scale = auto scale = KernelRegistry::Global().Create("scale");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("scale");
ASSERT_FALSE(scale.empty()); ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front()); ASSERT_TRUE(scale.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/search_fc_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_fc_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -53,9 +55,7 @@ void fc_cpu_base(const lite::Tensor* X, ...@@ -53,9 +55,7 @@ void fc_cpu_base(const lite::Tensor* X,
} }
TEST(search_fc_x86, retrive_op) { TEST(search_fc_x86, retrive_op) {
auto search_fc = auto search_fc = KernelRegistry::Global().Create("search_fc");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_fc");
ASSERT_FALSE(search_fc.empty()); ASSERT_FALSE(search_fc.empty());
ASSERT_TRUE(search_fc.front()); ASSERT_TRUE(search_fc.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/search_grnn_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_grnn_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(search_grnn_x86, retrive_op) { TEST(search_grnn_x86, retrive_op) {
auto kernel = auto kernel = KernelRegistry::Global().Create("search_grnn");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_grnn");
ASSERT_FALSE(kernel.empty()); ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front()); ASSERT_TRUE(kernel.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/search_group_padding_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_group_padding_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,8 +28,7 @@ namespace x86 { ...@@ -26,8 +28,7 @@ namespace x86 {
TEST(search_group_padding_x86, retrieve_op) { TEST(search_group_padding_x86, retrieve_op) {
auto search_group_padding = auto search_group_padding =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>( KernelRegistry::Global().Create("search_group_padding");
"search_group_padding");
ASSERT_FALSE(search_group_padding.empty()); ASSERT_FALSE(search_group_padding.empty());
ASSERT_TRUE(search_group_padding.front()); ASSERT_TRUE(search_group_padding.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/search_seq_depadding_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_seq_depadding_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(search_seq_depadding_x86, retrive_op) { TEST(search_seq_depadding_x86, retrive_op) {
auto kernel = auto kernel = KernelRegistry::Global().Create("search_seq_depadding");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_seq_depadding");
ASSERT_FALSE(kernel.empty()); ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front()); ASSERT_TRUE(kernel.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -77,8 +79,7 @@ void prepare_input(Tensor* x, const LoD& x_lod) { ...@@ -77,8 +79,7 @@ void prepare_input(Tensor* x, const LoD& x_lod) {
TEST(sequence_arithmetic_x86, retrive_op) { TEST(sequence_arithmetic_x86, retrive_op) {
auto sequence_arithmetic = auto sequence_arithmetic =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>( KernelRegistry::Global().Create("sequence_arithmetic");
"sequence_arithmetic");
ASSERT_FALSE(sequence_arithmetic.empty()); ASSERT_FALSE(sequence_arithmetic.empty());
ASSERT_TRUE(sequence_arithmetic.front()); ASSERT_TRUE(sequence_arithmetic.front());
} }
......
...@@ -12,12 +12,15 @@ ...@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/sequence_concat_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_concat_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -94,9 +97,7 @@ static void sequence_concat_ref(const std::vector<lite::Tensor*>& xs, ...@@ -94,9 +97,7 @@ static void sequence_concat_ref(const std::vector<lite::Tensor*>& xs,
} // namespace } // namespace
TEST(sequence_concat_x86, retrive_op) { TEST(sequence_concat_x86, retrive_op) {
auto sequence_concat = auto sequence_concat = KernelRegistry::Global().Create("sequence_concat");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_concat");
ASSERT_FALSE(sequence_concat.empty()); ASSERT_FALSE(sequence_concat.empty());
ASSERT_TRUE(sequence_concat.front()); ASSERT_TRUE(sequence_concat.front());
} }
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/sequence_expand_as_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_expand_as_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -27,8 +29,7 @@ namespace x86 { ...@@ -27,8 +29,7 @@ namespace x86 {
TEST(sequence_expand_as_x86, retrive_op) { TEST(sequence_expand_as_x86, retrive_op) {
auto sequence_expand_as = auto sequence_expand_as =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>( KernelRegistry::Global().Create("sequence_expand_as");
"sequence_expand_as");
ASSERT_FALSE(sequence_expand_as.empty()); ASSERT_FALSE(sequence_expand_as.empty());
ASSERT_TRUE(sequence_expand_as.front()); ASSERT_TRUE(sequence_expand_as.front());
} }
......
...@@ -12,21 +12,22 @@ ...@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/sequence_pool_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_pool_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace x86 { namespace x86 {
TEST(sequence_pool_x86, retrive_op) { TEST(sequence_pool_x86, retrive_op) {
auto sequence_pool = auto sequence_pool = KernelRegistry::Global().Create("sequence_pool");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_pool");
ASSERT_FALSE(sequence_pool.empty()); ASSERT_FALSE(sequence_pool.empty());
ASSERT_TRUE(sequence_pool.front()); ASSERT_TRUE(sequence_pool.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/sequence_reverse_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_reverse_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -44,9 +46,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) { ...@@ -44,9 +46,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
} // namespace } // namespace
TEST(sequence_reverse_x86, retrive_op) { TEST(sequence_reverse_x86, retrive_op) {
auto sequence_reverse = auto sequence_reverse = KernelRegistry::Global().Create("sequence_reverse");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_reverse");
ASSERT_FALSE(sequence_reverse.empty()); ASSERT_FALSE(sequence_reverse.empty());
ASSERT_TRUE(sequence_reverse.front()); ASSERT_TRUE(sequence_reverse.front());
} }
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/shape_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/shape_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -23,8 +25,7 @@ namespace kernels { ...@@ -23,8 +25,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(shape_x86, retrive_op) { TEST(shape_x86, retrive_op) {
auto shape = auto shape = KernelRegistry::Global().Create("shape");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("shape");
ASSERT_FALSE(shape.empty()); ASSERT_FALSE(shape.empty());
ASSERT_TRUE(shape.front()); ASSERT_TRUE(shape.front());
} }
......
...@@ -12,13 +12,16 @@ ...@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/slice_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/slice_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -79,8 +82,7 @@ static void slice_ref(const float* input, ...@@ -79,8 +82,7 @@ static void slice_ref(const float* input,
} }
TEST(slice_x86, retrive_op) { TEST(slice_x86, retrive_op) {
auto slice = auto slice = KernelRegistry::Global().Create("slice");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("slice");
ASSERT_FALSE(slice.empty()); ASSERT_FALSE(slice.empty());
ASSERT_TRUE(slice.front()); ASSERT_TRUE(slice.front());
} }
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/softmax_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/softmax_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,9 +27,7 @@ namespace kernels { ...@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(softmax_x86, retrive_op) { TEST(softmax_x86, retrive_op) {
auto softmax = auto softmax = KernelRegistry::Global().Create("softmax");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"softmax");
ASSERT_FALSE(softmax.empty()); ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front()); ASSERT_TRUE(softmax.front());
} }
......
...@@ -12,12 +12,15 @@ ...@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/stack_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/stack_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -25,8 +28,7 @@ namespace x86 { ...@@ -25,8 +28,7 @@ namespace x86 {
// stack // stack
TEST(stack_x86, retrive_op) { TEST(stack_x86, retrive_op) {
auto stack = auto stack = KernelRegistry::Global().Create("stack");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("stack");
ASSERT_FALSE(stack.empty()); ASSERT_FALSE(stack.empty());
ASSERT_TRUE(stack.front()); ASSERT_TRUE(stack.front());
} }
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc" #include "lite/kernels/x86/activation_compute.cc"
...@@ -26,8 +28,7 @@ namespace kernels { ...@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 { namespace x86 {
TEST(tanh_x86, retrive_op) { TEST(tanh_x86, retrive_op) {
auto tanh = auto tanh = KernelRegistry::Global().Create("tanh");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("tanh");
ASSERT_FALSE(tanh.empty()); ASSERT_FALSE(tanh.empty());
ASSERT_TRUE(tanh.front()); ASSERT_TRUE(tanh.front());
} }
......
...@@ -12,12 +12,15 @@ ...@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/kernels/x86/transpose_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -25,9 +28,7 @@ namespace x86 { ...@@ -25,9 +28,7 @@ namespace x86 {
// transpose // transpose
TEST(transpose_x86, retrive_op) { TEST(transpose_x86, retrive_op) {
auto transpose = auto transpose = KernelRegistry::Global().Create("transpose");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose");
ASSERT_FALSE(transpose.empty()); ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front()); ASSERT_TRUE(transpose.front());
} }
...@@ -75,9 +76,7 @@ TEST(transpose_x86, run_test) { ...@@ -75,9 +76,7 @@ TEST(transpose_x86, run_test) {
// transpose2 // transpose2
TEST(transpose2_x86, retrive_op) { TEST(transpose2_x86, retrive_op) {
auto transpose2 = auto transpose2 = KernelRegistry::Global().Create("transpose2");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose2");
ASSERT_FALSE(transpose2.empty()); ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front()); ASSERT_TRUE(transpose2.front());
} }
......
...@@ -12,13 +12,16 @@ ...@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/x86/var_conv_2d_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/kernels/x86/var_conv_2d_compute.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -197,9 +200,7 @@ static void var_conv_2d_ref(const lite::Tensor* bottom, ...@@ -197,9 +200,7 @@ static void var_conv_2d_ref(const lite::Tensor* bottom,
} }
TEST(var_conv_2d_x86, retrive_op) { TEST(var_conv_2d_x86, retrive_op) {
auto var_conv_2d = auto var_conv_2d = KernelRegistry::Global().Create("var_conv_2d");
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"var_conv_2d");
ASSERT_FALSE(var_conv_2d.empty()); ASSERT_FALSE(var_conv_2d.empty());
ASSERT_TRUE(var_conv_2d.front()); ASSERT_TRUE(var_conv_2d.front());
} }
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h" #include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
#include "lite/utils/variant.h"
/* /*
* This file contains all the argument parameter data structure for operators. * This file contains all the argument parameter data structure for operators.
*/ */
......
...@@ -14,10 +14,16 @@ ...@@ -14,10 +14,16 @@
#pragma once #pragma once
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "lite/utils/any.h" #include "lite/utils/any.h"
#include "lite/utils/check.h" #include "lite/utils/check.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/factory.h"
#include "lite/utils/hash.h" #include "lite/utils/hash.h"
#include "lite/utils/io.h" #include "lite/utils/io.h"
#include "lite/utils/macros.h" #include "lite/utils/macros.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "lite/utils/all.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
/*
* Factor for any Type creator.
*
* Usage:
*
* struct SomeType;
* // Register a creator.
* Factory<SomeType>::Global().Register("some_key", [] ->
* std::unique_ptr<SomeType> { ... });
* // Retrive a creator.
* auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
*/
template <typename ItemType, typename ItemTypePtr>
class Factory {
public:
using item_t = ItemType;
using self_t = Factory<item_t, ItemTypePtr>;
using item_ptr_t = ItemTypePtr;
using creator_t = std::function<item_ptr_t()>;
static Factory& Global() {
static Factory* x = new self_t;
return *x;
}
void Register(const std::string& op_type, creator_t&& creator) {
creators_[op_type].emplace_back(std::move(creator));
}
item_ptr_t Create(const std::string& op_type) const {
auto res = Creates(op_type);
if (res.empty()) return nullptr;
CHECK_EQ(res.size(), 1UL) << "Get multiple Op for type " << op_type;
return std::move(res.front());
}
std::list<item_ptr_t> Creates(const std::string& op_type) const {
std::list<item_ptr_t> res;
auto it = creators_.find(op_type);
if (it == creators_.end()) return res;
for (auto& c : it->second) {
res.emplace_back(c());
}
return res;
}
std::string DebugString() const {
STL::stringstream ss;
for (const auto& item : creators_) {
ss << " - " << item.first << "\n";
}
return ss.str();
}
protected:
std::map<std::string, std::list<creator_t>> creators_;
};
/* A helper function to help run a lambda at the start.
*/
template <typename Type>
class Registor {
public:
explicit Registor(std::function<void()>&& functor) { functor(); }
// Touch will do nothing.
int Touch() { return 0; }
};
} // namespace lite
} // namespace paddle
Subproject commit ac203b20926b13a35ff85277d2e5d3c38698eee8 Subproject commit 6df40a2471737b27271bdd9b900ab5f3aec746c7
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册