From c53abcdf1d12582c35be53a39e6383a520257642 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Jan 2021 17:36:41 +0800 Subject: [PATCH] chore(mge): minor improvements related to grad GitOrigin-RevId: 102467d79d148b52f4dfefadeb3e6a7d7a0d2ad6 --- imperative/python/src/grad.cpp | 12 +--- imperative/python/src/tensor.h | 12 ++++ imperative/src/impl/backward_graph_opt.cpp | 4 ++ imperative/src/impl/ops/backward_graph.cpp | 65 +++++++++++++++++++ .../megbrain/imperative/ops/backward_graph.h | 4 ++ 5 files changed, 86 insertions(+), 11 deletions(-) diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 831d7f755..cecdc0cb1 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -155,17 +155,7 @@ struct BackwardGraphWithClosure { } if (null_grad) return; - ApplyContext ctx; - ctx.op = backward_graph->backward; - ctx.flags = is_tracing ? Flags::TRACE : 0; - ctx.nargs = nargs; - ctx.args = args; - for (size_t i = 0; i < nargs; ++i) { - ctx.flags |= args[i]->m_flags; - mgb_assert(args[i]); - } - - auto igrads = apply(ctx); + auto igrads = apply(backward_graph->backward, args, nargs); auto&& it = igrads.begin(); for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) { if (p) { diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 7dbd2ed97..0345ed3f3 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -252,6 +252,18 @@ auto apply(std::shared_ptr op, T&& tensors) return apply(ctx); } +inline auto apply(std::shared_ptr op, Tensor*const* args, size_t nargs) { + ApplyContext ctx; + ctx.op = std::move(op); + ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; + ctx.nargs = nargs; + ctx.args = args; + for (size_t i = 0; i < nargs; ++i) { + ctx.flags |= args[i]->m_flags; + } + return apply(ctx); +} + void init_tensor(pybind11::module); extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode; diff --git a/imperative/src/impl/backward_graph_opt.cpp b/imperative/src/impl/backward_graph_opt.cpp index 61b3bc545..08f37ce19 100644 --- a/imperative/src/impl/backward_graph_opt.cpp +++ b/imperative/src/impl/backward_graph_opt.cpp @@ -111,4 +111,8 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe } } } + + if (!fgraph.outputs.size()) { + precomp.reset(); + } } diff --git a/imperative/src/impl/ops/backward_graph.cpp b/imperative/src/impl/ops/backward_graph.cpp index 6b729f43c..38ddd7690 100644 --- a/imperative/src/impl/ops/backward_graph.cpp +++ b/imperative/src/impl/ops/backward_graph.cpp @@ -9,7 +9,11 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include +#include + #include "megbrain/imperative/ops/backward_graph.h" +#include "megbrain/imperative/ops/opr_attr.h" #include "../op_trait.h" namespace mgb { @@ -66,6 +70,67 @@ std::tuple, bool> BackwardGraph::InternalGraph::i return {ret, validated}; } +std::string BackwardGraph::InternalGraph::repr() { + std::ostringstream buf; + buf << "("; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i > 0) buf << ", "; + buf << "%" << inputs[i]; + } + buf << ") => {\n"; + auto fmt_const = [](size_t i, TensorPtr& t) { + if (t->shape().ndim == 1 && t->shape()[0] == 1) { + auto&& v = t->get_value(); + if (v.dtype() == dtype::Float32{}) { + return std::to_string(*v.ptr()); + } else if (v.dtype() == dtype::Int32{}) { + return std::to_string(*v.ptr()); + } + } + return std::string("%c") + std::to_string(i); + }; + std::unordered_map const_reps; + for (auto&& [i, t] : constants) { + const_reps.emplace(i, fmt_const(i, t)); + } + for (auto& [op, ins, outs] : exprs) { + buf << " "; + if (outs.size()) { + for (size_t i = 0; i < outs.size(); ++i) { + if (i > 0) buf << ", "; + buf << "%" << outs[i]; + } + buf << " = "; + } + if (auto* p = op->try_cast_final()) { + buf << p->type; + } else { + buf << op->dyn_typeinfo()->name; + } + for (size_t i : ins) { + buf << " "; + auto&& it = const_reps.find(i); + if (it != const_reps.end()) { + buf << it->second; + } else { + buf << "%" << i; + } + } + buf << "\n"; + } + buf << " "; + if (outputs.size()) { + for (size_t i = 0; i < outputs.size(); ++i) { + if (i > 0) buf << ", "; + buf << "%" << outputs[i]; + } + } else { + buf << "()"; + } + buf << "\n}\n"; + return buf.str(); +} + MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph); namespace { diff --git a/imperative/src/include/megbrain/imperative/ops/backward_graph.h b/imperative/src/include/megbrain/imperative/ops/backward_graph.h index ba452703f..0df4d1283 100644 --- a/imperative/src/include/megbrain/imperative/ops/backward_graph.h +++ b/imperative/src/include/megbrain/imperative/ops/backward_graph.h @@ -71,6 +71,8 @@ public: } return ret; } + + std::string repr(); }; const InternalGraph& graph() const { @@ -93,6 +95,8 @@ public: return false; } + std::string repr() {return m_graph.repr();} + private: InternalGraph m_graph; }; -- GitLab