From 8631d73a0a0c7e7be452641236b1a820c3baedd8 Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Tue, 19 Apr 2022 17:45:08 +0800
Subject: [PATCH] [Eager] make fast through to linear (#41945)

* make fast through to linear

* make fast through to linear

* add to do for later upgrades

* support build once for now
---
 ...e_op_funcs.h => eager_custom_python_api.h} | 46 ++++++++++++++++++-
 .../pybind/eager_op_function_generator.cc     |  2 +-
 python/paddle/fluid/dygraph/layers.py         |  9 +++-
 python/paddle/nn/functional/common.py         |  8 +---
 4 files changed, 55 insertions(+), 10 deletions(-)
 rename paddle/fluid/pybind/{custom_handwrite_op_funcs.h => eager_custom_python_api.h} (59%)

diff --git a/paddle/fluid/pybind/custom_handwrite_op_funcs.h b/paddle/fluid/pybind/eager_custom_python_api.h
similarity index 59%
rename from paddle/fluid/pybind/custom_handwrite_op_funcs.h
rename to paddle/fluid/pybind/eager_custom_python_api.h
index 044c3d5d176..c509ab56749 100644
--- a/paddle/fluid/pybind/custom_handwrite_op_funcs.h
+++ b/paddle/fluid/pybind/eager_custom_python_api.h
@@ -54,9 +54,53 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args,
   }
 }
 
+static PyObject *eager_api_final_state_linear(PyObject *self, PyObject *args,
+                                              PyObject *kwargs) {
+  PyThreadState *tstate = nullptr;
+  try {
+    auto x = GetTensorFromArgs("linear", "X", args, 0, false);
+    auto weight = GetTensorFromArgs("linear", "weight", args, 1, false);
+    auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true);
+    tstate = PyEval_SaveThread();
+    if (bias.initialized()) {
+      auto mm_out =
+          matmul_final_state_dygraph_function(x, weight, false, false);
+      auto out = add_final_state_dygraph_function(bias, mm_out);
+      PyEval_RestoreThread(tstate);
+      tstate = nullptr;
+      return ToPyObject(out);
+    } else {
+      auto mm_out =
+          matmul_final_state_dygraph_function(x, weight, false, false);
+      PyEval_RestoreThread(tstate);
+      tstate = nullptr;
+      return ToPyObject(mm_out);
+    }
+  } catch (paddle::platform::EnforceNotMet &exception) {
+    if (tstate) {
+      PyEval_RestoreThread(tstate);
+    }
+    std::ostringstream sout;
+    sout << exception.what();
+    sout << " [operator < linear > error]";
+    exception.set_error_str(sout.str());
+    ThrowExceptionToPython(std::current_exception());
+    return nullptr;
+  } catch (...) {
+    if (tstate) {
+      PyEval_RestoreThread(tstate);
+    }
+    ThrowExceptionToPython(std::current_exception());
+    return nullptr;
+  }
+}
+
 static PyMethodDef CustomEagerFinalStateMethods[] = {
     {"run_program", (PyCFunction)(void (*)(void))eager_api_run_program,
      METH_VARARGS | METH_KEYWORDS,
      "C++ interface function for run_program in dygraph."},
-
+    {"final_state_linear",
+     (PyCFunction)(void (*)(void))eager_api_final_state_linear,
+     METH_VARARGS | METH_KEYWORDS,
+     "C++ interface function for linear in dygraph."},
     {nullptr, nullptr, 0, nullptr}};
diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc
index 06d88be9bc8..2ac12165c1a 100644
--- a/paddle/fluid/pybind/eager_op_function_generator.cc
+++ b/paddle/fluid/pybind/eager_op_function_generator.cc
@@ -134,7 +134,7 @@ const char* PYBIND_ITEM_TEMPLATE = R"(  {"%s", (PyCFunction)(void(*)(void))%s, M
 // need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE
 std::unordered_set<std::string> CUSTOM_HANDWRITE_OPS_SET = {"run_program"};
 const char* CUSTOM_HANDWRITE_OP_FUNC_FILE =
-    "#include \"paddle/fluid/pybind/custom_handwrite_op_funcs.h\"\n";
+    "#include \"paddle/fluid/pybind/eager_custom_python_api.h\"\n";
 // clang-format on
 
 static inline bool FindInsMap(const std::string& op_type,
diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py
index 193025b1864..41c1a0aa580 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -36,7 +36,7 @@ from .base import program_desc_tracing_guard, param_guard, in_declarative_mode,
 from paddle.fluid import framework
 from ..param_attr import ParamAttr
 from paddle.fluid.executor import Executor, global_scope
-from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_
+from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_, in_dygraph_mode
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.core import VarDesc
 from paddle.fluid.dygraph import no_grad
@@ -918,7 +918,12 @@ class Layer(object):
         return outputs
 
     def __call__(self, *inputs, **kwargs):
-        return self._dygraph_call_func(*inputs, **kwargs)
+        if (not in_declarative_mode()) and (not self._forward_pre_hooks) \
+                and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode():
+            self._build_once(*inputs, **kwargs)
+            return self.forward(*inputs, **kwargs)
+        else:
+            return self._dygraph_call_func(*inputs, **kwargs)
 
     def forward(self, *inputs, **kwargs):
         """
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 287dc7d67de..907fd4e6252 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1534,12 +1534,8 @@ def linear(x, weight, bias=None, name=None):
           #     [2.1077576 2.1077576 2.1077576 2.1077576 ]]
     """
     if in_dygraph_mode():
-        pre_bias = _C_ops.final_state_matmul(x, weight, False, False)
-
-        if bias is None:
-            return pre_bias
-
-        return _C_ops.final_state_add(pre_bias, bias)
+        # TODO(jiabin): use addmm for a faster forward route
+        return _C_ops.final_state_linear(x, weight, bias)
     else:
         if _in_legacy_dygraph():
             pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y',
-- 
GitLab
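For reference, the handwritten final_state_linear binding added above simply
composes matmul and add (skipping the add when no bias is passed), so
paddle.nn.functional.linear keeps its semantics while saving one Python-level
_C_ops round trip per call. A minimal sketch of that equivalence, assuming a
Paddle build that includes this patch (the tensor shapes and variable names
below are illustrative, not part of the patch):

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([4, 16])
    weight = paddle.randn([16, 8])
    bias = paddle.randn([8])

    # With this patch, eager-mode F.linear dispatches to the single
    # handwritten final_state_linear binding instead of two separate
    # calls (final_state_matmul followed by final_state_add).
    out = F.linear(x, weight, bias)

    # The pre-patch composition it replaces; results should match.
    ref = paddle.matmul(x, weight, transpose_x=False, transpose_y=False) + bias
    print(paddle.allclose(out, ref))  # expected: True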