diff --git a/paddle/fluid/pybind/custom_handwrite_op_funcs.h b/paddle/fluid/pybind/eager_custom_python_api.h similarity index 59% rename from paddle/fluid/pybind/custom_handwrite_op_funcs.h rename to paddle/fluid/pybind/eager_custom_python_api.h index 044c3d5d176e1a021952469db0623197b6302936..c509ab5674930a8814ccb1934fcfbec2f55fdfef 100644 --- a/paddle/fluid/pybind/custom_handwrite_op_funcs.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -54,9 +54,53 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args, } } +static PyObject *eager_api_final_state_linear(PyObject *self, PyObject *args, + PyObject *kwargs) { + PyThreadState *tstate = nullptr; + try { + auto x = GetTensorFromArgs("linear", "X", args, 0, false); + auto weight = GetTensorFromArgs("linear", "weight", args, 1, false); + auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true); + tstate = PyEval_SaveThread(); + if (bias.initialized()) { + auto mm_out = + matmul_final_state_dygraph_function(x, weight, false, false); + auto out = add_final_state_dygraph_function(bias, mm_out); + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + } else { + auto mm_out = + matmul_final_state_dygraph_function(x, weight, false, false); + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(mm_out); + } + } catch (paddle::platform::EnforceNotMet &exception) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + std::ostringstream sout; + sout << exception.what(); + sout << " [operator < linear > error]"; + exception.set_error_str(sout.str()); + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } catch (...) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + static PyMethodDef CustomEagerFinalStateMethods[] = { {"run_program", (PyCFunction)(void (*)(void))eager_api_run_program, METH_VARARGS | METH_KEYWORDS, "C++ interface function for run_program in dygraph."}, - + {"final_state_linear", + (PyCFunction)(void (*)(void))eager_api_final_state_linear, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for run_program in dygraph."}, {nullptr, nullptr, 0, nullptr}}; diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index 06d88be9bc8ccd5739d3b88c8b152348aab69393..2ac12165c1a66c0379442284c6ad68f6c2c32bfe 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -134,7 +134,7 @@ const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, M // need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE std::unordered_set CUSTOM_HANDWRITE_OPS_SET = {"run_program"}; const char* CUSTOM_HANDWRITE_OP_FUNC_FILE = - "#include \"paddle/fluid/pybind/custom_handwrite_op_funcs.h\"\n"; + "#include \"paddle/fluid/pybind/eager_custom_python_api.h\"\n"; // clang-format on static inline bool FindInsMap(const std::string& op_type, diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 193025b1864abcfd33e1c40d2ef587f98a070342..41c1a0aa5808e8007cd5d234cacc3f109c3e327d 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -36,7 +36,7 @@ from .base import program_desc_tracing_guard, param_guard, in_declarative_mode, from paddle.fluid import framework from ..param_attr import ParamAttr from paddle.fluid.executor import Executor, global_scope -from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_ +from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_, in_dygraph_mode from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.core import VarDesc from paddle.fluid.dygraph import no_grad @@ -918,7 +918,12 @@ class Layer(object): return outputs def __call__(self, *inputs, **kwargs): - return self._dygraph_call_func(*inputs, **kwargs) + if (not in_declarative_mode()) and (not self._forward_pre_hooks) \ + and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode(): + self._build_once(*inputs, **kwargs) + return self.forward(*inputs, **kwargs) + else: + return self._dygraph_call_func(*inputs, **kwargs) def forward(self, *inputs, **kwargs): """ diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 287dc7d67def88680a312c8b988080058d393548..907fd4e6252c6fdfac776c588d5c55ecc6144b2e 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1534,12 +1534,8 @@ def linear(x, weight, bias=None, name=None): # [2.1077576 2.1077576 2.1077576 2.1077576 ]] """ if in_dygraph_mode(): - pre_bias = _C_ops.final_state_matmul(x, weight, False, False) - - if bias is None: - return pre_bias - - return _C_ops.final_state_add(pre_bias, bias) + #TODO(jiabin): using addmm for fast forward route + return _C_ops.final_state_linear(x, weight, bias) else: if _in_legacy_dygraph(): pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y',