未验证 提交 49d2a778 编写于 作者: W Weilong Wu 提交者: GitHub

Polish black_ops_list logic in eager_gen (#44188)

* Polish black_ops_list logic in eager_gen

* update black_ops_list
上级 b61d8f77
......@@ -40,8 +40,11 @@ from codegen_utils import AssertMessage, GetIndent
# keeping the code compatible, here we also skip inplace check in new dygraph temporarily,
# and this will be fixed in the futrue.
inplace_check_blacklist = set(["assign_out_"])
# # --- Black Ops list that's NO NEED to apply backward code generation
black_ops_list = ["conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n"]
# Black Ops list that's NO NEED to apply code generation
black_ops_list = [
"conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n", "add_n_grad"
]
###########
......@@ -1637,7 +1640,6 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
if 'backward' not in forward_api_contents.keys(): return None
backward_api_name = forward_api_contents['backward']
if backward_api_name in black_ops_list: return None
assert backward_api_name in grad_api_dict.keys(), AssertMessage(
backward_api_name, grad_api_dict.keys())
backward_api_contents = grad_api_dict[backward_api_name]
......@@ -1655,7 +1657,7 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
backward_api_contents = self.GetBackwardAPIContents(
forward_api_contents)
if backward_api_contents is None: continue
if forward_api_contents['api'] in black_ops_list: continue
# Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator(
forward_api_contents, backward_api_contents, namespace)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册