提交 c53abcdf 编写于 作者: M Megvii Engine Team

chore(mge): minor improvements related to grad

GitOrigin-RevId: 102467d79d148b52f4dfefadeb3e6a7d7a0d2ad6
上级 0a3ca253
......@@ -155,17 +155,7 @@ struct BackwardGraphWithClosure {
}
if (null_grad) return;
ApplyContext ctx;
ctx.op = backward_graph->backward;
ctx.flags = is_tracing ? Flags::TRACE : 0;
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
mgb_assert(args[i]);
}
auto igrads = apply(ctx);
auto igrads = apply(backward_graph->backward, args, nargs);
auto&& it = igrads.begin();
for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
if (p) {
......
......@@ -252,6 +252,18 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
return apply(ctx);
}
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
}
return apply(ctx);
}
void init_tensor(pybind11::module);
extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
......
......@@ -111,4 +111,8 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
}
}
}
if (!fgraph.outputs.size()) {
precomp.reset();
}
}
......@@ -9,7 +9,11 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <sstream>
#include <range/v3/all.hpp>
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "../op_trait.h"
namespace mgb {
......@@ -66,6 +70,67 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
return {ret, validated};
}
std::string BackwardGraph::InternalGraph::repr() {
std::ostringstream buf;
buf << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << inputs[i];
}
buf << ") => {\n";
auto fmt_const = [](size_t i, TensorPtr& t) {
if (t->shape().ndim == 1 && t->shape()[0] == 1) {
auto&& v = t->get_value();
if (v.dtype() == dtype::Float32{}) {
return std::to_string(*v.ptr<dt_float32>());
} else if (v.dtype() == dtype::Int32{}) {
return std::to_string(*v.ptr<int32_t>());
}
}
return std::string("%c") + std::to_string(i);
};
std::unordered_map<size_t, std::string> const_reps;
for (auto&& [i, t] : constants) {
const_reps.emplace(i, fmt_const(i, t));
}
for (auto& [op, ins, outs] : exprs) {
buf << " ";
if (outs.size()) {
for (size_t i = 0; i < outs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outs[i];
}
buf << " = ";
}
if (auto* p = op->try_cast_final<OprAttr>()) {
buf << p->type;
} else {
buf << op->dyn_typeinfo()->name;
}
for (size_t i : ins) {
buf << " ";
auto&& it = const_reps.find(i);
if (it != const_reps.end()) {
buf << it->second;
} else {
buf << "%" << i;
}
}
buf << "\n";
}
buf << " ";
if (outputs.size()) {
for (size_t i = 0; i < outputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outputs[i];
}
} else {
buf << "()";
}
buf << "\n}\n";
return buf.str();
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
namespace {
......
......@@ -71,6 +71,8 @@ public:
}
return ret;
}
std::string repr();
};
const InternalGraph& graph() const {
......@@ -93,6 +95,8 @@ public:
return false;
}
std::string repr() {return m_graph.repr();}
private:
InternalGraph m_graph;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册