提交 0b8dc2c9 编写于 作者: M Megvii Engine Team

refactor(subgraph): add generic encoded_graph

GitOrigin-RevId: 56d90be0e702ed15cafc9e00586a313bf817c9dd
上级 88b3c842
......@@ -77,7 +77,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = OpDef::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad);
if (!bg.backward.empty()) {
if (!bg.graph.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
backward_graph_cache.emplace(key, ret);
......
......@@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) {
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad){
auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
return std::make_tuple("backward_graph", result.input_mask, result.output_mask);
};
m.def("make_backward_graph", make_backward_graph);
}
......@@ -16,19 +16,19 @@
using namespace mgb;
using namespace imperative;
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
: input_has_grad(src.input_has_grad) {
if (src.backward.exprs.size() <= 1) {
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src)
: input_has_grad(src.output_mask) {
if (src.graph.exprs.size() <= 1) {
// backward graph only contains a single op
backward = src.backward;
save_for_backward = src.save_for_backward;
backward = src.graph;
save_for_backward = src.input_mask;
return;
}
save_for_backward.resize(src.save_for_backward.size(), false);
save_for_backward.resize(src.input_mask.size(), false);
auto&& graph = src.backward;
auto&& mask = src.save_for_backward;
size_t input_size = src.input_has_grad.size();
auto&& graph = src.graph;
auto&& mask = src.input_mask;
size_t input_size = src.output_mask.size();
size_t output_size = (mask.size() - input_size) / 2;
mgb_assert(input_size + output_size * 2 == mask.size());
......
......@@ -80,7 +80,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli
return def.trait()->infer_output_attrs_fallible(def, inputs);
}
BackwardGraphResult OpDef::make_backward_graph(
EncodedSubraph OpDef::make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph {
cg::VarNode* grad;
};
BackwardGraphResult
EncodedSubraph
ProxyGraph::make_backward_graph(
const OpDef& opdef,
const SmallVector<LogicalTensorDesc>& input_descs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
ThinHashMap<VarNode*, size_t> var2idx;
auto push = [&var2idx, cnt=0](VarNode* var) mutable {
auto push = [&var2idx, cnt=1](VarNode* var) mutable { //cnt is always greater non zero
auto&& ret = var2idx.emplace(var, cnt ++);
mgb_assert(ret.second, "var %s has been already inserted", var->cname());
return ret.first->second;
......@@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph(
}
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
BackwardGraphResult result;
auto&& igraph = result.backward;
EncodedSubraph result;
auto&& igraph = result.graph;
size_t nr_backward_graph_inputs = 0;
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
......@@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph(
// set backward graph outputs
cg::DepOprIter iter{gen_expr};
iter.set_visited(fwd);
result.input_has_grad.resize(inputs.size());
result.output_mask.resize(inputs.size());
VarNodeArray output_grads_with_unused_var;
{
......@@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph(
if (grad_results.valid()) {
grad = grad_results.val()[i];
} else {
mgb_assert(gfunc, "could not find grad function");
auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
if (res.from_single()) {
grad = res.single();
......@@ -776,9 +777,9 @@ ProxyGraph::make_backward_graph(
fwd->dyn_typeinfo()->name, i);
iter.add(grad);
igraph.outputs.push_back(var2idx.at(grad));
result.input_has_grad[i] = true;
result.output_mask[i] = true;
} else {
result.input_has_grad[i] = false;
result.output_mask[i] = false;
}
}
if (igraph.outputs.empty()) {
......@@ -787,15 +788,15 @@ ProxyGraph::make_backward_graph(
// set backward graph inputs
igraph.inputs.reserve(nr_backward_graph_inputs);
result.save_for_backward.reserve(nr_backward_graph_inputs);
result.input_mask.reserve(nr_backward_graph_inputs);
auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
for (auto&& i: vars) {
auto&& iter = var2idx.find(i);
if (iter != var2idx.end()) {
igraph.inputs.push_back(iter->second);
result.save_for_backward.push_back(true);
result.input_mask.push_back(true);
} else {
result.save_for_backward.push_back(false);
result.input_mask.push_back(false);
}
}
};
......
......@@ -40,7 +40,7 @@ public:
const SmallVector<Tensor*>& outputs,
const SmallVector<Tensor*>& workspace);
BackwardGraphResult make_backward_graph(
EncodedSubraph make_backward_graph(
const OpDef& opdef,
const SmallVector<LogicalTensorDesc>& input_descs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def,
return state.digest();
}
struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject {
struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override {
clear();
return {};
......@@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, Com
} // anonymous namespace
BackwardGraphResult
EncodedSubraph
make_backward_graph(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -101,5 +101,26 @@ void Subgraph::replace_vars(
}
}
std::string EncodedSubraph::repr() const {
std::string buffer;
buffer.push_back('|');
for (size_t i = 0; i < input_mask.size(); ++i) {
buffer.push_back(input_mask[i] ? '#' : ' ');
}
buffer.push_back('|');
buffer.push_back('\n');
buffer.append(graph.repr());
buffer.push_back('|');
for (size_t i = 0; i < output_mask.size(); ++i) {
buffer.push_back(output_mask[i] ? '#' : ' ');
}
buffer.push_back('|');
return buffer;
}
size_t EncodedSubraph::hash() const {
return std::hash<std::string>{}(repr());
}
} // namespace imperative
} // namespace mgb
......@@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult {
SmallVector<bool> save_for_backward;
SmallVector<bool> input_has_grad;
OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
OptimizedBackwardGraphResult(const EncodedSubraph& bgraph);
};
} // namespace mgb::imperative
......@@ -29,12 +29,6 @@ enum DispatchMode {
using SharedOp = std::shared_ptr<OpDef>;
struct BackwardGraphResult {
Subgraph backward;
SmallVector<bool> save_for_backward;
SmallVector<bool> input_has_grad;
};
class OpDef : public Hashable,
public NonCopyableObj,
public std::enable_shared_from_this<OpDef> {
......@@ -91,7 +85,7 @@ public:
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);
static BackwardGraphResult make_backward_graph(
static EncodedSubraph make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -38,7 +38,7 @@ void exec(const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs);
BackwardGraphResult
EncodedSubraph
make_backward_graph(const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -96,5 +96,185 @@ struct Subgraph {
bool operator==(const Subgraph& rhs) const;
};
struct EncodedSubraph {
Subgraph graph;
SmallVector<bool> input_mask;
SmallVector<bool> output_mask;
template <typename TContainer>
TContainer encode_inputs(TContainer inputs) const {
TContainer encoded_inputs;
size_t index = 0;
for (auto&& input : inputs) {
mgb_assert(index < input_mask.size(), "index out of range");
if (input_mask[index++]) {
encoded_inputs.push_back(input);
}
}
mgb_assert(index == input_mask.size(), "mask size mismatch");
return encoded_inputs;
}
template <typename TContainer>
TContainer encode_outputs(TContainer outputs) const {
TContainer encoded_outputs;
size_t index = 0;
for (auto&& output : outputs) {
mgb_assert(index < output_mask.size(), "index out of range");
if (output_mask[index++]) {
encoded_outputs.push_back(output);
}
}
mgb_assert(index == output_mask.size(), "mask size mismatch");
return encoded_outputs;
}
template <typename TContainer>
TContainer decode_outputs(TContainer outputs) const {
TContainer decoded_outputs;
size_t index = 0;
for (size_t i = 0; i < output_mask.size(); i++) {
mgb_assert(index < output_mask.size(), "index out of range");
if (output_mask[i]) {
decoded_outputs.push_back(outputs[index++]);
} else {
decoded_outputs.emplace_back();
}
}
mgb_assert(decoded_outputs.size() == output_mask.size(),
"mask size mismatch");
return decoded_outputs;
}
static EncodedSubraph make(Subgraph graph) {
EncodedSubraph result;
result.input_mask = graph.gen_input_mask();
result.output_mask = graph.gen_output_mask();
graph.inputs = result.encode_inputs(graph.inputs);
graph.outputs = result.encode_outputs(graph.outputs);
result.graph = graph;
return result;
}
static EncodedSubraph make_single(
std::shared_ptr<OpDef> op,
SmallVector<bool> input_mask,
SmallVector<bool> output_mask) {
EncodedSubraph result;
result.input_mask = input_mask;
result.output_mask = output_mask;
Subgraph::var_t last_var = 0;
for (auto&& mask: input_mask) {
if (mask) {
result.graph.inputs.push_back(++last_var);
}
}
for (auto&& mask: output_mask) {
if (mask) {
result.graph.outputs.push_back(++last_var);
}
}
result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}};
return result;
}
template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
auto encoded_inputs = encode_inputs(input_vars);
auto encoded_outputs = graph.apply(encoded_inputs, std::forward<F>(f),
std::forward<C>(c));
return decode_outputs(encoded_outputs);
}
std::string repr() const;
size_t hash() const;
};
template <typename T>
class GradContext {
public:
using var_t = T;
using vars_t = SmallVector<var_t>;
using expr_t = Expr<T>;
private:
std::unordered_map<var_t, var_t> m_grads;
std::unordered_set<var_t> m_vars_require_grad;
std::function<var_t(var_t, var_t)> m_accumulator;
std::vector<expr_t> m_exprs;
public:
GradContext(std::function<var_t(var_t, var_t)> accumulator): m_accumulator{std::move(accumulator)}{}
SmallVector<bool> get_require_grads(vars_t dests) {
SmallVector<bool> mask;
for (auto&& dest: dests) {
mask.push_back(bool(m_vars_require_grad.count(dest)));
}
return mask;
}
SmallVector<bool> get_has_grads(vars_t dests) {
SmallVector<bool> mask;
for (auto&& dest: dests) {
mask.push_back(bool(m_grads.count(dest)));
}
return mask;
}
void mark_require_grads(vars_t dests) {
for (auto&& dest: dests) {
m_vars_require_grad.insert(dest);
}
}
var_t accumulate_grad(var_t dest, var_t grad) {
if (!m_grads.count(dest)) {
return m_grads[dest] = grad;
} else {
return m_grads[dest] = m_accumulator(m_grads[dest], grad);
}
}
void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
bool require_grad = false;
for (auto&& input: inputs) {
if (m_vars_require_grad.count(input)) {
require_grad = true;
break;
}
}
if (require_grad) {
m_exprs.push_back({op, inputs, outputs});
mark_require_grads(outputs);
}
}
template <typename TFunctor>
void backward(vars_t outputs, vars_t output_grads, TFunctor functor) {
size_t nr_outputs = outputs.size();
for (size_t i = 0; i < nr_outputs; ++i) {
m_grads[outputs[i]] = output_grads[i];
}
auto exprs = m_exprs;
std::reverse(exprs.begin(), exprs.end());
for (const expr_t& expr: exprs) {
size_t nr_inputs = expr.inputs.size();
vars_t input_grads = functor(expr, get_grads(expr.outputs));
mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
for (size_t i = 0; i < nr_inputs; ++i) {
if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
accumulate_grad(expr.inputs[i], input_grads[i]);
}
}
}
}
var_t get_grad(var_t dest) {
if (m_grads.count(dest)) {
return m_grads.at(dest);
}
return 0;
}
vars_t get_grads(vars_t dests) {
vars_t grads;
for (auto&& dest: dests) {
grads.push_back(get_grad(dest));
}
return grads;
}
};
} // namespace imperative
} // namespace mgb
\ No newline at end of file
......@@ -22,22 +22,22 @@ using namespace cg;
using namespace imperative;
template <typename T>
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs,
const T& outputs, const T& grads) {
T ret;
size_t i = 0;
for (auto&& t : inputs) {
if (bg.save_for_backward[i++]) {
if (bg.input_mask[i++]) {
ret.push_back(t);
}
}
for (auto&& t : outputs) {
if (bg.save_for_backward[i++]) {
if (bg.input_mask[i++]) {
ret.push_back(t);
}
}
for (auto&& t : grads) {
if (bg.save_for_backward[i++]) {
if (bg.input_mask[i++]) {
ret.push_back(t);
}
}
......@@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
}
template <typename T, typename U>
T expand_grads(const U& bg, const T& outputs) {
T ret(bg.input_has_grad.size());
for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
if (bg.input_has_grad[i]) {
T expand_grads(const U& mask, const T& outputs) {
T ret(mask.size());
for (size_t i = 0, j = 0; i < mask.size(); ++i) {
if (mask[i]) {
ret[i] = outputs[j++];
}
}
......@@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
}
SmallVector<TensorPtr> apply_shared_on_physical_tensor(
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
return OpDef::apply_on_physical_tensor(*def, inputs);
}
......@@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) {
}
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true},
{true});
auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad;
auto&& save_for_backward = result.input_mask;
auto&& input_has_grad = result.output_mask;
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
inputs.push_back(outputs[0]);
......@@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
inputs.clear();
auto input_grads = result.backward.apply(backward_graph_inputs,
auto input_grads = result.graph.apply(backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size());
......@@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) {
input_descs.push_back({a->layout(), a->comp_node()});
auto result =
OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad;
auto&& save_for_backward = result.input_mask;
auto&& input_has_grad = result.output_mask;
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
inputs.push_back(outputs[0]);
......@@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
inputs.clear();
auto input_grads = result.backward.apply(backward_graph_inputs,
auto input_grads = result.graph.apply(backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size());
......@@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads =
expand_grads(bg, bg.backward.apply(backward_graph_inputs,
expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }));
......@@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(
obg,
obg.input_has_grad,
obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册