diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index f4ae8cfcc46f7ba382985f1224fc6d8a6171280e..831d7f75535d4235c23660cc5f38a1ab2da31478 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -11,6 +11,7 @@
 #include "./grad.h"
 
 #include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/imperative/backward_graph_opt.h"
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/utils/mempool.h"
@@ -32,14 +33,14 @@ struct GradSlotWeakPtr {
     size_t idx;
 };
 
-struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
+struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
     std::shared_ptr<void> on_comp_node_finalize() override {
         clear();
         return {};
     }
 } backward_graph_cache;
 
-std::shared_ptr<BackwardGraphResult> make_backward_graph(
+std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
         ApplyContext& ctx, const apply_result_t& outputs) {
     // hash
     static_assert(alignof(size_t) % alignof(bool) == 0);
@@ -72,23 +73,23 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
         inputs[i].layout.dtype = ctx.args[i]->dtype();
         input_requires_grad[i] = python::input_requires_grad(ctx, i);
     }
-    auto result = std::make_shared<BackwardGraphResult>(
-            proxy_graph_detail::make_backward_graph(
-                    *ctx.op, inputs, input_requires_grad, output_has_grad));
-    if (!result->backward) {
-        result.reset();
+    std::shared_ptr<OptimizedBackwardGraphResult> ret;
+    auto bg = proxy_graph_detail::make_backward_graph(
+            *ctx.op, inputs, input_requires_grad, output_has_grad);
+    if (bg.backward) {
+        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }
-    backward_graph_cache.emplace(key, result);
-    return result;
+    backward_graph_cache.emplace(key, ret);
+    return ret;
 }
 
 struct BackwardGraphWithClosure {
-    std::shared_ptr<BackwardGraphResult> backward_graph;
+    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
     SmallVector<std::shared_ptr<Tensor>> closure;
     size_t output_mask_offset;
     size_t grad_mask_offset;
 
-    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
+    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
                              ApplyContext& ctx, const apply_result_t& outputs)
             : backward_graph(backward_graph_),
               output_mask_offset(ctx.nargs),
@@ -107,9 +108,18 @@ struct BackwardGraphWithClosure {
         // b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
         mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
-        closure.reserve(std::count_if(save_for_backward.begin(),
-                                      save_for_backward.end(),
-                                      ranges::identity{}));
+        size_t count = std::count_if(save_for_backward.begin(),
+                                     save_for_backward.end(),
+                                     ranges::identity{});
+        if (backward_graph->precomp) {
+            auto&& irng = ranges::span(ctx.args, ctx.nargs);
+            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
+            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
+            closure.reserve(precomp.size() + count);
+            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
+        } else {
+            closure.reserve(count);
+        }
         for (size_t i = 0; i < ctx.nargs; ++i) {
             if (save_for_backward[i]) {
                 closure.push_back(ctx.args[i]->shared_from_this());
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 3d78bb9006d16176730e41ab4f8d4228f4ebbcda..7dbd2ed97d73e765b2f5151b654eb1893368d42c 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) {
     if constexpr (std::is_invocable_v<decltype(&std::decay_t<T>::operator->), T>) {
         return resolve_arrow(p.operator->());
     } else {
-        return p;
+        return std::forward<T>(p);
     }
 }
 
diff --git a/imperative/src/impl/backward_graph_opt.cpp b/imperative/src/impl/backward_graph_opt.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..838fd49ef99d817a9e40bc4889f9a0667db802df
--- /dev/null
+++ b/imperative/src/impl/backward_graph_opt.cpp
@@ -0,0 +1,114 @@
+/**
+ * \file imperative/src/impl/backward_graph_opt.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/backward_graph_opt.h"
+#include "megbrain/imperative/ops/backward_graph.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+using namespace mgb;
+using namespace imperative;
+
+OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
+        : input_has_grad(src.input_has_grad) {
+    if (!src.backward->same_type<BackwardGraph>()) {
+        // backward graph only contains a single op
+        backward = src.backward;
+        save_for_backward = src.save_for_backward;
+        return;
+    }
+    save_for_backward.resize(src.save_for_backward.size(), false);
+    precomp.reset(new BackwardGraph);
+    backward.reset(new BackwardGraph);
+
+    auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
+    auto&& mask = src.save_for_backward;
+    size_t input_size = src.input_has_grad.size();
+    size_t output_size = (mask.size() - input_size) / 2;
+    mgb_assert(input_size + output_size * 2 == mask.size());
+
+    auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
+    auto& bgraph = backward->cast_final<BackwardGraph>().graph();
+
+    // optimization: move ops (e.g. GetVarShape) to forward to
+    // reduce memory footprint
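+    // (layout note: `mask`, i.e. src.save_for_backward, covers
+    // [forward inputs | forward outputs | output grads]; graph.inputs holds
+    // exactly one var per set bit of `mask`, in that same order)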
+
+    struct VInfo {
+        bool appears_in_backward = false;
+    };
+    std::unordered_map<size_t, VInfo> vinfo;
+
+    // step 1.1: ops not in the whitelist must run in backward;
+    // mark their inputs as always appearing in backward
+    for (auto&& [op, iv, ov] : graph.exprs) {
+        if (!op->same_type<GetVarShape>()) {
+            for (auto&& v : iv) {
+                vinfo[v].appears_in_backward = true;
+            }
+        }
+    }
+
+    // step 1.2: inputs only available in backward (i.e. grads)
+    // should be marked as always appearing in backward
+    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
+        if (!mask[i]) continue;
+        if (i >= input_size + output_size) {
+            vinfo[graph.inputs[j]].appears_in_backward = true;
+        }
+        ++j;
+    }
+
+    // step 2: move an op to forward unless all of its inputs are
+    // marked as always appearing in backward (otherwise no memory saving)
+    for (auto&& expr : graph.exprs) {
+        auto&& [op, iv, ov] = expr;
+        if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) {
+            bgraph.exprs.push_back(expr);
+            for (auto&& v : ov) {
+                vinfo[v].appears_in_backward = true;
+            }
+            // logically we should also mark all inputs as appearing in
+            // backward, but that is clearly a no-op here
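+    // the forward graph takes all (forward inputs ++ forward outputs); slots
+    // that were never saved get a dummy id (1000000000 + i), assumed to never
+    // collide with a real var id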
+        } else {
+            fgraph.exprs.push_back(expr);
+            for (auto&& v : ov) {
+                if (vinfo[v].appears_in_backward) {
+                    // appears_in_backward won't change after this point,
+                    // so it is safe to set fgraph.outputs based on the current value
+                    fgraph.outputs.push_back(v);
+                }
+            }
+        }
+    }
+
+    // initialize remaining parts
+
+    fgraph.constants = graph.constants;
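+    fgraph.inputs.reserve(input_size + output_size);
+    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
+        if (!mask[i]) {
+            fgraph.inputs.push_back(1000000000 + i);
+            continue;
+        }
+        fgraph.inputs.push_back(graph.inputs[j++]);
+    }
+
+    bgraph.constants = graph.constants;
+    bgraph.outputs = graph.outputs;
+    bgraph.inputs = fgraph.outputs;
+    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
+        if (mask[i]) {
+            auto&& v = graph.inputs[j++];
+            if (vinfo[v].appears_in_backward) {
+                save_for_backward[i] = true;
+                bgraph.inputs.push_back(v);
+            }
+        }
+    }
+}
diff --git a/imperative/src/include/megbrain/imperative/backward_graph_opt.h b/imperative/src/include/megbrain/imperative/backward_graph_opt.h
new file mode 100644
index 0000000000000000000000000000000000000000..457c40f4afa8b4adfe8f9337f7a63f8c8968c3cd
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/backward_graph_opt.h
@@ -0,0 +1,25 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/backward_graph_opt.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "./op_def.h"
+
+namespace mgb::imperative {
+
+struct OptimizedBackwardGraphResult {
+    std::shared_ptr<OpDef> precomp;
+    std::shared_ptr<OpDef> backward;
+    std::vector<bool> save_for_backward;
+    std::vector<bool> input_has_grad;
+
+    OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
+};
+
+} // namespace mgb::imperative
diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp
index d673c7f6232a9bc29a17b8278c4d835ec658ab3d..8830e429af81d07c0974b56aaadbbc23fe3d5bc0 100644
--- a/imperative/src/test/backward_graph.cpp
+++ b/imperative/src/test/backward_graph.cpp
@@ -13,11 +13,68 @@
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/backward_graph_opt.h"
 
 using namespace mgb;
 using namespace cg;
 using namespace imperative;
 
+template <typename T>
+T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) {
+    T ret;
+    size_t i = 0;
+    for (auto&& t : inputs) {
+        if (bg.save_for_backward[i++]) {
+            ret.push_back(t);
+        }
+    }
+    for (auto&& t : outputs) {
+        if (bg.save_for_backward[i++]) {
+            ret.push_back(t);
+        }
+    }
+    for (auto&& t : grads) {
+        if (bg.save_for_backward[i++]) {
+            ret.push_back(t);
+        }
+    }
+    return ret;
+}
+
+template <typename T, typename U>
+T expand_grads(const U& bg, const T& outputs) {
+    T ret(bg.input_has_grad.size());
+    for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
+        if (bg.input_has_grad[i]) {
+            ret[i] = outputs[j++];
+        }
+    }
+    return ret;
+}
+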
+template <typename T>
+T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) {
+    T ret = precomp;
+    size_t i = 0;
+    for (auto&& t : inputs) {
+        if (bg.save_for_backward[i++]) {
+            ret.push_back(t);
+        }
+    }
+    for (auto&& t : outputs) {
+        if (bg.save_for_backward[i++]) {
+            ret.push_back(t);
+        }
+    }
+    for (auto&& t : grads) {
+        if (bg.save_for_backward[i++]) {
+            ret.push_back(t);
+        }
+    }
+    return ret;
+}
+
 TEST(TestImperative, BackwardGraphBasic) {
     HostTensorGenerator<> gen;
     SmallVector<HostTensorND> hvs;
@@ -121,27 +178,65 @@ TEST(TestImperative, BackwardGraphIdentity) {
 }
 
 TEST(TestImperative, BatchNormGrad) {
-  auto cn = CompNode::load("xpux");
-  using Param = opr::BatchNorm::Param;
-  size_t N=2, C=3, H=5, W=5;
-  LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
-  LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
-  {
-    auto op = OprAttr::make("BatchNorm");
-    auto&& attr = op->cast_final_safe<OprAttr>();
-    Param param;
-    param.fwd_mode = Param::FwdMode::TRAINING;
-    attr.param.write_pod(param);
-    OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
-        {true, true ,true, false, false}, {false, false, false, false, true});
-  }
-  {
-    auto op = OprAttr::make("BatchNorm");
-    auto&& attr = op->cast_final_safe<OprAttr>();
-    Param param;
-    param.fwd_mode = Param::FwdMode::TRAINING;
-    attr.param.write_pod(param);
-    OpDef::make_backward_graph(attr, {inp, stat, stat},
-        {true, true ,true}, {false, false, true});
-  }
-}
+    auto cn = CompNode::load("xpux");
+    using Param = opr::BatchNorm::Param;
+    size_t N=2, C=3, H=5, W=5;
+    LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
+    LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
+    {
+        auto op = OprAttr::make("BatchNorm");
+        auto&& attr = op->cast_final_safe<OprAttr>();
+        Param param;
+        param.fwd_mode = Param::FwdMode::TRAINING;
+        attr.param.write_pod(param);
+        OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
+            {true, true ,true, false, false}, {false, false, false, false, true});
+    }
+    {
+        auto op = OprAttr::make("BatchNorm");
+        auto&& attr = op->cast_final_safe<OprAttr>();
+        Param param;
+        param.fwd_mode = Param::FwdMode::TRAINING;
+        attr.param.write_pod(param);
+        OpDef::make_backward_graph(attr, {inp, stat, stat},
+            {true, true ,true}, {false, false, true});
+    }
+}
+
+TEST(TestImperative, OptimizedBackwardGraphBasic) {
+    auto cn = CompNode::load("xpux");
+    LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
+    HostTensorGenerator<> gen;
+    auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
+    auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
+    auto obg = OptimizedBackwardGraphResult(bg);
+    ASSERT_EQ(obg.save_for_backward.size(), 4);
+    ASSERT_FALSE(obg.save_for_backward[0]);
+    ASSERT_FALSE(obg.save_for_backward[1]);
+    ASSERT_FALSE(obg.save_for_backward[2]);
+
+    auto a_hv = gen({42});
+    auto b_hv = gen({5, 42});
+    auto dc_hv = gen({5, 42});
+    auto a_tn = Tensor::make(*a_hv);
+    auto b_tn = Tensor::make(*b_hv);
+    auto dc_tn = Tensor::make(*dc_hv);
+    auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];
+
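+    // reference gradients: run the unoptimized backward graph in one shot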
+    auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
+    auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));
+
+    auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
+    ASSERT_EQ(precomp.size(), 2);
+    ASSERT_EQ(precomp[0]->shape().ndim, 1);
+    ASSERT_LE(precomp[0]->shape()[0], 2);
+    ASSERT_EQ(precomp[1]->shape().ndim, 1);
+    ASSERT_LE(precomp[1]->shape()[0], 2);
+
+    auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
+    auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));
+
+    ASSERT_EQ(grads2.size(), 2);
+    MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
+    MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value());
+}