未验证 提交 0624ea56 编写于 作者: C Chen Weihang 提交者: GitHub

polish custom api content for performence (#32209)

上级 4b5cb22f
...@@ -793,19 +793,26 @@ def _custom_api_content(op_name): ...@@ -793,19 +793,26 @@ def _custom_api_content(op_name):
params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name) params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name)
API_TEMPLATE = textwrap.dedent(""" API_TEMPLATE = textwrap.dedent("""
from paddle.fluid.core import VarBase
from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
def {op_name}({inputs}): def {op_name}({inputs}):
helper = LayerHelper("{op_name}", **locals())
# prepare inputs and outputs # prepare inputs and outputs
ins = {ins} ins = {ins}
attrs = {attrs} attrs = {attrs}
outs = {{}} outs = {{}}
out_names = {out_names} out_names = {out_names}
# The output variable's dtype use default value 'float32',
# and the actual dtype of output variable will be inferred in runtime.
if in_dygraph_mode():
for out_name in out_names:
outs[out_name] = VarBase()
_dygraph_tracer().trace_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)
else:
helper = LayerHelper("{op_name}", **locals())
for out_name in out_names: for out_name in out_names:
# Set 'float32' temporarily, and the actual dtype of output variable will be inferred
# in runtime.
outs[out_name] = helper.create_variable(dtype='float32') outs[out_name] = helper.create_variable(dtype='float32')
helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册