Commit 90dd0716 authored by Megvii Engine Team

refactor(imperative): modify the python interface of custom op

GitOrigin-RevId: e82e5de480048bda95faf4107fbf9bbacfb79233
Parent cbf024bf
@@ -7,24 +7,19 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ..._imperative_rt.ops import _custom
+from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op
 
-__all__ = []
+__all__ = ["load"]
 
-for k, v in _custom.__dict__.items():
-    globals()[k] = v
-    __all__.append(k)
-
-def gen_custom_op_maker(custom_op_name):
+def _gen_custom_op_maker(custom_op_name):
     def op_maker(**kwargs):
-        return make_custom_op(custom_op_name, kwargs)
+        return _make_custom_op(custom_op_name, kwargs)
     return op_maker
 
 def load(lib_path):
-    op_in_this_lib = install(lib_path[0:-3], lib_path)
+    op_in_this_lib = _install(lib_path[0:-3], lib_path)
     for op in op_in_this_lib:
-        op_maker = gen_custom_op_maker(op)
+        op_maker = _gen_custom_op_maker(op)
         globals()[op] = op_maker
         __all__.append(op)
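Only `load` stays in `__all__` after this change; everything else is an underscore-prefixed internal. A minimal usage sketch of the new surface, assuming the module is importable as `megengine.core.ops.custom` and that a library `matmul_scale.so` exporting an op `MatMulScale` exists (both names are illustrative, not part of this commit):

    from megengine.core.ops import custom  # import path assumed from the relative imports above

    # load() strips the three-character ".so" suffix (lib_path[0:-3]) to form
    # the library key; _install() returns the names of every op in the library.
    custom.load("./matmul_scale.so")

    # One maker function per op is injected into the module namespace; keyword
    # arguments are forwarded to _make_custom_op() as the op's parameters.
    op = custom.MatMulScale(scale=0.1)  # builds the op, ready to be applied to tensors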
@@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
     for (const auto &op: ops_in_lib) {
         ret.append(op);
     }
-    return std::move(ret);
+    return ret;
 }
 
 bool uninstall_custom(const std::string &name) {
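(Dropping the `std::move` here is the usual fix for a pessimizing move: returning the local `py::list` by value already permits NRVO, and compilers flag `return std::move(local)` with `-Wpessimizing-move`.)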
@@ -701,12 +701,12 @@ py::list get_custom_op_list(void) {
 }
 
 void init_custom(pybind11::module m) {
-    m.def("install", &install_custom);
-    m.def("uninstall", &uninstall_custom);
-    m.def("get_custom_op_list", &get_custom_op_list);
+    m.def("_install", &install_custom);
+    m.def("_uninstall", &uninstall_custom);
+    m.def("_get_custom_op_list", &get_custom_op_list);
 
     static PyMethodDef method_def = {
-        "make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
+        "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
     };
     auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
     pybind11::setattr(m, method_def.ml_name, func);
...
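The raw bindings are now consistently underscore-prefixed on the Python side. For reference, a sketch of driving them directly, assuming the absolute form of the relative import in the Python hunk above (`megengine.core._imperative_rt.ops._custom`); the library path is again illustrative:

    from megengine.core._imperative_rt.ops._custom import (
        _install, _uninstall, _get_custom_op_list,
    )

    lib = "./matmul_scale.so"        # hypothetical library path
    ops = _install(lib[0:-3], lib)   # the same call load() makes internally
    print(_get_custom_op_list())     # names of all currently registered custom ops
    _uninstall(lib[0:-3])            # uninstall_custom() takes the library key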
@@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) {
     return op.name();
 }
 
-} // custom_opdef
-
 OP_TRAIT_REG(CustomOpDef, CustomOpDef)
-    .apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor)
-    .apply_on_var_node(imperative::custom_opdef::apply_on_var_node)
-    .apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd)
-    .infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible)
-    .infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc)
-    .hash(imperative::custom_opdef::hash)
-    .is_same_st(imperative::custom_opdef::is_same_st)
-    .props(imperative::custom_opdef::props)
-    .make_name(imperative::custom_opdef::make_name)
+    .apply_on_physical_tensor(apply_on_physical_tensor)
+    .apply_on_var_node(apply_on_var_node)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .infer_output_mem_desc(infer_output_mem_desc)
+    .hash(hash)
+    .is_same_st(is_same_st)
+    .props(props)
+    .make_name(make_name)
     .fallback();
 
+} // custom_opdef
+
 } // imperative
 } // mgb
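(Moving the `OP_TRAIT_REG` block inside the `custom_opdef` namespace lets each registration name its helper unqualified; as the header hunk below shows, the helper declarations can then be dropped from the public header entirely.)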
@@ -60,18 +60,5 @@ public:
     std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const;
 };
 
-namespace custom_opdef { // avoid name conflict
-
-void apply_on_device_tensornd(const OpDef&, const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*);
-SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef&, const SmallVector<TensorPtr>&);
-VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&);
-std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector<LogicalTensorDesc>&);
-size_t hash(const OpDef&);
-bool is_same_st(const OpDef&, const OpDef&);
-std::vector<std::pair<const char*, std::string>> props(const OpDef&);
-std::string make_name(const OpDef&);
-
-} // custom_opdef
-
 } // imperative
 } // mgb
@@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() {
 }
 
 cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
-    // auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop());
-    // for (auto &&inp_var: input())
-    //     ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE);
-    // ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE);
-    // return ret;
     return OperatorNodeBase::do_make_node_prop();
 }
...