Unverified commit 11a526ab, authored by WangZhen, committed by GitHub

[NewIR] Call _C_ops.xx in both dygraph and static mode (#56809)

Parent: 179d4264
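In short: for ops wired up as unified, the Python entry point now checks for an active dygraph tracer at call time; with a tracer the eager kernel runs, without one the call builds a new-IR static op. A minimal sketch of the mode helpers this commit adds (flag name taken from the diff below; only `mean` is unified here):

import paddle

# Dygraph is the default, so the combined check is already true.
print(paddle.framework.in_dynamic_or_new_ir_mode())   # True

# Static graph plus the new IR flag flips in_new_ir_mode() on, and the
# combined check stays true, so _C_ops.mean keeps serving both modes.
paddle.enable_static()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
print(paddle.framework.in_new_ir_mode())               # True
print(paddle.framework.in_dynamic_or_new_ir_mode())    # True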
@@ -72,7 +72,7 @@ tools/nvcc_lazy
 # This file is automatically generated.
 # TODO(zhiqiang) Move this file to build directory.
-paddle/fluid/pybind/eager_op_function.cc
+paddle/fluid/pybind/eager_op_function.*
 tools/nvcc_lazy
 paddle/phi/kernels/sparse/gpu/cutlass_generator/all_gemm_operations.h
 paddle/phi/kernels/sparse/gpu/cutlass_generator/configurations.h
......
@@ -18,5 +18,7 @@
 namespace egr {
 Controller* Controller::controller_ = new Controller();
+thread_local std::shared_ptr<paddle::imperative::Tracer> Controller::tracer_ =
+    std::make_shared<paddle::imperative::Tracer>();
 }  // namespace egr
@@ -145,8 +145,7 @@ class Controller {
  private:
   Controller() = default;
   static Controller* controller_;
-  std::shared_ptr<paddle::imperative::Tracer> tracer_{
-      new paddle::imperative::Tracer()};
+  static thread_local std::shared_ptr<paddle::imperative::Tracer> tracer_;
   std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>
       op_meta_info_map_;
   /* op_type : {{{grad_outputs}, {grad_inputs}, {input}, {output}, {attrs}},
......
@@ -53,10 +53,14 @@ add_custom_target(
   ${nodes_h_path}
   VERBATIM)
-set(tmp_python_c_output_path
+set(tmp_python_c_source_path
     "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.cc.tmp")
-set(python_c_output_path
+set(python_c_source_path
     "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.cc")
+set(tmp_python_c_header_path
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.h.tmp")
+set(python_c_header_path
+    "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.h")
 add_custom_target(
   eager_python_c_codegen
@@ -64,7 +68,10 @@ add_custom_target(
   "${PYTHON_EXECUTABLE}"
   "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py"
   "--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
-  "--output_path=${tmp_python_c_output_path}"
-  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path}
-          ${python_c_output_path}
+  "--source_path=${tmp_python_c_source_path}"
+  "--header_path=${tmp_python_c_header_path}"
+  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_source_path}
+          ${python_c_source_path}
+  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_header_path}
+          ${python_c_header_path}
   VERBATIM)
@@ -87,7 +87,7 @@ RETURN_INPLACE_PYOBJECT_TEMPLATE = """
 PYTHON_C_FUNCTION_TEMPLATE = """
-static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
+PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
 {}
   PyThreadState *tstate = nullptr;
   try {{
@@ -173,6 +173,7 @@ PYTHON_C_WRAPPER_TEMPLATE = """
 #include "paddle/fluid/pybind/eager.h"
 #include "paddle/fluid/eager/amp_utils.h"
 #include "paddle/fluid/eager/eager_amp_auto_cast.h"
+#include "paddle/fluid/pybind/eager_op_function.h"
 namespace paddle {{
 namespace pybind {{
@@ -253,6 +254,29 @@ NAMESPACE_WRAPPER_TEMPLATE = """namespace {} {{
 }}
 """
+PYTHON_C_H_TEMPLATE = """
+#pragma once
+
+#include <Python.h>
+
+// Avoid a problem with copysign defined in pyconfig.h on Windows.
+#ifdef copysign
+#undef copysign
+#endif
+
+namespace paddle {{
+namespace pybind {{
+
+{body}
+
+}} // namespace pybind
+}} // namespace paddle
+"""
+
+PYTHON_C_FUNCTION_DECLARE_TEMPLATE = """
+PyObject *eager_api_{name}(PyObject *self, PyObject *args, PyObject *kwargs);
+"""
 #####################
 # Generator Classes #
@@ -279,6 +303,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
         # Generated Results
         self.python_c_function_str = ""
         self.python_c_function_reg_str = ""
+        self.python_c_funcion_declare_str = ""

     def CollectIsForwardOnly(self):
         forward_api_contents = self.forward_api_contents
@@ -428,6 +453,9 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
             noamp_dygraph_function_str,
             return_str,
         )
+        self.python_c_funcion_declare_str = (
+            PYTHON_C_FUNCTION_DECLARE_TEMPLATE.format(name=forward_api_name)
+        )

         # Set prefix of forward_api_name to avoid conflicts
         prefix = self.namespace.strip("::")
@@ -483,6 +511,12 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
                 return_str,
             )
+            python_c_funcion_declare_str = (
+                PYTHON_C_FUNCTION_DECLARE_TEMPLATE.format(
+                    name=inplaced_forward_api_name
+                )
+            )

             python_c_inplace_func_reg_str = (
                 PYTHON_C_FUNCTION_REG_TEMPLATE.format(
                     forward_api_name_prefix,
@@ -496,10 +530,14 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
             # self.forward_api_name ending with '_' means it only has inplace api
             if self.forward_api_name[-1] == '_':
                 self.python_c_function_str = python_c_inplace_func_str
+                self.python_c_funcion_declare_str = python_c_funcion_declare_str

                 # Generate Python-C Function Registration
                 self.python_c_function_reg_str = python_c_inplace_func_reg_str
             else:
                 self.python_c_function_str += python_c_inplace_func_str
+                self.python_c_funcion_declare_str += (
+                    python_c_funcion_declare_str
+                )

                 # Generate Python-C Function Registration
                 self.python_c_function_reg_str += python_c_inplace_func_reg_str
@@ -541,6 +579,7 @@ class PythonCGenerator(GeneratorBase):
         # Generated Result
         self.python_c_functions_str = ""
         self.python_c_functions_reg_str = ""
+        self.python_c_funcion_declare_str = ""

     def GeneratePythonCFunctions(self):
         namespace = self.namespace
@@ -559,6 +598,9 @@ class PythonCGenerator(GeneratorBase):
                 self.python_c_functions_reg_str += (
                     f_generator.python_c_function_reg_str
                 )
+                self.python_c_funcion_declare_str += (
+                    f_generator.python_c_funcion_declare_str
+                )

     def AttachNamespace(self):
         namespace = self.namespace
@@ -570,6 +612,11 @@ class PythonCGenerator(GeneratorBase):
             self.python_c_functions_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                 namespace, python_c_functions_str
             )
+            self.python_c_funcion_declare_str = (
+                NAMESPACE_WRAPPER_TEMPLATE.format(
+                    namespace, self.python_c_funcion_declare_str
+                )
+            )

     def run(self):
         # Infer namespace from yaml_path
@@ -593,7 +640,8 @@ def ParseArguments():
         description='Eager Code Generator Args Parser'
     )
     parser.add_argument('--api_yaml_path', type=str)
-    parser.add_argument('--output_path', type=str)
+    parser.add_argument('--source_path', type=str)
+    parser.add_argument('--header_path', type=str)

     args = parser.parse_args()
     return args
@@ -631,6 +679,7 @@ if __name__ == "__main__":
     generated_python_c_functions = ""
     generated_python_c_registration = ""
+    generated_python_c_functions_header = ""

     for i in range(len(api_yaml_paths)):
         api_yaml_path = api_yaml_paths[i]
@@ -643,14 +692,22 @@ if __name__ == "__main__":
         generated_python_c_registration += (
             py_c_generator.python_c_functions_reg_str
         )
+        generated_python_c_functions_header += (
+            py_c_generator.python_c_funcion_declare_str
+        )

     python_c_str = GeneratePythonCWrappers(
         generated_python_c_functions, generated_python_c_registration
     )

-    output_path = args.output_path
-    for path in [output_path]:
+    soucre_path = args.source_path
+    header_path = args.header_path
+    for path in [soucre_path, header_path]:
         if os.path.exists(path):
             os.remove(path)

-    GeneratePythonCFile(output_path, python_c_str)
+    GeneratePythonCFile(soucre_path, python_c_str)
+    GeneratePythonCFile(
+        header_path,
+        PYTHON_C_H_TEMPLATE.format(body=generated_python_c_functions_header),
+    )
@@ -56,7 +56,7 @@ thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
 thread_local phi::DataType Tracer::amp_dtype_ = phi::DataType::FLOAT32;
-static std::shared_ptr<Tracer> g_current_tracer(nullptr);
+static thread_local std::shared_ptr<Tracer> g_current_tracer(nullptr);
 const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
......
@@ -21,7 +21,9 @@ CPP_FILE_TEMPLATE = """
 #include <pybind11/pybind11.h>
 #include "paddle/fluid/pybind/static_op_function.h"
+#include "paddle/fluid/pybind/eager_op_function.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/fluid/eager/api/utils/global_utils.h"
 {body}
@@ -44,18 +46,62 @@ void BindOpsAPI(pybind11::module *module) {{
 FUNCTION_IMPL_TEMPLATE = """
 static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
+  if (egr::Controller::Instance().GetCurrentTracer() == nullptr) {{
+    VLOG(6) << "Call static_api_{name}";
+    return static_api_{name}(self, args, kwargs);
+  }} else {{
+    VLOG(6) << "Call eager_api_{name}";
+    return eager_api_{name}(self, args, kwargs);
+  }}
+}}"""
+
+NO_DY_FUNCTION_IMPL_TEMPLATE = """
+static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
+  VLOG(6) << "Call static_api_{name}";
   return static_api_{name}(self, args, kwargs);
 }}"""

 OPS_API_TEMPLATE = """
 {{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},"""

+SPECIAL_STATIC_ONLY_APIS = [
+    'fetch',
+    'set_value_with_tensor',
+    'set_value_with_tensor_',
+    'fused_bn_add_activation_',
+    'fused_batch_norm_act_',
+    'add_n_',
+    'set_value',
+    'assign_value',
+    'set_value_',
+    'embedding_grad_sparse',
+    'add_n_with_kernel',
+    'print',
+    'send_v2',
+    'shadow_feed',
+    'recv_v2',
+    'rnn_',
+    'fused_scale_bias_relu_conv_bnstats',
+    'batch_norm_',
+    'c_allreduce_sum',
+    'c_embedding',
+    'c_identity',
+]

 class OpsAPIGen(CodeGen):
     def __init__(self) -> None:
         super().__init__()

     def _gen_one_function_impl(self, name):
-        return FUNCTION_IMPL_TEMPLATE.format(name=name)
+        if (
+            name.endswith('grad')
+            or name.endswith('grad_')
+            or name.endswith('xpu')
+            or name in SPECIAL_STATIC_ONLY_APIS
+        ):
+            return NO_DY_FUNCTION_IMPL_TEMPLATE.format(name=name)
+        else:
+            return FUNCTION_IMPL_TEMPLATE.format(name=name)

     def _gen_one_ops_api(self, name):
......
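To see what the generator now emits for a given op, the two templates above can be instantiated directly from this module; a small illustrative check (run inside ops_api_gen.py, names as defined above):

# Unified op: the emitted C++ dispatches on the current tracer at call time.
print(FUNCTION_IMPL_TEMPLATE.format(name="mean"))

# Static-only op: grad/xpu-suffixed names and SPECIAL_STATIC_ONLY_APIS entries
# keep the single static_api_ branch.
print(NO_DY_FUNCTION_IMPL_TEMPLATE.format(name="fetch"))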
 pybind.h
-eager_op_function.cc
+eager_op_function.*
 eager_legacy_op_function.cc
@@ -16,6 +16,14 @@ from paddle.fluid import core
 __all__ = []

+UNIFIED_APIS = ['mean']
+
 for name in dir(core.eager.ops):
     globals()[name] = getattr(core.eager.ops, name)
     __all__.append(name)
+
+for name in dir(core.ir.ops):
+    if name in UNIFIED_APIS:
+        globals()[name] = getattr(core.ir.ops, name)
+        if name not in __all__:
+            __all__.append(name)
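The added loop re-binds every name listed in UNIFIED_APIS to its core.ir.ops implementation, so the new-IR `mean` is reachable through paddle._ir_ops as well. A quick sanity check (hedged; only `mean` is listed in this commit):

import paddle

# 'mean' should resolve through the _ir_ops module after this change.
print(hasattr(paddle._ir_ops, "mean"))   # expected: True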
@@ -459,6 +459,7 @@ from . import hub  # noqa: F401
 from . import linalg  # noqa: F401
 from . import fft  # noqa: F401
 from . import signal  # noqa: F401
+from . import _ir_ops  # noqa: F401
 import paddle.text  # noqa: F401
 import paddle.vision  # noqa: F401
......
@@ -55,6 +55,8 @@ __all__ = [
     'xpu_places',
     'cuda_pinned_places',
     'in_dygraph_mode',
+    'in_new_ir_mode',
+    'in_dynamic_or_new_ir_mode',
     'is_compiled_with_cinn',
     'is_compiled_with_cuda',
     'is_compiled_with_rocm',
@@ -102,6 +104,7 @@ class GlobalThreadLocal(threading.local):
         if name == '_dygraph_tracer_':
             global _dygraph_tracer_
             _dygraph_tracer_ = val
+            core._switch_tracer(val)
         self.__dict__[name] = val
@@ -209,6 +212,59 @@ def in_dygraph_mode():
     return global_var._dygraph_tracer_ is not None

+def in_new_ir_mode():
+    """
+    This API checks whether paddle runs in static graph mode and uses the new IR API.
+
+    Returns:
+        bool: Whether paddle runs in static graph mode and uses the new IR API.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> print(paddle.framework.in_new_ir_mode())
+            False
+
+            >>> paddle.enable_static()
+            >>> paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
+            >>> print(paddle.framework.in_new_ir_mode())
+            True
+    """
+    return ir.core._use_new_ir_api() and not in_dygraph_mode()
+
+
+def in_dynamic_or_new_ir_mode():
+    """
+    This API checks whether paddle runs in dynamic graph mode or in new IR mode.
+
+    Returns:
+        bool: Whether paddle runs in dynamic graph mode or in new IR mode.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> print(paddle.framework.in_dynamic_or_new_ir_mode())
+            True
+
+            >>> paddle.enable_static()
+            >>> print(paddle.framework.in_dynamic_or_new_ir_mode())
+            False
+
+            >>> paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
+            >>> print(paddle.framework.in_dynamic_or_new_ir_mode())
+            True
+    """
+    return in_dygraph_mode() or in_new_ir_mode()
+
+
 global_ipu_index = -1
 global_ipu_stage = -1
 ipu_index_attr_name = 'ipu_index'
@@ -7604,14 +7660,10 @@ def dygraph_guard_if_declarative():
 def _dygraph_guard(tracer):
     tmp_tracer = global_var._dygraph_tracer_
     global_var._dygraph_tracer_ = tracer
-    if tracer is not None:
-        core._switch_tracer(tracer)

     try:
         yield
     finally:
-        if tmp_tracer is not None:
-            core._switch_tracer(tmp_tracer)
         global_var._dygraph_tracer_ = tmp_tracer
@@ -7622,8 +7674,6 @@ def _static_guard():
     try:
         yield
     finally:
-        if tmp_tracer is not None:
-            core._switch_tracer(tmp_tracer)
         global_var._dygraph_tracer_ = tmp_tracer
......
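Because GlobalThreadLocal.__setattr__ now forwards the tracer to core._switch_tracer, the guards above no longer switch the C++ tracer themselves; assigning the Python-side tracer is enough. A hedged check of that behavior (assumes a script that starts in the default dygraph mode):

import paddle
from paddle.fluid import framework

# Clearing the tracer through the guard should also clear it on the C++ side,
# so dygraph mode reads as off inside the guard and is restored afterwards.
with framework._dygraph_guard(None):
    print(paddle.in_dynamic_mode())   # expected: False
print(paddle.in_dynamic_mode())       # expected: True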
@@ -56,6 +56,8 @@ from ..fluid.framework import Parameter
 from ..fluid.dygraph.base import enable_dygraph as disable_static  # noqa: F401
 from ..fluid.dygraph.base import disable_dygraph as enable_static  # noqa: F401
 from ..fluid.framework import in_dygraph_mode as in_dynamic_mode  # noqa: F401
+from ..fluid.framework import in_new_ir_mode  # noqa: F401
+from ..fluid.framework import in_dynamic_or_new_ir_mode  # noqa: F401
 from ..fluid.framework import (
     _current_expected_place,
     _get_paddle_place,
......
@@ -190,6 +190,9 @@ class PartialProgramLayer:
         assert isinstance(self._build_strategy, BuildStrategy)
         self._origin_main_program = self._verify_program(main_program)
-        self._cuda_graph_vec = self._create_cuda_graph_vec()
+        with paddle.fluid.framework._dygraph_guard(
+            paddle.fluid.dygraph.Tracer()
+        ):
+            self._cuda_graph_vec = self._create_cuda_graph_vec()
         self._cuda_graph_capture_mode = ""
         self._cuda_graph_pool_id = 0
......
@@ -15,8 +15,8 @@
 # TODO: define statistical functions of a tensor

 import paddle
-from paddle import _C_ops, _ir_ops, ir
-from paddle.framework import in_dynamic_mode
+from paddle import _C_ops
+from paddle.framework import in_dynamic_mode, in_dynamic_or_new_ir_mode

 from ..common_ops_import import Variable
 from ..fluid.data_feeder import check_type, check_variable_and_dtype
@@ -83,11 +83,9 @@ def mean(x, axis=None, keepdim=False, name=None):
             >>> print(out4.numpy())
             [ 8.5 12.5 16.5]
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_new_ir_mode():
         return _C_ops.mean(x, axis, keepdim)
     else:
-        if ir.core._use_new_ir_api():
-            return _ir_ops.mean(x, axis, keepdim)
         reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
         check_variable_and_dtype(
             x,
......
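A hedged usage sketch for the updated mean(): the first branch now covers dygraph and static-with-new-IR alike, so _C_ops.mean serves both, while the legacy static branch below it is untouched:

import paddle
from paddle.framework import in_dynamic_or_new_ir_mode

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
print(paddle.mean(x, axis=0))          # dygraph: eager _C_ops.mean

paddle.enable_static()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
print(in_dynamic_or_new_ir_mode())     # True: the _C_ops.mean branch still applies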
@@ -190,6 +190,7 @@ class TestInferenceBaseAPI(unittest.TestCase):
             predictor.run()

         def test_paddle_tensor():
+            paddle.disable_static()
             config = self.get_config(program, params)
             predictor = create_predictor(config)
             in_names = predictor.get_input_names()
@@ -197,6 +198,7 @@ class TestInferenceBaseAPI(unittest.TestCase):
             in_data = paddle.Tensor(np.ones((1, 6, 32, 32)).astype(np.float32))
             in_handle.share_external_data(in_data)
             predictor.run()
+            paddle.enable_static()

         test_lod_tensor()
         test_paddle_tensor()
......