提交 278b2baa 编写于 作者: M Megvii Engine Team

perf(mge): add memory optimization for backward graph

precompute ops in forward to reduce saved tensor size

GitOrigin-RevId: d67043ba82673406cef141771c56d3eded68f9c5
上级 ebe86892
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "./grad.h" #include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"
...@@ -32,14 +33,14 @@ struct GradSlotWeakPtr { ...@@ -32,14 +33,14 @@ struct GradSlotWeakPtr {
size_t idx; size_t idx;
}; };
struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject { struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override { std::shared_ptr<void> on_comp_node_finalize() override {
clear(); clear();
return {}; return {};
} }
} backward_graph_cache; } backward_graph_cache;
std::shared_ptr<BackwardGraphResult> make_backward_graph( std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
ApplyContext& ctx, const apply_result_t& outputs) { ApplyContext& ctx, const apply_result_t& outputs) {
// hash // hash
static_assert(alignof(size_t) % alignof(bool) == 0); static_assert(alignof(size_t) % alignof(bool) == 0);
...@@ -72,23 +73,23 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph( ...@@ -72,23 +73,23 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
inputs[i].layout.dtype = ctx.args[i]->dtype(); inputs[i].layout.dtype = ctx.args[i]->dtype();
input_requires_grad[i] = python::input_requires_grad(ctx, i); input_requires_grad[i] = python::input_requires_grad(ctx, i);
} }
auto result = std::make_shared<BackwardGraphResult>( std::shared_ptr<OptimizedBackwardGraphResult> ret;
proxy_graph_detail::make_backward_graph( auto bg = proxy_graph_detail::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad)); *ctx.op, inputs, input_requires_grad, output_has_grad);
if (!result->backward) { if (bg.backward) {
result.reset(); ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
} }
backward_graph_cache.emplace(key, result); backward_graph_cache.emplace(key, ret);
return result; return ret;
} }
struct BackwardGraphWithClosure { struct BackwardGraphWithClosure {
std::shared_ptr<BackwardGraphResult> backward_graph; std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
SmallVector<std::shared_ptr<Tensor>> closure; SmallVector<std::shared_ptr<Tensor>> closure;
size_t output_mask_offset; size_t output_mask_offset;
size_t grad_mask_offset; size_t grad_mask_offset;
BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_, BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
ApplyContext& ctx, const apply_result_t& outputs) ApplyContext& ctx, const apply_result_t& outputs)
: backward_graph(backward_graph_), : backward_graph(backward_graph_),
output_mask_offset(ctx.nargs), output_mask_offset(ctx.nargs),
...@@ -107,9 +108,18 @@ struct BackwardGraphWithClosure { ...@@ -107,9 +108,18 @@ struct BackwardGraphWithClosure {
// b.requires_grad == False, save_for_backward = [0, 1, 0, 1] // b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
auto& save_for_backward = backward_graph->save_for_backward; auto& save_for_backward = backward_graph->save_for_backward;
mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
closure.reserve(std::count_if(save_for_backward.begin(), size_t count = std::count_if(save_for_backward.begin(),
save_for_backward.end(), save_for_backward.end(),
ranges::identity{})); ranges::identity{});
if (backward_graph->precomp) {
auto&& irng = ranges::span(ctx.args, ctx.nargs);
auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
closure.reserve(precomp.size() + count);
std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
} else {
closure.reserve(count);
}
for (size_t i = 0; i < ctx.nargs; ++i) { for (size_t i = 0; i < ctx.nargs; ++i) {
if (save_for_backward[i]) { if (save_for_backward[i]) {
closure.push_back(ctx.args[i]->shared_from_this()); closure.push_back(ctx.args[i]->shared_from_this());
......
...@@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) { ...@@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) {
if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
return resolve_arrow(p.operator->()); return resolve_arrow(p.operator->());
} else { } else {
return p; return std::forward<T>(p);
} }
} }
} }
......
/**
* \file imperative/src/impl/backward_graph_opt.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"
using namespace mgb;
using namespace imperative;
/*!
 * \brief Split the backward graph in \p src into a forward-time "precomp"
 *        graph and a reduced backward graph.
 *
 * Ops in the whitelist (currently only GetVarShape) whose inputs are not all
 * required by the backward pass anyway are moved into `precomp`. Their
 * results are then handed to `backward` in place of the tensors they were
 * computed from, and `save_for_backward` is recomputed accordingly.
 *
 * If `src.backward` is not a BackwardGraph (a single op), it is copied
 * through unchanged together with `src.save_for_backward`, and `precomp`
 * stays null.
 *
 * \param src unoptimized result; `src.backward` is dereferenced
 *        unconditionally, so it must be non-null.
 */
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
        : input_has_grad(src.input_has_grad) {
    if (!src.backward->same_type<BackwardGraph>()) {
        // backward graph only contains a single op
        backward = src.backward;
        save_for_backward = src.save_for_backward;
        return;
    }
    save_for_backward.resize(src.save_for_backward.size(), false);
    precomp.reset(new BackwardGraph);
    backward.reset(new BackwardGraph);

    auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
    auto&& mask = src.save_for_backward;
    // mask layout: [forward inputs | forward outputs | output grads]
    size_t input_size = src.input_has_grad.size();
    size_t output_size = (mask.size() - input_size) / 2;
    mgb_assert(input_size + output_size * 2 == mask.size());

    auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
    auto& bgraph = backward->cast_final<BackwardGraph>().graph();

    // optimization: move ops (e.g. GetVarShape) to forward to
    // reduce memory footprint

    struct VInfo {
        bool appears_in_backward = false;
    };
    std::unordered_map<size_t, VInfo> vinfo;

    // step 1.1: ops not in whitelist must run in backward.
    // mark their inputs as always appears in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            for (auto&& v : iv) {
                vinfo[v].appears_in_backward = true;
            }
        }
    }
    // step 1.2: inputs only available in backward (i.e. grads)
    // should be marked as always appears in backward.
    // Grad slots occupy mask indices [input_size + output_size, mask.size()),
    // so the comparison must be inclusive; a strict `>` would skip the
    // first grad and allow an op consuming only that grad to be hoisted
    // into the forward graph, where grads are not available.
    const size_t grad_begin = input_size + output_size;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
        if (i >= grad_begin) {
            vinfo[graph.inputs[j]].appears_in_backward = true;
        }
        ++j;
    }

    // step 2: try to move ops to forward, if not all their inputs
    // are marked always appears in backward (otherwise no memory saving)
    // NOTE(review): values listed in graph.outputs are not marked here; if a
    // hoisted op could directly produce a graph output it would not be
    // forwarded to bgraph -- confirm this cannot happen.
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) {
            bgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                vinfo[v].appears_in_backward = true;
            }
            // logically should also mark all inputs as appears in backward
            // but clearly that's a no-op.
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                if (vinfo[v].appears_in_backward) {
                    // appears_in_backward won't change after this point
                    // so it is safe to set fgraph.outputs based on current value
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }

    // initialize remaining parts

    // The forward graph takes every forward input and output positionally,
    // including those the original graph did not save. Unsaved slots get a
    // placeholder id; presumably the base is chosen so it cannot collide with
    // real variable ids -- TODO confirm.
    constexpr size_t kUnusedInputIdBase = 1000000000;
    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        if (!mask[i]) {
            fgraph.inputs.push_back(kUnusedInputIdBase + i);
            continue;
        }
        fgraph.inputs.push_back(graph.inputs[j++]);
    }

    // The backward graph receives the precomputed values first, followed by
    // the original inputs that are still needed in backward.
    bgraph.constants = graph.constants;
    bgraph.outputs = graph.outputs;
    bgraph.inputs = fgraph.outputs;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            auto&& v = graph.inputs[j++];
            if (vinfo[v].appears_in_backward) {
                save_for_backward[i] = true;
                bgraph.inputs.push_back(v);
            }
        }
    }
}
/**
* \file imperative/src/include/megbrain/imperative/backward_graph_opt.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./op_def.h"
namespace mgb::imperative {
// Backward graph split into a part that runs at forward time (precomp) and
// a part that runs at backward time (backward).
struct OptimizedBackwardGraphResult {
// Op applied at forward time to the forward inputs followed by the forward
// outputs; left null when the source backward op is not a BackwardGraph.
std::shared_ptr<OpDef> precomp;
// Op applied at backward time. When precomp is set, its inputs are precomp's
// outputs followed by the tensors selected by save_for_backward.
std::shared_ptr<OpDef> backward;
// Mask over [forward inputs..., forward outputs..., output grads...] telling
// which of those tensors are fed to the backward op.
std::vector<bool> save_for_backward;
// Copied from the source BackwardGraphResult; one flag per forward input.
std::vector<bool> input_has_grad;
// Builds the optimized form from an unoptimized backward graph result.
OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
};
} // namespace mgb::imperative
...@@ -13,11 +13,68 @@ ...@@ -13,11 +13,68 @@
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/backward_graph_opt.h"
using namespace mgb; using namespace mgb;
using namespace cg; using namespace cg;
using namespace imperative; using namespace imperative;
// Collects the tensors to feed to bg.backward: walks inputs, then outputs,
// then grads, keeping each element whose slot in bg.save_for_backward is set.
// The mask is indexed by one running position across the three sequences.
template <typename T>
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) {
    T selected;
    size_t slot = 0;
    auto append_saved = [&](const T& group) {
        for (auto&& tensor : group) {
            const bool keep = bg.save_for_backward[slot];
            ++slot;
            if (keep) {
                selected.push_back(tensor);
            }
        }
    };
    append_saved(inputs);
    append_saved(outputs);
    append_saved(grads);
    return selected;
}
// Scatters the densely packed backward results in `outputs` into a container
// with one slot per forward input. Slots whose bg.input_has_grad flag is
// unset keep their value-initialized element.
template <typename T, typename U>
T expand_grads(const U& bg, const T& outputs) {
    const auto& has_grad = bg.input_has_grad;
    const size_t total = has_grad.size();
    T expanded(total);
    size_t next_output = 0;
    for (size_t slot = 0; slot != total; ++slot) {
        if (!has_grad[slot]) {
            continue;
        }
        expanded[slot] = outputs[next_output];
        ++next_output;
    }
    return expanded;
}
// Collects the tensors to feed to the optimized backward op: starts from a
// copy of the precomputed values, then appends the elements of inputs,
// outputs and grads (in that order) whose slot in bg.save_for_backward is set.
template <typename T>
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) {
    T selected = precomp;
    size_t slot = 0;
    auto append_saved = [&](const T& group) {
        for (auto&& tensor : group) {
            const bool keep = bg.save_for_backward[slot];
            ++slot;
            if (keep) {
                selected.push_back(tensor);
            }
        }
    };
    append_saved(inputs);
    append_saved(outputs);
    append_saved(grads);
    return selected;
}
TEST(TestImperative, BackwardGraphBasic) { TEST(TestImperative, BackwardGraphBasic) {
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
SmallVector<HostTensorND> hvs; SmallVector<HostTensorND> hvs;
...@@ -121,27 +178,65 @@ TEST(TestImperative, BackwardGraphIdentity) { ...@@ -121,27 +178,65 @@ TEST(TestImperative, BackwardGraphIdentity) {
} }
TEST(TestImperative, BatchNormGrad) { TEST(TestImperative, BatchNormGrad) {
auto cn = CompNode::load("xpux"); auto cn = CompNode::load("xpux");
using Param = opr::BatchNorm::Param; using Param = opr::BatchNorm::Param;
size_t N=2, C=3, H=5, W=5; size_t N=2, C=3, H=5, W=5;
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
{ {
auto op = OprAttr::make("BatchNorm"); auto op = OprAttr::make("BatchNorm");
auto&& attr = op->cast_final_safe<OprAttr>(); auto&& attr = op->cast_final_safe<OprAttr>();
Param param; Param param;
param.fwd_mode = Param::FwdMode::TRAINING; param.fwd_mode = Param::FwdMode::TRAINING;
attr.param.write_pod(param); attr.param.write_pod(param);
OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
{true, true ,true, false, false}, {false, false, false, false, true}); {true, true ,true, false, false}, {false, false, false, false, true});
} }
{ {
auto op = OprAttr::make("BatchNorm"); auto op = OprAttr::make("BatchNorm");
auto&& attr = op->cast_final_safe<OprAttr>(); auto&& attr = op->cast_final_safe<OprAttr>();
Param param; Param param;
param.fwd_mode = Param::FwdMode::TRAINING; param.fwd_mode = Param::FwdMode::TRAINING;
attr.param.write_pod(param); attr.param.write_pod(param);
OpDef::make_backward_graph(attr, {inp, stat, stat}, OpDef::make_backward_graph(attr, {inp, stat, stat},
{true, true ,true}, {false, false, true}); {true, true ,true}, {false, false, true});
} }
}
// Checks that optimizing the backward graph of an elementwise ADD moves the
// work on the forward tensors into precomp (so none of a, b, c has to be
// saved) and that the optimized graph yields the same grads as the original.
TEST(TestImperative, OptimizedBackwardGraphBasic) {
    auto cn = CompNode::load("xpux");
    LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
    HostTensorGenerator<> gen;
    auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
    auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
    auto obg = OptimizedBackwardGraphResult(bg);
    // mask layout is [a, b, c, dc]; unsigned literals avoid signed/unsigned
    // comparison warnings inside the gtest comparison templates
    ASSERT_EQ(obg.save_for_backward.size(), 4u);
    ASSERT_FALSE(obg.save_for_backward[0]);
    ASSERT_FALSE(obg.save_for_backward[1]);
    ASSERT_FALSE(obg.save_for_backward[2]);

    auto a_hv = gen({42});
    auto b_hv = gen({5, 42});
    auto dc_hv = gen({5, 42});
    auto a_tn = Tensor::make(*a_hv);
    auto b_tn = Tensor::make(*b_hv);
    auto dc_tn = Tensor::make(*dc_hv);
    auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];

    // reference result from the unoptimized backward graph
    auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));
    ASSERT_EQ(grads.size(), 2u);

    // forward-time precomputation: expected to yield two small 1-d tensors
    auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
    ASSERT_EQ(precomp.size(), 2u);
    ASSERT_EQ(precomp[0]->shape().ndim, 1u);
    ASSERT_LE(precomp[0]->shape()[0], 2u);
    ASSERT_EQ(precomp[1]->shape().ndim, 1u);
    ASSERT_LE(precomp[1]->shape()[0], 2u);

    // optimized backward must reproduce the reference grads
    auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));
    ASSERT_EQ(grads2.size(), 2u);
    MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
    MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册