未验证 提交 181af56c 编写于 作者: J jackzhang235 提交者: GitHub

[MLU][lib_increased] add some basic mlu definition (#3275)

上级 1f8b5c2b
......@@ -54,7 +54,8 @@ enum class TargetType : int {
kXPU = 9,
kBM = 10,
kAny = 6, // any target
NUM = 11, // number of fields.
kMLU = 11,
NUM = 12, // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
......
......@@ -52,6 +52,7 @@ using XPUContext = Context<TargetType::kXPU>;
using OpenCLContext = Context<TargetType::kOpenCL>;
using FPGAContext = Context<TargetType::kFPGA>;
using BMContext = Context<TargetType::kBM>;
using MLUContext = Context<TargetType::kMLU>;
template <>
class Context<TargetType::kHost> {
......
......@@ -107,6 +107,9 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
case TARGET(kBM): {
CREATE_KERNEL(kBM);
} break;
case TARGET(kMLU): {
CREATE_KERNEL(kMLU);
} break;
default:
CHECK(false) << "not supported kernel target " << TargetToStr(target);
}
......@@ -139,6 +142,15 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kCUDA, kInt64, kNCHW);
INIT_FOR(kCUDA, kInt64, kNHWC);
INIT_FOR(kMLU, kFloat, kNHWC);
INIT_FOR(kMLU, kFloat, kNCHW);
INIT_FOR(kMLU, kFP16, kNHWC);
INIT_FOR(kMLU, kFP16, kNCHW);
INIT_FOR(kMLU, kInt8, kNHWC);
INIT_FOR(kMLU, kInt8, kNCHW);
INIT_FOR(kMLU, kInt16, kNHWC);
INIT_FOR(kMLU, kInt16, kNCHW);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
......
......@@ -268,7 +268,32 @@ class KernelRegistry final {
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kAny),
DATALAYOUT(kAny)> * //
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16),
DATALAYOUT(kNCHW)> * //
>;
KernelRegistry();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册