提交 cdb692d2 编写于 作者: M Megvii Engine Team

refactor(imperative): add TODO tag for some functions

GitOrigin-RevId: e295a1fa5537f13bc65f9e82b44a3f9cd56992a6
上级 90dd0716
...@@ -7,13 +7,20 @@ ...@@ -7,13 +7,20 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op from .._imperative_rt.ops._custom import (
_get_custom_op_list,
_install,
_make_custom_op,
_uninstall,
)
__all__ = ["load"] __all__ = ["load"]
def _gen_custom_op_maker(custom_op_name): def _gen_custom_op_maker(custom_op_name):
def op_maker(**kwargs): def op_maker(**kwargs):
return _make_custom_op(custom_op_name, kwargs) return _make_custom_op(custom_op_name, kwargs)
return op_maker return op_maker
......
...@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs ...@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
for (auto i_shape: i_shapes) { for (auto i_shape: i_shapes) {
if (i_shape.ndim == 0) { if (i_shape.ndim == 0) {
success = false; success = false;
break;
} }
} }
...@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def, ...@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def,
auto cn = output.comp_node(); auto cn = output.comp_node();
cn.activate(); cn.activate();
} }
// [TODO] sync should be modified
CompNode::sync_all(); CompNode::sync_all();
auto&& op = static_cast<const CustomOpDef&>(def); auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inputs, outputs); op.compute(inputs, outputs);
// for (auto &&output: (*outputs)) {
// auto cn = output.comp_node();
// cn.sync(); // cannot sync ??????????
// }
CompNode::sync_all(); CompNode::sync_all();
} }
...@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
} }
VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) { VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) {
SymbolVarArray input_syms;
for (auto &input_var: inputs)
input_syms.emplace_back(input_var);
auto&& op = static_cast<const CustomOpDef&>(def); auto&& op = static_cast<const CustomOpDef&>(def);
OperatorNodeConfig config; OperatorNodeConfig config;
SymbolVarArray output_syms = opr::CustomOpNode::make( VarNodeArray outputs = opr::CustomOpNode::make(
op.impl(), input_syms, op.param(), config op.impl(), inputs, op.param(), config
); );
VarNodeArray outputs;
for (auto &output_sym: output_syms)
outputs.push_back(output_sym.node());
return outputs; return outputs;
} }
...@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) { ...@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
return a.param() == b.param() && a.runtime_id() == b.runtime_id(); return a.param() == b.param() && a.runtime_id() == b.runtime_id();
} }
// [TODO] to be implemented
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now"); mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
// can be implement with param schema // can be implement with param schema
......
...@@ -140,7 +140,8 @@ void CustomOpNode::do_execute(ExecEnv &env) { ...@@ -140,7 +140,8 @@ void CustomOpNode::do_execute(ExecEnv &env) {
std::vector<custom::Tensor> custom_inputs = custom::to_custom<DeviceTensorND, custom::Tensor>(inputs); std::vector<custom::Tensor> custom_inputs = custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
std::vector<custom::Tensor> custom_outputs = custom::to_custom<DeviceTensorND, custom::Tensor>(outputs); std::vector<custom::Tensor> custom_outputs = custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
m_op->compute(custom_inputs, m_param, custom_outputs); m_op->compute(custom_inputs, m_param, custom_outputs);
CompNode::sync_all(); // whether reasonable // [TODO] sync should be modified
CompNode::sync_all();
this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>( this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
this, m_comp_node this, m_comp_node
...@@ -157,7 +158,8 @@ void CustomOpNode::init_output_static_infer_desc() { ...@@ -157,7 +158,8 @@ void CustomOpNode::init_output_static_infer_desc() {
auto &&mgr = owner_graph()->static_infer_manager(); auto &&mgr = owner_graph()->static_infer_manager();
DepVal dep; DepVal dep;
if (true) { // need design a function to allow user to decide it // [TODO] need design a interface to allow user to decide it
if (true) {
for (auto input_var: input()) for (auto input_var: input())
dep.push_back({input_var, DepType::SHAPE}); dep.push_back({input_var, DepType::SHAPE});
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册