未验证 提交 05c71543 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Python] Polish auto-gen python codes (#53080)

上级 fa4ae88a
......@@ -1037,16 +1037,22 @@ def _generate_python_module(
def _gen_output_content(
op_name, in_names, out_names, ins_map, attrs_map, inplace_reverse_idx
op_name,
in_names,
out_names,
ins_map,
attrs_map,
outs_list,
inplace_reverse_idx,
):
# ' ' * tab space * tab number
indent = ' ' * 4 * 2
dynamic_content = f"""
{indent}res = []
dynamic_content = f"""res = []
{indent}start_idx = 0"""
static_content = f"""
{indent}ins = {{}}
static_content = f"""ins = {{}}
{indent}ins_map = {ins_map}
{indent}outs = {{}}
{indent}outs_list = {outs_list}
{indent}for key, value in ins_map.items():
{indent} # handle optional inputs
{indent} if value is not None:
......@@ -1131,20 +1137,16 @@ def _custom_api_content(op_name):
out_names,
ins_map,
attrs_map,
outs_list,
inplace_reverse_idx,
)
API_TEMPLATE = textwrap.dedent(
"""
import paddle.fluid.core as core
from paddle.fluid.core import Tensor
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
def {op_name}({params_list}):
# prepare inputs and outputs
outs = {{}}
outs_list = {outs_list}
# The output variable's dtype use default value 'float32',
# and the actual dtype of output variable will be inferred in runtime.
if in_dygraph_mode():
......@@ -1159,7 +1161,6 @@ def _custom_api_content(op_name):
api_content = API_TEMPLATE.format(
op_name=op_name,
params_list=params_list,
outs_list=outs_list,
dynamic_content=dynamic_content,
static_content=static_content,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册