From c2293815b2dc0dfebaa510d4ddf94dffa7dee90f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 17 Mar 2022 17:20:21 +0800
Subject: [PATCH] fix(autodiff): proxy_graph_detail::make_backward_graph support multiple opnodes

GitOrigin-RevId: 2c0c8f330da645438f2a5ef17c9acef588f89fb3
---
 imperative/src/impl/proxy_graph.cpp           | 177 ++++++------------
 imperative/src/impl/subgraph_detail.cpp       |  92 ++++++++-
 .../include/megbrain/imperative/subgraph.h    |  17 +-
 .../megbrain/imperative/subgraph_detail.h     |   2 +
 4 files changed, 158 insertions(+), 130 deletions(-)

diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp
index 8cb19e987..b25b9124a 100644
--- a/imperative/src/impl/proxy_graph.cpp
+++ b/imperative/src/impl/proxy_graph.cpp
@@ -11,10 +11,13 @@
 #include "./proxy_graph.h"
 #include "./blob_manager_impl.h"
+#include "megbrain/graph.h"
 #include "megbrain/graph/operator_node.h"
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/subgraph_detail.h"
+#include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
@@ -486,139 +489,83 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
         const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    ThinHashMap<VarNode*, size_t> var2idx;
-    auto push = [&var2idx,
-                 cnt = 1](VarNode* var) mutable {  // cnt is always greater non zero
-        auto&& ret = var2idx.emplace(var, cnt++);
-        mgb_assert(ret.second, "var %s has been already inserted", var->cname());
-        return ret.first->second;
-    };
+    using op_t = OperatorNodeBase*;
+    using var_t = VarNode*;
+    using vars_t = VarNodeArray;
     auto inputs = make_input_place_holders(input_descs);
-    auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
-    auto&& outputs = fwd->usable_output();
+    auto outputs = OpDef::apply_on_var_node(opdef, inputs);
     SmallVector<LogicalTensorDesc> output_descs;
     for (auto&& i : outputs) {
         output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
     }
+    GradContext<op_t, var_t> grad_context{[&](VarNode* lhs, VarNode* rhs) -> VarNode* {
+        auto add = opr::Elemwise::Mode::ADD;
+        return opr::Elemwise::make(VarNodeArray{lhs, rhs}, add).node();
+    }};
+    cg::DepOprIter iter{[&](OperatorNodeBase* op) {
+        grad_context.record_expr(op, op->input(), op->output());
+    }};
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto& input = inputs[i];
+        iter.set_visited(input->owner_opr());
+        if (input_requires_grad[i]) {
+            grad_context.mark_require_grad(input);
+        }
+    }
+    for (auto&& output : outputs) {
+        iter.add(output);
+    }
     auto output_grads = make_input_place_holders(output_descs);
-    mgb_assert(
-            output_grads.size() == output_has_grad.size(), "%d vs %d",
-            output_grads.size(), output_has_grad.size());
-    bool any_input_has_grad = false;
-    for (size_t i = 0; i < output_grads.size(); ++i) {
+    for (size_t i = 0; i < outputs.size(); ++i) {
         if (!output_has_grad[i]) {
             output_grads[i] = nullptr;
-        } else {
-            any_input_has_grad = true;
         }
     }
-    if (!any_input_has_grad) {
-        return {};
-    }
-    auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
-
-    EncodedSubgraph result;
-    auto&& igraph = result.graph;
-
-    size_t nr_backward_graph_inputs = 0;
-    auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
-                     &nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
-        if (auto t = as_tensor(op)) {
-            mgb_assert(op->output().size() == 1);
-            igraph.constants.emplace_back(push(op->output(0)), std::move(t));
-        } else if (op->same_type<InputPlaceholder>()) {
-            ++nr_backward_graph_inputs;
-            push(op->output(0));
-        } else {
-            SmallVector<size_t> inputs, outputs;
-            for (auto&& i : op->input()) {
-                if (i->owner_opr() == fwd) {
-                    if (var2idx.find(i) == var2idx.end()) {
-                        ++nr_backward_graph_inputs;
-                        push(i);
-                    }
-                }
-                inputs.push_back(var2idx.at(i));
-            }
-            for (auto&& i : op->usable_output()) {
-                outputs.push_back(push(i));
+    auto compute_input_grads = [&](op_t op, vars_t inputs, vars_t outputs,
+                                   vars_t output_grads) {
+        auto* gfunc = cg::lookup_grad_func(op->dyn_typeinfo());
+        vars_t input_grads(inputs.size(), nullptr);
+        bool any_grad = false;
+        for (auto&& output_grad : output_grads) {
+            if (output_grad) {
+                any_grad = true;
             }
-            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
         }
-    };
-
-    // set backward graph outputs
-    cg::DepOprIter iter{gen_expr};
-    iter.set_visited(fwd);
-    result.output_mask.resize(inputs.size());
-
-    VarNodeArray output_grads_with_unused_var;
-    {
-        auto iter = output_grads.begin();
-        for (auto&& i : fwd->output()) {
-            if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-                // the var node with VOLATILE_CONTENT(e.g. workspace
-                // or an empty var) would not be considered as a normal
-                // output, so its grad is always NULL
-                output_grads_with_unused_var.push_back(nullptr);
-            } else {
-                output_grads_with_unused_var.push_back(*iter);
-                ++iter;
-            }
+        if (!gfunc || !any_grad) {
+            return input_grads;
         }
-        mgb_assert(iter == output_grads.end());
-    }
-
-    Maybe<VarNodeArray> grad_results;
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        VarNode* grad;
-        if (grad_results.valid()) {
-            grad = grad_results.val()[i];
-        } else {
-            mgb_assert(gfunc, "could not find grad function");
-            auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
-            if (res.from_single()) {
-                grad = res.single();
-            } else {
-                grad_results.emplace(res.all(fwd));
+        Maybe<VarNodeArray> grad_results;
+        auto&& input_requires_grad = grad_context.get_require_grads(inputs);
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            VarNode* grad;
+            if (grad_results.valid()) {
                 grad = grad_results.val()[i];
-            }
-        }
-        if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>() &&
-            input_requires_grad[i]) {
-            mgb_assert(
-                    !grad->owner_opr()->same_type<opr::InvalidGrad>(),
-                    "gradient of operator %s w.r.t. input #%lu is "
-                    "either not well defined or not implemented",
-                    fwd->dyn_typeinfo()->name, i);
-            iter.add(grad);
-            igraph.outputs.push_back(var2idx.at(grad));
-            result.output_mask[i] = true;
-        } else {
-            result.output_mask[i] = false;
-        }
-    }
-    if (igraph.outputs.empty()) {
-        return {};
-    }
-
-    // set backward graph inputs
-    auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
-        for (auto&& i : vars) {
-            auto&& iter = var2idx.find(i);
-            if (iter != var2idx.end()) {
-                igraph.inputs.push_back(iter->second);
-                result.input_mask.push_back(true);
             } else {
-                result.input_mask.push_back(false);
+                mgb_assert(gfunc, "could not find grad function");
+                auto res = (*gfunc)(op, i, output_grads);
+                if (res.from_single()) {
+                    grad = res.single();
+                } else {
+                    grad_results.emplace(res.all(op));
+                    grad = grad_results.val()[i];
+                }
+            }
+            if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()) {
+                if (input_requires_grad[i]) {
+                    input_grads[i] = grad;
+                }
             }
         }
+        return input_grads;
     };
-    write_inputs(inputs);
-    write_inputs(outputs);
-    write_inputs(output_grads);
-    mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
-    return result;
+    grad_context.backward(outputs, output_grads, compute_input_grads);
+    auto input_grads = grad_context.get_grads(inputs);
+    VarNodeArray bgraph_inputs;
+    bgraph_inputs.insert(bgraph_inputs.end(), inputs.begin(), inputs.end());
+    bgraph_inputs.insert(bgraph_inputs.end(), outputs.begin(), outputs.end());
+    bgraph_inputs.insert(bgraph_inputs.end(), output_grads.begin(), output_grads.end());
+    auto graph = subgraph_detail::make_from_computing_graph(bgraph_inputs, input_grads);
+    return graph;
 }
 
 VarNodeArray ProxyGraph::make_input_place_holders(
diff --git a/imperative/src/impl/subgraph_detail.cpp b/imperative/src/impl/subgraph_detail.cpp
index 724f6cd6b..f0736fcd6 100644
--- a/imperative/src/impl/subgraph_detail.cpp
+++ b/imperative/src/impl/subgraph_detail.cpp
@@ -107,13 +107,16 @@ EncodedSubgraph make_backward_graph_from_forward(
     Subgraph::Builder<LogicalTensorDesc> builder(
             [](auto&& op, auto&& input_descs, size_t nr_outputs) {
                 auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+                mgb_assert(
+                        descs.size() == nr_outputs, "nr_outputs mismatch for %s",
+                        op->to_string().c_str());
                 return descs;
             });
     auto accum_grad = [&](var_t lhs, var_t rhs) {
         return builder.write_expr(
                 Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0];
     };
-    GradContext<var_t> grad_context{accum_grad};
+    GradContext<std::shared_ptr<OpDef>, var_t> grad_context{accum_grad};
     auto input_vars = builder.write_inputs(inputs);
     auto outputs = forward_graph.apply(
             input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
@@ -143,19 +146,17 @@ EncodedSubgraph make_backward_graph_from_forward(
     grad_context.backward(
             apply_mask(outputs, output_has_grad),
             apply_mask(output_grads, output_has_grad),
-            [&](Subgraph::expr_t expr, vars_t output_grads) {
+            [&](Subgraph::op_t op, vars_t inputs, vars_t outputs, vars_t output_grads) {
                 auto bg = OpDef::make_backward_graph(
-                        *expr.op, builder.get_descs(expr.inputs),
-                        grad_context.get_require_grads(expr.inputs),
-                        grad_context.get_has_grads(expr.outputs));
+                        *op, builder.get_descs(inputs),
+                        grad_context.get_require_grads(inputs),
+                        grad_context.get_has_grads(outputs));
                 if (bg.graph.empty()) {
-                    return vars_t(expr.inputs.size(), 0);
+                    return vars_t(inputs.size(), 0);
                 }
                 vars_t grad_inputs;
-                grad_inputs.insert(
-                        grad_inputs.end(), expr.inputs.begin(), expr.inputs.end());
-                grad_inputs.insert(
-                        grad_inputs.end(), expr.outputs.begin(), expr.outputs.end());
+                grad_inputs.insert(grad_inputs.end(), inputs.begin(), inputs.end());
+                grad_inputs.insert(grad_inputs.end(), outputs.begin(), outputs.end());
                 grad_inputs.insert(
                         grad_inputs.end(), output_grads.begin(), output_grads.end());
                 auto apply_functor =
@@ -183,6 +184,77 @@ EncodedSubgraph make_backward_graph(
             forward_graph, inputs, input_requires_grad, output_has_grad);
 }
 
+EncodedSubgraph make_from_computing_graph(
+        const VarNodeArray& inputs, const VarNodeArray& outputs) {
+    Subgraph subgraph;
+    std::unordered_map<VarNode*, size_t> var2idx;
+    size_t next_idx = 0;
+    var2idx[nullptr] = next_idx++;
+    for (auto&& input : inputs) {
+        if (input) {
+            var2idx[input] = next_idx++;
+        }
+    }
+    auto is_tensor_holder = [](cg::OperatorNodeBase* op) {
+        return op->input().empty();
+    };
+    auto as_tensor = [](VarNode* var) -> TensorPtr {
+        auto* opr = var->owner_opr();
+        if (auto* imm_tensor = opr->try_cast_final<opr::ImmutableTensor>()) {
+            auto&& dv = imm_tensor->value();
+            HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
+            // get host value
+            auto&& cpu_value = imm_tensor->host_value();
+            mgb_assert(cpu_value.comp_node() == CompNode::default_cpu());
+            // default_cpu is synchronous with respect to caller
+            hv.proxy_to_default_cpu().copy_from_fixlayout(cpu_value);
+            return Tensor::make(dv, hv);
+        } else if (
+                auto* shared_tensor = opr->try_cast_final<opr::SharedDeviceTensor>()) {
+            return Tensor::make(shared_tensor->get_dev_tensor());
+        } else {
+            mgb_assert(
+                    false, "unsupported tensor holder opr %s",
+                    opr->dyn_typeinfo()->name);
+        }
+    };
+    cg::DepOprIter iter{[&](cg::OperatorNodeBase* op) {
+        // TODO: implement make_backward_graph for mm ops
+        // mgb_assert(!op->node_prop().contain(cg::OperatorNodeProp::Flag::IMPURE_FUNC));
+        if (is_tensor_holder(op)) {
+            for (auto&& output : op->usable_output()) {
+                subgraph.constants.push_back(
+                        {var2idx[output] = next_idx++, as_tensor(output)});
+            }
+        } else {
+            Subgraph::vars_t inputs;
+            Subgraph::vars_t outputs;
+            for (auto&& input : op->input()) {
+                inputs.push_back(var2idx.at(input));
+            }
+            // NOTE: use usable_output
+            for (auto&& output : op->usable_output()) {
+                outputs.push_back(var2idx[output] = next_idx++);
+            }
+            auto opdef = OpDef::make_from_op_node(op);
+            subgraph.exprs.push_back({opdef, inputs, outputs});
+        }
+    }};
+    for (auto&& input : inputs) {
+        if (input) {
+            iter.set_visited(input->owner_opr());
+        }
+        subgraph.inputs.push_back(var2idx.at(input));
+    }
+    for (auto&& output : outputs) {
+        if (output) {
+            iter.add(output);
+        }
+        subgraph.outputs.push_back(var2idx.at(output));
+    }
+    return EncodedSubgraph::make(subgraph);
+}
+
 } // namespace subgraph_detail
 } // namespace imperative
 } // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/subgraph.h b/imperative/src/include/megbrain/imperative/subgraph.h
index 564443726..0e18b6be5 100644
--- a/imperative/src/include/megbrain/imperative/subgraph.h
+++ b/imperative/src/include/megbrain/imperative/subgraph.h
@@ -189,12 +189,17 @@ struct EncodedSubgraph {
     size_t hash() const;
 };
 
-template <typename T>
+template <typename TOp, typename TVar>
 class GradContext {
 public:
-    using var_t = T;
+    using op_t = TOp;
+    using var_t = TVar;
     using vars_t = SmallVector<var_t>;
-    using expr_t = Expr;
+    struct expr_t {
+        op_t op;
+        vars_t inputs;
+        vars_t outputs;
+    };
 
 private:
     std::unordered_map<var_t, var_t> m_grads;
@@ -219,6 +224,7 @@ public:
         }
         return mask;
     }
+    void mark_require_grad(var_t dest) { m_vars_require_grad.insert(dest); }
     void mark_require_grads(vars_t dests) {
         for (auto&& dest : dests) {
             m_vars_require_grad.insert(dest);
@@ -231,7 +237,7 @@ public:
             return m_grads[dest] = m_accumulator(m_grads[dest], grad);
         }
     }
-    void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
+    void record_expr(op_t op, vars_t inputs, vars_t outputs) {
         bool require_grad = false;
         for (auto&& input : inputs) {
             if (m_vars_require_grad.count(input)) {
@@ -254,7 +260,8 @@ public:
         std::reverse(exprs.begin(), exprs.end());
         for (const expr_t& expr : exprs) {
             size_t nr_inputs = expr.inputs.size();
-            vars_t input_grads = functor(expr, get_grads(expr.outputs));
+            vars_t input_grads = functor(
+                    expr.op, expr.inputs, expr.outputs, get_grads(expr.outputs));
             mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
             for (size_t i = 0; i < nr_inputs; ++i) {
                 if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
diff --git a/imperative/src/include/megbrain/imperative/subgraph_detail.h b/imperative/src/include/megbrain/imperative/subgraph_detail.h
index e55f19ad1..ccebce1fd 100644
--- a/imperative/src/include/megbrain/imperative/subgraph_detail.h
+++ b/imperative/src/include/megbrain/imperative/subgraph_detail.h
@@ -43,6 +43,8 @@ EncodedSubgraph make_backward_graph_from_forward(
         const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad);
+EncodedSubgraph make_from_computing_graph(
+        const VarNodeArray& inputs, const VarNodeArray& outputs);
 
 } // namespace subgraph_detail
 } // namespace imperative
-- 
GitLab
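Note on the pattern used by this patch: both ProxyGraph::make_backward_graph and make_backward_graph_from_forward now drive the same GradContext<TOp, TVar> reverse pass, whose callback receives (op, inputs, outputs, output_grads) and returns per-input gradients that the context accumulates. The sketch below is a minimal, self-contained toy model of that reverse-traversal pattern only; ToyGradContext, the int/std::string variable and operator types, and main() are invented for illustration and are not MegEngine APIs.

// toy_grad_context.cpp -- illustration only, not part of the patch.
#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using var_t = int;                 // toy variable id (the patch uses VarNode* or size_t)
using vars_t = std::vector<var_t>;
using op_t = std::string;          // toy operator tag (the patch uses OperatorNodeBase* or OpDef)

struct Expr {
    op_t op;
    vars_t inputs;
    vars_t outputs;
};

struct ToyGradContext {
    std::vector<Expr> exprs;               // recorded in forward (topological) order
    std::unordered_map<var_t, int> grads;  // variable id -> accumulated gradient value
    std::unordered_set<var_t> require_grad;

    // Skip expressions whose inputs need no gradient; outputs of a recorded
    // expression become gradient-requiring themselves.
    void record_expr(op_t op, vars_t inputs, vars_t outputs) {
        bool need = false;
        for (auto&& i : inputs) need = need || require_grad.count(i) > 0;
        if (!need) return;
        for (auto&& o : outputs) require_grad.insert(o);
        exprs.push_back({std::move(op), std::move(inputs), std::move(outputs)});
    }

    // Walk recorded expressions in reverse; the functor mirrors the new
    // four-argument callback: (op, inputs, outputs, output_grads) -> input_grads.
    void backward(
            const vars_t& dests, const std::vector<int>& dest_grads,
            const std::function<std::vector<int>(
                    const op_t&, const vars_t&, const vars_t&,
                    const std::vector<int>&)>& functor) {
        for (size_t i = 0; i < dests.size(); ++i) grads[dests[i]] = dest_grads[i];
        for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
            std::vector<int> ograds;
            for (auto&& o : it->outputs) ograds.push_back(grads[o]);
            auto igrads = functor(it->op, it->inputs, it->outputs, ograds);
            for (size_t i = 0; i < it->inputs.size(); ++i) {
                if (igrads[i] != 0 && require_grad.count(it->inputs[i])) {
                    grads[it->inputs[i]] += igrads[i];  // accumulate, like m_accumulator
                }
            }
        }
    }
};

int main() {
    // Forward: y = x + x, recorded as a single "add" expression.
    ToyGradContext ctx;
    ctx.require_grad.insert(/*x=*/0);
    ctx.record_expr("add", {0, 0}, {/*y=*/1});
    // Backward with dy = 1: d(add)/d(each input) = 1, so dx accumulates to 2.
    ctx.backward({1}, {1}, [](const op_t&, const vars_t& in, const vars_t&,
                              const std::vector<int>& og) {
        return std::vector<int>(in.size(), og[0]);
    });
    std::printf("dx = %d\n", ctx.grads[0]);  // prints: dx = 2
    return 0;
}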