From bcdffe6698025020f5903401a29c1deaad4f892f Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 1 Apr 2022 16:51:24 +0800 Subject: [PATCH] [Eager]Enhance eager_trace_op logic to support More Op (#41210) * [Eager]Enhance eager_trace_op logic to support Optimizer Op * fix AsDispensable --- paddle/fluid/pybind/op_function_generator.h | 3 +++ python/paddle/fluid/dygraph/tracer.py | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 2bfc16c7d5b..75175958978 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -36,6 +36,8 @@ std::map> op_ins_map = { {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, {"assign", {"X"}}, + {"crop", {"X", "Y", "Offsets"}}, + {"crop_tensor", {"X", "Shape", "Offsets"}}, {"reshape2", {"X", "Shape"}}, {"expand", {"X", "ExpandTimes"}}, {"slice", @@ -55,6 +57,7 @@ std::map> op_ins_map = { {"repeat_interleave", {"X", "RepeatsTensor"}}, {"roi_pool", {"X", "ROIs", "RoisNum"}}, {"roi_align", {"X", "ROIs", "RoisNum"}}, + {"prroi_pool", {"X", "ROIs", "BatchRoINums"}}, {"psroi_pool", {"X", "ROIs", "RoisNum"}}, {"collect_fpn_proposals", {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index e1fabf9aeda..747fe7d32cb 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -110,6 +110,9 @@ class Tracer(core.Tracer): arg_list = [] for i in range(len(op_args)): + # initialized with None + arg_to_append = None + arg_name = op_args[i] arg_type = op_args_type[i] if arg_name in inputs.keys(): @@ -117,14 +120,20 @@ class Tracer(core.Tracer): elif arg_name in outputs.keys(): arg_to_append = outputs[arg_name] else: - if "Num" in arg_name: + if "Num" in arg_name[-3:]: # Remove "Num" suffix to get out_name out_name = arg_name[:-3] assert out_name in outputs.keys() num_outs = len(outputs[out_name]) arg_to_append = num_outs - else: - arg_to_append = None + # NOTE(dev): For MasterParam/MasterParamOut in optimzer op + elif "Var" in arg_name[-3:]: + out_name = arg_name[:-3] + print(out_name) + if out_name in outputs.keys(): + arg_to_append = outputs[out_name] + elif out_name in inputs.keys(): + arg_to_append = inputs[out_name] if arg_to_append is None: arg_list.append(arg_to_append) -- GitLab