提交 7bf1c38c 编写于 作者: M Megvii Engine Team

fix(mgb/imperative): fix imperative code gen

GitOrigin-RevId: da9e8a280acff1295fb455676cb2063dcb493e53
上级 c49d3070
0df57b38e71a4d1882ed6c24f3a26b57 ../../dnn/scripts/opr_param_defs.py
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td
c613316001b5f0294ede198f5563f041 generated/opdef.h.inl
a1f7f13c909f9d4c173277f4ed28fb61 generated/opdef.cpp.inl
cf48f9ca352fabaeb6c846c11c6b1662 generated/opdef.py.inl
12365b938f564e5b3639d309f7c83414 generated/opdef.cpy.inl
2a5851d0e2470d4d045811e7a20b1a3f generated/opdef.h.inl
55b862badeed19aed8e84c5d6f468ff2 generated/opdef.cpp.inl
f3f4c7f0ee1b39392df8a679f6d22596 generated/opdef.py.inl
6b11ca844a7855fdc5eebffaf563a89c generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
......@@ -3067,6 +3067,15 @@ std::vector<std::pair<const char*, std::string>> Elemwise_props_impl(const OpDef
case Elemwise::Mode::COND_LT_MOV:
props_.emplace_back("mode", "COND_LT_MOV");
break;
case Elemwise::Mode::NEQ:
props_.emplace_back("mode", "NEQ");
break;
case Elemwise::Mode::ISNAN:
props_.emplace_back("mode", "ISNAN");
break;
case Elemwise::Mode::ISINF:
props_.emplace_back("mode", "ISINF");
break;
default:
props_.emplace_back("mode", "INVALID");
break;
......@@ -3285,6 +3294,24 @@ std::vector<std::pair<const char*, std::string>> ElemwiseMultiType_props_impl(co
case ElemwiseMultiType::Mode::QCOND_LT_MOV:
props_.emplace_back("mode", "QCOND_LT_MOV");
break;
case ElemwiseMultiType::Mode::EQ:
props_.emplace_back("mode", "EQ");
break;
case ElemwiseMultiType::Mode::NEQ:
props_.emplace_back("mode", "NEQ");
break;
case ElemwiseMultiType::Mode::LT:
props_.emplace_back("mode", "LT");
break;
case ElemwiseMultiType::Mode::LEQ:
props_.emplace_back("mode", "LEQ");
break;
case ElemwiseMultiType::Mode::ISNAN:
props_.emplace_back("mode", "ISNAN");
break;
case ElemwiseMultiType::Mode::ISINF:
props_.emplace_back("mode", "ISINF");
break;
default:
props_.emplace_back("mode", "INVALID");
break;
......
......@@ -780,6 +780,9 @@ case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD";
case Elemwise::Mode::GELU: return "GELU";
case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD";
case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV";
case Elemwise::Mode::NEQ: return "NEQ";
case Elemwise::Mode::ISNAN: return "ISNAN";
case Elemwise::Mode::ISINF: return "ISINF";
default:
return "Elemwise::Mode::Unknown";
}
......@@ -863,6 +866,12 @@ case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32: return "FUSE_MUL_
case ElemwiseMultiType::Mode::MUL_INT16xF32xF32: return "MUL_INT16xF32xF32";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32: return "FUSE_MUL_ADD3_UINT8xF32xF32xF32";
case ElemwiseMultiType::Mode::QCOND_LT_MOV: return "QCOND_LT_MOV";
case ElemwiseMultiType::Mode::EQ: return "EQ";
case ElemwiseMultiType::Mode::NEQ: return "NEQ";
case ElemwiseMultiType::Mode::LT: return "LT";
case ElemwiseMultiType::Mode::LEQ: return "LEQ";
case ElemwiseMultiType::Mode::ISNAN: return "ISNAN";
case ElemwiseMultiType::Mode::ISINF: return "ISINF";
default:
return "ElemwiseMultiType::Mode::Unknown";
}
......
......@@ -893,6 +893,9 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
.value("GELU", Elemwise::Mode::GELU)
.value("GELU_GRAD", Elemwise::Mode::GELU_GRAD)
.value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV)
.value("NEQ", Elemwise::Mode::NEQ)
.value("ISNAN", Elemwise::Mode::ISNAN)
.value("ISINF", Elemwise::Mode::ISINF)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "RELU") return Elemwise::Mode::RELU;
......@@ -956,6 +959,9 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
if (str == "GELU") return Elemwise::Mode::GELU;
if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD;
if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV;
if (str == "NEQ") return Elemwise::Mode::NEQ;
if (str == "ISNAN") return Elemwise::Mode::ISNAN;
if (str == "ISINF") return Elemwise::Mode::ISINF;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Elemwise::Mode>();
......@@ -1025,6 +1031,12 @@ py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
.value("MUL_INT16xF32xF32", ElemwiseMultiType::Mode::MUL_INT16xF32xF32)
.value("FUSE_MUL_ADD3_UINT8xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32)
.value("QCOND_LT_MOV", ElemwiseMultiType::Mode::QCOND_LT_MOV)
.value("EQ", ElemwiseMultiType::Mode::EQ)
.value("NEQ", ElemwiseMultiType::Mode::NEQ)
.value("LT", ElemwiseMultiType::Mode::LT)
.value("LEQ", ElemwiseMultiType::Mode::LEQ)
.value("ISNAN", ElemwiseMultiType::Mode::ISNAN)
.value("ISINF", ElemwiseMultiType::Mode::ISINF)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "FUSE_MUL_ADD3_INT16x32x32x32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
......@@ -1085,6 +1097,12 @@ py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
if (str == "MUL_INT16xF32xF32") return ElemwiseMultiType::Mode::MUL_INT16xF32xF32;
if (str == "FUSE_MUL_ADD3_UINT8xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32;
if (str == "QCOND_LT_MOV") return ElemwiseMultiType::Mode::QCOND_LT_MOV;
if (str == "EQ") return ElemwiseMultiType::Mode::EQ;
if (str == "NEQ") return ElemwiseMultiType::Mode::NEQ;
if (str == "LT") return ElemwiseMultiType::Mode::LT;
if (str == "LEQ") return ElemwiseMultiType::Mode::LEQ;
if (str == "ISNAN") return ElemwiseMultiType::Mode::ISNAN;
if (str == "ISINF") return ElemwiseMultiType::Mode::ISINF;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, ElemwiseMultiType::Mode>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册