未验证 提交 0c141322 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] make fast through to linear (#41945) (#41995)

* make fast through to linear

* make fast through to linear

* add to do for later upgrades

* support build once for now
上级 f637e3d2
...@@ -54,9 +54,53 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args, ...@@ -54,9 +54,53 @@ static PyObject *eager_api_run_program(PyObject *self, PyObject *args,
} }
} }
static PyObject *eager_api_final_state_linear(PyObject *self, PyObject *args,
PyObject *kwargs) {
PyThreadState *tstate = nullptr;
try {
auto x = GetTensorFromArgs("linear", "X", args, 0, false);
auto weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true);
tstate = PyEval_SaveThread();
if (bias.initialized()) {
auto mm_out =
matmul_final_state_dygraph_function(x, weight, false, false);
auto out = add_final_state_dygraph_function(bias, mm_out);
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
} else {
auto mm_out =
matmul_final_state_dygraph_function(x, weight, false, false);
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(mm_out);
}
} catch (paddle::platform::EnforceNotMet &exception) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
std::ostringstream sout;
sout << exception.what();
sout << " [operator < linear > error]";
exception.set_error_str(sout.str());
ThrowExceptionToPython(std::current_exception());
return nullptr;
} catch (...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyMethodDef CustomEagerFinalStateMethods[] = { static PyMethodDef CustomEagerFinalStateMethods[] = {
{"run_program", (PyCFunction)(void (*)(void))eager_api_run_program, {"run_program", (PyCFunction)(void (*)(void))eager_api_run_program,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
"C++ interface function for run_program in dygraph."}, "C++ interface function for run_program in dygraph."},
{"final_state_linear",
(PyCFunction)(void (*)(void))eager_api_final_state_linear,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for run_program in dygraph."},
{nullptr, nullptr, 0, nullptr}}; {nullptr, nullptr, 0, nullptr}};
...@@ -134,7 +134,7 @@ const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, M ...@@ -134,7 +134,7 @@ const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, M
// need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE // need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE
std::unordered_set<std::string> CUSTOM_HANDWRITE_OPS_SET = {"run_program"}; std::unordered_set<std::string> CUSTOM_HANDWRITE_OPS_SET = {"run_program"};
const char* CUSTOM_HANDWRITE_OP_FUNC_FILE = const char* CUSTOM_HANDWRITE_OP_FUNC_FILE =
"#include \"paddle/fluid/pybind/custom_handwrite_op_funcs.h\"\n"; "#include \"paddle/fluid/pybind/eager_custom_python_api.h\"\n";
// clang-format on // clang-format on
static inline bool FindInsMap(const std::string& op_type, static inline bool FindInsMap(const std::string& op_type,
......
...@@ -36,7 +36,7 @@ from .base import program_desc_tracing_guard, param_guard, in_declarative_mode, ...@@ -36,7 +36,7 @@ from .base import program_desc_tracing_guard, param_guard, in_declarative_mode,
from paddle.fluid import framework from paddle.fluid import framework
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_ from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_, in_dygraph_mode
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.core import VarDesc from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad from paddle.fluid.dygraph import no_grad
...@@ -918,7 +918,12 @@ class Layer(object): ...@@ -918,7 +918,12 @@ class Layer(object):
return outputs return outputs
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
return self._dygraph_call_func(*inputs, **kwargs) if (not in_declarative_mode()) and (not self._forward_pre_hooks) \
and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode():
self._build_once(*inputs, **kwargs)
return self.forward(*inputs, **kwargs)
else:
return self._dygraph_call_func(*inputs, **kwargs)
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
""" """
......
...@@ -1534,12 +1534,8 @@ def linear(x, weight, bias=None, name=None): ...@@ -1534,12 +1534,8 @@ def linear(x, weight, bias=None, name=None):
# [2.1077576 2.1077576 2.1077576 2.1077576 ]] # [2.1077576 2.1077576 2.1077576 2.1077576 ]]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
pre_bias = _C_ops.final_state_matmul(x, weight, False, False) #TODO(jiabin): using addmm for fast forward route
return _C_ops.final_state_linear(x, weight, bias)
if bias is None:
return pre_bias
return _C_ops.final_state_add(pre_bias, bias)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y', pre_bias = _C_ops.matmul_v2(x, weight, 'trans_x', False, 'trans_y',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册