提交 273c0e87 编写于 作者: M Megvii Engine Team

fix(autodiff): fix some bugs in relation to 2nd order grad

1. implement double backward for batchnorm
2. fix grad attach in nested grad manager
3. pad empty tensor for unsatisfied output_has_grad
4. support double backward for jit subgraph
5. support double backward for autodiff.Function
6. readd debug flag MGE_LOG_OP_DISPATCH

GitOrigin-RevId: cd31ddc620a35e0582c9721df7290c972fa3c610
上级 bc9aa47a
......@@ -212,10 +212,7 @@ class Function:
if self.__single_output:
outputs = (outputs,)
for grad in reversed(group):
if grad._impl is None:
continue
outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
outputs = core2.set_grad(normalized_backward, args, outputs)
if self.__single_output:
(outputs,) = outputs
return outputs
......
......@@ -209,7 +209,6 @@ def subgraph(
outputs = gen.send(None)
nr_outputs = len(outputs)
forward_fn = build(builder, outputs, [False] * nr_outputs)
output_grads = [builder.input() for _ in range(nr_outputs)]
input_grads = gen.send(output_grads)
assert len(input_grads) == nr_inputs
......@@ -222,21 +221,45 @@ def subgraph(
]
encoded_input_grads = [grad for grad in input_grads if grad is not None]
backward_fn = build(
builder, encoded_input_grads, [False] * len(encoded_input_grads)
builder, encoded_input_grads, [True] * len(encoded_input_grads)
)
class SubgraphOp(Function):
def __init__(self):
self.inputs = None
self.output_shapes = None
def forward(self, *inputs):
self.inputs = inputs
return apply(forward_fn(), *inputs)
outputs = apply(forward_fn(), *inputs)
if len(outputs) > 1:
self.output_shapes = [output.shape for output in outputs]
return outputs
def backward(self, *output_grads):
inputs = self.inputs
self.inputs = None
encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
any_valid = False
all_valid = True
for output_grad in output_grads:
if output_grad is None:
all_valid = False
else:
any_valid = True
if not any_valid:
input_grads = [None] * len(indices)
else:
if not all_valid:
assert self.output_shapes is not None
from ...functional import zeros
output_grads = [
zeros(self.output_shapes[i]) if grad is None else grad
for i, grad in enumerate(output_grads)
]
self = None
encoded_input_grads = apply(
backward_fn(), *inputs, *output_grads
)
input_grads = [
encoded_input_grads[i] if i is not None else None
for i in indices
......
......@@ -896,7 +896,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
@lru_cache(maxsize=None)
def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None):
def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None):
@subgraph_fn(
"LeakyReLU",
dtype=dtype,
......@@ -925,7 +925,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
Refer to :class:`~.LeakyReLU` for more information.
"""
leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
(oup,) = leakyReLU(inp)
return oup
......@@ -1399,7 +1399,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
f("fma3", input, inv_var_wt,
f("+", f("*", neg_channel_mean, inv_var_wt),
bias))
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
return (outvar, channel_mean, channel_var), (True, True, True)
@subgraph("SyncBnStage1Inference", dtype, device, 6)
def syncbn_stage1_inference(inputs, f, c):
......@@ -1509,7 +1509,7 @@ def sync_batch_norm(
"""
_eps_mode = eps_mode.lower()
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
if _eps_mode == "additive" and not (is_distributed() and training):
if _eps_mode == "additive" and not (is_distributed() or training):
return batch_norm(
inp,
running_mean,
......
......@@ -121,13 +121,13 @@ void GradKeyWrapper::enter() {
m_key = m_transformation->key();
m_key->name(m_name);
grad_key_map[m_key] = this;
TransformationManager::get_instance().register_at<TransformationManager::Grad>(
m_transformation);
m_transformation_guard =
TransformationManager::get_instance()
.register_at<TransformationManager::Grad>(m_transformation);
}
void GradKeyWrapper::exit() {
TransformationManager::get_instance().unregister<TransformationManager::Grad>(
m_transformation);
m_transformation_guard.reset();
grad_key_map.erase(m_key);
m_key = {};
m_transformation.reset();
......
......@@ -29,6 +29,7 @@ struct GradKeyWrapper : NonCopyableObj {
std::string m_name;
std::shared_ptr<GradKey> m_key;
std::shared_ptr<GradTransformation> m_transformation;
std::unique_ptr<CleanupGuard<>> m_transformation_guard;
GradKeyWrapper();
......
......@@ -449,15 +449,24 @@ void init_tensor(py::module m) {
interpreter::Interpreter::inst().create_channel())
->get();
interpreter_for_py = channel;
transformations.register_at<Segment::Eval>(
MGB_MARK_USED_VAR(
transformations
.register_at<Segment::Eval>(
std::make_shared<InterpreterTransformation>(
std::shared_ptr<Channel>(channel, [](Channel*) {})));
transformations.register_at<Segment::Scalar>(
std::make_shared<ScalarTransformation>());
transformations.register_at<Segment::DTypePromote>(
std::make_shared<DTypePromoteTransformation>());
transformations.register_at<Segment::DimExpansion>(
std::make_shared<DimExpansionTransformation>());
std::shared_ptr<Channel>(channel, [](Channel*) {})))
.release());
MGB_MARK_USED_VAR(transformations
.register_at<Segment::Scalar>(
std::make_shared<ScalarTransformation>())
.release());
MGB_MARK_USED_VAR(transformations
.register_at<Segment::DTypePromote>(
std::make_shared<DTypePromoteTransformation>())
.release());
MGB_MARK_USED_VAR(transformations
.register_at<Segment::DimExpansion>(
std::make_shared<DimExpansionTransformation>())
.release());
static py::exception<interpreter::AsyncError> py_async_error(
m, "AsyncError", PyExc_RuntimeError);
......@@ -681,6 +690,9 @@ void init_tensor(py::module m) {
std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
std::optional<TraceResult> trace_result;
std::function<bool(py::object, py::object)> array_comparator;
std::unique_ptr<CleanupGuard<>> tracing_guard;
std::unique_ptr<CleanupGuard<>> compiled_guard;
std::unique_ptr<CleanupGuard<>> lazy_eval_guard;
bool compare_value(ValueRef lhs, ValueRef rhs) {
auto lvalue = lhs.cast_ref<HostValue>();
......@@ -730,12 +742,15 @@ void init_tensor(py::module m) {
std::make_shared<GraphProfiler>(&current_graph));
}
}
compiled_guard =
transformations.register_at<Segment::Trace>(self.compiled);
// start execute because InputCallback depends
self.compiled->execute();
} else if (self.tracing) {
tracing_guard =
transformations.register_at<Segment::Trace>(self.tracing);
if (self.lazy_eval) {
lazy_eval_guard =
transformations.register_at<Segment::Eval>(self.lazy_eval);
}
} else {
......@@ -746,16 +761,16 @@ void init_tensor(py::module m) {
void exit() {
auto& self = *this;
if (self.tracing) {
transformations.unregister<Segment::Trace>(self.tracing);
tracing_guard.reset();
self.trace_result = self.tracing->get_result();
self.tracing.reset();
if (self.lazy_eval) {
auto lazy_eval = std::move(self.lazy_eval);
transformations.unregister<Segment::Eval>(lazy_eval);
lazy_eval_guard.reset();
lazy_eval->check_exception();
}
} else if (self.compiled) {
transformations.unregister<Segment::Trace>(self.compiled);
compiled_guard.reset();
self.compiled->wait();
} else {
mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
......@@ -829,16 +844,18 @@ void init_tensor(py::module m) {
[](Trace& self) {
mgb_assert(bool(self.tracing) ^ bool(self.compiled));
if (self.tracing) {
transformations.unregister<Segment::Trace>(self.tracing);
self.tracing_guard.reset();
} else if (self.compiled) {
transformations.unregister<Segment::Trace>(self.compiled);
self.compiled_guard.reset();
}
})
.def("end_excluded_region", [](Trace& self) {
mgb_assert(bool(self.tracing) ^ bool(self.compiled));
if (self.tracing) {
self.tracing_guard =
transformations.register_at<Segment::Trace>(self.tracing);
} else if (self.compiled) {
self.compiled_guard =
transformations.register_at<Segment::Trace>(self.compiled);
}
});
......@@ -900,11 +917,8 @@ void init_tensor(py::module m) {
GradKeyWrapper::get(output.cast<GradKeyValue>())));
});
m.def("set_grad", [](py::object py_key, py::function backward_fn,
std::vector<py::object> inputs,
m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
std::vector<py::object> outputs) {
mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr()));
auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst();
GenericFunction generic_backward_fn =
[backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
py::list output_grad_tws;
......@@ -937,8 +951,8 @@ void init_tensor(py::module m) {
values[i + inputs.size()] =
outputs[i].cast<TensorWrapper>().m_tensor->data();
}
auto wrapped_output_values = imperative::apply(
SetGrad(key->m_key, generic_backward_fn, inputs.size()), values);
auto wrapped_output_values =
imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
std::vector<py::object> wrapped_outputs;
mgb_assert(wrapped_output_values.size() == outputs.size());
for (auto&& output_value : wrapped_output_values) {
......@@ -956,8 +970,10 @@ void init_tensor(py::module m) {
mgb_assert(module_trace_hook);
module_trace_transformation =
std::make_shared<ModuleTraceTransformation>(module_trace_hook);
transformations.register_at<Segment::ModuleTrace>(
module_trace_transformation);
MGB_MARK_USED_VAR(transformations
.register_at<Segment::ModuleTrace>(
module_trace_transformation)
.release());
}
return module_trace_transformation;
};
......
......@@ -18,11 +18,13 @@
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/value.h"
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative::python {
struct TransformationManager {
public:
enum Segment {
ModuleTrace,
DTypePromote,
......@@ -35,8 +37,21 @@ struct TransformationManager {
std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;
private:
template <Segment segment>
void unregister(std::shared_ptr<Transformation> transformation) noexcept {
mgb_assert(segment < segments.size());
auto iter = std::find(
segments[segment].begin(), segments[segment].end(), transformation);
mgb_assert(iter != segments[segment].end());
transformation->unregister();
segments[segment].erase(iter);
}
public:
template <Segment segment>
void register_at(std::shared_ptr<Transformation> transformation) {
[[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at(
std::shared_ptr<Transformation> transformation) {
mgb_assert(segment < segments.size());
std::shared_ptr<Transformation> next;
for (size_t i = segment; i < segments.size(); ++i) {
......@@ -51,16 +66,8 @@ struct TransformationManager {
transformation->register_at(next->pos());
}
segments[segment].push_back(transformation);
}
template <Segment segment>
void unregister(std::shared_ptr<Transformation> transformation) noexcept {
mgb_assert(segment < segments.size());
auto iter = std::find(
segments[segment].begin(), segments[segment].end(), transformation);
mgb_assert(iter != segments[segment].end());
transformation->unregister();
segments[segment].erase(iter);
return std::make_unique<CleanupGuard<>>(
[this, transformation]() { unregister<segment>(transformation); });
}
static TransformationManager& get_instance() {
......
......@@ -452,6 +452,8 @@ def test_2nd_grad_with_custom_gradient():
return y
def backward(self, dy):
if dy is None:
return None
dx = -MySin()(self.inp) * dy
return dx
......
......@@ -14,6 +14,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad
......
......@@ -318,3 +318,41 @@ def test_throw_on_non_tensor_argument():
func = NonTensorArg()
with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"):
func(x, 1)
def test_multiple_grad():
data_shape = (9, 2, 6)
av = np.random.random(data_shape).astype(np.float32)
class MulFunc(Function):
def forward(self, a):
self.a = a
return a * 10
def backward(self, grad_o):
return grad_o * 20
class Simple(Module):
def __init__(self, a):
super().__init__()
self.a = Parameter(a, dtype=np.float32)
self.layer1 = MulFunc()
def forward(self):
x = self.layer1(self.a)
return x
net = Simple(av)
gm = ad.GradManager().attach(net.parameters())
gm2 = ad.GradManager().attach(net.parameters())
opt = optimizer.SGD(net.parameters(), lr=1.0)
opt.clear_grad()
with gm:
with gm2:
loss = net()
gm.backward(loss.sum())
opt.step()
np.testing.assert_almost_equal(loss.numpy(), (av * 10))
np.testing.assert_almost_equal(net.a.numpy(), (av - 20))
......@@ -109,3 +109,46 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level,
_assert_allclose(out1.numpy(), out2.numpy())
_assert_allclose(grad1.numpy(), grad2.numpy())
@functools.lru_cache(maxsize=None)
def _get_mul_fn(dtype, device):
@subgraph_fn(
"Mul",
dtype=dtype,
device=device,
nr_inputs=2,
gopt_level=None,
jit_fusion=False,
custom_grad=True,
)
def mul(inputs, f, c):
x, y = inputs[0:2]
z = f("*", x, y)
(dz,) = yield (z,)
dx = f("*", dz, y)
dy = f("*", dz, x)
yield (dx, dy)
return mul
def test_subgraph_jit_backward():
x_np = np.random.rand(3, 4, 5).astype("float32")
x1 = megengine.Tensor(x_np)
x2 = megengine.Tensor(x_np)
mul = _get_mul_fn(x1.dtype, x1.device)
gm = GradManager()
gm.attach([x1, x2])
with gm:
y1 = x1 * x1
y2 = mul(x2, x2)
gm.backward(y1)
with gm:
y1 = x1 * x1
y2 = mul(x2, x2)
gm.backward(y1 + y2)
with gm:
y1 = x1 * x1
y2 = mul(x2, x2)
gm.backward(y2)
......@@ -18,18 +18,44 @@
namespace mgb {
namespace imperative {
namespace {
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
ValueRefList apply_release(const Operator& op, Span<ValueRef> inputs) {
auto& context = Transformation::get_context();
size_t& depth = context.next_transformation;
// TODO: add fallback transformation
bool fallback = depth >= context.transformations.size();
if (mgb_unlikely(fallback)) {
return op.fallback(inputs);
} else {
mgb_assert(depth < context.transformations.size());
auto& transformation = *context.transformations[depth++];
CleanupGuard _{[&] { --depth; }};
return transformation.apply_transformation(op, inputs);
}
MGB_NOINLINE ValueRefList apply_debug(const Operator& op, Span<ValueRef> inputs) {
auto& context = Transformation::get_context();
size_t& depth = context.next_transformation;
mgb_assert(depth < context.transformations.size());
static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1;
mgb_log_debug(
"%s apply %s to %s", prefix, op.to_string().c_str(),
imperative::to_string(inputs).c_str());
ValueRefList result;
auto& transformation = *context.transformations[depth++];
CleanupGuard _{[&] { --depth; }};
result = transformation.apply_transformation(op, inputs);
mgb_log_debug(
"%s returns %s", prefix,
imperative::to_string(Span<ValueRef>(result)).c_str());
return result;
}
} // namespace
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
static bool debug = MGB_GETENV("MGE_LOG_OP_DISPATCH");
if (mgb_unlikely(debug)) {
return apply_debug(op, inputs);
} else {
return apply_release(op, inputs);
}
}
......
......@@ -106,7 +106,8 @@ EncodedSubgraph OpDef::make_forward_graph(
}
std::string OpDef::to_string() const {
std::string builder = trait()->make_name(*this) + "{";
std::string builder = trait()->name;
builder += "{";
for (auto&& [name, value] : props(*this)) {
builder += name;
builder += ": ";
......@@ -196,7 +197,7 @@ std::string Subgraph::repr() const {
if (auto* p = op->try_cast_final<OprAttr>()) {
buf << p->type;
} else {
buf << op->make_name();
buf << op->to_string();
}
for (size_t i : ins) {
buf << " ";
......
......@@ -11,13 +11,94 @@
#include "megbrain/opr/dnn/batch_norm.h"
#include "../op_trait.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/tensor.h"
namespace mgb {
namespace imperative {
namespace {
EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) {
Subgraph::Builder<LogicalTensorDesc> builder{
[](std::shared_ptr<OpDef> op, SmallVector<LogicalTensorDesc> inputs,
size_t nr_outputs) {
auto [outputs, validated] =
OpDef::infer_output_attrs_fallible(*op, inputs);
mgb_assert(outputs.size() == nr_outputs, "nr_outputs mismatch");
return outputs;
}};
auto f = [&](auto&& op, auto... args) {
return builder.write_expr(
op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0];
};
auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0));
auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM));
auto sub = Elemwise::make(Elemwise::Mode::SUB);
auto mul = Elemwise::make(Elemwise::Mode::MUL);
auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV);
auto floor_div = Elemwise::make(Elemwise::Mode::FLOOR_DIV);
auto broadcast = Broadcast::make();
auto c = [&](TensorPtr tensor, DType dtype) {
auto result = builder.write_constant(
tensor, {TensorLayout{tensor->dtype()}, tensor->comp_node()});
if (tensor->dtype() != dtype) {
result = f(TypeCvt::make(dtype), result);
}
return result;
};
auto ci = [&](megdnn::dt_int32 value) {
return c(Tensor::make_scalar(DTypeScalar(value), device), dtype::Int32());
};
auto cf = [&](megdnn::dt_float32 value) {
return c(Tensor::make_scalar(DTypeScalar(value), device), dtype);
};
auto desc = LogicalTensorDesc{TensorLayout{dtype}, device};
auto x = builder.write_input(desc);
auto y_grad = builder.write_input(desc);
auto save_mean = builder.write_input(desc);
auto save_invstd = builder.write_input(desc);
auto weight = builder.write_input(desc);
auto reserved = builder.write_input(desc);
MGB_MARK_USED_VAR(reserved);
// assert x.ndim == 4
auto input_shape = f(GetVarShape::make(), x);
auto channels = f(GetVarShape::make(1), x);
auto reduce_shape = f(Concat::make(0, device), ci(1), channels, ci(1), ci(1));
auto input_elems = f(prod, input_shape);
auto reduce_size = f(floor_div, input_elems, channels);
auto reduce_size_f = f(TypeCvt::make(dtype), reduce_size);
auto mean = f(broadcast, save_mean, input_shape);
auto invstd = save_invstd;
auto norm = f(div, cf(1), reduce_size_f);
auto output_grad_sum = f(sum, y_grad, reduce_shape);
auto dot_p = f(sum, f(mul, y_grad, f(sub, x, mean)), reduce_shape);
auto mean_grad = f(broadcast, f(mul, output_grad_sum, norm), input_shape);
auto proj_scale =
f(broadcast, f(mul, f(mul, dot_p, norm), f(mul, invstd, invstd)),
input_shape);
auto grad_scale = f(
mul, f(broadcast, invstd, input_shape), f(broadcast, weight, input_shape));
auto proj = f(mul, f(sub, x, mean), proj_scale);
auto x_grad = f(mul, f(sub, f(sub, y_grad, proj), mean_grad), grad_scale);
auto weight_grad = f(mul, dot_p, invstd);
auto bias_grad = output_grad_sum;
builder.add_outputs({weight_grad, bias_grad, x_grad});
auto bn_backward = builder.encode();
return bn_backward;
}
namespace bn {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::BatchNorm>();
return BatchNorm::make(node->param());
......@@ -72,8 +153,60 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // anonymous namespace
} // namespace bn
namespace bn_backward {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::BatchNormBackward>();
return BatchNormBackward::make(node->param());
}
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto& op = def.cast_final_safe<BatchNormBackward>();
cg::SymbolVar x, y_grad, save_mean, save_variance, weight, reserve;
x = inputs[0];
y_grad = inputs[1];
save_mean = inputs[2];
save_variance = inputs[3];
weight = inputs[4];
if (inputs.size() == 6) {
reserve = inputs[5];
}
return opr::BatchNormBackward::make(
x, y_grad, save_mean, save_variance, weight, reserve, op.param())[0]
.node()
->owner_opr()
->usable_output();
}
EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
def.cast_final_safe<BatchNormBackward>();
size_t nr_inputs = 6;
size_t nr_outputs = 3;
mgb_assert(inputs.size() == nr_inputs);
mgb_assert(input_requires_grad.size() == nr_inputs);
mgb_assert(output_has_grad.size() == nr_outputs);
auto dtype = inputs[0].layout.dtype;
auto device = inputs[0].comp_node;
auto bn_backward = generate_batchnorm_backward_graph(dtype, device);
auto bn_double_backward = subgraph_detail::make_backward_graph_from_forward(
bn_backward, inputs, input_requires_grad, output_has_grad);
return bn_double_backward;
}
OP_TRAIT_REG(BatchNormBackward, BatchNormBackward, opr::BatchNormBackward)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.make_backward_graph(make_backward_graph)
.fallback();
} // namespace bn_backward
} // anonymous namespace
} // namespace imperative
} // namespace mgb
......
......@@ -762,7 +762,9 @@ EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
return {};
return OpDef::make_backward_graph(
*def.cast_final_safe<JITFusionOp>().op, inputs, input_requires_grad,
output_has_grad);
}
OP_TRAIT_REG(JITFusionOp, JITFusionOp)
......
......@@ -96,10 +96,11 @@ SmallVector<LayoutConstraintCallback> get_input_layout_constraint(
return res;
}
static EncodedSubgraph make_backward_graph_from_forward(
EncodedSubgraph make_backward_graph_from_forward(
const EncodedSubgraph& forward_graph,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad, EncodedSubgraph forward_graph) {
const SmallVector<bool>& output_has_grad) {
using namespace std::placeholders;
using var_t = Subgraph::var_t;
using vars_t = Subgraph::vars_t;
......@@ -179,7 +180,7 @@ EncodedSubgraph make_backward_graph(
const SmallVector<bool>& output_has_grad) {
auto forward_graph = OpDef::make_forward_graph(def, inputs);
return make_backward_graph_from_forward(
inputs, input_requires_grad, output_has_grad, forward_graph);
forward_graph, inputs, input_requires_grad, output_has_grad);
}
} // namespace subgraph_detail
......
......@@ -139,7 +139,7 @@ ValueRefList InterpreterTransformation::apply_transformation(
return {ValueRef()};
}
} else {
return imperative::apply(op, inputs);
return op.fallback(inputs);
}
}
......
......@@ -62,7 +62,8 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
: backward_graph(backward_graph),
output_mask_offset(inputs.size()),
grad_mask_offset(inputs.size() + outputs.size()) {
grad_mask_offset(inputs.size() + outputs.size()),
op(op) {
auto& save_for_backward = backward_graph->save_for_backward;
mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
size_t count = std::count_if(
......@@ -92,6 +93,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
closure.push_back(outputs[i]);
}
}
if (outputs.size() > 1) {
output_descs.reserve(outputs.size());
for (auto&& output : outputs) {
auto symbolic_shape = imperative::apply(*GetVarShape::make(), output)[0];
output_descs.push_back({symbolic_shape, output.dtype(), output.device()});
}
}
}
void BackwardGraphWithClosure::operator()(
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
......@@ -100,23 +108,46 @@ void BackwardGraphWithClosure::operator()(
for (auto&& value : closure) {
args[nargs++] = value;
}
bool null_grad = false;
size_t null_grad = 0;
size_t valid_grad = 0;
for (size_t i = 0; i < grads.size(); ++i) {
if (backward_graph->save_for_backward[grad_mask_offset + i]) {
if (grads[i]) {
mgb_assert(!null_grad, "null_grad");
valid_grad++;
args[nargs++] = grads[i];
} else {
null_grad = true;
null_grad++;
nargs++;
}
}
}
if (null_grad) {
if (valid_grad == 0) {
return;
}
auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs));
SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()};
igrads_.clear();
if (null_grad > 0) {
auto zeros_like = [](const OutputDesc& desc) {
HostTensorStorage storage(*desc.device);
storage.ensure_size(desc.dtype->size());
std::memset(storage.ptr(), 0, desc.dtype->size());
auto t = imperative::apply(
CreateTensor(
CreateTensor::Unique, *desc.device, *desc.dtype,
ValueShape()),
HostStorage::make(storage))[0];
auto res = imperative::apply(*Broadcast::make(), t, desc.shape)[0];
return res;
};
nargs = closure.size();
for (size_t i = 0; i < grads.size(); ++i) {
if (backward_graph->save_for_backward[grad_mask_offset + i]) {
if (!grads[i]) {
args[nargs] = zeros_like(output_descs[i]);
}
nargs++;
}
}
}
auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
auto&& iter = igrads.begin();
for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
if (p) {
......@@ -221,11 +252,13 @@ void GradKey::backward() {
if (!dest) {
continue;
}
if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
if (!dest.m_producer_record.next && dest->callback) {
// I'm the last grad producer, invoke callback
if (dest->m_grad) {
dest->callback(dest->m_grad);
}
}
}
grad_fn->clear();
}
tape.clear();
......@@ -394,16 +427,22 @@ ValueRefList GradTransformation::apply_transformation(
return imperative::apply(op, inputs);
}
if (auto* attach_grad = op.as<AttachGrad>()) {
if (!has_key(attach_grad->key())) {
auto& tensor = inputs[0];
if (auto&& grad_value = tensor.as_ref(m_value_type)) {
mgb_assert(!has_key(attach_grad->key()));
auto output = fallback()[0];
return record_grad(m_value_type.make(output, m_key, grad_value->slot()));
} else if (!has_key(attach_grad->key())) {
return fallback();
}
auto tensor = inputs[0];
GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
} else {
GenericFunction callback =
(GenericFunction&)inputs[1].cast<FunctionValue>();
auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
auto ret = callback({&grad, 1});
assert(ret.empty());
});
return {record_grad(output)};
}
} else if (auto* grad_backward = op.as<GradBackward>()) {
if (!has_key(grad_backward->key())) {
return fallback();
......@@ -431,10 +470,10 @@ ValueRefList GradTransformation::apply_transformation(
mgb_assert(inputs.size() > nr_inputs);
size_t nr_outputs = inputs.size() - nr_inputs;
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
backward.m_input_has_grad = SmallVector(nr_inputs, true);
backward.m_output_attrs =
SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
auto outputs_ = fallback();
backward.m_input_has_grad.resize(nr_inputs, true);
backward.m_output_attrs.resize(
nr_outputs, CustomBackward::OutputAttr{true, true});
backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) {
auto result = fn(inputs);
return SmallVector<ValueRef>(result.begin(), result.end());
......
......@@ -31,6 +31,7 @@ class Subgraph::Builder {
using infer_fn_t = std::function<descs_t(op_t, descs_t, size_t)>;
using encoded_graph_t = EncodedSubgraph;
using var_map_t = std::unordered_map<var_t, var_t>;
using mask_t = SmallVector<bool>;
vars_t m_inputs;
SmallVector<std::pair<var_t, TensorPtr>> m_constants;
vars_t m_outputs;
......@@ -94,6 +95,7 @@ public:
descs_t get_descs(vars_t vars) {
descs_t descs;
for (auto&& var : vars) {
mgb_assert(var, "invalid var");
descs.push_back(get_desc(var));
}
return descs;
......
......@@ -38,7 +38,6 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
std::shared_ptr<OpDef> op;
SmallVector<CompNode> devices;
SmallVector<DType> dtypes;
EncodedSubgraph graph;
ShapeInfer() = default;
ShapeInfer(
std::shared_ptr<OpDef> op, SmallVector<CompNode> devices,
......
......@@ -39,6 +39,11 @@ EncodedSubgraph make_backward_graph(
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs);
EncodedSubgraph make_backward_graph_from_forward(
const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
......@@ -29,6 +29,15 @@ struct BackwardGraphWithClosure {
SmallVector<ValueRef> closure;
size_t output_mask_offset;
size_t grad_mask_offset;
std::shared_ptr<OpDef> op;
struct OutputDesc {
ValueRef shape;
DTypeValue::ref_t dtype;
CompNodeValue::ref_t device;
};
SmallVector<OutputDesc> output_descs;
BackwardGraphWithClosure(
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
......@@ -356,20 +365,22 @@ public:
class SetGrad : public OperatorImpl<SetGrad> {
private:
std::shared_ptr<GradKey> m_key;
GenericFunction m_grad_fn;
size_t m_nr_inputs;
public:
SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs)
: m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
SetGrad(GenericFunction grad_fn, size_t nr_inputs)
: m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
GenericFunction grad_fn() const { return m_grad_fn; }
size_t nr_inputs() const { return m_nr_inputs; }
std::string to_string() const override {
return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());
std::string to_string() const override { return ssprintf("SetGradValue{}"); }
ValueRefList fallback(Span<ValueRef> inputs) const override {
auto outputs = inputs.sub(m_nr_inputs, inputs.size() - m_nr_inputs);
return {outputs.begin(), outputs.end()};
}
};
......
......@@ -15,12 +15,14 @@
#include <memory>
#include <sstream>
#include "megbrain/utils/metahelper.h"
namespace mgb {
namespace imperative {
template <typename T>
class CleanupGuard {
template <typename T = std::function<void()>>
class CleanupGuard : public NonCopyableObj {
private:
T m_callback;
......
......@@ -89,6 +89,8 @@ def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWin
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;
def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>;
def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册