backward_graph_opt.cpp 3.9 KB
Newer Older
1 2 3 4
/**
 * \file imperative/src/impl/backward_graph_opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"

using namespace mgb;
using namespace imperative;

M
Megvii Engine Team 已提交
19
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubgraph& src)
20 21
        : input_has_grad(src.output_mask) {
    if (src.graph.exprs.size() <= 1) {
22
        // backward graph only contains a single op
23 24
        backward = src.graph;
        save_for_backward = src.input_mask;
25 26
        return;
    }
27
    save_for_backward.resize(src.input_mask.size(), false);
28

29 30 31
    auto&& graph = src.graph;
    auto&& mask = src.input_mask;
    size_t input_size = src.output_mask.size();
32 33 34
    size_t output_size = (mask.size() - input_size) / 2;
    mgb_assert(input_size + output_size * 2 == mask.size());

35 36
    auto& fgraph = precomp;
    auto& bgraph = backward;
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

    // optimization: move ops (e.g. GetVarShape) to forward to
    // reduce memory footprint

    struct VInfo {
        bool appears_in_backward = false;
    };
    std::unordered_map<size_t, VInfo> vinfo;

    // step 1.1: ops not in whitelist must run in backward.
    // mark their inputs as always appears in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            for (auto&& v : iv) {
                vinfo[v].appears_in_backward = true;
            }
        }
    }
    // step 1.2: inputs only available in backward (i.e. grads)
    // should be marked as always appears in backward
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
59
        if (i >= input_size + output_size) {
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
            vinfo[graph.inputs[j]].appears_in_backward = true;
        }
        ++j;
    }

    // step 2: try to move ops to forward, if not all their inputs
    // are marked always appears in backward (otherwise no memory saving)
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) {
            bgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                vinfo[v].appears_in_backward = true;
            }
            // logically should also mark all inputs as appears in backward
            // but clearly that's a no-op.
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                if (vinfo[v].appears_in_backward) {
                    // appears_in_backward won't change after this point
                    // so it is safe to set fgraph.outputs based on current value
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }

    // initialize remaining parts

    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        if (!mask[i]) {
            fgraph.inputs.push_back(1000000000 + i);
            continue;
        }
        fgraph.inputs.push_back(graph.inputs[j++]);
    }

    bgraph.constants = graph.constants;
    bgraph.outputs = graph.outputs;
    bgraph.inputs = fgraph.outputs;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            auto&& v = graph.inputs[j++];
            if (vinfo[v].appears_in_backward) {
                save_for_backward[i] = true;
                bgraph.inputs.push_back(v);
            }
        }
    }
112 113

    if (!fgraph.outputs.size()) {
114
        precomp = {};
115
    }
116
}