未验证 提交 db7639ca 编写于 作者: S Shibo Tao 提交者: GitHub

optimize register mechanism (#3745)

* refactor register mechanism, current so size: 1.20MB. test=develop

* fix KernelRegistry::Global().Create. test=develop

* fix cpplint errors. test=develop

* fix test_subgraph_pass bug. test=develop

* register kernel with target,precision,datalayout combination. test=develop

* fix test_paddle_api no op found bug. test=develop

* enhance comment

* fix lite/kernels/arm/elementwise_compute_test.cc. test=develop

* fix code style

* revert format of unchanged files. test=develop

* fix code format according to cpplint 1.5.1. test=develop

* remove redundant include header. test=develop
上级 732bb91b
......@@ -15,8 +15,6 @@
#include "lite/api/light_api.h"
#include <algorithm>
#include <map>
#include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT
namespace paddle {
namespace lite {
......
......@@ -15,8 +15,11 @@
#include "lite/api/paddle_api.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/io.h"
DEFINE_string(model_dir, "", "");
namespace paddle {
......
......@@ -13,8 +13,12 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
......
......@@ -17,277 +17,5 @@
#include <set>
namespace paddle {
namespace lite {
// Free-function shortcut: forwards to the GetOp2PathDict() member of the
// OpKernelInfoCollector singleton and hands back its result by const
// reference.
const std::map<std::string, std::string> &GetOp2PathDict() {
  auto &&collector = OpKernelInfoCollector::Global();
  return collector.GetOp2PathDict();
}
// Creates the kernels registered for `op_type` under the given runtime
// (target, precision, layout) triple.
//
// The templated Create<Target, Precision, Layout>() overload needs
// compile-time arguments, so the runtime values are dispatched through three
// nested switches (target -> precision -> layout) generated by the two helper
// macros below.
//
// Handled precisions: kFloat, kInt8, kFP16, kAny, kInt32, kInt64.
// Handled layouts: kNCHW, kAny, kNHWC, kImageDefault, kImageFolder, kImageNW.
// Targets other than kHost are only compiled in when LITE_ON_TINY_PUBLISH is
// not defined or the matching LITE_WITH_* macro is defined.
//
// Any unsupported value is reported through LOG(FATAL) / CHECK(false).
// NOTE(review): this presumably aborts; if those macros ever return, control
// falls through to the next case label and eventually to the empty list
// returned at the end -- confirm against the logging implementation.
std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
    const std::string &op_type,
    TargetType target,
    PrecisionType precision,
    DataLayoutType layout) {
  Place place{target, precision, layout};
  VLOG(5) << "creating " << op_type << " kernel for " << place.DebugString();

// Innermost dispatch: target and precision are already fixed at compile time,
// select the layout at runtime. Every supported case returns directly.
#define CREATE_KERNEL1(target__, precision__)                                \
  switch (layout) {                                                          \
    case DATALAYOUT(kNCHW):                                                  \
      return Create<TARGET(target__),                                        \
                    PRECISION(precision__),                                  \
                    DATALAYOUT(kNCHW)>(op_type);                             \
    case DATALAYOUT(kAny):                                                   \
      return Create<TARGET(target__),                                        \
                    PRECISION(precision__),                                  \
                    DATALAYOUT(kAny)>(op_type);                              \
    case DATALAYOUT(kNHWC):                                                  \
      return Create<TARGET(target__),                                        \
                    PRECISION(precision__),                                  \
                    DATALAYOUT(kNHWC)>(op_type);                             \
    case DATALAYOUT(kImageDefault):                                          \
      return Create<TARGET(target__),                                        \
                    PRECISION(precision__),                                  \
                    DATALAYOUT(kImageDefault)>(op_type);                     \
    case DATALAYOUT(kImageFolder):                                           \
      return Create<TARGET(target__),                                        \
                    PRECISION(precision__),                                  \
                    DATALAYOUT(kImageFolder)>(op_type);                      \
    case DATALAYOUT(kImageNW):                                               \
      return Create<TARGET(target__),                                        \
                    PRECISION(precision__),                                  \
                    DATALAYOUT(kImageNW)>(op_type);                          \
    default:                                                                 \
      LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
  }

// Middle dispatch: target is fixed, select the precision at runtime and hand
// over to CREATE_KERNEL1 for the layout.
#define CREATE_KERNEL(target__)                         \
  switch (precision) {                                  \
    case PRECISION(kFloat):                             \
      CREATE_KERNEL1(target__, kFloat);                 \
    case PRECISION(kInt8):                              \
      CREATE_KERNEL1(target__, kInt8);                  \
    case PRECISION(kFP16):                              \
      CREATE_KERNEL1(target__, kFP16);                  \
    case PRECISION(kAny):                               \
      CREATE_KERNEL1(target__, kAny);                   \
    case PRECISION(kInt32):                             \
      CREATE_KERNEL1(target__, kInt32);                 \
    case PRECISION(kInt64):                             \
      CREATE_KERNEL1(target__, kInt64);                 \
    default:                                            \
      CHECK(false) << "not supported kernel precision " \
                   << PrecisionToStr(precision);        \
  }

  // Outermost dispatch on the target.
  switch (target) {
    case TARGET(kHost): {
      CREATE_KERNEL(kHost);
    } break;
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_X86)
    case TARGET(kX86): {
      CREATE_KERNEL(kX86);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_CUDA)
    case TARGET(kCUDA): {
      CREATE_KERNEL(kCUDA);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_ARM)
    case TARGET(kARM): {
      CREATE_KERNEL(kARM);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_OPENCL)
    case TARGET(kOpenCL): {
      CREATE_KERNEL(kOpenCL);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_NPU)
    case TARGET(kNPU): {
      CREATE_KERNEL(kNPU);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_APU)
    case TARGET(kAPU): {
      CREATE_KERNEL(kAPU);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_XPU)
    case TARGET(kXPU): {
      CREATE_KERNEL(kXPU);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_FPGA)
    case TARGET(kFPGA): {
      CREATE_KERNEL(kFPGA);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_BM)
    case TARGET(kBM): {
      CREATE_KERNEL(kBM);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_MLU)
    case TARGET(kMLU): {
      CREATE_KERNEL(kMLU);
    } break;
#endif
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_RKNPU)
    case TARGET(kRKNPU): {
      CREATE_KERNEL(kRKNPU);
    } break;
#endif
    default:
      CHECK(false) << "not supported kernel target " << TargetToStr(target);
  }

// Both helper macros are local to this function; remove them so they cannot
// leak into the rest of the translation unit.
#undef CREATE_KERNEL
#undef CREATE_KERNEL1
  return std::list<std::unique_ptr<KernelBase>>();
}
// Populates registries_ with one entry per supported
// (target, precision, layout) combination. Each entry stores a pointer to the
// Global() instance of the matching KernelRegistryForTarget<...>
// specialization.
//
// Host combinations are always registered. Every other target is registered
// only when LITE_ON_TINY_PUBLISH is not defined or the corresponding
// LITE_WITH_* macro is defined, mirroring the guards used by the runtime
// dispatch in KernelRegistry::Create().
KernelRegistry::KernelRegistry() : registries_() {
// Registers a single (target, precision, layout) combination.
#define INIT_FOR(target__, precision__, layout__)                            \
  registries_[std::make_tuple(TARGET(target__),                              \
                              PRECISION(precision__),                        \
                              DATALAYOUT(layout__))]                         \
      .set<KernelRegistryForTarget<TARGET(target__),                         \
                                   PRECISION(precision__),                   \
                                   DATALAYOUT(layout__)> *>(                 \
          &KernelRegistryForTarget<TARGET(target__),                         \
                                   PRECISION(precision__),                   \
                                   DATALAYOUT(layout__)>::Global());

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_CUDA)
  INIT_FOR(kCUDA, kFloat, kNCHW);
  INIT_FOR(kCUDA, kFloat, kNHWC);
  INIT_FOR(kCUDA, kInt8, kNCHW);
  INIT_FOR(kCUDA, kFP16, kNCHW);
  INIT_FOR(kCUDA, kFP16, kNHWC);
  INIT_FOR(kCUDA, kAny, kNCHW);
  INIT_FOR(kCUDA, kAny, kAny);
  INIT_FOR(kCUDA, kInt8, kNHWC);
  INIT_FOR(kCUDA, kInt64, kNCHW);
  INIT_FOR(kCUDA, kInt64, kNHWC);
  INIT_FOR(kCUDA, kInt32, kNCHW);
  INIT_FOR(kCUDA, kInt32, kNHWC);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_MLU)
  INIT_FOR(kMLU, kFloat, kNHWC);
  INIT_FOR(kMLU, kFloat, kNCHW);
  INIT_FOR(kMLU, kFP16, kNHWC);
  INIT_FOR(kMLU, kFP16, kNCHW);
  INIT_FOR(kMLU, kInt8, kNHWC);
  INIT_FOR(kMLU, kInt8, kNCHW);
  INIT_FOR(kMLU, kInt16, kNHWC);
  INIT_FOR(kMLU, kInt16, kNCHW);
#endif

  // Host kernels are registered unconditionally.
  INIT_FOR(kHost, kAny, kNCHW);
  INIT_FOR(kHost, kAny, kNHWC);
  INIT_FOR(kHost, kAny, kAny);
  INIT_FOR(kHost, kBool, kNCHW);
  INIT_FOR(kHost, kBool, kNHWC);
  INIT_FOR(kHost, kBool, kAny);
  INIT_FOR(kHost, kFloat, kNCHW);
  INIT_FOR(kHost, kFloat, kNHWC);
  INIT_FOR(kHost, kFloat, kAny);
  INIT_FOR(kHost, kFP16, kNCHW);
  INIT_FOR(kHost, kFP16, kNHWC);
  INIT_FOR(kHost, kFP16, kAny);
  INIT_FOR(kHost, kInt8, kNCHW);
  INIT_FOR(kHost, kInt8, kNHWC);
  INIT_FOR(kHost, kInt8, kAny);
  INIT_FOR(kHost, kInt16, kNCHW);
  INIT_FOR(kHost, kInt16, kNHWC);
  INIT_FOR(kHost, kInt16, kAny);
  INIT_FOR(kHost, kInt32, kNCHW);
  INIT_FOR(kHost, kInt32, kNHWC);
  INIT_FOR(kHost, kInt32, kAny);
  INIT_FOR(kHost, kInt64, kNCHW);
  INIT_FOR(kHost, kInt64, kNHWC);
  INIT_FOR(kHost, kInt64, kAny);

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_X86)
  INIT_FOR(kX86, kFloat, kNCHW);
  INIT_FOR(kX86, kAny, kNCHW);
  INIT_FOR(kX86, kAny, kAny);
  INIT_FOR(kX86, kInt64, kNCHW);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_ARM)
  INIT_FOR(kARM, kFloat, kNCHW);
  INIT_FOR(kARM, kFloat, kNHWC);
  INIT_FOR(kARM, kInt8, kNCHW);
  INIT_FOR(kARM, kInt8, kNHWC);
  INIT_FOR(kARM, kAny, kNCHW);
  INIT_FOR(kARM, kAny, kAny);
  INIT_FOR(kARM, kInt32, kNCHW);
  INIT_FOR(kARM, kInt64, kNCHW);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_OPENCL)
  INIT_FOR(kOpenCL, kFloat, kNCHW);
  INIT_FOR(kOpenCL, kFloat, kNHWC);
  INIT_FOR(kOpenCL, kAny, kNCHW);
  INIT_FOR(kOpenCL, kAny, kNHWC);
  INIT_FOR(kOpenCL, kFloat, kAny);
  INIT_FOR(kOpenCL, kInt8, kNCHW);
  INIT_FOR(kOpenCL, kAny, kAny);
  INIT_FOR(kOpenCL, kFP16, kNCHW);
  INIT_FOR(kOpenCL, kFP16, kNHWC);
  INIT_FOR(kOpenCL, kFP16, kImageDefault);
  INIT_FOR(kOpenCL, kFP16, kImageFolder);
  INIT_FOR(kOpenCL, kFP16, kImageNW);
  INIT_FOR(kOpenCL, kFloat, kImageDefault);
  INIT_FOR(kOpenCL, kFloat, kImageFolder);
  INIT_FOR(kOpenCL, kFloat, kImageNW);
  INIT_FOR(kOpenCL, kAny, kImageDefault);
  INIT_FOR(kOpenCL, kAny, kImageFolder);
  INIT_FOR(kOpenCL, kAny, kImageNW);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_NPU)
  INIT_FOR(kNPU, kFloat, kNCHW);
  INIT_FOR(kNPU, kFloat, kNHWC);
  INIT_FOR(kNPU, kInt8, kNCHW);
  INIT_FOR(kNPU, kInt8, kNHWC);
  INIT_FOR(kNPU, kAny, kNCHW);
  INIT_FOR(kNPU, kAny, kNHWC);
  INIT_FOR(kNPU, kAny, kAny);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_APU)
  INIT_FOR(kAPU, kInt8, kNCHW);
#endif

// XPU gets its own guard (it previously shared the APU one) so that
// registration follows LITE_WITH_XPU, the same macro that gates the kXPU case
// in KernelRegistry::Create().
#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_XPU)
  INIT_FOR(kXPU, kFloat, kNCHW);
  INIT_FOR(kXPU, kInt8, kNCHW);
  INIT_FOR(kXPU, kAny, kNCHW);
  INIT_FOR(kXPU, kAny, kAny);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_FPGA)
  INIT_FOR(kFPGA, kFP16, kNHWC);
  INIT_FOR(kFPGA, kFP16, kAny);
  INIT_FOR(kFPGA, kFloat, kNHWC);
  INIT_FOR(kFPGA, kAny, kNHWC);
  INIT_FOR(kFPGA, kAny, kAny);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_BM)
  INIT_FOR(kBM, kFloat, kNCHW);
  INIT_FOR(kBM, kInt8, kNCHW);
  INIT_FOR(kBM, kAny, kNCHW);
  INIT_FOR(kBM, kAny, kAny);
#endif

#if !defined(LITE_ON_TINY_PUBLISH) || defined(LITE_WITH_RKNPU)
  INIT_FOR(kRKNPU, kFloat, kNCHW);
  INIT_FOR(kRKNPU, kInt8, kNCHW);
  INIT_FOR(kRKNPU, kAny, kNCHW);
  INIT_FOR(kRKNPU, kAny, kAny);
#endif
#undef INIT_FOR
}
// Accessor for the shared KernelRegistry instance. The registry is allocated
// with `new` the first time this function runs and is never deleted here, so
// the returned reference stays usable for the rest of the process.
KernelRegistry &KernelRegistry::Global() {
  static KernelRegistry *const instance = new KernelRegistry;
  return *instance;
}
} // namespace lite
namespace lite {} // namespace lite
} // namespace paddle
此差异已折叠。
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/argmax_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/argmax_compute.h"
namespace paddle {
namespace lite {
......@@ -66,9 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
}
// Looking up "arg_max" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(argmax_arm, retrive_op) {
  auto argmax = KernelRegistry::Global().Create("arg_max");
  ASSERT_FALSE(argmax.empty());
  ASSERT_TRUE(argmax.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/axpy_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/axpy_compute.h"
namespace paddle {
namespace lite {
......@@ -61,8 +63,7 @@ void axpy_compute_ref(const operators::AxpyParam& param) {
}
// Looking up "axpy" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(axpy_arm, retrive_op) {
  auto axpy = KernelRegistry::Global().Create("axpy");
  ASSERT_FALSE(axpy.empty());
  ASSERT_TRUE(axpy.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/batch_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -78,9 +80,7 @@ void batch_norm_compute_ref(const operators::BatchNormParam& param) {
}
// Looking up "batch_norm" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(batch_norm_arm, retrive_op) {
  auto batch_norm = KernelRegistry::Global().Create("batch_norm");
  ASSERT_FALSE(batch_norm.empty());
  ASSERT_TRUE(batch_norm.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/concat_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/concat_compute.h"
namespace paddle {
namespace lite {
......@@ -221,8 +223,7 @@ TEST(concat_arm, compute_input_multi) {
}
// Looking up "concat" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(concat, retrive_op) {
  auto concat = KernelRegistry::Global().Create("concat");
  ASSERT_FALSE(concat.empty());
  ASSERT_TRUE(concat.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/decode_bboxes_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/decode_bboxes_compute.h"
namespace paddle {
namespace lite {
......@@ -115,9 +117,7 @@ void decode_bboxes_compute_ref(const operators::DecodeBboxesParam& param) {
}
// Looking up "decode_bboxes" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(decode_bboxes_arm, retrive_op) {
  auto decode_bboxes = KernelRegistry::Global().Create("decode_bboxes");
  ASSERT_FALSE(decode_bboxes.empty());
  ASSERT_TRUE(decode_bboxes.front());
}
......
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/dropout_compute.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/dropout_compute.h"
namespace paddle {
namespace lite {
......@@ -30,9 +32,7 @@ TEST(dropout_arm, init) {
}
// Looking up "dropout" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(dropout, retrive_op) {
  auto dropout = KernelRegistry::Global().Create("dropout");
  ASSERT_FALSE(dropout.empty());
  ASSERT_TRUE(dropout.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/elementwise_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/elementwise_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace arm {
// Looking up "elementwise_add" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(elementwise_add_arm, retrive_op) {
  auto elementwise_add = KernelRegistry::Global().Create("elementwise_add");
  ASSERT_FALSE(elementwise_add.empty());
  ASSERT_TRUE(elementwise_add.front());
}
......@@ -336,8 +336,7 @@ TEST(elementwise_add, compute) {
// Looking up "fusion_elementwise_add_activation" by op type in the global
// registry must yield a non-empty collection whose first entry is set.
TEST(fusion_elementwise_add_activation_arm, retrive_op) {
  auto fusion_elementwise_add_activation =
      KernelRegistry::Global().Create("fusion_elementwise_add_activation");
  ASSERT_FALSE(fusion_elementwise_add_activation.empty());
  ASSERT_TRUE(fusion_elementwise_add_activation.front());
}
......@@ -435,9 +434,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
}
// Looking up "elementwise_mul" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(elementwise_mul_arm, retrive_op) {
  auto elementwise_mul = KernelRegistry::Global().Create("elementwise_mul");
  ASSERT_FALSE(elementwise_mul.empty());
  ASSERT_TRUE(elementwise_mul.front());
}
......@@ -530,8 +527,7 @@ TEST(elementwise_mul, compute) {
// Looking up "fusion_elementwise_mul_activation" by op type in the global
// registry must yield a non-empty collection whose first entry is set.
TEST(fusion_elementwise_mul_activation_arm, retrive_op) {
  auto fusion_elementwise_mul_activation =
      KernelRegistry::Global().Create("fusion_elementwise_mul_activation");
  ASSERT_FALSE(fusion_elementwise_mul_activation.empty());
  ASSERT_TRUE(fusion_elementwise_mul_activation.front());
}
......@@ -629,9 +625,7 @@ TEST(fusion_elementwise_mul_activation_arm, compute) {
}
// Looking up "elementwise_max" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(elementwise_max_arm, retrive_op) {
  auto elementwise_max = KernelRegistry::Global().Create("elementwise_max");
  ASSERT_FALSE(elementwise_max.empty());
  ASSERT_TRUE(elementwise_max.front());
}
......@@ -724,8 +718,7 @@ TEST(elementwise_max, compute) {
// Looking up "fusion_elementwise_max_activation" by op type in the global
// registry must yield a non-empty collection whose first entry is set.
TEST(fusion_elementwise_max_activation_arm, retrive_op) {
  auto fusion_elementwise_max_activation =
      KernelRegistry::Global().Create("fusion_elementwise_max_activation");
  ASSERT_FALSE(fusion_elementwise_max_activation.empty());
  ASSERT_TRUE(fusion_elementwise_max_activation.front());
}
......@@ -823,9 +816,7 @@ TEST(fusion_elementwise_max_activation_arm, compute) {
}
// Looking up "elementwise_mod" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(elementwise_mod_int64_arm, retrive_op) {
  auto elementwise_mod = KernelRegistry::Global().Create("elementwise_mod");
  ASSERT_FALSE(elementwise_mod.empty());
  ASSERT_TRUE(elementwise_mod.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/layer_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -181,9 +183,7 @@ TEST(layer_norm_arm, compute) {
}
// Looking up "layer_norm" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(layer_norm, retrive_op) {
  auto layer_norm = KernelRegistry::Global().Create("layer_norm");
  ASSERT_FALSE(layer_norm.empty());
  ASSERT_TRUE(layer_norm.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lrn_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/lrn_compute.h"
namespace paddle {
namespace lite {
......@@ -133,8 +135,7 @@ void lrn_compute_ref(const operators::LrnParam& param) {
}
// Looking up "lrn" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(lrn_arm, retrive_op) {
  auto lrn = KernelRegistry::Global().Create("lrn");
  ASSERT_FALSE(lrn.empty());
  ASSERT_TRUE(lrn.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace arm {
// Looking up "merge_lod_tensor" by op type in the global registry must yield
// a non-empty collection whose first entry is set.
TEST(merge_lod_tensor_arm, retrive_op) {
  auto kernel = KernelRegistry::Global().Create("merge_lod_tensor");
  ASSERT_FALSE(kernel.empty());
  ASSERT_TRUE(kernel.front());
}
......
......@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/mul_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/mul_compute.h"
namespace paddle {
namespace lite {
......@@ -69,8 +71,7 @@ void FillData(T* a,
}
// Looking up "mul" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(mul_arm, retrive_op) {
  auto mul = KernelRegistry::Global().Create("mul");
  ASSERT_FALSE(mul.empty());
  ASSERT_TRUE(mul.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/pool_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/pool_compute.h"
namespace paddle {
namespace lite {
......@@ -341,8 +343,7 @@ TEST(pool_arm, compute) {
}
// Looking up "pool2d" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(pool_arm, retrive_op) {
  auto pool = KernelRegistry::Global().Create("pool2d");
  ASSERT_FALSE(pool.empty());
  ASSERT_TRUE(pool.front());
}
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/scale_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/scale_compute.h"
namespace paddle {
namespace lite {
......@@ -103,8 +105,7 @@ TEST(scale_arm, compute) {
}
// Looking up "scale" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(scale, retrive_op) {
  auto scale = KernelRegistry::Global().Create("scale");
  ASSERT_FALSE(scale.empty());
  ASSERT_TRUE(scale.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/softmax_compute.h"
namespace paddle {
namespace lite {
......@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) {
}
// Looking up "softmax" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(softmax, retrive_op) {
  auto softmax = KernelRegistry::Global().Create("softmax");
  ASSERT_FALSE(softmax.empty());
  ASSERT_TRUE(softmax.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/split_compute.h"
#include <gtest/gtest.h>
#include <cstring>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/split_compute.h"
namespace paddle {
namespace lite {
......@@ -165,8 +167,7 @@ TEST(split_arm, compute) {
}
// Looking up "split" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(split, retrive_op) {
  auto split = KernelRegistry::Global().Create("split");
  ASSERT_FALSE(split.empty());
  ASSERT_TRUE(split.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/split_lod_tensor_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/split_lod_tensor_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace arm {
// Looking up "split_lod_tensor" by op type in the global registry must yield
// a non-empty collection whose first entry is set.
TEST(split_lod_tensor_arm, retrive_op) {
  auto kernel = KernelRegistry::Global().Create("split_lod_tensor");
  ASSERT_FALSE(kernel.empty());
  ASSERT_TRUE(kernel.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/transpose_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/arm/transpose_compute.h"
namespace paddle {
namespace lite {
......@@ -121,9 +123,7 @@ TEST(transpose_arm, compute_shape_nchw) {
}
// Looking up "transpose" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(transpose, retrive_op) {
  auto transpose = KernelRegistry::Global().Create("transpose");
  ASSERT_FALSE(transpose.empty());
  ASSERT_TRUE(transpose.front());
}
......@@ -189,9 +189,7 @@ TEST(transpose2_arm, compute_shape_nchw) {
}
// Looking up "transpose2" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(transpose2, retrive_op) {
  auto transpose2 = KernelRegistry::Global().Create("transpose2");
  ASSERT_FALSE(transpose2.empty());
  ASSERT_TRUE(transpose2.front());
}
......
......@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/lookup_table_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/lookup_table_compute.h"
namespace paddle {
namespace lite {
......@@ -56,9 +58,7 @@ void LookupTableComputeRef(const operators::LookupTableParam& param) {
}
// Looking up "lookup_table" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(lookup_table_cuda, retrieve_op) {
  auto lookup_table = KernelRegistry::Global().Create("lookup_table");
  ASSERT_FALSE(lookup_table.empty());
  ASSERT_TRUE(lookup_table.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/activation_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/activation_compute.h"
namespace paddle {
namespace lite {
......@@ -37,8 +39,7 @@ void activation_compute_ref(const operators::ActivationParam& param) {
}
// Looking up "relu" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(activation_fpga, retrive_op) {
  auto activation = KernelRegistry::Global().Create("relu");
  ASSERT_FALSE(activation.empty());
  ASSERT_TRUE(activation.front());
}
......
......@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/fc_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/fc_compute.h"
namespace paddle {
namespace lite {
......@@ -76,8 +78,7 @@ void FillData(T* a,
}
// Looking up "fc" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(fc_fpga, retrive_op) {
  auto fc = KernelRegistry::Global().Create("fc");
  ASSERT_FALSE(fc.empty());
  ASSERT_TRUE(fc.front());
}
......
......@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/pooling_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/pooling_compute.h"
namespace paddle {
namespace lite {
......@@ -277,8 +278,7 @@ TEST(pool_fpga, compute) {
}
// Looking up "pool2d" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(pool_fpga, retrive_op) {
  auto pool = KernelRegistry::Global().Create("pool2d");
  ASSERT_FALSE(pool.empty());
  ASSERT_TRUE(pool.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/fpga/softmax_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <vector>
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/core/op_registry.h"
#include "lite/kernels/fpga/softmax_compute.h"
namespace paddle {
namespace lite {
......@@ -121,9 +123,7 @@ TEST(softmax_arm, compute) {
}
// Looking up "softmax" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(softmax, retrive_op) {
  auto softmax = KernelRegistry::Global().Create("softmax");
  ASSERT_FALSE(softmax.empty());
  ASSERT_TRUE(softmax.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/activation_compute.cc"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
namespace paddle {
namespace lite {
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
// Looking up "relu" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(relu_x86, retrive_op) {
  auto relu = KernelRegistry::Global().Create("relu");
  ASSERT_FALSE(relu.empty());
  ASSERT_TRUE(relu.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/attention_padding_mask_compute.cc"
namespace paddle {
namespace lite {
......@@ -81,8 +83,7 @@ int get_max_len(const LoD& lod) {
// Looking up "attention_padding_mask" by op type in the global registry must
// yield a non-empty collection whose first entry is set.
TEST(attention_padding_mask_x86, retrive_op) {
  auto attention_padding_mask =
      KernelRegistry::Global().Create("attention_padding_mask");
  ASSERT_FALSE(attention_padding_mask.empty());
  ASSERT_TRUE(attention_padding_mask.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/batch_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
// Looking up "batch_norm" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(batch_norm_x86, retrive_op) {
  auto batch_norm = KernelRegistry::Global().Create("batch_norm");
  ASSERT_FALSE(batch_norm.empty());
  ASSERT_TRUE(batch_norm.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/cast_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/cast_compute.h"
namespace paddle {
namespace lite {
......@@ -25,8 +27,7 @@ namespace kernels {
namespace x86 {
// Looking up "cast" by op type in the global registry must yield a non-empty
// collection whose first entry is set.
TEST(cast_x86, retrive_op) {
  auto cast = KernelRegistry::Global().Create("cast");
  ASSERT_FALSE(cast.empty());
  ASSERT_TRUE(cast.front());
}
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/concat_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/concat_compute.h"
namespace paddle {
namespace lite {
......@@ -23,9 +25,7 @@ namespace kernels {
namespace x86 {
// Looking up "concat" by op type in the global registry must yield a
// non-empty collection whose first entry is set.
TEST(concat_x86, retrive_op) {
  auto concat = KernelRegistry::Global().Create("concat");
  ASSERT_FALSE(concat.empty());
  ASSERT_TRUE(concat.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/conv_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/conv_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "conv2d" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(conv_x86, retrive_op) {
auto conv2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"conv2d");
auto conv2d = KernelRegistry::Global().Create("conv2d");
ASSERT_FALSE(conv2d.empty());
ASSERT_TRUE(conv2d.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/dropout_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/dropout_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "dropout" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(dropout_x86, retrive_op) {
auto dropout =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"dropout");
auto dropout = KernelRegistry::Global().Create("dropout");
ASSERT_FALSE(dropout.empty());
ASSERT_TRUE(dropout.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/elementwise_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "elementwise_add" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(elementwise_add_x86, retrive_op) {
auto elementwise_add =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"elementwise_add");
auto elementwise_add = KernelRegistry::Global().Create("elementwise_add");
ASSERT_FALSE(elementwise_add.empty());
ASSERT_TRUE(elementwise_add.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -26,8 +29,7 @@ namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "fill_constant_batch_size_like" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(fill_constant_batch_size_like_x86, retrive_op) {
auto fill_constant_batch_size_like =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"fill_constant_batch_size_like");
KernelRegistry::Global().Create("fill_constant_batch_size_like");
ASSERT_FALSE(fill_constant_batch_size_like.empty());
ASSERT_TRUE(fill_constant_batch_size_like.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/gather_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/gather_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
TEST(gather_x86, retrive_op) {
auto gather =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"gather");
auto gather = KernelRegistry::Global().Create("gather");
ASSERT_FALSE(gather.empty());
int cnt = 0;
for (auto item = gather.begin(); item != gather.end(); ++item) {
......
......@@ -13,10 +13,12 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "gelu" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(gelu_x86, retrive_op) {
auto gelu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gelu");
auto gelu = KernelRegistry::Global().Create("gelu");
ASSERT_FALSE(gelu.empty());
ASSERT_TRUE(gelu.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/gru_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/gru_compute.h"
namespace paddle {
namespace lite {
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "gru" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(gru_x86, retrive_op) {
auto gru =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("gru");
auto gru = KernelRegistry::Global().Create("gru");
ASSERT_FALSE(gru.empty());
ASSERT_TRUE(gru.front());
}
......
......@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/layer_norm_compute.h"
namespace paddle {
namespace lite {
......@@ -74,9 +76,7 @@ std::vector<float> ref(lite::Tensor* x,
// layer_norm
// Smoke test: the kernel registry must return a non-empty result for
// "layer_norm" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(layer_norm_x86, retrive_op) {
auto layer_norm =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"layer_norm");
auto layer_norm = KernelRegistry::Global().Create("layer_norm");
ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front());
}
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.h"
......@@ -24,9 +26,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "leaky_relu" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(leaky_relu_x86, retrive_op) {
auto leaky_relu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"leaky_relu");
auto leaky_relu = KernelRegistry::Global().Create("leaky_relu");
ASSERT_FALSE(leaky_relu.empty());
ASSERT_TRUE(leaky_relu.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "match_matrix_tensor" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(match_matrix_tensor_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"match_matrix_tensor");
auto kernel = KernelRegistry::Global().Create("match_matrix_tensor");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,22 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/matmul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/matmul_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "matmul" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(matmul_x86, retrive_op) {
auto matmul =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"matmul");
auto matmul = KernelRegistry::Global().Create("matmul");
ASSERT_FALSE(matmul.empty());
ASSERT_TRUE(matmul.front());
}
......
......@@ -12,21 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/mul_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "mul" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(mul_x86, retrive_op) {
auto mul =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("mul");
auto mul = KernelRegistry::Global().Create("mul");
ASSERT_FALSE(mul.empty());
ASSERT_TRUE(mul.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/pool_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/pool_compute.h"
namespace paddle {
namespace lite {
......@@ -26,9 +28,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "pool2d" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(pool_x86, retrive_op) {
auto pool2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"pool2d");
auto pool2d = KernelRegistry::Global().Create("pool2d");
ASSERT_FALSE(pool2d.empty());
ASSERT_TRUE(pool2d.front());
}
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.h"
......@@ -24,8 +26,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "relu" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(relu_x86, retrive_op) {
auto relu =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("relu");
auto relu = KernelRegistry::Global().Create("relu");
ASSERT_FALSE(relu.empty());
ASSERT_TRUE(relu.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/reshape_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/reshape_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -26,9 +29,7 @@ namespace x86 {
// reshape
// Smoke test: the kernel registry must return a non-empty result for
// "reshape" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(reshape_x86, retrive_op) {
auto reshape =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape");
auto reshape = KernelRegistry::Global().Create("reshape");
ASSERT_FALSE(reshape.empty());
ASSERT_TRUE(reshape.front());
}
......@@ -86,9 +87,7 @@ TEST(reshape_x86, run_test) {
// reshape2
// Smoke test: the kernel registry must return a non-empty result for
// "reshape2" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(reshape2_x86, retrive_op) {
auto reshape2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape2");
auto reshape2 = KernelRegistry::Global().Create("reshape2");
ASSERT_FALSE(reshape2.empty());
ASSERT_TRUE(reshape2.front());
}
......
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/scale_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/scale_compute.h"
namespace paddle {
namespace lite {
......@@ -24,8 +26,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "scale" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(scale_x86, retrive_op) {
auto scale =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("scale");
auto scale = KernelRegistry::Global().Create("scale");
ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_fc_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_fc_compute.h"
namespace paddle {
namespace lite {
......@@ -53,9 +55,7 @@ void fc_cpu_base(const lite::Tensor* X,
}
// Smoke test: the kernel registry must return a non-empty result for
// "search_fc" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(search_fc_x86, retrive_op) {
auto search_fc =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_fc");
auto search_fc = KernelRegistry::Global().Create("search_fc");
ASSERT_FALSE(search_fc.empty());
ASSERT_TRUE(search_fc.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_grnn_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_grnn_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "search_grnn" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(search_grnn_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_grnn");
auto kernel = KernelRegistry::Global().Create("search_grnn");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_group_padding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_group_padding_compute.h"
namespace paddle {
namespace lite {
......@@ -26,8 +28,7 @@ namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "search_group_padding" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(search_group_padding_x86, retrieve_op) {
auto search_group_padding =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_group_padding");
KernelRegistry::Global().Create("search_group_padding");
ASSERT_FALSE(search_group_padding.empty());
ASSERT_TRUE(search_group_padding.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_seq_depadding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/search_seq_depadding_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "search_seq_depadding" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(search_seq_depadding_x86, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"search_seq_depadding");
auto kernel = KernelRegistry::Global().Create("search_seq_depadding");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
namespace paddle {
namespace lite {
......@@ -77,8 +79,7 @@ void prepare_input(Tensor* x, const LoD& x_lod) {
// Smoke test: the kernel registry must return a non-empty result for
// "sequence_arithmetic" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(sequence_arithmetic_x86, retrive_op) {
auto sequence_arithmetic =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_arithmetic");
KernelRegistry::Global().Create("sequence_arithmetic");
ASSERT_FALSE(sequence_arithmetic.empty());
ASSERT_TRUE(sequence_arithmetic.front());
}
......
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_concat_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_concat_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -94,9 +97,7 @@ static void sequence_concat_ref(const std::vector<lite::Tensor*>& xs,
} // namespace
// Smoke test: the kernel registry must return a non-empty result for
// "sequence_concat" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(sequence_concat_x86, retrive_op) {
auto sequence_concat =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_concat");
auto sequence_concat = KernelRegistry::Global().Create("sequence_concat");
ASSERT_FALSE(sequence_concat.empty());
ASSERT_TRUE(sequence_concat.front());
}
......
......@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_expand_as_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_expand_as_compute.h"
namespace paddle {
namespace lite {
......@@ -27,8 +29,7 @@ namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "sequence_expand_as" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(sequence_expand_as_x86, retrive_op) {
auto sequence_expand_as =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_expand_as");
KernelRegistry::Global().Create("sequence_expand_as");
ASSERT_FALSE(sequence_expand_as.empty());
ASSERT_TRUE(sequence_expand_as.front());
}
......
......@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_pool_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_pool_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "sequence_pool" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(sequence_pool_x86, retrive_op) {
auto sequence_pool =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_pool");
auto sequence_pool = KernelRegistry::Global().Create("sequence_pool");
ASSERT_FALSE(sequence_pool.empty());
ASSERT_TRUE(sequence_pool.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_reverse_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/sequence_reverse_compute.h"
namespace paddle {
namespace lite {
......@@ -44,9 +46,7 @@ static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
} // namespace
// Smoke test: the kernel registry must return a non-empty result for
// "sequence_reverse" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(sequence_reverse_x86, retrive_op) {
auto sequence_reverse =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_reverse");
auto sequence_reverse = KernelRegistry::Global().Create("sequence_reverse");
ASSERT_FALSE(sequence_reverse.empty());
ASSERT_TRUE(sequence_reverse.front());
}
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/shape_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/shape_compute.h"
namespace paddle {
namespace lite {
......@@ -23,8 +25,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "shape" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(shape_x86, retrive_op) {
auto shape =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("shape");
auto shape = KernelRegistry::Global().Create("shape");
ASSERT_FALSE(shape.empty());
ASSERT_TRUE(shape.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/slice_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/slice_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -79,8 +82,7 @@ static void slice_ref(const float* input,
}
// Smoke test: the kernel registry must return a non-empty result for
// "slice" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(slice_x86, retrive_op) {
auto slice =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("slice");
auto slice = KernelRegistry::Global().Create("slice");
ASSERT_FALSE(slice.empty());
ASSERT_TRUE(slice.front());
}
......
......@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/softmax_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/softmax_compute.h"
namespace paddle {
namespace lite {
......@@ -25,9 +27,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "softmax" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(softmax_x86, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"softmax");
auto softmax = KernelRegistry::Global().Create("softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
......
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/stack_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/stack_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -25,8 +28,7 @@ namespace x86 {
// stack
// Smoke test: the kernel registry must return a non-empty result for
// "stack" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(stack_x86, retrive_op) {
auto stack =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("stack");
auto stack = KernelRegistry::Global().Create("stack");
ASSERT_FALSE(stack.empty());
ASSERT_TRUE(stack.front());
}
......
......@@ -13,10 +13,12 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/activation_compute.cc"
......@@ -26,8 +28,7 @@ namespace kernels {
namespace x86 {
// Smoke test: the kernel registry must return a non-empty result for
// "tanh" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(tanh_x86, retrive_op) {
auto tanh =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("tanh");
auto tanh = KernelRegistry::Global().Create("tanh");
ASSERT_FALSE(tanh.empty());
ASSERT_TRUE(tanh.front());
}
......
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/x86/transpose_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -25,9 +28,7 @@ namespace x86 {
// transpose
// Smoke test: the kernel registry must return a non-empty result for
// "transpose" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(transpose_x86, retrive_op) {
auto transpose =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose");
auto transpose = KernelRegistry::Global().Create("transpose");
ASSERT_FALSE(transpose.empty());
ASSERT_TRUE(transpose.front());
}
......@@ -75,9 +76,7 @@ TEST(transpose_x86, run_test) {
// transpose2
// Smoke test: the kernel registry must return a non-empty result for
// "transpose2" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(transpose2_x86, retrive_op) {
auto transpose2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"transpose2");
auto transpose2 = KernelRegistry::Global().Create("transpose2");
ASSERT_FALSE(transpose2.empty());
ASSERT_TRUE(transpose2.front());
}
......
......@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/var_conv_2d_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/x86/var_conv_2d_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -197,9 +200,7 @@ static void var_conv_2d_ref(const lite::Tensor* bottom,
}
// Smoke test: the kernel registry must return a non-empty result for
// "var_conv_2d" whose first element is truthy.
// NOTE(review): both the templated Create<...> call and the plain Create(...)
// call appear; likely removed/added diff lines shown together -- confirm.
TEST(var_conv_2d_x86, retrive_op) {
auto var_conv_2d =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"var_conv_2d");
auto var_conv_2d = KernelRegistry::Global().Create("var_conv_2d");
ASSERT_FALSE(var_conv_2d.empty());
ASSERT_TRUE(var_conv_2d.front());
}
......
......@@ -24,7 +24,6 @@
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h"
#include "lite/utils/variant.h"
/*
* This file contains all the argument parameter data structure for operators.
*/
......
......@@ -14,10 +14,16 @@
#pragma once
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "lite/utils/any.h"
#include "lite/utils/check.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/factory.h"
#include "lite/utils/hash.h"
#include "lite/utils/io.h"
#include "lite/utils/macros.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "lite/utils/all.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
/*
* Factor for any Type creator.
*
* Usage:
*
* struct SomeType;
* // Register a creator.
* Factory<SomeType>::Global().Register("some_key", [] ->
* std::unique_ptr<SomeType> { ... });
* // Retrive a creator.
* auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
*/
template <typename ItemType, typename ItemTypePtr>
class Factory {
public:
using item_t = ItemType;
using self_t = Factory<item_t, ItemTypePtr>;
using item_ptr_t = ItemTypePtr;
using creator_t = std::function<item_ptr_t()>;
static Factory& Global() {
static Factory* x = new self_t;
return *x;
}
void Register(const std::string& op_type, creator_t&& creator) {
creators_[op_type].emplace_back(std::move(creator));
}
item_ptr_t Create(const std::string& op_type) const {
auto res = Creates(op_type);
if (res.empty()) return nullptr;
CHECK_EQ(res.size(), 1UL) << "Get multiple Op for type " << op_type;
return std::move(res.front());
}
std::list<item_ptr_t> Creates(const std::string& op_type) const {
std::list<item_ptr_t> res;
auto it = creators_.find(op_type);
if (it == creators_.end()) return res;
for (auto& c : it->second) {
res.emplace_back(c());
}
return res;
}
std::string DebugString() const {
STL::stringstream ss;
for (const auto& item : creators_) {
ss << " - " << item.first << "\n";
}
return ss.str();
}
protected:
std::map<std::string, std::list<creator_t>> creators_;
};
/*
 * Helper that runs a callable once, from inside its constructor.
 *
 * `Type` is not used by the class body.
 * NOTE(review): presumably it only gives each registration its own distinct
 * class type -- confirm against the registration macros.
 */
template <typename Type>
class Registor {
 public:
  // Invokes `functor` immediately; the callable is not stored.
  explicit Registor(std::function<void()>&& functor) {
    functor();
  }

  // Does no work and always returns 0.
  int Touch() {
    return 0;
  }
};
} // namespace lite
} // namespace paddle
Subproject commit ac203b20926b13a35ff85277d2e5d3c38698eee8
Subproject commit 6df40a2471737b27271bdd9b900ab5f3aec746c7
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册