提交 7bf1c38c 编写于 作者: M Megvii Engine Team

fix(mgb/imperative): fix imperative code gen

GitOrigin-RevId: da9e8a280acff1295fb455676cb2063dcb493e53
上级 c49d3070
0df57b38e71a4d1882ed6c24f3a26b57 ../../dnn/scripts/opr_param_defs.py 905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td 759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td
c613316001b5f0294ede198f5563f041 generated/opdef.h.inl 2a5851d0e2470d4d045811e7a20b1a3f generated/opdef.h.inl
a1f7f13c909f9d4c173277f4ed28fb61 generated/opdef.cpp.inl 55b862badeed19aed8e84c5d6f468ff2 generated/opdef.cpp.inl
cf48f9ca352fabaeb6c846c11c6b1662 generated/opdef.py.inl f3f4c7f0ee1b39392df8a679f6d22596 generated/opdef.py.inl
12365b938f564e5b3639d309f7c83414 generated/opdef.cpy.inl 6b11ca844a7855fdc5eebffaf563a89c generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
...@@ -3067,6 +3067,15 @@ std::vector<std::pair<const char*, std::string>> Elemwise_props_impl(const OpDef ...@@ -3067,6 +3067,15 @@ std::vector<std::pair<const char*, std::string>> Elemwise_props_impl(const OpDef
case Elemwise::Mode::COND_LT_MOV: case Elemwise::Mode::COND_LT_MOV:
props_.emplace_back("mode", "COND_LT_MOV"); props_.emplace_back("mode", "COND_LT_MOV");
break; break;
case Elemwise::Mode::NEQ:
props_.emplace_back("mode", "NEQ");
break;
case Elemwise::Mode::ISNAN:
props_.emplace_back("mode", "ISNAN");
break;
case Elemwise::Mode::ISINF:
props_.emplace_back("mode", "ISINF");
break;
default: default:
props_.emplace_back("mode", "INVALID"); props_.emplace_back("mode", "INVALID");
break; break;
...@@ -3285,6 +3294,24 @@ std::vector<std::pair<const char*, std::string>> ElemwiseMultiType_props_impl(co ...@@ -3285,6 +3294,24 @@ std::vector<std::pair<const char*, std::string>> ElemwiseMultiType_props_impl(co
case ElemwiseMultiType::Mode::QCOND_LT_MOV: case ElemwiseMultiType::Mode::QCOND_LT_MOV:
props_.emplace_back("mode", "QCOND_LT_MOV"); props_.emplace_back("mode", "QCOND_LT_MOV");
break; break;
case ElemwiseMultiType::Mode::EQ:
props_.emplace_back("mode", "EQ");
break;
case ElemwiseMultiType::Mode::NEQ:
props_.emplace_back("mode", "NEQ");
break;
case ElemwiseMultiType::Mode::LT:
props_.emplace_back("mode", "LT");
break;
case ElemwiseMultiType::Mode::LEQ:
props_.emplace_back("mode", "LEQ");
break;
case ElemwiseMultiType::Mode::ISNAN:
props_.emplace_back("mode", "ISNAN");
break;
case ElemwiseMultiType::Mode::ISINF:
props_.emplace_back("mode", "ISINF");
break;
default: default:
props_.emplace_back("mode", "INVALID"); props_.emplace_back("mode", "INVALID");
break; break;
......
...@@ -780,6 +780,9 @@ case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD"; ...@@ -780,6 +780,9 @@ case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD";
case Elemwise::Mode::GELU: return "GELU"; case Elemwise::Mode::GELU: return "GELU";
case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD"; case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD";
case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV"; case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV";
case Elemwise::Mode::NEQ: return "NEQ";
case Elemwise::Mode::ISNAN: return "ISNAN";
case Elemwise::Mode::ISINF: return "ISINF";
default: default:
return "Elemwise::Mode::Unknown"; return "Elemwise::Mode::Unknown";
} }
...@@ -863,6 +866,12 @@ case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32: return "FUSE_MUL_ ...@@ -863,6 +866,12 @@ case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32: return "FUSE_MUL_
case ElemwiseMultiType::Mode::MUL_INT16xF32xF32: return "MUL_INT16xF32xF32"; case ElemwiseMultiType::Mode::MUL_INT16xF32xF32: return "MUL_INT16xF32xF32";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32: return "FUSE_MUL_ADD3_UINT8xF32xF32xF32"; case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32: return "FUSE_MUL_ADD3_UINT8xF32xF32xF32";
case ElemwiseMultiType::Mode::QCOND_LT_MOV: return "QCOND_LT_MOV"; case ElemwiseMultiType::Mode::QCOND_LT_MOV: return "QCOND_LT_MOV";
case ElemwiseMultiType::Mode::EQ: return "EQ";
case ElemwiseMultiType::Mode::NEQ: return "NEQ";
case ElemwiseMultiType::Mode::LT: return "LT";
case ElemwiseMultiType::Mode::LEQ: return "LEQ";
case ElemwiseMultiType::Mode::ISNAN: return "ISNAN";
case ElemwiseMultiType::Mode::ISINF: return "ISINF";
default: default:
return "ElemwiseMultiType::Mode::Unknown"; return "ElemwiseMultiType::Mode::Unknown";
} }
......
...@@ -893,6 +893,9 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode") ...@@ -893,6 +893,9 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
.value("GELU", Elemwise::Mode::GELU) .value("GELU", Elemwise::Mode::GELU)
.value("GELU_GRAD", Elemwise::Mode::GELU_GRAD) .value("GELU_GRAD", Elemwise::Mode::GELU_GRAD)
.value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV) .value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV)
.value("NEQ", Elemwise::Mode::NEQ)
.value("ISNAN", Elemwise::Mode::ISNAN)
.value("ISINF", Elemwise::Mode::ISINF)
.def(py::init([](const std::string& in) { .def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in); auto&& str = normalize_enum(in);
if (str == "RELU") return Elemwise::Mode::RELU; if (str == "RELU") return Elemwise::Mode::RELU;
...@@ -956,6 +959,9 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode") ...@@ -956,6 +959,9 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
if (str == "GELU") return Elemwise::Mode::GELU; if (str == "GELU") return Elemwise::Mode::GELU;
if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD; if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD;
if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV; if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV;
if (str == "NEQ") return Elemwise::Mode::NEQ;
if (str == "ISNAN") return Elemwise::Mode::ISNAN;
if (str == "ISINF") return Elemwise::Mode::ISINF;
throw py::cast_error("invalid enum value " + in); throw py::cast_error("invalid enum value " + in);
})); }));
py::implicitly_convertible<std::string, Elemwise::Mode>(); py::implicitly_convertible<std::string, Elemwise::Mode>();
...@@ -1025,6 +1031,12 @@ py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode") ...@@ -1025,6 +1031,12 @@ py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
.value("MUL_INT16xF32xF32", ElemwiseMultiType::Mode::MUL_INT16xF32xF32) .value("MUL_INT16xF32xF32", ElemwiseMultiType::Mode::MUL_INT16xF32xF32)
.value("FUSE_MUL_ADD3_UINT8xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32) .value("FUSE_MUL_ADD3_UINT8xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32)
.value("QCOND_LT_MOV", ElemwiseMultiType::Mode::QCOND_LT_MOV) .value("QCOND_LT_MOV", ElemwiseMultiType::Mode::QCOND_LT_MOV)
.value("EQ", ElemwiseMultiType::Mode::EQ)
.value("NEQ", ElemwiseMultiType::Mode::NEQ)
.value("LT", ElemwiseMultiType::Mode::LT)
.value("LEQ", ElemwiseMultiType::Mode::LEQ)
.value("ISNAN", ElemwiseMultiType::Mode::ISNAN)
.value("ISINF", ElemwiseMultiType::Mode::ISINF)
.def(py::init([](const std::string& in) { .def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in); auto&& str = normalize_enum(in);
if (str == "FUSE_MUL_ADD3_INT16x32x32x32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32; if (str == "FUSE_MUL_ADD3_INT16x32x32x32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
...@@ -1085,6 +1097,12 @@ py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode") ...@@ -1085,6 +1097,12 @@ py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
if (str == "MUL_INT16xF32xF32") return ElemwiseMultiType::Mode::MUL_INT16xF32xF32; if (str == "MUL_INT16xF32xF32") return ElemwiseMultiType::Mode::MUL_INT16xF32xF32;
if (str == "FUSE_MUL_ADD3_UINT8xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32; if (str == "FUSE_MUL_ADD3_UINT8xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32;
if (str == "QCOND_LT_MOV") return ElemwiseMultiType::Mode::QCOND_LT_MOV; if (str == "QCOND_LT_MOV") return ElemwiseMultiType::Mode::QCOND_LT_MOV;
if (str == "EQ") return ElemwiseMultiType::Mode::EQ;
if (str == "NEQ") return ElemwiseMultiType::Mode::NEQ;
if (str == "LT") return ElemwiseMultiType::Mode::LT;
if (str == "LEQ") return ElemwiseMultiType::Mode::LEQ;
if (str == "ISNAN") return ElemwiseMultiType::Mode::ISNAN;
if (str == "ISINF") return ElemwiseMultiType::Mode::ISINF;
throw py::cast_error("invalid enum value " + in); throw py::cast_error("invalid enum value " + in);
})); }));
py::implicitly_convertible<std::string, ElemwiseMultiType::Mode>(); py::implicitly_convertible<std::string, ElemwiseMultiType::Mode>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册