diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index cd68ca5187146fcf4d88b2008fc44533b3e1cf10..1de46a39467af125e705cfcb7a9eeae64a0be133 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -54,7 +54,8 @@ enum class TargetType : int { kXPU = 9, kBM = 10, kAny = 6, // any target - NUM = 11, // number of fields. + kMLU = 11, + NUM = 12, // number of fields. }; enum class PrecisionType : int { kUnk = 0, diff --git a/lite/core/context.h b/lite/core/context.h index 978fb5d67a2fae8025fa725ed1f717aa3df611c0..88fe00d0f2aab41cfd3e5562d29f0a8a82598428 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -52,6 +52,7 @@ using XPUContext = Context; using OpenCLContext = Context; using FPGAContext = Context; using BMContext = Context; +using MLUContext = Context; template <> class Context { diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 4b6d3282ed300654c612325ff9c53c153ccea30a..fe1dff3c99c1d2413888e78c89c999caea0ab030 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -107,6 +107,9 @@ std::list> KernelRegistry::Create( case TARGET(kBM): { CREATE_KERNEL(kBM); } break; + case TARGET(kMLU): { + CREATE_KERNEL(kMLU); + } break; default: CHECK(false) << "not supported kernel target " << TargetToStr(target); } @@ -139,6 +142,15 @@ KernelRegistry::KernelRegistry() INIT_FOR(kCUDA, kInt64, kNCHW); INIT_FOR(kCUDA, kInt64, kNHWC); + INIT_FOR(kMLU, kFloat, kNHWC); + INIT_FOR(kMLU, kFloat, kNCHW); + INIT_FOR(kMLU, kFP16, kNHWC); + INIT_FOR(kMLU, kFP16, kNCHW); + INIT_FOR(kMLU, kInt8, kNHWC); + INIT_FOR(kMLU, kInt8, kNCHW); + INIT_FOR(kMLU, kInt16, kNHWC); + INIT_FOR(kMLU, kInt16, kNCHW); + INIT_FOR(kHost, kFloat, kNCHW); INIT_FOR(kHost, kAny, kNCHW); INIT_FOR(kHost, kFloat, kNHWC); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index 6f8f1e8bc6662a7b22fd8f4c3b9683eb6f4da139..3c41c1fd8af240401c3edf0343433f8d8d9c85db 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -268,7 +268,32 @@ class KernelRegistry final { DATALAYOUT(kAny)> *, // KernelRegistryForTarget * // + DATALAYOUT(kAny)> *, // + + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // >; KernelRegistry();