From 90dd07161c18a64d396597230b514b2d903c561b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 3 Sep 2021 13:46:29 +0800 Subject: [PATCH] refactor(imperative): modify the python interface of custom op GitOrigin-RevId: e82e5de480048bda95faf4107fbf9bbacfb79233 --- .../ops/{custom/__init__.py => custom.py} | 17 +++++--------- imperative/python/src/ops.cpp | 10 ++++----- imperative/src/impl/ops/custom_opdef.cpp | 22 +++++++++---------- .../megbrain/imperative/ops/custom_opdef.h | 13 ----------- src/opr/impl/custom_opnode.cpp | 5 ----- 5 files changed, 22 insertions(+), 45 deletions(-) rename imperative/python/megengine/core/ops/{custom/__init__.py => custom.py} (61%) diff --git a/imperative/python/megengine/core/ops/custom/__init__.py b/imperative/python/megengine/core/ops/custom.py similarity index 61% rename from imperative/python/megengine/core/ops/custom/__init__.py rename to imperative/python/megengine/core/ops/custom.py index 3f701da0e..75e56aa89 100644 --- a/imperative/python/megengine/core/ops/custom/__init__.py +++ b/imperative/python/megengine/core/ops/custom.py @@ -7,24 +7,19 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..._imperative_rt.ops import _custom +from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op -__all__ = [] +__all__ = ["load"] -for k, v in _custom.__dict__.items(): - globals()[k] = v - __all__.append(k) - - -def gen_custom_op_maker(custom_op_name): +def _gen_custom_op_maker(custom_op_name): def op_maker(**kwargs): - return make_custom_op(custom_op_name, kwargs) + return _make_custom_op(custom_op_name, kwargs) return op_maker def load(lib_path): - op_in_this_lib = install(lib_path[0:-3], lib_path) + op_in_this_lib = _install(lib_path[0:-3], lib_path) for op in op_in_this_lib: - op_maker = gen_custom_op_maker(op) + op_maker = _gen_custom_op_maker(op) globals()[op] = op_maker __all__.append(op) diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 2ecfadf58..f9e1fbde4 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) { for (const auto &op: ops_in_lib) { ret.append(op); } - return std::move(ret); + return ret; } bool uninstall_custom(const std::string &name) { @@ -701,12 +701,12 @@ py::list get_custom_op_list(void) { } void init_custom(pybind11::module m) { - m.def("install", &install_custom); - m.def("uninstall", &uninstall_custom); - m.def("get_custom_op_list", &get_custom_op_list); + m.def("_install", &install_custom); + m.def("_uninstall", &uninstall_custom); + m.def("_get_custom_op_list", &get_custom_op_list); static PyMethodDef method_def = { - "make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" + "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" }; auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); pybind11::setattr(m, method_def.ml_name, func); diff --git a/imperative/src/impl/ops/custom_opdef.cpp b/imperative/src/impl/ops/custom_opdef.cpp index d3376f6e2..f1353c935 100644 --- a/imperative/src/impl/ops/custom_opdef.cpp +++ b/imperative/src/impl/ops/custom_opdef.cpp @@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) { return op.name(); } -} // custom_opdef - OP_TRAIT_REG(CustomOpDef, CustomOpDef) - .apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor) - .apply_on_var_node(imperative::custom_opdef::apply_on_var_node) - .apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd) - .infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible) - .infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc) - .hash(imperative::custom_opdef::hash) - .is_same_st(imperative::custom_opdef::is_same_st) - .props(imperative::custom_opdef::props) - .make_name(imperative::custom_opdef::make_name) + .apply_on_physical_tensor(apply_on_physical_tensor) + .apply_on_var_node(apply_on_var_node) + .apply_on_device_tensornd(apply_on_device_tensornd) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .infer_output_mem_desc(infer_output_mem_desc) + .hash(hash) + .is_same_st(is_same_st) + .props(props) + .make_name(make_name) .fallback(); +} // custom_opdef + } // imperative } // mgb diff --git a/imperative/src/include/megbrain/imperative/ops/custom_opdef.h b/imperative/src/include/megbrain/imperative/ops/custom_opdef.h index 5f69762f7..82568a827 100644 --- a/imperative/src/include/megbrain/imperative/ops/custom_opdef.h +++ b/imperative/src/include/megbrain/imperative/ops/custom_opdef.h @@ -60,18 +60,5 @@ public: std::shared_ptr create_opdef(const custom::RunTimeId&, const custom::Param&) const; }; -namespace custom_opdef { // avoid name conflict - -void apply_on_device_tensornd(const OpDef&, const SmallVector&, SmallVector*); -SmallVector apply_on_physical_tensor(const OpDef&, const SmallVector&); -VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&); -std::tuple, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector&); -size_t hash(const OpDef&); -bool is_same_st(const OpDef&, const OpDef&); -std::vector> props(const OpDef&); -std::string make_name(const OpDef&); - -} // custom_opdef - } // imperative } // mgb diff --git a/src/opr/impl/custom_opnode.cpp b/src/opr/impl/custom_opnode.cpp index 3b7931d2f..4fb49a185 100644 --- a/src/opr/impl/custom_opnode.cpp +++ b/src/opr/impl/custom_opnode.cpp @@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() { } cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const { - // auto ret = &const_cast(node_prop()); - // for (auto &&inp_var: input()) - // ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE); - // ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE); - // return ret; return OperatorNodeBase::do_make_node_prop(); } -- GitLab