提交 05186e7b 编写于 作者: M Megvii Engine Team

fix(midout): fix elemwise crash after midout

some dnn backends opr will use agency opr,
for example: softmax cpu naive imp will call elemwise opr,
at model dump stage, we can not get dnn runtime logic,
so we record elemwise mode info at runtime stage.

GitOrigin-RevId: 6528b4c85da90251df513e6d5bcabffb8a0e2c61
上级 9be8de60
......@@ -17,6 +17,9 @@
#include "midout.h"
MIDOUT_DECL(megdnn_common_elemwise)
//! this tag will be used at tools/gen_header_for_bin_reduce.py
//! please do not modify it
MIDOUT_DECL(megdnn_common_elemwise_mode)
#include <mutex>
#include <vector>
......@@ -154,6 +157,88 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
#if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL
megdnn_assert(ret.arity);
#endif
//! Some DNN backend OPRS will use proxy OPRS. For example, softmax@cpu Naive imp
//! will call elemwise OPR. In the model dump stage, we have no information about
//! this logic, which will lead to the loss of elemwise mode. As a solution, we
//! record the elemwise mode information by adding the 'midout' case flag in the run
//! stage.
#define CB_MODE(mode) \
case mode: \
MIDOUT_BEGIN(megdnn_common_elemwise_mode, midout_iv(mode)) { return ret; } \
MIDOUT_END(); \
break;
switch (mode) {
CB_MODE(Mode::RELU);
CB_MODE(Mode::ABS);
CB_MODE(Mode::ACOS);
CB_MODE(Mode::ASIN);
CB_MODE(Mode::CEIL);
CB_MODE(Mode::COS);
CB_MODE(Mode::EXP);
CB_MODE(Mode::EXPM1);
CB_MODE(Mode::FLOOR);
CB_MODE(Mode::LOG);
CB_MODE(Mode::LOG1P);
CB_MODE(Mode::NEGATE);
CB_MODE(Mode::SIGMOID);
CB_MODE(Mode::SIN);
CB_MODE(Mode::TANH);
CB_MODE(Mode::ABS_GRAD);
CB_MODE(Mode::ADD);
CB_MODE(Mode::FLOOR_DIV);
CB_MODE(Mode::MAX);
CB_MODE(Mode::MIN);
CB_MODE(Mode::MOD);
CB_MODE(Mode::MUL);
CB_MODE(Mode::POW);
CB_MODE(Mode::SIGMOID_GRAD);
CB_MODE(Mode::SUB);
CB_MODE(Mode::SWITCH_GT0);
CB_MODE(Mode::TANH_GRAD);
CB_MODE(Mode::TRUE_DIV);
CB_MODE(Mode::LOG_SUM_EXP);
CB_MODE(Mode::LT);
CB_MODE(Mode::LEQ);
CB_MODE(Mode::EQ);
CB_MODE(Mode::SHL);
CB_MODE(Mode::SHR);
CB_MODE(Mode::COND_LEQ_MOV);
CB_MODE(Mode::FUSE_MUL_ADD3);
CB_MODE(Mode::FUSE_MUL_ADD4);
CB_MODE(Mode::FUSE_ADD_RELU);
CB_MODE(Mode::FUSE_ADD_SIGMOID);
CB_MODE(Mode::FUSE_ADD_TANH);
CB_MODE(Mode::FAST_TANH);
CB_MODE(Mode::FAST_TANH_GRAD);
CB_MODE(Mode::ROUND);
CB_MODE(Mode::RMULH);
CB_MODE(Mode::ATAN2);
CB_MODE(Mode::ERF);
CB_MODE(Mode::ERFINV);
CB_MODE(Mode::ERFC);
CB_MODE(Mode::ERFCINV);
CB_MODE(Mode::H_SWISH);
CB_MODE(Mode::H_SWISH_GRAD);
CB_MODE(Mode::FUSE_ADD_H_SWISH);
CB_MODE(Mode::NOT);
CB_MODE(Mode::AND);
CB_MODE(Mode::OR);
CB_MODE(Mode::XOR);
CB_MODE(Mode::SILU);
CB_MODE(Mode::SILU_GRAD);
CB_MODE(Mode::GELU);
CB_MODE(Mode::GELU_GRAD);
default:
megdnn_assert(
0,
"code issue happened!!, please add new elemwise to switch mode.");
return ret;
#undef CB_MODE
}
return ret;
}
......
......@@ -77,18 +77,40 @@ class HeaderGen:
self._dtypes.add(i)
for i in data["opr_types"]:
self._oprs.add(i)
for i in data["elemwise_modes"]:
self._elemwise_modes.add(i)
def extend_midout(self, fname):
self._midout_files.append(fname)
def extend_elemwise_mode_info(self, fname):
for line in open(fname):
# tag write in dnn/src/common/elemwise/opr_impl.cpp
idx = line.find("megdnn_common_elemwise_mode")
if idx > 0:
cmd = "c++filt -t {}".format(line)
demangle = subprocess.check_output(cmd, shell=True).decode("utf-8")
demangle = demangle.replace(">", "").split()
is_find_number = False
for i in demangle:
if i.isnumeric():
self._elemwise_modes.add(i)
is_find_number = True
break
assert (
is_find_number
), "code issue happened!! can not find elemwise mode in: {}".format(
line
)
def generate(self, fout):
self._fout = fout
self._write_def("MGB_BINREDUCE_VERSION", "20190219")
self._write_def("MGB_BINREDUCE_VERSION", "20220507")
if self._has_netinfo:
self._write_dtype()
if len(self._elemwise_modes) > 0:
self._write_elemwise_modes()
if self._has_netinfo:
self._write_oprs()
self._write_hash()
self._write_midout()
......@@ -156,22 +178,32 @@ class HeaderGen:
with open(fpath) as fin:
mode_list = [i.strip() for i in fin]
all_elemwise_modes = set()
for i in mode_list:
i = i.split(" ")[0].split("=")[0]
if i in self._elemwise_modes:
content = "_cb({})".format(i)
i_type = i.replace(" ", "").replace("=", " ").split()[0]
i_id = i.replace(" ", "").replace("=", " ").split()[1]
all_elemwise_modes.add(i_id)
if i_id in self._elemwise_modes:
content = "_cb({})".format(i_type)
else:
content = ""
self._write_def(
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(
i.split(" ")[0].split("=")[0]
),
content,
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(i_type), content,
)
# write end of elemwise macro
self._write_def(
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
)
# finally check all self._elemwise_modes is in all_elemwise_modes
for i in self._elemwise_modes:
assert (
i in all_elemwise_modes
), "code issue happened, can not find elemwise mode: {} in {}".format(
i, all_elemwise_modes
)
def _write_dtype(self):
if "Float16" not in self._dtypes:
......@@ -267,6 +299,7 @@ def main():
with open(i) as fin:
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
gen.extend_midout(i)
gen.extend_elemwise_mode_info(i)
else:
fin.seek(0)
gen.extend_netinfo(json.loads(fin.read()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册