Commit c53abcdf authored by Megvii Engine Team

chore(mge): minor improvements related to grad

GitOrigin-RevId: 102467d79d148b52f4dfefadeb3e6a7d7a0d2ad6
Parent 0a3ca253
@@ -155,17 +155,7 @@ struct BackwardGraphWithClosure {
         }
         if (null_grad) return;
-        ApplyContext ctx;
-        ctx.op = backward_graph->backward;
-        ctx.flags = is_tracing ? Flags::TRACE : 0;
-        ctx.nargs = nargs;
-        ctx.args = args;
-        for (size_t i = 0; i < nargs; ++i) {
-            ctx.flags |= args[i]->m_flags;
-            mgb_assert(args[i]);
-        }
-        auto igrads = apply(ctx);
+        auto igrads = apply(backward_graph->backward, args, nargs);
         auto&& it = igrads.begin();
         for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
             if (p) {
@@ -252,6 +252,18 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
     return apply(ctx);
 }
 
+inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
+    ApplyContext ctx;
+    ctx.op = std::move(op);
+    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
+    ctx.nargs = nargs;
+    ctx.args = args;
+    for (size_t i = 0; i < nargs; ++i) {
+        ctx.flags |= args[i]->m_flags;
+    }
+    return apply(ctx);
+}
+
 void init_tensor(pybind11::module);
 
 extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
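The overload above packs `(op, args, nargs)` into an `ApplyContext` and accumulates per-argument flags, which is what lets the hand-rolled setup in the first hunk collapse to a single `apply(...)` call. A minimal, self-contained sketch of the same pattern, with all type names invented for illustration (they are not MegEngine's real definitions):

```cpp
#include <cstddef>
#include <cstdio>
#include <memory>

struct Op { int id; };
struct Tensor { unsigned flags; };

// Stand-in for ApplyContext: bundles everything the core call needs.
struct ApplyCtx {
    std::shared_ptr<Op> op;
    unsigned flags = 0;
    Tensor* const* args = nullptr;
    std::size_t nargs = 0;
};

int run(ApplyCtx& ctx) {  // stand-in for apply(ApplyContext&)
    std::printf("op=%d flags=%u nargs=%zu\n", ctx.op->id, ctx.flags, ctx.nargs);
    return 0;
}

// Convenience overload mirroring the one in the diff: fill the context,
// OR together per-argument flags, then forward to the context-based entry.
int run(std::shared_ptr<Op> op, Tensor* const* args, std::size_t nargs) {
    ApplyCtx ctx;
    ctx.op = std::move(op);
    ctx.args = args;
    ctx.nargs = nargs;
    for (std::size_t i = 0; i < nargs; ++i) ctx.flags |= args[i]->flags;
    return run(ctx);
}

int main() {
    Tensor a{1}, b{2};
    Tensor* ptrs[] = {&a, &b};
    return run(std::make_shared<Op>(Op{42}), ptrs, 2);
}
```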
@@ -111,4 +111,8 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
             }
         }
     }
+
+    if (!fgraph.outputs.size()) {
+        precomp.reset();
+    }
 }
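The added guard reads as a normalization of the "nothing to precompute" case: rather than handing callers an empty precompute graph, the handle is cleared so a null check suffices. A tiny sketch of that idiom, with illustrative names rather than the real OptimizedBackwardGraphResult:

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct Graph { std::vector<int> outputs; };

struct Result {
    std::shared_ptr<Graph> precomp;  // null means: no precomputation step
};

int main() {
    Result r{std::make_shared<Graph>()};
    if (!r.precomp->outputs.size()) r.precomp.reset();  // mirrors the diff
    assert(!r.precomp);  // callers simply test the pointer
    return 0;
}
```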
@@ -9,7 +9,11 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
+#include <sstream>
+#include <range/v3/all.hpp>
+
 #include "megbrain/imperative/ops/backward_graph.h"
+#include "megbrain/imperative/ops/opr_attr.h"
 #include "../op_trait.h"
 
 namespace mgb {
@@ -66,6 +70,67 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
     return {ret, validated};
 }
 
+std::string BackwardGraph::InternalGraph::repr() {
+    std::ostringstream buf;
+    buf << "(";
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (i > 0) buf << ", ";
+        buf << "%" << inputs[i];
+    }
+    buf << ") => {\n";
+    auto fmt_const = [](size_t i, TensorPtr& t) {
+        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
+            auto&& v = t->get_value();
+            if (v.dtype() == dtype::Float32{}) {
+                return std::to_string(*v.ptr<dt_float32>());
+            } else if (v.dtype() == dtype::Int32{}) {
+                return std::to_string(*v.ptr<int32_t>());
+            }
+        }
+        return std::string("%c") + std::to_string(i);
+    };
+    std::unordered_map<size_t, std::string> const_reps;
+    for (auto&& [i, t] : constants) {
+        const_reps.emplace(i, fmt_const(i, t));
+    }
+    for (auto& [op, ins, outs] : exprs) {
+        buf << " ";
+        if (outs.size()) {
+            for (size_t i = 0; i < outs.size(); ++i) {
+                if (i > 0) buf << ", ";
+                buf << "%" << outs[i];
+            }
+            buf << " = ";
+        }
+        if (auto* p = op->try_cast_final<OprAttr>()) {
+            buf << p->type;
+        } else {
+            buf << op->dyn_typeinfo()->name;
+        }
+        for (size_t i : ins) {
+            buf << " ";
+            auto&& it = const_reps.find(i);
+            if (it != const_reps.end()) {
+                buf << it->second;
+            } else {
+                buf << "%" << i;
+            }
+        }
+        buf << "\n";
+    }
+    buf << " ";
+    if (outputs.size()) {
+        for (size_t i = 0; i < outputs.size(); ++i) {
+            if (i > 0) buf << ", ";
+            buf << "%" << outputs[i];
+        }
+    } else {
+        buf << "()";
+    }
+    buf << "\n}\n";
+    return buf.str();
+}
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
 
 namespace {
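For a sense of the text `repr()` emits: inputs are listed as `%i`, each expression prints its outputs, the operator name (`OprAttr::type` when available, otherwise the dynamic type name), and its inputs, with scalar constants inlined as literal values. A hypothetical graph might render along these lines (operator name and indices invented for illustration; a scalar float constant at index 2 prints via `std::to_string`):

```
(%0, %1) => {
 %3 = Elemwise %0 2.000000
 %4 = Elemwise %1 %3
 %4
}
```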
@@ -71,6 +71,8 @@ public:
         }
         return ret;
     }
+
+    std::string repr();
 };
 
 const InternalGraph& graph() const {
@@ -93,6 +95,8 @@ public:
         return false;
     }
 
+    std::string repr() {return m_graph.repr();}
+
 private:
     InternalGraph m_graph;
 };
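With the forwarding method in the header, dumping a backward graph becomes a one-liner at any call site. A hypothetical usage, assuming `bg` is a `BackwardGraph` instance obtained elsewhere (e.g. from a backward-graph construction pass):

```cpp
// Hypothetical call site; `bg` is assumed, not part of this diff.
std::string s = bg.repr();
mgb_log("backward graph:\n%s", s.c_str());
```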