未验证 提交 05c71543 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Python] Polish auto-gen python codes (#53080)

上级 fa4ae88a
...@@ -1037,16 +1037,22 @@ def _generate_python_module( ...@@ -1037,16 +1037,22 @@ def _generate_python_module(
def _gen_output_content( def _gen_output_content(
op_name, in_names, out_names, ins_map, attrs_map, inplace_reverse_idx op_name,
in_names,
out_names,
ins_map,
attrs_map,
outs_list,
inplace_reverse_idx,
): ):
# ' ' * tab space * tab number # ' ' * tab space * tab number
indent = ' ' * 4 * 2 indent = ' ' * 4 * 2
dynamic_content = f""" dynamic_content = f"""res = []
{indent}res = []
{indent}start_idx = 0""" {indent}start_idx = 0"""
static_content = f""" static_content = f"""ins = {{}}
{indent}ins = {{}}
{indent}ins_map = {ins_map} {indent}ins_map = {ins_map}
{indent}outs = {{}}
{indent}outs_list = {outs_list}
{indent}for key, value in ins_map.items(): {indent}for key, value in ins_map.items():
{indent} # handle optional inputs {indent} # handle optional inputs
{indent} if value is not None: {indent} if value is not None:
...@@ -1131,20 +1137,16 @@ def _custom_api_content(op_name): ...@@ -1131,20 +1137,16 @@ def _custom_api_content(op_name):
out_names, out_names,
ins_map, ins_map,
attrs_map, attrs_map,
outs_list,
inplace_reverse_idx, inplace_reverse_idx,
) )
API_TEMPLATE = textwrap.dedent( API_TEMPLATE = textwrap.dedent(
""" """
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.core import Tensor from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
def {op_name}({params_list}): def {op_name}({params_list}):
# prepare inputs and outputs
outs = {{}}
outs_list = {outs_list}
# The output variable's dtype use default value 'float32', # The output variable's dtype use default value 'float32',
# and the actual dtype of output variable will be inferred in runtime. # and the actual dtype of output variable will be inferred in runtime.
if in_dygraph_mode(): if in_dygraph_mode():
...@@ -1159,7 +1161,6 @@ def _custom_api_content(op_name): ...@@ -1159,7 +1161,6 @@ def _custom_api_content(op_name):
api_content = API_TEMPLATE.format( api_content = API_TEMPLATE.format(
op_name=op_name, op_name=op_name,
params_list=params_list, params_list=params_list,
outs_list=outs_list,
dynamic_content=dynamic_content, dynamic_content=dynamic_content,
static_content=static_content, static_content=static_content,
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册