Commit c2293815 authored by Megvii Engine Team

fix(autodiff): proxy_graph_detail::make_backward_graph support multiple opnodes

GitOrigin-RevId: 2c0c8f330da645438f2a5ef17c9acef588f89fb3
Parent 335d51b4
@@ -11,10 +11,13 @@
#include "./proxy_graph.h"
#include "./blob_manager_impl.h"
#include "megbrain/graph.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
@@ -486,139 +489,83 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
ThinHashMap<VarNode*, size_t> var2idx;
auto push = [&var2idx,
cnt = 1](VarNode* var) mutable { // cnt is always greater than zero
auto&& ret = var2idx.emplace(var, cnt++);
mgb_assert(ret.second, "var %s has been already inserted", var->cname());
return ret.first->second;
};
using op_t = OperatorNodeBase*;
using var_t = VarNode*;
using vars_t = VarNodeArray;
auto inputs = make_input_place_holders(input_descs);
auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
auto&& outputs = fwd->usable_output();
auto outputs = OpDef::apply_on_var_node(opdef, inputs);
SmallVector<LogicalTensorDesc> output_descs;
for (auto&& i : outputs) {
output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
}
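// when a var receives grads from more than one path, accumulate them with an elementwise ADD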
GradContext<op_t, var_t> grad_context{[&](VarNode* lhs, VarNode* rhs) -> VarNode* {
auto add = opr::Elemwise::Mode::ADD;
return opr::Elemwise::make(VarNodeArray{lhs, rhs}, add).node();
}};
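// record every forward operator into grad_context, visiting them in topological order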
cg::DepOprIter iter{[&](OperatorNodeBase* op) {
grad_context.record_expr(op, op->input(), op->output());
}};
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
iter.set_visited(input->owner_opr());
if (input_requires_grad[i]) {
grad_context.mark_require_grad(input);
}
}
for (auto&& output : outputs) {
iter.add(output);
}
auto output_grads = make_input_place_holders(output_descs);
mgb_assert(
output_grads.size() == output_has_grad.size(), "%zu vs %zu",
output_grads.size(), output_has_grad.size());
bool any_input_has_grad = false;
for (size_t i = 0; i < output_grads.size(); ++i) {
for (size_t i = 0; i < outputs.size(); ++i) {
if (!output_has_grad[i]) {
output_grads[i] = nullptr;
} else {
any_input_has_grad = true;
}
}
if (!any_input_has_grad) {
return {};
}
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
EncodedSubgraph result;
auto&& igraph = result.graph;
size_t nr_backward_graph_inputs = 0;
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
&nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
if (auto t = as_tensor(op)) {
mgb_assert(op->output().size() == 1);
igraph.constants.emplace_back(push(op->output(0)), std::move(t));
} else if (op->same_type<InputPlaceholder>()) {
++nr_backward_graph_inputs;
push(op->output(0));
} else {
SmallVector<size_t> inputs, outputs;
for (auto&& i : op->input()) {
if (i->owner_opr() == fwd) {
if (var2idx.find(i) == var2idx.end()) {
++nr_backward_graph_inputs;
push(i);
}
}
inputs.push_back(var2idx.at(i));
}
for (auto&& i : op->usable_output()) {
outputs.push_back(push(i));
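// invoked by grad_context.backward() for each recorded forward op (in reverse order);
// returns the grads of that op's inputs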
auto compute_input_grads = [&](op_t op, vars_t inputs, vars_t outputs,
vars_t output_grads) {
auto* gfunc = cg::lookup_grad_func(op->dyn_typeinfo());
vars_t input_grads(inputs.size(), nullptr);
bool any_grad = false;
for (auto&& output_grad : output_grads) {
if (output_grad) {
any_grad = true;
}
igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
}
};
// set backward graph outputs
cg::DepOprIter iter{gen_expr};
iter.set_visited(fwd);
result.output_mask.resize(inputs.size());
VarNodeArray output_grads_with_unused_var;
{
auto iter = output_grads.begin();
for (auto&& i : fwd->output()) {
if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
// a var node with VOLATILE_CONTENT (e.g. a workspace
// or an empty var) is not considered a normal
// output, so its grad is always NULL
output_grads_with_unused_var.push_back(nullptr);
} else {
output_grads_with_unused_var.push_back(*iter);
++iter;
}
if (!gfunc || !any_grad) {
return input_grads;
}
mgb_assert(iter == output_grads.end());
}
Maybe<VarNodeArray> grad_results;
for (size_t i = 0; i < inputs.size(); ++i) {
VarNode* grad;
if (grad_results.valid()) {
grad = grad_results.val()[i];
} else {
mgb_assert(gfunc, "could not find grad function");
auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
if (res.from_single()) {
grad = res.single();
} else {
grad_results.emplace(res.all(fwd));
Maybe<VarNodeArray> grad_results;
auto&& input_requires_grad = grad_context.get_require_grads(inputs);
for (size_t i = 0; i < inputs.size(); ++i) {
VarNode* grad;
if (grad_results.valid()) {
grad = grad_results.val()[i];
}
}
if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>() &&
input_requires_grad[i]) {
mgb_assert(
!grad->owner_opr()->same_type<opr::InvalidGrad>(),
"gradient of operator %s w.r.t. input #%lu is "
"either not well defined or not implemented",
fwd->dyn_typeinfo()->name, i);
iter.add(grad);
igraph.outputs.push_back(var2idx.at(grad));
result.output_mask[i] = true;
} else {
result.output_mask[i] = false;
}
}
if (igraph.outputs.empty()) {
return {};
}
// set backward graph inputs
auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
for (auto&& i : vars) {
auto&& iter = var2idx.find(i);
if (iter != var2idx.end()) {
igraph.inputs.push_back(iter->second);
result.input_mask.push_back(true);
} else {
result.input_mask.push_back(false);
mgb_assert(gfunc, "could not find grad function");
auto res = (*gfunc)(op, i, output_grads);
if (res.from_single()) {
grad = res.single();
} else {
grad_results.emplace(res.all(op));
grad = grad_results.val()[i];
}
}
if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()) {
if (input_requires_grad[i]) {
input_grads[i] = grad;
}
}
}
return input_grads;
};
write_inputs(inputs);
write_inputs(outputs);
write_inputs(output_grads);
mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
return result;
grad_context.backward(outputs, output_grads, compute_input_grads);
auto input_grads = grad_context.get_grads(inputs);
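// the backward graph is built over (forward inputs, forward outputs, output grads)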
VarNodeArray bgraph_inputs;
bgraph_inputs.insert(bgraph_inputs.end(), inputs.begin(), inputs.end());
bgraph_inputs.insert(bgraph_inputs.end(), outputs.begin(), outputs.end());
bgraph_inputs.insert(bgraph_inputs.end(), output_grads.begin(), output_grads.end());
auto graph = subgraph_detail::make_from_computing_graph(bgraph_inputs, input_grads);
return graph;
}
VarNodeArray ProxyGraph::make_input_place_holders(
......
@@ -107,13 +107,16 @@ EncodedSubgraph make_backward_graph_from_forward(
Subgraph::Builder<LogicalTensorDesc> builder(
[](auto&& op, auto&& input_descs, size_t nr_outputs) {
auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs);
mgb_assert(
descs.size() == nr_outputs, "nr_outputs mismatch for %s",
op->to_string().c_str());
return descs;
});
auto accum_grad = [&](var_t lhs, var_t rhs) {
return builder.write_expr(
Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0];
};
GradContext<var_t> grad_context{accum_grad};
GradContext<std::shared_ptr<OpDef>, var_t> grad_context{accum_grad};
auto input_vars = builder.write_inputs(inputs);
auto outputs = forward_graph.apply<var_t>(
input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
@@ -143,19 +146,17 @@ EncodedSubgraph make_backward_graph_from_forward(
grad_context.backward(
apply_mask(outputs, output_has_grad),
apply_mask(output_grads, output_has_grad),
[&](Subgraph::expr_t expr, vars_t output_grads) {
[&](Subgraph::op_t op, vars_t inputs, vars_t outputs, vars_t output_grads) {
auto bg = OpDef::make_backward_graph(
*expr.op, builder.get_descs(expr.inputs),
grad_context.get_require_grads(expr.inputs),
grad_context.get_has_grads(expr.outputs));
*op, builder.get_descs(inputs),
grad_context.get_require_grads(inputs),
grad_context.get_has_grads(outputs));
if (bg.graph.empty()) {
return vars_t(expr.inputs.size(), 0);
return vars_t(inputs.size(), 0);
}
vars_t grad_inputs;
grad_inputs.insert(
grad_inputs.end(), expr.inputs.begin(), expr.inputs.end());
grad_inputs.insert(
grad_inputs.end(), expr.outputs.begin(), expr.outputs.end());
grad_inputs.insert(grad_inputs.end(), inputs.begin(), inputs.end());
grad_inputs.insert(grad_inputs.end(), outputs.begin(), outputs.end());
grad_inputs.insert(
grad_inputs.end(), output_grads.begin(), output_grads.end());
auto apply_functor =
@@ -183,6 +184,77 @@ EncodedSubgraph make_backward_graph(
forward_graph, inputs, input_requires_grad, output_has_grad);
}
EncodedSubgraph make_from_computing_graph(
const VarNodeArray& inputs, const VarNodeArray& outputs) {
Subgraph subgraph;
std::unordered_map<VarNode*, size_t> var2idx;
size_t next_idx = 0;
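// index 0 is reserved for absent (nullptr) vars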
var2idx[nullptr] = next_idx++;
for (auto&& input : inputs) {
if (input) {
var2idx[input] = next_idx++;
}
}
auto is_tensor_holder = [](cg::OperatorNodeBase* op) {
return op->input().empty();
};
auto as_tensor = [](VarNode* var) -> TensorPtr {
auto* opr = var->owner_opr();
if (auto* imm_tensor = opr->try_cast_final<opr::ImmutableTensor>()) {
auto&& dv = imm_tensor->value();
HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
// get host value
auto&& cpu_value = imm_tensor->host_value();
mgb_assert(cpu_value.comp_node() == CompNode::default_cpu());
// default_cpu is synchronous with respect to caller
hv.proxy_to_default_cpu().copy_from_fixlayout(cpu_value);
return Tensor::make(dv, hv);
} else if (
auto* shared_tensor = opr->try_cast_final<opr::SharedDeviceTensor>()) {
return Tensor::make(shared_tensor->get_dev_tensor());
} else {
mgb_assert(
false, "unsupported tensor holder opr %s",
opr->dyn_typeinfo()->name);
}
};
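// visit operators in topological order: tensor holders become subgraph constants,
// everything else becomes a subgraph expr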
cg::DepOprIter iter{[&](cg::OperatorNodeBase* op) {
// TODO: implement make_backward_graph for mm ops
// mgb_assert(!op->node_prop().contain(cg::OperatorNodeProp::Flag::IMPURE_FUNC));
if (is_tensor_holder(op)) {
for (auto&& output : op->usable_output()) {
subgraph.constants.push_back(
{var2idx[output] = next_idx++, as_tensor(output)});
}
} else {
Subgraph::vars_t inputs;
Subgraph::vars_t outputs;
for (auto&& input : op->input()) {
inputs.push_back(var2idx.at(input));
}
// NOTE: use usable_output() so that VOLATILE_CONTENT outputs (e.g. workspaces) are skipped
for (auto&& output : op->usable_output()) {
outputs.push_back(var2idx[output] = next_idx++);
}
auto opdef = OpDef::make_from_op_node(op);
subgraph.exprs.push_back({opdef, inputs, outputs});
}
}};
for (auto&& input : inputs) {
if (input) {
iter.set_visited(input->owner_opr());
}
subgraph.inputs.push_back(var2idx.at(input));
}
for (auto&& output : outputs) {
if (output) {
iter.add(output);
}
subgraph.outputs.push_back(var2idx.at(output));
}
return EncodedSubgraph::make(subgraph);
}
} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
@@ -189,12 +189,17 @@ struct EncodedSubgraph {
size_t hash() const;
};
template <typename T>
template <typename TOp, typename TVar>
class GradContext {
public:
using var_t = T;
using op_t = TOp;
using var_t = TVar;
using vars_t = SmallVector<var_t>;
using expr_t = Expr<T>;
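// one recorded forward application: the op together with its input and output vars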
struct expr_t {
op_t op;
vars_t inputs;
vars_t outputs;
};
private:
std::unordered_map<var_t, var_t> m_grads;
@@ -219,6 +224,7 @@ public:
}
return mask;
}
void mark_require_grad(var_t dest) { m_vars_require_grad.insert(dest); }
void mark_require_grads(vars_t dests) {
for (auto&& dest : dests) {
m_vars_require_grad.insert(dest);
@@ -231,7 +237,7 @@ public:
return m_grads[dest] = m_accumulator(m_grads[dest], grad);
}
}
void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
void record_expr(op_t op, vars_t inputs, vars_t outputs) {
bool require_grad = false;
for (auto&& input : inputs) {
if (m_vars_require_grad.count(input)) {
@@ -254,7 +260,8 @@ public:
std::reverse(exprs.begin(), exprs.end());
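// replay recorded exprs from last to first, propagating grads backwards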
for (const expr_t& expr : exprs) {
size_t nr_inputs = expr.inputs.size();
vars_t input_grads = functor(expr, get_grads(expr.outputs));
vars_t input_grads = functor(
expr.op, expr.inputs, expr.outputs, get_grads(expr.outputs));
mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
for (size_t i = 0; i < nr_inputs; ++i) {
if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
......
@@ -43,6 +43,8 @@ EncodedSubgraph make_backward_graph_from_forward(
const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
EncodedSubgraph make_from_computing_graph(
const VarNodeArray& inputs, const VarNodeArray& outputs);
} // namespace subgraph_detail
} // namespace imperative
......