// backward_graph_opt.cpp
/**
 * \file imperative/src/impl/backward_graph_opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"

#include <algorithm>
#include <unordered_set>

using namespace mgb;
using namespace imperative;

/**
 * Split the backward graph of \p src into a precomputation graph (`precomp`)
 * that runs with the forward values, and a reduced graph (`backward`).
 *
 * Optimization: ops outside the whitelist (currently everything except
 * GetVarShape) always stay in backward. A whitelisted op is moved to
 * `precomp` when at least one of its inputs is not otherwise needed in
 * backward. That input then no longer has to be saved for backward, which
 * reduces the memory footprint.
 *
 * `src.save_for_backward` (the mask) has `input_size + 2 * output_size`
 * slots: the first `input_size` slots (one per entry of
 * `src.input_has_grad`), then `output_size` slots, then `output_size` slots.
 * Only the last section is treated as available exclusively in backward
 * (grads). The inputs of `src.backward`'s graph are the masked-in slots, in
 * order.
 *
 * On return:
 *  - `precomp` takes one input per slot of the first two sections. Unsaved
 *    slots get placeholder ids that no expr refers to.
 *  - `backward` takes `precomp`'s outputs first, followed by the saved vars
 *    it still needs.
 *  - `save_for_backward[i]` is true iff mask slot i is fed to `backward`.
 *
 * If `src.backward` is not a BackwardGraph, it is a single op. In that case
 * it and its mask are used unchanged, and `precomp` is left untouched.
 */
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
        : input_has_grad(src.input_has_grad) {
    if (!src.backward->same_type<BackwardGraph>()) {
        // backward graph only contains a single op: nothing to split
        backward = src.backward;
        save_for_backward = src.save_for_backward;
        return;
    }

    auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
    auto&& mask = src.save_for_backward;
    size_t input_size = src.input_has_grad.size();
    // Check before subtracting. Unsigned underflow could otherwise yield a
    // huge output_size that still satisfies the size identity below.
    mgb_assert(mask.size() >= input_size,
               "save_for_backward has %zu slots, fewer than %zu inputs",
               mask.size(), input_size);
    size_t output_size = (mask.size() - input_size) / 2;
    mgb_assert(input_size + output_size * 2 == mask.size());
    // graph.inputs is indexed below by a counter over the set bits of mask
    size_t n_saved = static_cast<size_t>(std::count(mask.begin(), mask.end(), true));
    mgb_assert(n_saved == graph.inputs.size(),
               "save_for_backward marks %zu slots but backward graph has %zu inputs",
               n_saved, graph.inputs.size());

    save_for_backward.resize(mask.size(), false);
    precomp.reset(new BackwardGraph);
    backward.reset(new BackwardGraph);

    auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
    auto& bgraph = backward->cast_final<BackwardGraph>().graph();

    // Var ids whose values must be available while running backward.
    // An op consuming only such vars gains nothing from moving to forward.
    std::unordered_set<size_t> in_backward;
    auto needed_in_backward = [&](size_t v) { return in_backward.count(v) > 0; };

    // step 1.1: ops not in the whitelist must run in backward, so all of
    // their inputs are needed in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            in_backward.insert(iv.begin(), iv.end());
        }
    }
    // step 1.2: inputs only available in backward (i.e. grads, the last
    // output_size mask slots) are needed in backward
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
        if (i >= input_size + output_size) {
            in_backward.insert(graph.inputs[j]);
        }
        ++j;
    }
    // step 1.3: values returned by backward must exist there. Without this,
    // an output produced by an op moved to forward (or an otherwise unused
    // forward input) would be neither computed in nor fed to backward.
    in_backward.insert(graph.outputs.begin(), graph.outputs.end());

    // step 2: an op whose inputs are all needed in backward stays there
    // (moving it would save no memory); every other op moves to forward
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        if (std::all_of(iv.begin(), iv.end(), needed_in_backward)) {
            bgraph.exprs.push_back(expr);
            // its inputs are already marked (that is the condition above)
            in_backward.insert(ov.begin(), ov.end());
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                // From here on only outputs of backward ops get marked, so,
                // assuming each var has a single producing expr, membership
                // of v is final and deciding precomp's outputs now is safe.
                if (needed_in_backward(v)) {
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }

    // initialize remaining parts

    // Placeholder ids for forward slots src did not save; no expr refers to
    // them. NOTE(review): assumes real var ids stay below this base.
    constexpr size_t kPlaceholderBase = 1000000000;
    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        fgraph.inputs.push_back(mask[i] ? graph.inputs[j++] : kPlaceholderBase + i);
    }

    // backward consumes precomp's outputs, then each saved var it still needs
    bgraph.constants = graph.constants;
    bgraph.outputs = graph.outputs;
    bgraph.inputs = fgraph.outputs;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
        auto&& v = graph.inputs[j++];
        if (needed_in_backward(v)) {
            save_for_backward[i] = true;
            bgraph.inputs.push_back(v);
        }
    }
}