diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index 898919ccc48525cf52e885d2065021712a0ae998..991d80bd2f012c3987e3ba892d18719326b36c07 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -77,7 +77,7 @@ std::shared_ptr make_backward_graph(
     std::shared_ptr ret;
     auto bg = OpDef::make_backward_graph(
             *ctx.op, inputs, input_requires_grad, output_has_grad);
-    if (!bg.backward.empty()) {
+    if (!bg.graph.empty()) {
         ret = std::make_shared(bg);
     }
     backward_graph_cache.emplace(key, ret);
diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp
index 2fed9188cfc78c765b2e8e53ae0bad25a556a72f..94a1abafc51dcacd2981a7ed105aa7052b18f2e6 100644
--- a/imperative/python/src/imperative_rt.cpp
+++ b/imperative/python/src/imperative_rt.cpp
@@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) {
             const SmallVector& input_requires_grad,
             const SmallVector& output_has_grad){
         auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-        return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
+        return std::make_tuple("backward_graph", result.input_mask, result.output_mask);
     };
     m.def("make_backward_graph", make_backward_graph);
 }
diff --git a/imperative/src/impl/backward_graph_opt.cpp b/imperative/src/impl/backward_graph_opt.cpp
index 49dd7673b66ccd0add85aa7e76f185fcc0ebbdfa..007f18c2a4bf20d32f15672ef338b394315b16fc 100644
--- a/imperative/src/impl/backward_graph_opt.cpp
+++ b/imperative/src/impl/backward_graph_opt.cpp
@@ -16,19 +16,19 @@
 using namespace mgb;
 using namespace imperative;
 
-OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
-        : input_has_grad(src.input_has_grad) {
-    if (src.backward.exprs.size() <= 1) {
+OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src)
+        : input_has_grad(src.output_mask) {
+    if (src.graph.exprs.size() <= 1) {
         // backward graph only contains a single op
-        backward = src.backward;
-        save_for_backward = src.save_for_backward;
+        backward = src.graph;
+        save_for_backward = src.input_mask;
         return;
     }
-    save_for_backward.resize(src.save_for_backward.size(), false);
+    save_for_backward.resize(src.input_mask.size(), false);
 
-    auto&& graph = src.backward;
-    auto&& mask = src.save_for_backward;
-    size_t input_size = src.input_has_grad.size();
+    auto&& graph = src.graph;
+    auto&& mask = src.input_mask;
+    size_t input_size = src.output_mask.size();
     size_t output_size = (mask.size() - input_size) / 2;
     mgb_assert(input_size + output_size * 2 == mask.size());
diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp
index ede6c6bca62397ec3d41b2fb7f5fc21863d1b3d2..2f856c02e9ed13333bf26ce952dca6dfc2ea13e2 100644
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -80,7 +80,7 @@ std::tuple, bool> OpDef::infer_output_attrs_falli
     return def.trait()->infer_output_attrs_fallible(def, inputs);
 }
 
-BackwardGraphResult OpDef::make_backward_graph(
+EncodedSubraph OpDef::make_backward_graph(
     const OpDef& def,
     const SmallVector& inputs,
     const SmallVector& input_requires_grad,
diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp
index a54e6a4606c50ac1aeb094c2a5153bd80f0e4522..2263f9f7b5247ccab0ad1f6cfb01b16a8ea9dea2 100644
--- a/imperative/src/impl/proxy_graph.cpp
+++ b/imperative/src/impl/proxy_graph.cpp
@@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph {
     cg::VarNode* grad;
 };
 
-BackwardGraphResult
+EncodedSubraph
 ProxyGraph::make_backward_graph(
         const OpDef& opdef, const SmallVector& input_descs,
         const SmallVector& input_requires_grad,
         const SmallVector& output_has_grad) {
     ThinHashMap var2idx;
-    auto push = [&var2idx, cnt=0](VarNode* var) mutable {
+    auto push = [&var2idx, cnt=1](VarNode* var) mutable { // cnt is always greater than zero
         auto&& ret = var2idx.emplace(var, cnt ++);
         mgb_assert(ret.second, "var %s has been already inserted", var->cname());
         return ret.first->second;
     };
@@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph(
     }
     auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
 
-    BackwardGraphResult result;
-    auto&& igraph = result.backward;
+    EncodedSubraph result;
+    auto&& igraph = result.graph;
     size_t nr_backward_graph_inputs = 0;
     auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
@@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph(
     // set backward graph outputs
     cg::DepOprIter iter{gen_expr};
     iter.set_visited(fwd);
-    result.input_has_grad.resize(inputs.size());
+    result.output_mask.resize(inputs.size());
 
     VarNodeArray output_grads_with_unused_var;
     {
@@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph(
         if (grad_results.valid()) {
             grad = grad_results.val()[i];
         } else {
+            mgb_assert(gfunc, "could not find grad function");
             auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
             if (res.from_single()) {
                 grad = res.single();
@@ -776,9 +777,9 @@
                        fwd->dyn_typeinfo()->name, i);
             iter.add(grad);
             igraph.outputs.push_back(var2idx.at(grad));
-            result.input_has_grad[i] = true;
+            result.output_mask[i] = true;
         } else {
-            result.input_has_grad[i] = false;
+            result.output_mask[i] = false;
         }
     }
     if (igraph.outputs.empty()) {
@@ -787,15 +788,15 @@
     // set backward graph inputs
     igraph.inputs.reserve(nr_backward_graph_inputs);
-    result.save_for_backward.reserve(nr_backward_graph_inputs);
+    result.input_mask.reserve(nr_backward_graph_inputs);
     auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
         for (auto&& i: vars) {
             auto&& iter = var2idx.find(i);
             if (iter != var2idx.end()) {
                 igraph.inputs.push_back(iter->second);
-                result.save_for_backward.push_back(true);
+                result.input_mask.push_back(true);
             } else {
-                result.save_for_backward.push_back(false);
+                result.input_mask.push_back(false);
             }
         }
     };
diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h
index 787f5cb15c5c4230dac10946b16065c239c03892..eecd7191a5ec67e12df5dc885f067595ee8376ed 100644
--- a/imperative/src/impl/proxy_graph.h
+++ b/imperative/src/impl/proxy_graph.h
@@ -40,7 +40,7 @@ public:
            const SmallVector& outputs,
            const SmallVector& workspace);
 
-    BackwardGraphResult make_backward_graph(
+    EncodedSubraph make_backward_graph(
            const OpDef& opdef,
            const SmallVector& input_descs,
            const SmallVector& input_requires_grad,
diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp
index 93d75b3daf4d567cd0502c22af3d24c29b249ad2..b49e672067f0373128cf6f867ed164a666a412b7 100644
--- a/imperative/src/impl/proxy_graph_detail.cpp
+++ b/imperative/src/impl/proxy_graph_detail.cpp
@@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def,
     return state.digest();
 }
 
-struct BackwardGraphCache : std::unordered_map, CompNodeDepedentObject {
+struct BackwardGraphCache : std::unordered_map, CompNodeDepedentObject {
     std::shared_ptr on_comp_node_finalize() override {
         clear();
         return {};
@@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map, Com
 } // anonymous namespace
-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(const OpDef& def,
         const SmallVector& inputs,
         const SmallVector& input_requires_grad,
diff --git a/imperative/src/impl/subgraph.cpp b/imperative/src/impl/subgraph.cpp
index 087bd1fbcb0b3f2e10ab2aa41d77cf066738a077..212627b7f2d45b652d7b065e9d7657955a2cb9d4 100644
--- a/imperative/src/impl/subgraph.cpp
+++ b/imperative/src/impl/subgraph.cpp
@@ -101,5 +101,26 @@ void Subgraph::replace_vars(
     }
 }
 
+std::string EncodedSubraph::repr() const {
+    std::string buffer;
+    buffer.push_back('|');
+    for (size_t i = 0; i < input_mask.size(); ++i) {
+        buffer.push_back(input_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    buffer.push_back('\n');
+    buffer.append(graph.repr());
+    buffer.push_back('|');
+    for (size_t i = 0; i < output_mask.size(); ++i) {
+        buffer.push_back(output_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    return buffer;
+}
+
+size_t EncodedSubraph::hash() const {
+    return std::hash{}(repr());
+}
+
 } // namespace imperative
 } // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/backward_graph_opt.h b/imperative/src/include/megbrain/imperative/backward_graph_opt.h
index ef6e6461fdbd76b14fa5942ef3073a3898c51c11..de1e2c1c8f3499258be5e1fadf43337a8846f3ed 100644
--- a/imperative/src/include/megbrain/imperative/backward_graph_opt.h
+++ b/imperative/src/include/megbrain/imperative/backward_graph_opt.h
@@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult {
     SmallVector save_for_backward;
     SmallVector input_has_grad;
 
-    OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
+    OptimizedBackwardGraphResult(const EncodedSubraph& bgraph);
 };
 
 } // namespace mgb::imperative
diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h
index 7ab725070f777b84f2759a298d7e6de396f9e0c5..591a2323ec61a17c4f8e37c1a5d55cc1218cf592 100644
--- a/imperative/src/include/megbrain/imperative/op_def.h
+++ b/imperative/src/include/megbrain/imperative/op_def.h
@@ -29,12 +29,6 @@ enum DispatchMode {
 
 using SharedOp = std::shared_ptr;
 
-struct BackwardGraphResult {
-    Subgraph backward;
-    SmallVector save_for_backward;
-    SmallVector input_has_grad;
-};
-
 class OpDef : public Hashable,
               public NonCopyableObj,
               public std::enable_shared_from_this {
@@ -91,7 +85,7 @@ public:
             const SmallVector& inputs_tensors,
             const SmallVector& inputs_mems);
 
-    static BackwardGraphResult make_backward_graph(
+    static EncodedSubraph make_backward_graph(
             const OpDef& def,
             const SmallVector& inputs,
             const SmallVector& input_requires_grad,
diff --git a/imperative/src/include/megbrain/imperative/proxy_graph_detail.h b/imperative/src/include/megbrain/imperative/proxy_graph_detail.h
index 5bdaab730dd37ed537492daefced4aa0e3c8378f..4ff95dcefa1986eb0947c90216062d92101dac94 100644
--- a/imperative/src/include/megbrain/imperative/proxy_graph_detail.h
+++ b/imperative/src/include/megbrain/imperative/proxy_graph_detail.h
@@ -38,7 +38,7 @@ void exec(const OpDef& def,
         const SmallVector& inputs,
         const SmallVector& outputs);
 
-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(const OpDef& def,
         const SmallVector& inputs,
         const SmallVector& input_requires_grad,
diff --git a/imperative/src/include/megbrain/imperative/subgraph.h b/imperative/src/include/megbrain/imperative/subgraph.h
index 68221e4538d1db09d5c32430ce6a1629fd66bae1..dce8dc6fb9ad814af37be838804eb90430fc1550 100644
--- a/imperative/src/include/megbrain/imperative/subgraph.h
+++ b/imperative/src/include/megbrain/imperative/subgraph.h
@@ -96,5 +96,185 @@ struct Subgraph {
     bool operator==(const Subgraph& rhs) const;
 };
 
+struct EncodedSubraph {
+    Subgraph graph;
+    SmallVector input_mask;
+    SmallVector output_mask;
+
+    template
+    TContainer encode_inputs(TContainer inputs) const {
+        TContainer encoded_inputs;
+        size_t index = 0;
+        for (auto&& input : inputs) {
+            mgb_assert(index < input_mask.size(), "index out of range");
+            if (input_mask[index++]) {
+                encoded_inputs.push_back(input);
+            }
+        }
+        mgb_assert(index == input_mask.size(), "mask size mismatch");
+        return encoded_inputs;
+    }
+
+    template
+    TContainer encode_outputs(TContainer outputs) const {
+        TContainer encoded_outputs;
+        size_t index = 0;
+        for (auto&& output : outputs) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[index++]) {
+                encoded_outputs.push_back(output);
+            }
+        }
+        mgb_assert(index == output_mask.size(), "mask size mismatch");
+        return encoded_outputs;
+    }
+
+    template
+    TContainer decode_outputs(TContainer outputs) const {
+        TContainer decoded_outputs;
+        size_t index = 0;
+        for (size_t i = 0; i < output_mask.size(); i++) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[i]) {
+                decoded_outputs.push_back(outputs[index++]);
+            } else {
+                decoded_outputs.emplace_back();
+            }
+        }
+        mgb_assert(decoded_outputs.size() == output_mask.size(),
+                   "mask size mismatch");
+        return decoded_outputs;
+    }
+
+    static EncodedSubraph make(Subgraph graph) {
+        EncodedSubraph result;
+        result.input_mask = graph.gen_input_mask();
+        result.output_mask = graph.gen_output_mask();
+        graph.inputs = result.encode_inputs(graph.inputs);
+        graph.outputs = result.encode_outputs(graph.outputs);
+        result.graph = graph;
+        return result;
+    }
+
+    static EncodedSubraph make_single(
+            std::shared_ptr op,
+            SmallVector input_mask,
+            SmallVector output_mask) {
+        EncodedSubraph result;
+        result.input_mask = input_mask;
+        result.output_mask = output_mask;
+        Subgraph::var_t last_var = 0;
+        for (auto&& mask: input_mask) {
+            if (mask) {
+                result.graph.inputs.push_back(++last_var);
+            }
+        }
+        for (auto&& mask: output_mask) {
+            if (mask) {
+                result.graph.outputs.push_back(++last_var);
+            }
+        }
+        result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}};
+        return result;
+    }
+
+    template
+    SmallVector apply(SmallVector input_vars, F&& f, C&& c) const {
+        auto encoded_inputs = encode_inputs(input_vars);
+        auto encoded_outputs = graph.apply(encoded_inputs, std::forward(f),
+                                           std::forward(c));
+        return decode_outputs(encoded_outputs);
+    }
+
+    std::string repr() const;
+    size_t hash() const;
+};
+
+template
+class GradContext {
+public:
+    using var_t = T;
+    using vars_t = SmallVector;
+    using expr_t = Expr;
+private:
+    std::unordered_map m_grads;
+    std::unordered_set m_vars_require_grad;
+    std::function m_accumulator;
+    std::vector m_exprs;
+public:
+    GradContext(std::function accumulator): m_accumulator{std::move(accumulator)}{}
+    SmallVector get_require_grads(vars_t dests) {
+        SmallVector mask;
+        for (auto&& dest: dests) {
+            mask.push_back(bool(m_vars_require_grad.count(dest)));
+        }
+        return mask;
+    }
+    SmallVector get_has_grads(vars_t dests) {
+        SmallVector mask;
+        for (auto&& dest: dests) {
+            mask.push_back(bool(m_grads.count(dest)));
+        }
+        return mask;
+    }
+    void mark_require_grads(vars_t dests) {
+        for (auto&& dest: dests) {
+            m_vars_require_grad.insert(dest);
+        }
+    }
+    var_t accumulate_grad(var_t dest, var_t grad) {
+        if (!m_grads.count(dest)) {
+            return m_grads[dest] = grad;
+        } else {
+            return m_grads[dest] = m_accumulator(m_grads[dest], grad);
+        }
+    }
+    void record_expr(std::shared_ptr op, vars_t inputs, vars_t outputs) {
+        bool require_grad = false;
+        for (auto&& input: inputs) {
+            if (m_vars_require_grad.count(input)) {
+                require_grad = true;
+                break;
+            }
+        }
+        if (require_grad) {
+            m_exprs.push_back({op, inputs, outputs});
+            mark_require_grads(outputs);
+        }
+    }
+    template
+    void backward(vars_t outputs, vars_t output_grads, TFunctor functor) {
+        size_t nr_outputs = outputs.size();
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            m_grads[outputs[i]] = output_grads[i];
+        }
+        auto exprs = m_exprs;
+        std::reverse(exprs.begin(), exprs.end());
+        for (const expr_t& expr: exprs) {
+            size_t nr_inputs = expr.inputs.size();
+            vars_t input_grads = functor(expr, get_grads(expr.outputs));
+            mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
+            for (size_t i = 0; i < nr_inputs; ++i) {
+                if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
+                    accumulate_grad(expr.inputs[i], input_grads[i]);
+                }
+            }
+        }
+    }
+    var_t get_grad(var_t dest) {
+        if (m_grads.count(dest)) {
+            return m_grads.at(dest);
+        }
+        return 0;
+    }
+    vars_t get_grads(vars_t dests) {
+        vars_t grads;
+        for (auto&& dest: dests) {
+            grads.push_back(get_grad(dest));
+        }
+        return grads;
+    }
+};
+
 } // namespace imperative
 } // namespace mgb
\ No newline at end of file
diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp
index 1ef56c588a2610c98ef393cc793db91e7200ffa4..0dddc73b113b7e329c095afb7f2911529afd96e2 100644
--- a/imperative/src/test/backward_graph.cpp
+++ b/imperative/src/test/backward_graph.cpp
@@ -22,22 +22,22 @@ using namespace cg;
 using namespace imperative;
 
 template
-T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
+T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs,
                                 const T& outputs, const T& grads) {
     T ret;
     size_t i = 0;
     for (auto&& t : inputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : outputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : grads) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
@@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
 }
 
 template
-T expand_grads(const U& bg, const T& outputs) {
-    T ret(bg.input_has_grad.size());
-    for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
-        if (bg.input_has_grad[i]) {
+T expand_grads(const U& mask, const T& outputs) {
+    T ret(mask.size());
+    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
+        if (mask[i]) {
             ret[i] = outputs[j++];
         }
     }
@@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
 }
 
 SmallVector apply_shared_on_physical_tensor(
-        std::shared_ptr def, SmallVector inputs) {
+        std::shared_ptr def, SmallVector inputs, size_t nr_outputs) {
     return OpDef::apply_on_physical_tensor(*def, inputs);
 }
 
@@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
     auto result = OpDef::make_backward_graph(
            *attr, input_descs, {true, true}, {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;
 
     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);
@@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                           apply_shared_on_physical_tensor,
                                           [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());
@@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) {
     input_descs.push_back({a->layout(), a->comp_node()});
     auto result = OpDef::make_backward_graph(
            *attr, input_descs, {true}, {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;
 
     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);
@@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                           apply_shared_on_physical_tensor,
                                           [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());
@@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
            prepare_backward_graph_inputs>(
                    bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
     auto grads =
-            expand_grads(bg, bg.backward.apply(backward_graph_inputs,
+            expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs,
                                                apply_shared_on_physical_tensor,
                                                [&](auto&& x) { return x; }));
 
@@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
            prepare_optimized_backward_inputs>(
                    obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
     auto grads2 = expand_grads(
-            obg,
+            obg.input_has_grad,
            obg.backward.apply(backward_inputs,
                               apply_shared_on_physical_tensor,
                               [&](auto&& x) { return x; }));
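Note on the mask convention. The EncodedSubraph struct added to subgraph.h above pairs a compact Subgraph with two bool masks: input_mask marks which entries of the caller's full input list the encoded graph actually consumes (encode_inputs drops the rest), and output_mask marks which of the caller's expected outputs the graph actually produces (decode_outputs re-expands the compact result, default-constructing the missing slots). This is why the patch can express the old save_for_backward as the backward graph's input_mask and the old input_has_grad as its output_mask. The sketch below is a self-contained illustration of that convention, not MegEngine code; the names demo_encode and demo_decode and the sample masks are invented for this example.

// Standalone illustration of the EncodedSubraph mask convention (C++17).
// Only the encode/decode behaviour mirrors the patch; everything else is made up.
#include <cassert>
#include <cstdio>
#include <optional>
#include <vector>

// Keep only the entries whose mask bit is set (cf. EncodedSubraph::encode_inputs).
std::vector<int> demo_encode(const std::vector<int>& full, const std::vector<bool>& mask) {
    assert(full.size() == mask.size());
    std::vector<int> compact;
    for (size_t i = 0; i < mask.size(); ++i) {
        if (mask[i]) compact.push_back(full[i]);
    }
    return compact;
}

// Expand a compact result back to mask.size() slots, leaving unset slots empty
// (cf. EncodedSubraph::decode_outputs, which emplaces a default element).
std::vector<std::optional<int>> demo_decode(const std::vector<int>& compact,
                                            const std::vector<bool>& mask) {
    std::vector<std::optional<int>> full;
    size_t j = 0;
    for (size_t i = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            full.emplace_back(compact[j++]);
        } else {
            full.emplace_back();  // placeholder for an output nobody asked for
        }
    }
    assert(j == compact.size());
    return full;
}

int main() {
    // input_mask: only inputs 0 and 2 of the forward list are saved for backward.
    std::vector<bool> input_mask{true, false, true};
    std::vector<int> full_inputs{10, 20, 30};
    std::vector<int> saved = demo_encode(full_inputs, input_mask);  // {10, 30}

    // output_mask: grads exist only for inputs 0 and 2, so the backward graph
    // returns two values that decode back into three slots.
    std::vector<bool> output_mask{true, false, true};
    std::vector<int> compact_grads{1, 3};
    auto grads = demo_decode(compact_grads, output_mask);

    std::printf("saved %zu of %zu inputs\n", saved.size(), full_inputs.size());
    for (size_t i = 0; i < grads.size(); ++i) {
        std::printf("grad[%zu] = %s\n", i, grads[i] ? "present" : "none");
    }
    return 0;
}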
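Note on GradContext. The GradContext template added alongside EncodedSubraph records forward expressions whose inputs require grad and then, in backward(), replays them in reverse order, asking a functor for the per-input grads of each expression and summing them per variable via accumulate_grad. Below is a minimal standalone sketch of that reverse accumulation loop under simplifying assumptions (one output per expression, double-valued grads, a functor that simply forwards the output grad); the Expr struct and the sample expressions are invented for illustration and are not the MegEngine types.

// Standalone sketch of the reverse pass performed by GradContext::backward.
#include <algorithm>
#include <cstdio>
#include <unordered_map>
#include <vector>

using var_t = int;
using vars_t = std::vector<var_t>;

struct Expr {
    vars_t inputs;
    vars_t outputs;
};

int main() {
    // Forward pass recorded two expressions: v3 = f(v1, v2), then v4 = g(v3).
    std::vector<Expr> exprs = {{{1, 2}, {3}}, {{3}, {4}}};

    std::unordered_map<var_t, double> grads;
    grads[4] = 1.0;  // seed: d(v4)/d(v4) = 1

    // For the sketch, every expression just forwards its output grad to each input.
    auto functor = [](const Expr& expr, double output_grad) {
        return std::vector<double>(expr.inputs.size(), output_grad);
    };

    // Walk the recorded expressions in reverse, accumulating grads per variable
    // (cf. GradContext::accumulate_grad, which applies a user-supplied accumulator).
    std::for_each(exprs.rbegin(), exprs.rend(), [&](const Expr& expr) {
        double og = grads.count(expr.outputs[0]) ? grads[expr.outputs[0]] : 0.0;
        auto input_grads = functor(expr, og);
        for (size_t i = 0; i < expr.inputs.size(); ++i) {
            grads[expr.inputs[i]] += input_grads[i];  // operator[] default-inits to 0
        }
    });

    std::printf("grad(v1) = %g, grad(v2) = %g\n", grads[1], grads[2]);
    return 0;
}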