diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 06596c0fae8886832d365f80312a69cfc32a2b5e..30ff3f81ca7af076effd17986ae2ca19d60174ac 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -793,22 +793,29 @@ def _custom_api_content(op_name): params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name) API_TEMPLATE = textwrap.dedent(""" + from paddle.fluid.core import VarBase + from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer from paddle.fluid.layer_helper import LayerHelper def {op_name}({inputs}): - helper = LayerHelper("{op_name}", **locals()) - # prepare inputs and outputs ins = {ins} attrs = {attrs} outs = {{}} out_names = {out_names} - for out_name in out_names: - # Set 'float32' temporarily, and the actual dtype of output variable will be inferred - # in runtime. - outs[out_name] = helper.create_variable(dtype='float32') - helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) + # The output variable's dtype use default value 'float32', + # and the actual dtype of output variable will be inferred in runtime. + if in_dygraph_mode(): + for out_name in out_names: + outs[out_name] = VarBase() + _dygraph_tracer().trace_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) + else: + helper = LayerHelper("{op_name}", **locals()) + for out_name in out_names: + outs[out_name] = helper.create_variable(dtype='float32') + + helper.append_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs) res = [outs[out_name] for out_name in out_names]