未验证 提交 bcdffe66 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Enhance eager_trace_op logic to support More Op (#41210)

* [Eager]Enhance eager_trace_op logic to support Optimizer Op

* fix AsDispensable
上级 a2c01db1
...@@ -36,6 +36,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -36,6 +36,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}}, {"label_smooth", {"X", "PriorDist"}},
{"assign", {"X"}}, {"assign", {"X"}},
{"crop", {"X", "Y", "Offsets"}},
{"crop_tensor", {"X", "Shape", "Offsets"}},
{"reshape2", {"X", "Shape"}}, {"reshape2", {"X", "Shape"}},
{"expand", {"X", "ExpandTimes"}}, {"expand", {"X", "ExpandTimes"}},
{"slice", {"slice",
...@@ -55,6 +57,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -55,6 +57,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"repeat_interleave", {"X", "RepeatsTensor"}}, {"repeat_interleave", {"X", "RepeatsTensor"}},
{"roi_pool", {"X", "ROIs", "RoisNum"}}, {"roi_pool", {"X", "ROIs", "RoisNum"}},
{"roi_align", {"X", "ROIs", "RoisNum"}}, {"roi_align", {"X", "ROIs", "RoisNum"}},
{"prroi_pool", {"X", "ROIs", "BatchRoINums"}},
{"psroi_pool", {"X", "ROIs", "RoisNum"}}, {"psroi_pool", {"X", "ROIs", "RoisNum"}},
{"collect_fpn_proposals", {"collect_fpn_proposals",
{"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}},
......
...@@ -110,6 +110,9 @@ class Tracer(core.Tracer): ...@@ -110,6 +110,9 @@ class Tracer(core.Tracer):
arg_list = [] arg_list = []
for i in range(len(op_args)): for i in range(len(op_args)):
# initialized with None
arg_to_append = None
arg_name = op_args[i] arg_name = op_args[i]
arg_type = op_args_type[i] arg_type = op_args_type[i]
if arg_name in inputs.keys(): if arg_name in inputs.keys():
...@@ -117,14 +120,20 @@ class Tracer(core.Tracer): ...@@ -117,14 +120,20 @@ class Tracer(core.Tracer):
elif arg_name in outputs.keys(): elif arg_name in outputs.keys():
arg_to_append = outputs[arg_name] arg_to_append = outputs[arg_name]
else: else:
if "Num" in arg_name: if "Num" in arg_name[-3:]:
# Remove "Num" suffix to get out_name # Remove "Num" suffix to get out_name
out_name = arg_name[:-3] out_name = arg_name[:-3]
assert out_name in outputs.keys() assert out_name in outputs.keys()
num_outs = len(outputs[out_name]) num_outs = len(outputs[out_name])
arg_to_append = num_outs arg_to_append = num_outs
else: # NOTE(dev): For MasterParam/MasterParamOut in optimzer op
arg_to_append = None elif "Var" in arg_name[-3:]:
out_name = arg_name[:-3]
print(out_name)
if out_name in outputs.keys():
arg_to_append = outputs[out_name]
elif out_name in inputs.keys():
arg_to_append = inputs[out_name]
if arg_to_append is None: if arg_to_append is None:
arg_list.append(arg_to_append) arg_list.append(arg_to_append)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册