提交 90dd0716 编写于 作者: M Megvii Engine Team

refactor(imperative): modify the python interface of custom op

GitOrigin-RevId: e82e5de480048bda95faf4107fbf9bbacfb79233
上级 cbf024bf
......@@ -7,24 +7,19 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..._imperative_rt.ops import _custom
from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op
__all__ = []
__all__ = ["load"]
for k, v in _custom.__dict__.items():
globals()[k] = v
__all__.append(k)
def gen_custom_op_maker(custom_op_name):
def _gen_custom_op_maker(custom_op_name):
def op_maker(**kwargs):
return make_custom_op(custom_op_name, kwargs)
return _make_custom_op(custom_op_name, kwargs)
return op_maker
def load(lib_path):
op_in_this_lib = install(lib_path[0:-3], lib_path)
op_in_this_lib = _install(lib_path[0:-3], lib_path)
for op in op_in_this_lib:
op_maker = gen_custom_op_maker(op)
op_maker = _gen_custom_op_maker(op)
globals()[op] = op_maker
__all__.append(op)
......@@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
for (const auto &op: ops_in_lib) {
ret.append(op);
}
return std::move(ret);
return ret;
}
bool uninstall_custom(const std::string &name) {
......@@ -701,12 +701,12 @@ py::list get_custom_op_list(void) {
}
void init_custom(pybind11::module m) {
m.def("install", &install_custom);
m.def("uninstall", &uninstall_custom);
m.def("get_custom_op_list", &get_custom_op_list);
m.def("_install", &install_custom);
m.def("_uninstall", &uninstall_custom);
m.def("_get_custom_op_list", &get_custom_op_list);
static PyMethodDef method_def = {
"make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
"_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
};
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
pybind11::setattr(m, method_def.ml_name, func);
......
......@@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) {
return op.name();
}
} // custom_opdef
OP_TRAIT_REG(CustomOpDef, CustomOpDef)
.apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor)
.apply_on_var_node(imperative::custom_opdef::apply_on_var_node)
.apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd)
.infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible)
.infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc)
.hash(imperative::custom_opdef::hash)
.is_same_st(imperative::custom_opdef::is_same_st)
.props(imperative::custom_opdef::props)
.make_name(imperative::custom_opdef::make_name)
.apply_on_physical_tensor(apply_on_physical_tensor)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.hash(hash)
.is_same_st(is_same_st)
.props(props)
.make_name(make_name)
.fallback();
} // custom_opdef
} // imperative
} // mgb
......@@ -60,18 +60,5 @@ public:
std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const;
};
namespace custom_opdef { // avoid name conflict
void apply_on_device_tensornd(const OpDef&, const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*);
SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef&, const SmallVector<TensorPtr>&);
VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector<LogicalTensorDesc>&);
size_t hash(const OpDef&);
bool is_same_st(const OpDef&, const OpDef&);
std::vector<std::pair<const char*, std::string>> props(const OpDef&);
std::string make_name(const OpDef&);
} // custom_opdef
} // imperative
} // mgb
......@@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() {
}
cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
// auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop());
// for (auto &&inp_var: input())
// ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE);
// ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE);
// return ret;
return OperatorNodeBase::do_make_node_prop();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册