From 0624ea568b1067eb6dc4139c85f0778149f526fe Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 12 Apr 2021 21:19:28 +0800 Subject: [PATCH] polish custom api content for performence (#32209) --- .../utils/cpp_extension/extension_utils.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 06596c0fae..30ff3f81ca 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -793,22 +793,29 @@ def _custom_api_content(op_name): params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name) API_TEMPLATE = textwrap.dedent(""" + from paddle.fluid.core import VarBase + from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer from paddle.fluid.layer_helper import LayerHelper def {op_name}({inputs}): - helper = LayerHelper("{op_name}", **locals()) - # prepare inputs and outputs ins = {ins} attrs = {attrs} outs = {{}} out_names = {out_names} - for out_name in out_names: - # Set 'float32' temporarily, and the actual dtype of output variable will be inferred - # in runtime. - outs[out_name] = helper.create_variable(dtype='float32') - helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) + # The output variable's dtype use default value 'float32', + # and the actual dtype of output variable will be inferred in runtime. + if in_dygraph_mode(): + for out_name in out_names: + outs[out_name] = VarBase() + _dygraph_tracer().trace_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) + else: + helper = LayerHelper("{op_name}", **locals()) + for out_name in out_names: + outs[out_name] = helper.create_variable(dtype='float32') + + helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) res = [outs[out_name] for out_name in out_names] -- GitLab