/**
 * \file imperative/src/test/backward_graph.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./helper.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/batch_norm.h"

using namespace mgb;
using namespace cg;
using namespace imperative;

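// Select the tensors the backward graph actually consumes: walk the forward
// inputs, forward outputs and output gradients in that order and keep only the
// entries whose position is set in bg.input_mask.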
template <typename T>
T prepare_backward_graph_inputs(const EncodedSubgraph& bg, const T& inputs,
                                const T& outputs, const T& grads) {
    T ret;
    size_t i = 0;
    for (auto&& t : inputs) {
        if (bg.input_mask[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : outputs) {
        if (bg.input_mask[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : grads) {
        if (bg.input_mask[i++]) {
            ret.push_back(t);
        }
    }
    return ret;
}

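// Scatter the compact gradient list returned by the backward graph back to its
// full-sized layout: slots whose mask bit is unset stay empty, i.e. no
// gradient is produced for that position.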
template <typename T, typename U>
T expand_grads(const U& mask, const T& outputs) {
    T ret(mask.size());
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            ret[i] = outputs[j++];
        }
    }
    return ret;
}

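// Counterpart of prepare_backward_graph_inputs for the optimized split: start
// from the tensors produced by the precompute graph and append the forward
// inputs/outputs/gradients selected by bg.save_for_backward.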
template <typename T>
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
                                    const T& precomp, const T& inputs,
                                    const T& outputs, const T& grads) {
    T ret = precomp;
    size_t i = 0;
    for (auto&& t : inputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : outputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : grads) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    return ret;
}

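// Op-apply callback handed to graph.apply() in the tests below; it simply
// forwards each op of the backward graph to OpDef::apply_on_physical_tensor
// on concrete tensors (nr_outputs is not needed for that).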
SmallVector<TensorPtr> apply_shared_on_physical_tensor(
        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
    return OpDef::apply_on_physical_tensor(*def, inputs);
}

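// Builds the backward graph of an elementwise MUL, drops the tensors it does
// not need, runs it on concrete tensors and checks the resulting gradient
// against dc multiplied by the other forward input
// (d(a*b)/da = b * dc, d(a*b)/db = a * dc).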
TEST(TestImperative, BackwardGraphBasic) {
    HostTensorGenerator<> gen;
    SmallVector<HostTensorND> hvs;
    SmallVector<TensorPtr> inputs;
    for (size_t i = 0; i < 2; ++i) {
        hvs.push_back(*gen({42}));
        inputs.push_back(Tensor::make(hvs.back()));
    }

    using Param = opr::Elemwise::Param;
    Param param{Param::Mode::MUL};
    auto attr = OprAttr::make("Elemwise");
    attr->cast_final_safe<OprAttr>().param.write_pod(param);

    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& i : inputs) {
        input_descs.push_back({i->layout(), i->comp_node()});
    }
    auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true},
                                             {true});
    auto&& save_for_backward = result.input_mask;
    auto&& input_has_grad = result.output_mask;

    auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
    inputs.push_back(outputs[0]);
    hvs.push_back(*gen({42}));
    inputs.push_back(Tensor::make(hvs.back()));
    mgb_assert(save_for_backward.size() == inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (!save_for_backward[i]) {
            inputs[i].reset();  // drop unused tensor
        }
    }
    SmallVector<TensorPtr> backward_graph_inputs;
    for (auto&& i : inputs) {
        if (i) {
            backward_graph_inputs.push_back(i);
        }
    }
    inputs.clear();
    auto input_grads = result.graph.apply(backward_graph_inputs,
                                             apply_shared_on_physical_tensor,
                                             [&](auto&& x) { return x; });
    mgb_assert(input_grads.size() == input_has_grad.size());
    for (size_t i = 0; i < input_has_grad.size(); ++i) {
        mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
    }

    SmallVector<HostTensorND> res;
    for (auto&& i : input_grads) {
        res.emplace_back();
        res.back().copy_from(i->dev_tensor()).sync();
    }
    for (size_t i = 0; i < 42; ++i) {
        for (size_t j = 0; j < 1; ++j) {
            ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i],
                      res[j ^ 1].ptr<float>()[i]);
        }
    }
}

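// The backward graph of Identity should pass the incoming gradient through
// unchanged, so the computed input gradient must equal dc element-wise.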
TEST(TestImperative, BackwardGraphIdentity) {
    HostTensorGenerator<> gen;
    auto host_a = gen({42}), host_dc = gen({42});
    auto a = Tensor::make(*host_a), dc = Tensor::make(*host_dc);
    SmallVector<TensorPtr> inputs;
    inputs.push_back(a);

    auto attr = OprAttr::make("Identity");
    attr->cast_final_safe<OprAttr>().param.write_pod<megdnn::param::Empty>({});

    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.push_back({a->layout(), a->comp_node()});
    auto result =
            OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
    auto&& save_for_backward = result.input_mask;
    auto&& input_has_grad = result.output_mask;

    auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
    inputs.push_back(outputs[0]);
    inputs.push_back(dc);
    mgb_assert(save_for_backward.size() == inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (!save_for_backward[i]) {
            inputs[i].reset();  // drop unused tensor
        }
    }
    SmallVector<TensorPtr> backward_graph_inputs;
    for (auto&& i : inputs) {
        if (i) {
            backward_graph_inputs.push_back(i);
        }
    }
    inputs.clear();
    auto input_grads = result.graph.apply(backward_graph_inputs,
                                             apply_shared_on_physical_tensor,
                                             [&](auto&& x) { return x; });
    mgb_assert(input_grads.size() == input_has_grad.size());
    for (size_t i = 0; i < input_has_grad.size(); ++i) {
        mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
    }

    HostTensorND hv;
    hv.copy_from(input_grads[0]->dev_tensor()).sync();
    for (size_t i = 0; i < 42; ++i) {
        ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]);
    }
}

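// Only exercises backward graph construction for BatchNorm in TRAINING mode,
// once with the five-input form and once with the three-input form; no
// numerical result is checked here.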
TEST(TestImperative, BatchNormGrad) {
    auto cn = CompNode::load("xpux");
    using Param = opr::BatchNorm::Param;
    size_t N = 2, C = 3, H = 5, W = 5;
    LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
    LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
    {
        auto op = OprAttr::make("BatchNorm");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
                                   {true, true, true, false, false},
                                   {false, false, false, false, true});
    }
    {
        auto op = OprAttr::make("BatchNorm");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true},
                                   {false, false, true});
    }
}

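// Checks the optimized backward graph of an elementwise ADD against the plain
// one: running the precompute graph followed by the optimized backward graph
// must yield the same gradients as the unoptimized backward graph.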
TEST(TestImperative, OptimizedBackwardGraphBasic) {
    auto cn = CompNode::load("xpux");
    LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
    HostTensorGenerator<> gen;
    auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
    auto bg =
            OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
    auto obg = OptimizedBackwardGraphResult(bg);
    ASSERT_EQ(obg.save_for_backward.size(), 4);
    ASSERT_FALSE(obg.save_for_backward[0]);
    ASSERT_FALSE(obg.save_for_backward[1]);
    ASSERT_FALSE(obg.save_for_backward[2]);

    auto a_hv = gen({42});
    auto b_hv = gen({5, 42});
    auto dc_hv = gen({5, 42});
    auto a_tn = Tensor::make(*a_hv);
    auto b_tn = Tensor::make(*b_hv);
    auto dc_tn = Tensor::make(*dc_hv);
    auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];

    auto backward_graph_inputs =
            prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
                    bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads =
            expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs,
                                               apply_shared_on_physical_tensor,
                                               [&](auto&& x) { return x; }));

    auto precomp = obg.precomp.apply(SmallVector<TensorPtr>{a_tn, b_tn, c_tn},
                                     apply_shared_on_physical_tensor,
                                     [&](auto&& x) { return x; });
    ASSERT_EQ(precomp.size(), 2);
    ASSERT_EQ(precomp[0]->shape().ndim, 1);
    ASSERT_LE(precomp[0]->shape()[0], 2);
    ASSERT_EQ(precomp[1]->shape().ndim, 1);
    ASSERT_LE(precomp[1]->shape()[0], 2);

    auto backward_inputs =
            prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
                    obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads2 = expand_grads(
            obg.input_has_grad,
            obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor,
                               [&](auto&& x) { return x; }));

    ASSERT_EQ(grads2.size(), 2);
    MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
    MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value());
}