From 634de5906105f458d95de35e606e21d249dde8f6 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 6 Nov 2020 18:27:43 +0800
Subject: [PATCH] feat(mge/imperative): add valid flag of `infer_output_attrs_fallible`

GitOrigin-RevId: b2b32774eeb893503c25d3434fa6f2ba64f1c8c6
---
 imperative/src/impl/interpreter_impl.cpp         | 19 +++++++---
 imperative/src/impl/op_def.cpp                   |  2 +-
 imperative/src/impl/ops/backward_graph.cpp       | 36 ++++++++++---------
 imperative/src/impl/ops/batch_norm.cpp           |  4 +--
 imperative/src/impl/ops/broadcast.cpp            |  6 ++--
 imperative/src/impl/ops/cond_take.cpp            | 10 +++---
 imperative/src/impl/ops/elemwise.cpp             |  8 ++---
 imperative/src/impl/ops/tensor_manip.cpp         |  6 ++--
 imperative/src/impl/profiler.cpp                 |  7 ++--
 imperative/src/impl/proxy_graph.cpp              | 25 +++++++------
 imperative/src/impl/proxy_graph.h                |  4 +--
 imperative/src/impl/proxy_graph_detail.cpp       |  5 ++-
 imperative/src/impl/proxy_graph_detail.h         |  5 ++-
 .../src/include/megbrain/imperative/op_def.h     |  2 +-
 .../megbrain/imperative/ops/backward_graph.h     |  4 +--
 15 files changed, 81 insertions(+), 62 deletions(-)

diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp
index b0500549a..f66428486 100644
--- a/imperative/src/impl/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter_impl.cpp
@@ -63,15 +63,17 @@ SmallVector<Handle> ChannelImpl::apply_op(
         input_infos.push_back(info);
         input_descs.push_back(info->desc);
     }
-    auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
+
+    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
     cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
     SmallVector<Handle> outputs;
-    bool is_fallible = false;
+    // FIXME: remove this check when op check is correct
+    bool validated_bkp = true;
     for (auto&& desc : output_descs) {
         if (desc.layout.ndim == 0) {
-            is_fallible = true;
+            validated_bkp = false;
         }
         auto info = alloc();
         info->desc = desc;
@@ -80,8 +82,14 @@ SmallVector<Handle> ChannelImpl::apply_op(
         outputs.push_back(info);
     }
     m_worker.add_task(std::move(cmd));
-    if (is_fallible && m_async_level <= 1) {
+    if (!(validated && validated_bkp) && m_async_level == 1) {
+        sync();
+    } else if (m_async_level == 0) {
         sync();
+        // check device error
+        for (auto&& oup : cmd.outputs) {
+            oup->ptr->comp_node().sync();
+        }
     }
     return outputs;
 }
@@ -194,6 +202,9 @@ ChannelImpl::~ChannelImpl() {
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
     MGB_LOCK_GUARD(m_mutex);
     dest->value_fetched = ptr->value_fetched();
+    // update tensor desc for static infer
+    dest->desc.layout = ptr->layout();
+    dest->desc.comp_node = ptr->comp_node();
     dest->ptr = std::move(ptr);
     if (m_waitee == dest) {
         m_cv.notify_all();
diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp
index 770aab4f1..9f052ac0e 100644
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -42,7 +42,7 @@ cg::OperatorNodeBase* OpDef::apply_on_var_node(
     return def.trait()->apply_on_var_node(def, inputs);
 }
 
-SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
     const OpDef& def,
     const SmallVector<LogicalTensorDesc>& inputs) {
     return def.trait()->infer_output_attrs_fallible(def, inputs);
diff --git a/imperative/src/impl/ops/backward_graph.cpp b/imperative/src/impl/ops/backward_graph.cpp
index e452ac590..c20432ac9 100644
--- a/imperative/src/impl/ops/backward_graph.cpp
+++ b/imperative/src/impl/ops/backward_graph.cpp
@@ -24,12 +24,12 @@ BackwardGraph::InternalGraph::apply(
             inputs);
 }
 
-SmallVector<LogicalTensorDesc>
-BackwardGraph::InternalGraph::infer_attrs(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs(
         const SmallVector<LogicalTensorDesc>& inputs) const {
     using TensorAttr = LogicalTensorDesc;
     ThinHashMap<size_t, TensorAttr> node2attr;
     auto&& input_nodes = this->inputs;
+    auto&& output_nodes = this->outputs;
     mgb_assert(inputs.size() == input_nodes.size());
     for (size_t i = 0; i < inputs.size(); ++ i) {
         node2attr[input_nodes[i]] = inputs[i];
     }
@@ -41,25 +41,29 @@ BackwardGraph::InternalGraph::infer_attrs(
                 i.second->layout(), i.second->comp_node(),
                 value->proxy_to_default_cpu()};
     }
+    bool validated = true;
     for (size_t i = 0; i < exprs.size(); ++ i) {
-        auto&& expr = exprs[i];
-        SmallVector<TensorAttr> inputs;
-        for (auto &&in : std::get<1>(expr)) {
-            inputs.push_back(node2attr.at(in));
+        auto&& [expr_op, expr_inps, expr_oups] = exprs[i];
+        SmallVector<TensorAttr> expr_input_descs;
+        for (auto &&inp : expr_inps) {
+            expr_input_descs.push_back(node2attr.at(inp));
         }
-        auto outputs = OpDef::infer_output_attrs_fallible(
-                *std::get<0>(expr), inputs);
-        auto output_nodes = std::get<2>(expr);
-        mgb_assert(outputs.size() == output_nodes.size());
-        for (size_t i = 0; i < outputs.size(); ++ i) {
-            node2attr[output_nodes[i]] = outputs[i];
+
+        auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
+                *expr_op, expr_input_descs);
+        validated = validated && expr_validated;
+
+        mgb_assert(expr_output_descs.size() == expr_oups.size());
+        for (size_t i = 0; i < expr_output_descs.size(); ++ i) {
+            node2attr[expr_oups[i]] = expr_output_descs[i];
         }
     }
+
     SmallVector<TensorAttr> ret;
-    for (auto &&i : outputs) {
+    for (auto &&i : output_nodes) {
         ret.push_back(node2attr.at(i));
     }
-    return ret;
+    return {ret, validated};
 }
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
@@ -72,11 +76,11 @@ SmallVector<TensorPtr> backward_impl(
             .graph().apply(tensors);
 }
 
-SmallVector<LogicalTensorDesc> infer_tensor_attrs(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
         const OpDef& backward_graph,
         const SmallVector<LogicalTensorDesc> inputs) {
     return backward_graph.cast_final_safe<BackwardGraph>()
-            .graph().infer_attrs(inputs);
+           .graph().infer_attrs(inputs);
 }
 
 OP_TRAIT_REG(BackwardGraph, BackwardGraph)
diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp
index 07e41899c..913dc3b07 100644
--- a/imperative/src/impl/ops/batch_norm.cpp
+++ b/imperative/src/impl/ops/batch_norm.cpp
@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     }
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op_def = def.cast_final_safe<BatchNorm>();
@@ -66,7 +66,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
     out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
-    return out_shapes;
+    return {out_shapes, true};
 }
 
 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 23e0137a0..1a2bb4627 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -47,7 +47,7 @@ bool valid_broadcast(const TensorShape& src_shape,
     return true;
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     def.cast_final_safe<Broadcast>();
@@ -59,7 +59,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
     TensorLayout out_layout = src.layout;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
         out_layout.ndim = 0;
-        return {{out_layout, src.comp_node}};
+        return {{{out_layout, src.comp_node}}, true};
     }
     mgb_assert(
             tshp.layout.ndim == 1,
@@ -77,7 +77,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
         src.layout.TensorShape::to_string().c_str(),
         out_layout.TensorShape::to_string().c_str());
 
-    return {{out_layout, src.comp_node}};
+    return {{{out_layout, src.comp_node}}, true};
 }
 
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp
index 133a9a933..cc49e5c23 100644
--- a/imperative/src/impl/ops/cond_take.cpp
+++ b/imperative/src/impl/ops/cond_take.cpp
@@ -25,7 +25,7 @@ namespace {
 
 class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {
     using Output = std::array<TensorPtr, 2>;
-    
+
     CompNode m_cn;
     Output m_out;
 
@@ -110,14 +110,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return out;
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto cn = inputs[0].comp_node;
-    return {
+    return {{
        {TensorLayout(inputs[0].layout.dtype), cn},
        {TensorLayout(dtype::Int32()), cn}
-    };
+    }, true};
 }
 
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
@@ -128,4 +128,4 @@ OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
 
 } // namespace
 
-} // namespace mgb::imperative
\ No newline at end of file
+} // namespace mgb::imperative
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
index edf55acb6..8b789b490 100644
--- a/imperative/src/impl/ops/elemwise.cpp
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -29,7 +29,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
@@ -40,7 +40,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
     TensorShapeArray inp_shapes;
     DType out_dt;
     CompNode out_cn;
-    for (size_t i = 0; i < inputs.size(); ++ i) { 
+    for (size_t i = 0; i < inputs.size(); ++ i) {
         auto &&t = inputs[i];
         if (!i) {
             out_cn = t.comp_node;
@@ -55,12 +55,12 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
             TensorLayout out_layout;
             out_layout.ndim = 0;
             out_layout.dtype = out_dt;
-            return {{out_layout, out_cn}};
+            return {{{out_layout, out_cn}}, true};
         }
     }
 
     auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
-    return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}};
+    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
 }
 
 OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index ae16edfd2..a4d23de37 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -40,21 +40,21 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return {Tensor::make(std::move(hv))};
 }
 
-SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
-        return {{TensorLayout(dtype::Int32()), desc.comp_node}};
+        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
     }
     DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
     auto* ptr = value.ptr<dt_int32>();
     for (size_t i = 0; i < desc.layout.ndim; ++i) {
         ptr[i] = desc.layout[i];
     }
-    return {{value.layout(), desc.comp_node, std::move(value)}};
+    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
diff --git a/imperative/src/impl/profiler.cpp b/imperative/src/impl/profiler.cpp
index ebd28c7fc..65deeb96f 100644
--- a/imperative/src/impl/profiler.cpp
+++ b/imperative/src/impl/profiler.cpp
@@ -28,12 +28,13 @@ namespace {
 CompNode::UnorderedSet collect_comp_nodes(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     CompNode::UnorderedSet comp_nodes;
-    SmallVector<LogicalTensorDesc> descs;
+    SmallVector<LogicalTensorDesc> inp_descs;
     for (auto&& i : inputs) {
         comp_nodes.insert(i->comp_node());
-        descs.push_back({i->layout(), i->comp_node(), {}});
+        inp_descs.push_back({i->layout(), i->comp_node(), {}});
     }
-    for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) {
+    SmallVector<LogicalTensorDesc> oup_descs = std::get<0>(def.infer_output_attrs_fallible(def, inp_descs));
+    for (auto&& output_attr : oup_descs) {
         comp_nodes.insert(output_attr.comp_node);
     }
     return comp_nodes;
diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp
index f04a71190..85fc8f39b 100644
--- a/imperative/src/impl/proxy_graph.cpp
+++ b/imperative/src/impl/proxy_graph.cpp
@@ -14,6 +14,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/operator_node.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/ops/backward_graph.h"
@@ -590,10 +591,9 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
         vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
     }
     auto opr = OpDef::apply_on_var_node(opdef, vinputs);
-    mgb_assert(opr->dyn_typeinfo() != InputPlaceholder::typeinfo());
+    mgb_assert(!opr->same_type<InputPlaceholder>());
     for (auto &&i : opr->input()) {
-        mgb_assert(i->owner_opr()->dyn_typeinfo() ==
-                InputPlaceholder::typeinfo());
+        mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
     }
     return opr;
 }
@@ -605,17 +605,18 @@ size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
     return get_proxy_opr(opdef, inputs)->usable_output().size();
 }
 
-SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs_fallible(
+std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
         const OpDef& opdef,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto opr = get_proxy_opr(opdef, inputs);
     CUR_OPR_GUARD(opr);
-    do_shape_infer(false);
-    SmallVector<LogicalTensorDesc> ret;
+    SmallVector<LogicalTensorDesc> outputs;
+    bool validated = do_shape_infer(false);
     for (auto&& i : opr->usable_output()) {
-        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
+        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
     }
-    return ret;
+    bool need_check = opr->same_type<opr::Reshape>();
+    return {outputs, validated && !need_check};
 }
 
 struct ProxyGraph::GradGraph {
@@ -811,16 +812,20 @@ VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<Tensor*>& inputs) {
 
 /********************** Common Helper **********************/
 
-void ProxyGraph::do_shape_infer(bool sync_value) {
+bool ProxyGraph::do_shape_infer(bool sync_value) {
     m_static_infer_manager->update();
+    bool validated = true;
     for (auto* var : m_cur_opr->output()) {
         if (sync_value) {
             var->shape(m_static_infer_manager->infer_shape(var));
         } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
-            var->shape(*shape);
+            var->shape(*shape);
+        } else {
+            validated = false;
         }
     }
+    return validated;
 }
 
 TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h
index e52768a54..d4b946457 100644
--- a/imperative/src/impl/proxy_graph.h
+++ b/imperative/src/impl/proxy_graph.h
@@ -48,7 +48,7 @@ public:
             const OpDef& opdef,
             const SmallVector<Tensor*>& inputs);
 
-    SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);
 
@@ -88,7 +88,7 @@ private:
 
     /********************** Common Helper **********************/
 
-    void do_shape_infer(bool sync_value);
+    bool do_shape_infer(bool sync_value);
 
     TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true);
 
diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp
index 42a02a82d..163f38ddb 100644
--- a/imperative/src/impl/proxy_graph_detail.cpp
+++ b/imperative/src/impl/proxy_graph_detail.cpp
@@ -80,8 +80,7 @@ apply_on_physical_tensor(const OpDef& def,
     return outputs;
 }
 
-SmallVector<LogicalTensorDesc>
-infer_output_attrs_fallible(const OpDef& def,
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& graph = ProxyGraph::get_default_graph();
     return graph->infer_output_attrs_fallible(def, inputs);
@@ -136,4 +135,4 @@ make_backward_graph(const OpDef& def,
 } // namespace imperative
 } // namespace mgb
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/impl/proxy_graph_detail.h b/imperative/src/impl/proxy_graph_detail.h
index e148b0bba..be0fbe484 100644
--- a/imperative/src/impl/proxy_graph_detail.h
+++ b/imperative/src/impl/proxy_graph_detail.h
@@ -21,8 +21,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef& def,
         const SmallVector<TensorPtr>& inputs);
 
-SmallVector<LogicalTensorDesc>
-infer_output_attrs_fallible(const OpDef& def,
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);
 
 BackwardGraphResult
@@ -35,4 +34,4 @@ make_backward_graph(const OpDef& def,
 } // namespace imperative
 } // namespace mgb
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h
index 0aff1d53e..d57f7edba 100644
--- a/imperative/src/include/megbrain/imperative/op_def.h
+++ b/imperative/src/include/megbrain/imperative/op_def.h
@@ -44,7 +44,7 @@ public:
            const OpDef& def,
            const VarNodeArray& inputs);
 
-    static SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+    static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& def,
            const SmallVector<LogicalTensorDesc>& inputs);
 
diff --git a/imperative/src/include/megbrain/imperative/ops/backward_graph.h b/imperative/src/include/megbrain/imperative/ops/backward_graph.h
index cec594573..ba452703f 100644
--- a/imperative/src/include/megbrain/imperative/ops/backward_graph.h
+++ b/imperative/src/include/megbrain/imperative/ops/backward_graph.h
@@ -38,8 +38,8 @@ public:
 
         SmallVector<TensorPtr> apply(const SmallVector<TensorPtr>& inputs) const;
 
-        SmallVector<LogicalTensorDesc>
-        infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const;
+        std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs(
+                const SmallVector<LogicalTensorDesc>& inputs) const;
 
         template <typename T, typename F, typename C>
         SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
-- 
GitLab
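
Reviewer note: below is a minimal, self-contained sketch of the calling convention this patch introduces, for readers who want to see the new `(descs, validated)` pair in isolation. `TensorDesc`, `infer_shapes`, `sync`, and `async_level` are simplified stand-ins invented for this example, not MegEngine types or APIs; only the shape of the return value and the decision to fall back to a blocking sync when the flag is false mirror the patched `ChannelImpl::apply_op`.

// Standalone illustration (C++17) of returning descriptors together with a
// "validated" flag instead of a bare descriptor list.
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <vector>

struct TensorDesc {
    size_t ndim = 0;  // 0 means the shape is not statically known yet
};

// Mirrors the new signature: return the inferred descriptors plus a flag
// telling the caller whether static inference fully succeeded.
std::tuple<std::vector<TensorDesc>, bool> infer_shapes(const std::vector<TensorDesc>& inputs) {
    std::vector<TensorDesc> outputs;
    bool validated = true;
    for (auto&& in : inputs) {
        outputs.push_back({in.ndim});
        if (in.ndim == 0) {
            validated = false;  // shape unknown -> result is not validated
        }
    }
    return {outputs, validated};
}

void sync() { std::printf("sync: wait for the worker to materialize shapes\n"); }

int main() {
    std::vector<TensorDesc> inputs{{2}, {0}};  // second input has an unknown shape
    auto [output_descs, validated] = infer_shapes(inputs);
    int async_level = 1;
    // Same decision as the patched apply_op: only block when inference is
    // not validated and the async level permits an eager fallback.
    if (!validated && async_level == 1) {
        sync();
    }
    std::printf("outputs: %zu, validated: %d\n", output_descs.size(), (int)validated);
    return 0;
}

The point of the explicit flag is that it lets callers distinguish "descriptors are complete and trustworthy" from the old heuristic of checking `ndim == 0` on each output, which the patch keeps only as a temporary `validated_bkp` fallback behind a FIXME in `interpreter_impl.cpp`.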