未验证 提交 49d2a778 编写于 作者: W Weilong Wu 提交者: GitHub

Polish black_ops_list logic in eager_gen (#44188)

* Polish black_ops_list logic in eager_gen

* update black_ops_list
上级 b61d8f77
...@@ -40,8 +40,11 @@ from codegen_utils import AssertMessage, GetIndent ...@@ -40,8 +40,11 @@ from codegen_utils import AssertMessage, GetIndent
# keeping the code compatible, here we also skip inplace check in new dygraph temporarily, # keeping the code compatible, here we also skip inplace check in new dygraph temporarily,
# and this will be fixed in the futrue. # and this will be fixed in the futrue.
inplace_check_blacklist = set(["assign_out_"]) inplace_check_blacklist = set(["assign_out_"])
# # --- Black Ops list that's NO NEED to apply backward code generation
black_ops_list = ["conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n"] # Black Ops list that's NO NEED to apply code generation
black_ops_list = [
"conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n", "add_n_grad"
]
########### ###########
...@@ -1637,7 +1640,6 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): ...@@ -1637,7 +1640,6 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
if 'backward' not in forward_api_contents.keys(): return None if 'backward' not in forward_api_contents.keys(): return None
backward_api_name = forward_api_contents['backward'] backward_api_name = forward_api_contents['backward']
if backward_api_name in black_ops_list: return None
assert backward_api_name in grad_api_dict.keys(), AssertMessage( assert backward_api_name in grad_api_dict.keys(), AssertMessage(
backward_api_name, grad_api_dict.keys()) backward_api_name, grad_api_dict.keys())
backward_api_contents = grad_api_dict[backward_api_name] backward_api_contents = grad_api_dict[backward_api_name]
...@@ -1655,7 +1657,7 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): ...@@ -1655,7 +1657,7 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
backward_api_contents = self.GetBackwardAPIContents( backward_api_contents = self.GetBackwardAPIContents(
forward_api_contents) forward_api_contents)
if backward_api_contents is None: continue if backward_api_contents is None: continue
if forward_api_contents['api'] in black_ops_list: continue
# Generate Dygraph Forward Function # Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator( function_generator = DygraphForwardFunctionGenerator(
forward_api_contents, backward_api_contents, namespace) forward_api_contents, backward_api_contents, namespace)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册