From 49d2a7788368ec4f856b53edd3a4299c034690c4 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Thu, 11 Aug 2022 11:01:39 +0800 Subject: [PATCH] Polish black_ops_list logic in eager_gen (#44188) * Polish black_ops_list logic in eager_gen * update black_ops_list --- .../final_state_generator/eager_gen.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index a3beb268cfa..bc48fe75149 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -40,8 +40,11 @@ from codegen_utils import AssertMessage, GetIndent # keeping the code compatible, here we also skip inplace check in new dygraph temporarily, # and this will be fixed in the futrue. inplace_check_blacklist = set(["assign_out_"]) -# # --- Black Ops list that's NO NEED to apply backward code generation -black_ops_list = ["conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n"] + +# Black Ops list that's NO NEED to apply code generation +black_ops_list = [ + "conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n", "add_n_grad" +] ########### @@ -1637,7 +1640,6 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): if 'backward' not in forward_api_contents.keys(): return None backward_api_name = forward_api_contents['backward'] - if backward_api_name in black_ops_list: return None assert backward_api_name in grad_api_dict.keys(), AssertMessage( backward_api_name, grad_api_dict.keys()) backward_api_contents = grad_api_dict[backward_api_name] @@ -1655,7 +1657,7 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): backward_api_contents = self.GetBackwardAPIContents( forward_api_contents) if backward_api_contents is None: continue - if forward_api_contents['api'] in black_ops_list: continue + # Generate Dygraph Forward Function function_generator = DygraphForwardFunctionGenerator( forward_api_contents, backward_api_contents, namespace) -- GitLab