提交 18274e02 编写于 作者: M Megvii Engine Team

fix(imperative): fix error message when applying custom function with non-tensor arguments

GitOrigin-RevId: 387d6fda4a917f4d822a88c9c56e57cc3667f853
上级 3e58cbb8
...@@ -123,6 +123,11 @@ class Function(ops.PyOpBase): ...@@ -123,6 +123,11 @@ class Function(ops.PyOpBase):
This method should return a tuple of Tensor or a single Tensor representing the output This method should return a tuple of Tensor or a single Tensor representing the output
of the function. of the function.
.. note::
positional arguments should all be Tensor
""" """
raise NotImplementedError raise NotImplementedError
......
...@@ -98,6 +98,12 @@ OpTraitRegistry& OpTraitRegistry::fallback() { ...@@ -98,6 +98,12 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
if (!trait->decide_dispatch_mode) { if (!trait->decide_dispatch_mode) {
trait->decide_dispatch_mode = fallback_decide_dispatch_mode; trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
} }
if (!trait->make_name) {
static auto make_name = [](const OpDef& def) -> std::string {
return def.trait()->name;
};
trait->make_name = make_name;
}
return *this; return *this;
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
namespace mgb::imperative { namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback();
namespace { namespace fastpathcopy { namespace { namespace fastpathcopy {
auto apply_on_var_node( auto apply_on_var_node(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册