未验证 提交 db7639ca 编写于 作者: S Shibo Tao 提交者: GitHub

optimize register mechanism (#3745)

* refactor register mechanism, current so size: 1.20MB. test=develop

* fix KernelRegistry::Global().Create. test=develop

* fix cpplint errors. test=develop

* fix test_subgraph_pass bug. test=develop

* register kernel with target,precision,datalayout combination. test=develop

* fix test_paddle_api no op found bug. test=develop

* enhance comment

* fix lite/kernels/arm/elementwise_compute_test.cc. test=develop

* fix code style

* revert format of unchanged files. test=develop

* fix code format according to cpplint 1.5.1. test=develop

* remove redundant include header. test=develop
上级 732bb91b
......@@ -15,8 +15,6 @@
#include "lite/api/light_api.h"
#include <algorithm>
#include <map>
#include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT
namespace paddle {
namespace lite {
......
......@@ -15,8 +15,11 @@
#include "lite/api/paddle_api.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/io.h"
DEFINE_string(model_dir, "", "");
namespace paddle {
......
......@@ -13,8 +13,12 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
......
......@@ -17,277 +17,5 @@
#include <set>
namespace paddle {
namespace lite {
const std::map<std::string, std::string> &GetOp2PathDict() {
return OpKernelInfoCollector::Global().GetOp2PathDict();
}
std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
const std::string &op_type,
TargetType target,
PrecisionType precision,
DataLayoutType layout) {
Place place{target, precision, layout};
VLOG(5) << "creating " << op_type << " kernel for " << place.DebugString();
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
case DATALAYOUT(kNCHW): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kNCHW)>(op_type); \
case DATALAYOUT(kAny): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kAny)>(op_type); \
case DATALAYOUT(kNHWC): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kNHWC)>(op_type); \
case DATALAYOUT(kImageDefault): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageDefault)>(op_type); \
case DATALAYOUT(kImageFolder): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageFolder)>(op_type); \
case DATALAYOUT(kImageNW): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageNW)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
CREATE_KERNEL1(target__, kFloat); \
case PRECISION(kInt8): \
CREATE_KERNEL1(target__, kInt8); \
case PRECISION(kFP16): \
CREATE_KERNEL1(target__, kFP16); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
case PRECISION(kInt32): \
CREATE_KERNEL1(target__, kInt32); \
case PRECISION(kInt64): \
CREATE_KERNEL1(target__, kInt64); \
default: \
CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
}
switch (target) {
case TARGET(kHost): {
CREATE_KERNEL(kHost);
} break;
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_X86)
case TARGET(kX86): {
CREATE_KERNEL(kX86);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_CUDA)
case TARGET(kCUDA): {
CREATE_KERNEL(kCUDA);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_ARM)
case TARGET(kARM): {
CREATE_KERNEL(kARM);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_OPENCL)
case TARGET(kOpenCL): {
CREATE_KERNEL(kOpenCL);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_NPU)
case TARGET(kNPU): {
CREATE_KERNEL(kNPU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_APU)
case TARGET(kAPU): {
CREATE_KERNEL(kAPU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_XPU)
case TARGET(kXPU): {
CREATE_KERNEL(kXPU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_FPGA)
case TARGET(kFPGA): {
CREATE_KERNEL(kFPGA);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_BM)
case TARGET(kBM): {
CREATE_KERNEL(kBM);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_MLU)
case TARGET(kMLU): {
CREATE_KERNEL(kMLU);
} break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_RKNPU)
case TARGET(kRKNPU): {
CREATE_KERNEL(kRKNPU);
} break;
#endif
default:
CHECK(false) << "not supported kernel target " << TargetToStr(target);
}
#undef CREATE_KERNEL
return std::list<std::unique_ptr<KernelBase>>();
}
KernelRegistry::KernelRegistry() : registries_() {
#define INIT_FOR(target__, precision__, layout__) \
registries_[std::make_tuple(TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__))] \
.set<KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_CUDA)
INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kFloat, kNHWC);
INIT_FOR(kCUDA, kInt8, kNCHW);
INIT_FOR(kCUDA, kFP16, kNCHW);
INIT_FOR(kCUDA, kFP16, kNHWC);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kCUDA, kAny, kAny);
INIT_FOR(kCUDA, kInt8, kNHWC);
INIT_FOR(kCUDA, kInt64, kNCHW);
INIT_FOR(kCUDA, kInt64, kNHWC);
INIT_FOR(kCUDA, kInt32, kNCHW);
INIT_FOR(kCUDA, kInt32, kNHWC);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_MLU)
INIT_FOR(kMLU, kFloat, kNHWC);
INIT_FOR(kMLU, kFloat, kNCHW);
INIT_FOR(kMLU, kFP16, kNHWC);
INIT_FOR(kMLU, kFP16, kNCHW);
INIT_FOR(kMLU, kInt8, kNHWC);
INIT_FOR(kMLU, kInt8, kNCHW);
INIT_FOR(kMLU, kInt16, kNHWC);
INIT_FOR(kMLU, kInt16, kNCHW);
#endif
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kAny, kNHWC);
INIT_FOR(kHost, kAny, kAny);
INIT_FOR(kHost, kBool, kNCHW);
INIT_FOR(kHost, kBool, kNHWC);
INIT_FOR(kHost, kBool, kAny);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
INIT_FOR(kHost, kFloat, kAny);
INIT_FOR(kHost, kFP16, kNCHW);
INIT_FOR(kHost, kFP16, kNHWC);
INIT_FOR(kHost, kFP16, kAny);
INIT_FOR(kHost, kInt8, kNCHW);
INIT_FOR(kHost, kInt8, kNHWC);
INIT_FOR(kHost, kInt8, kAny);
INIT_FOR(kHost, kInt16, kNCHW);
INIT_FOR(kHost, kInt16, kNHWC);
INIT_FOR(kHost, kInt16, kAny);
INIT_FOR(kHost, kInt32, kNCHW);
INIT_FOR(kHost, kInt32, kNHWC);
INIT_FOR(kHost, kInt32, kAny);
INIT_FOR(kHost, kInt64, kNCHW);
INIT_FOR(kHost, kInt64, kNHWC);
INIT_FOR(kHost, kInt64, kAny);
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_X86)
INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kAny, kNCHW);
INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kX86, kInt64, kNCHW);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_ARM)
INIT_FOR(kARM, kFloat, kNCHW);
INIT_FOR(kARM, kFloat, kNHWC);
INIT_FOR(kARM, kInt8, kNCHW);
INIT_FOR(kARM, kInt8, kNHWC);
INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny);
INIT_FOR(kARM, kInt32, kNCHW);
INIT_FOR(kARM, kInt64, kNCHW);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_OPENCL)
INIT_FOR(kOpenCL, kFloat, kNCHW);
INIT_FOR(kOpenCL, kFloat, kNHWC);
INIT_FOR(kOpenCL, kAny, kNCHW);
INIT_FOR(kOpenCL, kAny, kNHWC);
INIT_FOR(kOpenCL, kFloat, kAny);
INIT_FOR(kOpenCL, kInt8, kNCHW);
INIT_FOR(kOpenCL, kAny, kAny);
INIT_FOR(kOpenCL, kFP16, kNCHW);
INIT_FOR(kOpenCL, kFP16, kNHWC);
INIT_FOR(kOpenCL, kFP16, kImageDefault);
INIT_FOR(kOpenCL, kFP16, kImageFolder);
INIT_FOR(kOpenCL, kFP16, kImageNW);
INIT_FOR(kOpenCL, kFloat, kImageDefault);
INIT_FOR(kOpenCL, kFloat, kImageFolder);
INIT_FOR(kOpenCL, kFloat, kImageNW);
INIT_FOR(kOpenCL, kAny, kImageDefault);
INIT_FOR(kOpenCL, kAny, kImageFolder);
INIT_FOR(kOpenCL, kAny, kImageNW);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_NPU)
INIT_FOR(kNPU, kFloat, kNCHW);
INIT_FOR(kNPU, kFloat, kNHWC);
INIT_FOR(kNPU, kInt8, kNCHW);
INIT_FOR(kNPU, kInt8, kNHWC);
INIT_FOR(kNPU, kAny, kNCHW);
INIT_FOR(kNPU, kAny, kNHWC);
INIT_FOR(kNPU, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_APU)
INIT_FOR(kAPU, kInt8, kNCHW);
INIT_FOR(kXPU, kFloat, kNCHW);
INIT_FOR(kXPU, kInt8, kNCHW);
INIT_FOR(kXPU, kAny, kNCHW);
INIT_FOR(kXPU, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_FPGA)
INIT_FOR(kFPGA, kFP16, kNHWC);
INIT_FOR(kFPGA, kFP16, kAny);
INIT_FOR(kFPGA, kFloat, kNHWC);
INIT_FOR(kFPGA, kAny, kNHWC);
INIT_FOR(kFPGA, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_BM)
INIT_FOR(kBM, kFloat, kNCHW);
INIT_FOR(kBM, kInt8, kNCHW);
INIT_FOR(kBM, kAny, kNCHW);
INIT_FOR(kBM, kAny, kAny);
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_RKNPU)
INIT_FOR(kRKNPU, kFloat, kNCHW);
INIT_FOR(kRKNPU, kInt8, kNCHW);
INIT_FOR(kRKNPU, kAny, kNCHW);
INIT_FOR(kRKNPU, kAny, kAny);
#endif
#undef INIT_FOR
}
KernelRegistry &KernelRegistry::Global() {
static auto *x = new KernelRegistry;
return *x;
}
} // namespace lite
namespace lite {} // namespace lite
} // namespace paddle
......@@ -17,7 +17,6 @@
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
......@@ -33,19 +32,19 @@ using LiteType = paddle::lite::Type;
class OpKernelInfoCollector {
public:
static OpKernelInfoCollector &Global() {
static auto *x = new OpKernelInfoCollector;
static OpKernelInfoCollector& Global() {
static auto* x = new OpKernelInfoCollector;
return *x;
}
void AddOp2path(const std::string &op_name, const std::string &op_path) {
void AddOp2path(const std::string& op_name, const std::string& op_path) {
size_t index = op_path.find_last_of('/');
if (index != std::string::npos) {
op2path_.insert(std::pair<std::string, std::string>(
op_name, op_path.substr(index + 1)));
}
}
void AddKernel2path(const std::string &kernel_name,
const std::string &kernel_path) {
void AddKernel2path(const std::string& kernel_name,
const std::string& kernel_path) {
size_t index = kernel_path.find_last_of('/');
if (index != std::string::npos) {
kernel2path_.insert(std::pair<std::string, std::string>(
......@@ -53,13 +52,13 @@ class OpKernelInfoCollector {
}
}
void SetKernel2path(
const std::map<std::string, std::string> &kernel2path_map) {
const std::map<std::string, std::string>& kernel2path_map) {
kernel2path_ = kernel2path_map;
}
const std::map<std::string, std::string> &GetOp2PathDict() {
const std::map<std::string, std::string>& GetOp2PathDict() {
return op2path_;
}
const std::map<std::string, std::string> &GetKernel2PathDict() {
const std::map<std::string, std::string>& GetKernel2PathDict() {
return kernel2path_;
}
......@@ -71,409 +70,177 @@ class OpKernelInfoCollector {
namespace paddle {
namespace lite {
const std::map<std::string, std::string> &GetOp2PathDict();
using KernelFunc = std::function<void()>;
using KernelFuncCreator = std::function<std::unique_ptr<KernelFunc>()>;
class LiteOpRegistry final : public Factory<OpLite, std::shared_ptr<OpLite>> {
class OpLiteFactory {
public:
static LiteOpRegistry &Global() {
static auto *x = new LiteOpRegistry;
// Register a function to create an op
void RegisterCreator(const std::string& op_type,
std::function<std::shared_ptr<OpLite>()> fun) {
op_registry_[op_type] = fun;
}
static OpLiteFactory& Global() {
static OpLiteFactory* x = new OpLiteFactory;
return *x;
}
private:
LiteOpRegistry() = default;
std::shared_ptr<OpLite> Create(const std::string& op_type) const {
auto it = op_registry_.find(op_type);
if (it == op_registry_.end()) return nullptr;
return it->second();
}
std::string DebugString() const {
STL::stringstream ss;
for (const auto& item : op_registry_) {
ss << " - " << item.first << "\n";
}
return ss.str();
}
protected:
std::map<std::string, std::function<std::shared_ptr<OpLite>()>> op_registry_;
};
template <typename OpClass>
class OpLiteRegistor : public Registor<OpClass> {
using LiteOpRegistry = OpLiteFactory;
// Register OpLite by initializing a static OpLiteRegistrar instance
class OpLiteRegistrar {
public:
explicit OpLiteRegistor(const std::string &op_type)
: Registor<OpClass>([&] {
LiteOpRegistry::Global().Register(
op_type, [op_type]() -> std::unique_ptr<OpLite> {
return std::unique_ptr<OpLite>(new OpClass(op_type));
});
}) {}
OpLiteRegistrar(const std::string& op_type,
std::function<std::shared_ptr<OpLite>()> fun) {
OpLiteFactory::Global().RegisterCreator(op_type, fun);
}
// Touch function is used to guarantee registrar was initialized.
void touch() {}
};
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
using KernelRegistryForTarget =
Factory<KernelLite<Target, Precision, Layout>, std::unique_ptr<KernelBase>>;
class KernelRegistry final {
class KernelFactory {
public:
using any_kernel_registor_t =
variant<KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kInt64),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt64),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt32),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kNPU),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kNPU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kNPU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kAPU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kXPU),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kXPU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kXPU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kBM),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kBM),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kBM),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kRKNPU),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kRKNPU),
PRECISION(kAny),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kRKNPU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kRKNPU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kAny),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kAny),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16),
DATALAYOUT(kNCHW)> * //
>;
KernelRegistry();
static KernelRegistry &Global();
// Register a function to create kernels
void RegisterCreator(const std::string& op_type,
TargetType target,
PrecisionType precision,
DataLayoutType layout,
std::function<std::unique_ptr<KernelBase>()> fun) {
op_registry_[op_type][std::make_tuple(target, precision, layout)].push_back(
fun);
}
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
void Register(
const std::string &name,
typename KernelRegistryForTarget<Target, Precision, Layout>::creator_t
&&creator) {
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[std::make_tuple(Target, Precision, Layout)];
auto *reg = varient.template get<kernel_registor_t *>();
CHECK(reg) << "Can not be empty of " << name;
reg->Register(name, std::move(creator));
#ifdef LITE_ON_MODEL_OPTIMIZE_TOOL
kernel_info_map_[name].push_back(
std::make_tuple(Target, Precision, Layout));
#endif // LITE_ON_MODEL_OPTIMIZE_TOOL
static KernelFactory& Global() {
static KernelFactory* x = new KernelFactory;
return *x;
}
template <TargetType Target,
PrecisionType Precision = PRECISION(kFloat),
DataLayoutType Layout = DATALAYOUT(kNCHW)>
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
std::list<std::unique_ptr<KernelBase>> kernel_list;
std::tuple<TargetType, PrecisionType, DataLayoutType> temp_tuple(
Target, Precision, Layout);
if (registries_[temp_tuple].valid()) {
kernel_list =
registries_[temp_tuple].template get<kernel_registor_t *>()->Creates(
op_type);
/**
* Create all kernels belongs to an op.
*/
std::list<std::unique_ptr<KernelBase>> Create(const std::string& op_type) {
std::list<std::unique_ptr<KernelBase>> res;
if (op_registry_.find(op_type) == op_registry_.end()) return res;
auto& kernel_registry = op_registry_[op_type];
for (auto it = kernel_registry.begin(); it != kernel_registry.end(); ++it) {
for (auto& fun : it->second) {
res.emplace_back(fun());
}
}
return kernel_list;
return res;
}
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
/**
* Create a specific kernel. Return a list for API compatible.
*/
std::list<std::unique_ptr<KernelBase>> Create(const std::string& op_type,
TargetType target,
PrecisionType precision,
DataLayoutType layout);
DataLayoutType layout) {
std::list<std::unique_ptr<KernelBase>> res;
if (op_registry_.find(op_type) == op_registry_.end()) return res;
auto& kernel_registry = op_registry_[op_type];
auto it = kernel_registry.find(std::make_tuple(target, precision, layout));
if (it == kernel_registry.end()) return res;
for (auto& fun : it->second) {
res.emplace_back(fun());
}
return res;
}
std::string DebugString() const {
#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL
return "No more debug info";
#else // LITE_ON_MODEL_OPTIMIZE_TOOL
STL::stringstream ss;
ss << "\n";
ss << "Count of kernel kinds: ";
int count = 0;
for (auto &item : kernel_info_map_) {
count += item.second.size();
for (const auto& item : op_registry_) {
ss << " - " << item.first << "\n";
}
ss << count << "\n";
ss << "Count of registered kernels: " << kernel_info_map_.size() << "\n";
for (auto &item : kernel_info_map_) {
ss << "op: " << item.first << "\n";
for (auto &kernel : item.second) {
ss << " - (" << TargetToStr(std::get<0>(kernel)) << ",";
ss << PrecisionToStr(std::get<1>(kernel)) << ",";
ss << DataLayoutToStr(std::get<2>(kernel));
ss << ")";
ss << "\n";
}
}
return ss.str();
#endif // LITE_ON_MODEL_OPTIMIZE_TOOL
}
private:
mutable std::map<std::tuple<TargetType, PrecisionType, DataLayoutType>,
any_kernel_registor_t>
registries_;
#ifndef LITE_ON_TINY_PUBLISH
mutable std::map<
std::string,
std::vector<std::tuple<TargetType, PrecisionType, DataLayoutType>>>
kernel_info_map_;
#endif
protected:
// Outer map: op -> a map of kernel.
// Inner map: kernel -> creator function.
// Each kernel was represented by a combination of <TargetType, PrecisionType,
// DataLayoutType>
std::map<std::string,
std::map<std::tuple<TargetType, PrecisionType, DataLayoutType>,
std::list<std::function<std::unique_ptr<KernelBase>()>>>>
op_registry_;
};
template <TargetType target,
PrecisionType precision,
DataLayoutType layout,
typename KernelType>
class KernelRegistor : public lite::Registor<KernelType> {
using KernelRegistry = KernelFactory;
// Register Kernel by initializing a static KernelRegistrar instance
class KernelRegistrar {
public:
KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([=] {
KernelRegistry::Global().Register<target, precision, layout>(
op_type, [=]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType);
x->set_op_type(op_type);
x->set_alias(alias);
return x;
});
}) {}
KernelRegistrar(const std::string& op_type,
TargetType target,
PrecisionType precision,
DataLayoutType layout,
std::function<std::unique_ptr<KernelBase>()> fun) {
KernelFactory::Global().RegisterCreator(
op_type, target, precision, layout, fun);
}
// Touch function is used to guarantee registrar was initialized.
void touch() {}
};
} // namespace lite
} // namespace paddle
// Operator registry
#define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__
#define REGISTER_LITE_OP(op_type__, OpClass) \
static paddle::lite::OpLiteRegistor<OpClass> LITE_OP_REGISTER_INSTANCE( \
op_type__)(#op_type__); \
int touch_op_##op_type__() { \
OpKernelInfoCollector::Global().AddOp2path(#op_type__, __FILE__); \
return LITE_OP_REGISTER_INSTANCE(op_type__).Touch(); \
// Register an op.
#define REGISTER_LITE_OP(op_type__, OpClass) \
static paddle::lite::OpLiteRegistrar op_type__##__registry( \
#op_type__, []() { \
return std::unique_ptr<paddle::lite::OpLite>(new OpClass(#op_type__)); \
}); \
int touch_op_##op_type__() { \
op_type__##__registry.touch(); \
OpKernelInfoCollector::Global().AddOp2path(#op_type__, __FILE__); \
return 0; \
}
// Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE( \
op_type__, target__, precision__, layout__, alias__) \
op_type__##__##target__##__##precision__##__##layout__##registor__instance__##alias__ // NOLINT
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
// Register a kernel.
#define REGISTER_LITE_KERNEL( \
op_type__, target__, precision__, layout__, KernelClass, alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__), \
KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE( \
op_type__, target__, precision__, layout__, alias__)(#op_type__, \
#alias__); \
static KernelClass LITE_KERNEL_INSTANCE( \
op_type__, target__, precision__, layout__, alias__); \
static paddle::lite::KernelRegistrar \
op_type__##target__##precision__##layout__##alias__##_kernel_registry( \
#op_type__, \
TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__), \
[]() { \
std::unique_ptr<KernelClass> x(new KernelClass); \
x->set_op_type(#op_type__); \
x->set_alias(#alias__); \
return x; \
}); \
int touch_##op_type__##target__##precision__##layout__##alias__() { \
op_type__##target__##precision__##layout__##alias__##_kernel_registry \
.touch(); \
OpKernelInfoCollector::Global().AddKernel2path( \
#op_type__ "," #target__ "," #precision__ "," #layout__ "," #alias__, \
__FILE__); \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
.Touch(); \
return 0; \
} \
static bool LITE_KERNEL_PARAM_INSTANCE( \
op_type__, target__, precision__, layout__, alias__) UNUSED = \
paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)>( \
#op_type__ "/" #alias__)
#define LITE_KERNEL_INSTANCE( \
op_type__, target__, precision__, layout__, alias__) \
op_type__##target__##precision__##layout__##alias__
#define LITE_KERNEL_PARAM_INSTANCE( \
op_type__, target__, precision__, layout__, alias__) \
op_type__##target__##precision__##layout__##alias__##param_register
static auto \
op_type__##target__##precision__##layout__##alias__##param_register \
UNUSED = paddle::lite::ParamTypeRegistry::NewInstance< \
TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)>(#op_type__ "/" #alias__)
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/argmax_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/argmax_compute.h"
namespace paddle {
namespace lite {
......@@ -66,9 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
}
TEST(argmax_arm, retrive_op) {
auto argmax =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"arg_max");
auto argmax = KernelRegistry::Global().Create("arg_max");
ASSERT_FALSE(argmax.empty());
ASSERT_TRUE(argmax.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/axpy_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/axpy_compute.h"
namespace paddle {
namespace lite {
......@@ -61,8 +63,7 @@ void axpy_compute_ref(const operators::AxpyParam& param) {
}
TEST(axpy_arm, retrive_op) {
auto axpy =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("axpy");
auto axpy = KernelRegistry::Global().Create("axpy");
ASSERT_FALSE(axpy.empty());
ASSERT_TRUE(axpy.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/batch_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -78,9 +80,7 @@ void batch_norm_compute_ref(const operators::BatchNormParam& param) {
}
TEST(batch_norm_arm, retrive_op) {
auto batch_norm =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"batch_norm");
auto batch_norm = KernelRegistry::Global().Create("batch_norm");
ASSERT_FALSE(batch_norm.empty());
ASSERT_TRUE(batch_norm.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/concat_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/concat_compute.h"
namespace paddle {
namespace lite {
......@@ -221,8 +223,7 @@ TEST(concat_arm, compute_input_multi) {
}
TEST(concat, retrive_op) {
auto concat =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>("concat");
auto concat = KernelRegistry::Global().Create("concat");
ASSERT_FALSE(concat.empty());
ASSERT_TRUE(concat.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/decode_bboxes_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/decode_bboxes_compute.h"
namespace paddle {
namespace lite {
......@@ -115,9 +117,7 @@ void decode_bboxes_compute_ref(const operators::DecodeBboxesParam& param) {
}
TEST(decode_bboxes_arm, retrive_op) {
auto decode_bboxes =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"decode_bboxes");
auto decode_bboxes = KernelRegistry::Global().Create("decode_bboxes");
ASSERT_FALSE(decode_bboxes.empty());
ASSERT_TRUE(decode_bboxes.front());
}
......
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/dropout_compute.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/dropout_compute.h"
namespace paddle {
namespace lite {
......@@ -30,9 +32,7 @@ TEST(dropout_arm, init) {
}
TEST(dropout, retrive_op) {
auto dropout =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"dropout");
auto dropout = KernelRegistry::Global().Create("dropout");
ASSERT_FALSE(dropout.empty());
ASSERT_TRUE(dropout.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/elementwise_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/elementwise_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace arm {
TEST(elementwise_add_arm, retrive_op) {
auto elementwise_add =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_add");
auto elementwise_add = KernelRegistry::Global().Create("elementwise_add");
ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front());
}
......@@ -336,8 +336,7 @@ TEST(elementwise_add, compute) {
TEST(fusion_elementwise_add_activation_arm, retrive_op) {
auto fusion_elementwise_add_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"fusion_elementwise_add_activation");
KernelRegistry::Global().Create("fusion_elementwise_add_activation");
ASSERT_FALSE(fusion_elementwise_add_activation.empty());
ASSERT_TRUE(fusion_elementwise_add_activation.front());
}
......@@ -435,9 +434,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
}
TEST(elementwise_mul_arm, retrive_op) {
auto elementwise_mul =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_mul");
auto elementwise_mul = KernelRegistry::Global().Create("elementwise_mul");
ASSERT_FALSE(elementwise_mul.empty());
ASSERT_TRUE(elementwise_mul.front());
}
......@@ -530,8 +527,7 @@ TEST(elementwise_mul, compute) {
TEST(fusion_elementwise_mul_activation_arm, retrive_op) {
auto fusion_elementwise_mul_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"fusion_elementwise_mul_activation");
KernelRegistry::Global().Create("fusion_elementwise_mul_activation");
ASSERT_FALSE(fusion_elementwise_mul_activation.empty());
ASSERT_TRUE(fusion_elementwise_mul_activation.front());
}
......@@ -629,9 +625,7 @@ TEST(fusion_elementwise_mul_activation_arm, compute) {
}
TEST(elementwise_max_arm, retrive_op) {
auto elementwise_max =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"elementwise_max");
auto elementwise_max = KernelRegistry::Global().Create("elementwise_max");
ASSERT_FALSE(elementwise_max.empty());
ASSERT_TRUE(elementwise_max.front());
}
......@@ -724,8 +718,7 @@ TEST(elementwise_max, compute) {
TEST(fusion_elementwise_max_activation_arm, retrive_op) {
auto fusion_elementwise_max_activation =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"fusion_elementwise_max_activation");
KernelRegistry::Global().Create("fusion_elementwise_max_activation");
ASSERT_FALSE(fusion_elementwise_max_activation.empty());
ASSERT_TRUE(fusion_elementwise_max_activation.front());
}
......@@ -823,9 +816,7 @@ TEST(fusion_elementwise_max_activation_arm, compute) {
}
TEST(elementwise_mod_int64_arm, retrive_op) {
auto elementwise_mod =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kInt64)>(
"elementwise_mod");
auto elementwise_mod = KernelRegistry::Global().Create("elementwise_mod");
ASSERT_FALSE(elementwise_mod.empty());
ASSERT_TRUE(elementwise_mod.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/layer_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -181,9 +183,7 @@ TEST(layer_norm_arm, compute) {
}
TEST(layer_norm, retrive_op) {
auto layer_norm =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"layer_norm");
auto layer_norm = KernelRegistry::Global().Create("layer_norm");
ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lrn_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/lrn_compute.h"
namespace paddle {
namespace lite {
......@@ -133,8 +135,7 @@ void lrn_compute_ref(const operators::LrnParam& param) {
}
TEST(lrn_arm, retrive_op) {
auto lrn =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("lrn");
auto lrn = KernelRegistry::Global().Create("lrn");
ASSERT_FALSE(lrn.empty());
ASSERT_TRUE(lrn.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace arm {
TEST(merge_lod_tensor_arm, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"merge_lod_tensor");
auto kernel = KernelRegistry::Global().Create("merge_lod_tensor");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/mul_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/mul_compute.h"
namespace paddle {
namespace lite {
......@@ -69,8 +71,7 @@ void FillData(T* a,
}
TEST(mul_arm, retrive_op) {
auto mul =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("mul");
auto mul = KernelRegistry::Global().Create("mul");
ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/pool_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/pool_compute.h"
namespace paddle {
namespace lite {
......@@ -341,8 +343,7 @@ TEST(pool_arm, compute) {
}
TEST(pool_arm, retrive_op) {
auto pool = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"pool2d");
auto pool = KernelRegistry::Global().Create("pool2d");
ASSERT_FALSE(pool.empty());
ASSERT_TRUE(pool.front());
}
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/scale_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/scale_compute.h"
namespace paddle {
namespace lite {
......@@ -103,8 +105,7 @@ TEST(scale_arm, compute) {
}
TEST(scale, retrive_op) {
auto scale =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("scale");
auto scale = KernelRegistry::Global().Create("scale");
ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/softmax_compute.h"
namespace paddle {
namespace lite {
......@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) {
}
TEST(softmax, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"softmax");
auto softmax = KernelRegistry::Global().Create("softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/split_compute.h"
#include <gtest/gtest.h>
#include <cstring>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/split_compute.h"
namespace paddle {
namespace lite {
......@@ -165,8 +167,7 @@ TEST(split_arm, compute) {
}
TEST(split, retrive_op) {
auto split =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("split");
auto split = KernelRegistry::Global().Create("split");
ASSERT_FALSE(split.empty());
ASSERT_TRUE(split.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/split_lod_tensor_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/split_lod_tensor_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace arm {
TEST(split_lod_tensor_arm, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"split_lod_tensor");
auto kernel = KernelRegistry::Global().Create("split_lod_tensor");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/transpose_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/transpose_compute.h"
namespace paddle {
namespace lite {
......@@ -121,9 +123,7 @@ TEST(transpose_arm, compute_shape_nchw) {
}
TEST(transpose, retrive_op) {
auto transpose =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"transpose");
auto transpose = KernelRegistry::Global().Create("transpose");
ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front());
}
......@@ -189,9 +189,7 @@ TEST(transpose2_arm, compute_shape_nchw) {
}
TEST(transpose2, retrive_op) {
auto transpose2 =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"transpose2");
auto transpose2 = KernelRegistry::Global().Create("transpose2");
ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/lookup_table_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/lookup_table_compute.h"
namespace paddle {
namespace lite {
......@@ -56,9 +58,7 @@ void LookupTableComputeRef(const operators::LookupTableParam& param) {
}
TEST(lookup_table_cuda, retrieve_op) {
auto lookup_table =
KernelRegistry::Global().Create<TARGET(kCUDA), PRECISION(kFloat)>(
"lookup_table");
auto lookup_table = KernelRegistry::Global().Create("lookup_table");
ASSERT_FALSE(lookup_table.empty());
ASSERT_TRUE(lookup_table.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/activation_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/activation_compute.h"
namespace paddle {
namespace lite {
......@@ -37,8 +39,7 @@ void activation_compute_ref(const operators::ActivationParam& param) {
}
TEST(activation_fpga, retrive_op) {
auto activation =
KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>("relu");
auto activation = KernelRegistry::Global().Create("relu");
ASSERT_FALSE(activation.empty());
ASSERT_TRUE(activation.front());
}
......
......@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/fc_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/fc_compute.h"
namespace paddle {
namespace lite {
......@@ -76,8 +78,7 @@ void FillData(T* a,
}
TEST(fc_fpga, retrive_op) {
auto fc =
KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>("fc");
auto fc = KernelRegistry::Global().Create("fc");
ASSERT_FALSE(fc.empty());
ASSERT_TRUE(fc.front());
}
......
......@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/pooling_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/pooling_compute.h"
namespace paddle {
namespace lite {
......@@ -277,8 +278,7 @@ TEST(pool_fpga, compute) {
}
TEST(pool_fpga, retrive_op) {
auto pool = KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>(
"pool2d");
auto pool = KernelRegistry::Global().Create("pool2d");
ASSERT_FALSE(pool.empty());
ASSERT_TRUE(pool.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/softmax_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <vector>
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/softmax_compute.h"
namespace paddle {
namespace lite {
......@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) {
}
TEST(softmax, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kFPGA), PRECISION(kFP16)>(
"softmax");
auto softmax = KernelRegistry::Global().Create("softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/activation_compute.cc"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
namespace paddle {
namespace lite {
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
TEST(relu_x86, retrive_op) {
auto relu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
auto relu = KernelRegistry::Global().Create("relu");
ASSERT_FALSE(relu.empty());
ASSERT_TRUE(relu.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
namespace paddle {
namespace lite {
......@@ -81,8 +83,7 @@ int get_max_len(const LoD& lod) {
TEST(attention_padding_mask_x86, retrive_op) {
auto attention_padding_mask =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"attention_padding_mask");
KernelRegistry::Global().Create("attention_padding_mask");
ASSERT_FALSE(attention_padding_mask.empty());
ASSERT_TRUE(attention_padding_mask.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/batch_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
TEST(batch_norm_x86, retrive_op) {
auto batch_norm =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"batch_norm");
auto batch_norm = KernelRegistry::Global().Create("batch_norm");
ASSERT_FALSE(batch_norm.empty());
ASSERT_TRUE(batch_norm.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/cast_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/cast_compute.h"
namespace paddle {
namespace lite {
......@@ -25,8 +27,7 @@ namespace kernels {
namespace x86 {
TEST(cast_x86, retrive_op) {
auto cast =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("cast");
auto cast = KernelRegistry::Global().Create("cast");
ASSERT_FALSE(cast.empty());
ASSERT_TRUE(cast.front());
}
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/concat_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/concat_compute.h"
namespace paddle {
namespace lite {
......@@ -23,9 +25,7 @@ namespace kernels {
namespace x86 {
TEST(concat_x86, retrive_op) {
auto concat =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"concat");
auto concat = KernelRegistry::Global().Create("concat");
ASSERT_FALSE(concat.empty());
ASSERT_TRUE(concat.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/conv_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/conv_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(conv_x86, retrive_op) {
auto conv2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"conv2d");
auto conv2d = KernelRegistry::Global().Create("conv2d");
ASSERT_FALSE(conv2d.empty());
ASSERT_TRUE(conv2d.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/dropout_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/dropout_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
TEST(dropout_x86, retrive_op) {
auto dropout =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"dropout");
auto dropout = KernelRegistry::Global().Create("dropout");
ASSERT_FALSE(dropout.empty());
ASSERT_TRUE(dropout.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/elementwise_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
TEST(elementwise_add_x86, retrive_op) {
auto elementwise_add =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"elementwise_add");
auto elementwise_add = KernelRegistry::Global().Create("elementwise_add");
ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -26,8 +29,7 @@ namespace x86 {
TEST(fill_constant_batch_size_like_x86, retrive_op) {
auto fill_constant_batch_size_like =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"fill_constant_batch_size_like");
KernelRegistry::Global().Create("fill_constant_batch_size_like");
ASSERT_FALSE(fill_constant_batch_size_like.empty());
ASSERT_TRUE(fill_constant_batch_size_like.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/gather_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/gather_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(gather_x86, retrive_op) {
auto gather =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"gather");
auto gather = KernelRegistry::Global().Create("gather");
ASSERT_FALSE(gather.empty());
int cnt = 0;
for (auto item = gather.begin(); item != gather.end(); ++item) {
......
......@@ -13,10 +13,12 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
TEST(gelu_x86, retrive_op) {
auto gelu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gelu");
auto gelu = KernelRegistry::Global().Create("gelu");
ASSERT_FALSE(gelu.empty());
ASSERT_TRUE(gelu.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/gru_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/gru_compute.h"
namespace paddle {
namespace lite {
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
TEST(gru_x86, retrive_op) {
auto gru =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gru");
auto gru = KernelRegistry::Global().Create("gru");
ASSERT_FALSE(gru.empty());
ASSERT_TRUE(gru.front());
}
......
......@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/layer_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -74,9 +76,7 @@ std::vector<float> ref(lite::Tensor* x,
// layer_norm
TEST(layer_norm_x86, retrive_op) {
auto layer_norm =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"layer_norm");
auto layer_norm = KernelRegistry::Global().Create("layer_norm");
ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front());
}
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.h"
......@@ -24,9 +26,7 @@ namespace kernels {
namespace x86 {
TEST(leaky_relu_x86, retrive_op) {
auto leaky_relu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"leaky_relu");
auto leaky_relu = KernelRegistry::Global().Create("leaky_relu");
ASSERT_FALSE(leaky_relu.empty());
ASSERT_TRUE(leaky_relu.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(match_matrix_tensor_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"match_matrix_tensor");
auto kernel = KernelRegistry::Global().Create("match_matrix_tensor");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,22 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/matmul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/matmul_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(matmul_x86, retrive_op) {
auto matmul =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"matmul");
auto matmul = KernelRegistry::Global().Create("matmul");
ASSERT_FALSE(matmul.empty());
ASSERT_TRUE(matmul.front());
}
......
......@@ -12,21 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/mul_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(mul_x86, retrive_op) {
auto mul =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("mul");
auto mul = KernelRegistry::Global().Create("mul");
ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/pool_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/pool_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
TEST(pool_x86, retrive_op) {
auto pool2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"pool2d");
auto pool2d = KernelRegistry::Global().Create("pool2d");
ASSERT_FALSE(pool2d.empty());
ASSERT_TRUE(pool2d.front());
}
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.h"
......@@ -24,8 +26,7 @@ namespace kernels {
namespace x86 {
TEST(relu_x86, retrive_op) {
auto relu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
auto relu = KernelRegistry::Global().Create("relu");
ASSERT_FALSE(relu.empty());
ASSERT_TRUE(relu.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/reshape_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/reshape_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -26,9 +29,7 @@ namespace x86 {
// reshape
TEST(reshape_x86, retrive_op) {
auto reshape =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape");
auto reshape = KernelRegistry::Global().Create("reshape");
ASSERT_FALSE(reshape.empty());
ASSERT_TRUE(reshape.front());
}
......@@ -86,9 +87,7 @@ TEST(reshape_x86, run_test) {
// reshape2
TEST(reshape2_x86, retrive_op) {
auto reshape2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape2");
auto reshape2 = KernelRegistry::Global().Create("reshape2");
ASSERT_FALSE(reshape2.empty());
ASSERT_TRUE(reshape2.front());
}
......
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/scale_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/scale_compute.h"
namespace paddle {
namespace lite {
......@@ -24,8 +26,7 @@ namespace kernels {
namespace x86 {
TEST(scale_x86, retrive_op) {
auto scale =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("scale");
auto scale = KernelRegistry::Global().Create("scale");
ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_fc_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_fc_compute.h"
namespace paddle {
namespace lite {
......@@ -53,9 +55,7 @@ void fc_cpu_base(const lite::Tensor* X,
}
TEST(search_fc_x86, retrive_op) {
auto search_fc =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_fc");
auto search_fc = KernelRegistry::Global().Create("search_fc");
ASSERT_FALSE(search_fc.empty());
ASSERT_TRUE(search_fc.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_grnn_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_grnn_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(search_grnn_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_grnn");
auto kernel = KernelRegistry::Global().Create("search_grnn");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_group_padding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_group_padding_compute.h"
namespace paddle {
namespace lite {
......@@ -26,8 +28,7 @@ namespace x86 {
TEST(search_group_padding_x86, retrieve_op) {
auto search_group_padding =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_group_padding");
KernelRegistry::Global().Create("search_group_padding");
ASSERT_FALSE(search_group_padding.empty());
ASSERT_TRUE(search_group_padding.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_seq_depadding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_seq_depadding_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(search_seq_depadding_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_seq_depadding");
auto kernel = KernelRegistry::Global().Create("search_seq_depadding");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
namespace paddle {
namespace lite {
......@@ -77,8 +79,7 @@ void prepare_input(Tensor* x, const LoD& x_lod) {
TEST(sequence_arithmetic_x86, retrive_op) {
auto sequence_arithmetic =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_arithmetic");
KernelRegistry::Global().Create("sequence_arithmetic");
ASSERT_FALSE(sequence_arithmetic.empty());
ASSERT_TRUE(sequence_arithmetic.front());
}
......
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_concat_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_concat_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -94,9 +97,7 @@ static void sequence_concat_ref(const std::vector<lite::Tensor*>& xs,
} // namespace
TEST(sequence_concat_x86, retrive_op) {
auto sequence_concat =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_concat");
auto sequence_concat = KernelRegistry::Global().Create("sequence_concat");
ASSERT_FALSE(sequence_concat.empty());
ASSERT_TRUE(sequence_concat.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_expand_as_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_expand_as_compute.h"
namespace paddle {
namespace lite {
......@@ -27,8 +29,7 @@ namespace x86 {
TEST(sequence_expand_as_x86, retrive_op) {
auto sequence_expand_as =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_expand_as");
KernelRegistry::Global().Create("sequence_expand_as");
ASSERT_FALSE(sequence_expand_as.empty());
ASSERT_TRUE(sequence_expand_as.front());
}
......
......@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_pool_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_pool_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(sequence_pool_x86, retrive_op) {
auto sequence_pool =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_pool");
auto sequence_pool = KernelRegistry::Global().Create("sequence_pool");
ASSERT_FALSE(sequence_pool.empty());
ASSERT_TRUE(sequence_pool.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_reverse_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_reverse_compute.h"
namespace paddle {
namespace lite {
......@@ -44,9 +46,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
} // namespace
TEST(sequence_reverse_x86, retrive_op) {
auto sequence_reverse =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_reverse");
auto sequence_reverse = KernelRegistry::Global().Create("sequence_reverse");
ASSERT_FALSE(sequence_reverse.empty());
ASSERT_TRUE(sequence_reverse.front());
}
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/shape_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/shape_compute.h"
namespace paddle {
namespace lite {
......@@ -23,8 +25,7 @@ namespace kernels {
namespace x86 {
TEST(shape_x86, retrive_op) {
auto shape =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("shape");
auto shape = KernelRegistry::Global().Create("shape");
ASSERT_FALSE(shape.empty());
ASSERT_TRUE(shape.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/slice_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/slice_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -79,8 +82,7 @@ static void slice_ref(const float* input,
}
TEST(slice_x86, retrive_op) {
auto slice =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("slice");
auto slice = KernelRegistry::Global().Create("slice");
ASSERT_FALSE(slice.empty());
ASSERT_TRUE(slice.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/softmax_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/softmax_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(softmax_x86, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"softmax");
auto softmax = KernelRegistry::Global().Create("softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
......
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/stack_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/stack_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -25,8 +28,7 @@ namespace x86 {
// stack
TEST(stack_x86, retrive_op) {
auto stack =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("stack");
auto stack = KernelRegistry::Global().Create("stack");
ASSERT_FALSE(stack.empty());
ASSERT_TRUE(stack.front());
}
......
......@@ -13,10 +13,12 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
TEST(tanh_x86, retrive_op) {
auto tanh =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("tanh");
auto tanh = KernelRegistry::Global().Create("tanh");
ASSERT_FALSE(tanh.empty());
ASSERT_TRUE(tanh.front());
}
......
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/transpose_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -25,9 +28,7 @@ namespace x86 {
// transpose
TEST(transpose_x86, retrive_op) {
auto transpose =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose");
auto transpose = KernelRegistry::Global().Create("transpose");
ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front());
}
......@@ -75,9 +76,7 @@ TEST(transpose_x86, run_test) {
// transpose2
TEST(transpose2_x86, retrive_op) {
auto transpose2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose2");
auto transpose2 = KernelRegistry::Global().Create("transpose2");
ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/var_conv_2d_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/x86/var_conv_2d_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -197,9 +200,7 @@ static void var_conv_2d_ref(const lite::Tensor* bottom,
}
TEST(var_conv_2d_x86, retrive_op) {
auto var_conv_2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"var_conv_2d");
auto var_conv_2d = KernelRegistry::Global().Create("var_conv_2d");
ASSERT_FALSE(var_conv_2d.empty());
ASSERT_TRUE(var_conv_2d.front());
}
......
......@@ -24,7 +24,6 @@
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h"
#include "lite/utils/variant.h"
/*
* This file contains all the argument parameter data structure for operators.
*/
......
......@@ -14,10 +14,16 @@
#pragma once
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "lite/utils/any.h"
#include "lite/utils/check.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/factory.h"
#include "lite/utils/hash.h"
#include "lite/utils/io.h"
#include "lite/utils/macros.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "lite/utils/all.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
/*
* Factor for any Type creator.
*
* Usage:
*
* struct SomeType;
* // Register a creator.
* Factory<SomeType>::Global().Register("some_key", [] ->
* std::unique_ptr<SomeType> { ... });
* // Retrive a creator.
* auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
*/
template <typename ItemType, typename ItemTypePtr>
class Factory {
public:
using item_t = ItemType;
using self_t = Factory<item_t, ItemTypePtr>;
using item_ptr_t = ItemTypePtr;
using creator_t = std::function<item_ptr_t()>;
static Factory& Global() {
static Factory* x = new self_t;
return *x;
}
void Register(const std::string& op_type, creator_t&& creator) {
creators_[op_type].emplace_back(std::move(creator));
}
item_ptr_t Create(const std::string& op_type) const {
auto res = Creates(op_type);
if (res.empty()) return nullptr;
CHECK_EQ(res.size(), 1UL) << "Get multiple Op for type " << op_type;
return std::move(res.front());
}
std::list<item_ptr_t> Creates(const std::string& op_type) const {
std::list<item_ptr_t> res;
auto it = creators_.find(op_type);
if (it == creators_.end()) return res;
for (auto& c : it->second) {
res.emplace_back(c());
}
return res;
}
std::string DebugString() const {
STL::stringstream ss;
for (const auto& item : creators_) {
ss << " - " << item.first << "\n";
}
return ss.str();
}
protected:
std::map<std::string, std::list<creator_t>> creators_;
};
/* A helper function to help run a lambda at the start.
*/
template <typename Type>
class Registor {
public:
explicit Registor(std::function<void()>&& functor) { functor(); }
// Touch will do nothing.
int Touch() { return 0; }
};
} // namespace lite
} // namespace paddle
Subproject commit ac203b20926b13a35ff85277d2e5d3c38698eee8
Subproject commit 6df40a2471737b27271bdd9b900ab5f3aec746c7
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册