Commit 58df717e authored by Megvii Engine Team

fix(mge/autodiff): fix attaching tensor already in gradient path

GitOrigin-RevId: da774509cabeb525ba717dbcb0ae88c3b0ad836b
Parent 05186e7b
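The user-visible effect of the fix: Grad.wrt can now be called on a tensor that is already an intermediate result of the same recording; previously this hit the mgb_assert(!has_key(attach_grad->key())) in the AttachGrad branch removed below. A minimal sketch of the newly supported scenario, mirroring the added test_wrt_intermediate_var (the Grad import path is an assumption, since the test file's imports are elided from this diff):

import numpy as np
import megengine as mge
from megengine.core.autodiff.grad import Grad   # assumed import location

x_np = np.random.rand(10).astype("float32")
x = mge.Tensor(x_np)
result = {}

with Grad() as grad:
    grad.wrt(x, callback=lambda dx: result.update(dx=dx))
    y = x * x                                                # y already lies on x's gradient path
    grad.wrt(y, callback=lambda dy: result.update(dy=dy))    # asserted before this fix
    z = y * y
    grad(z, mge.Tensor(np.ones_like(x_np)))

# both callbacks fire: dz/dy == 2 * x**2 and dz/dx == 4 * x**3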
@@ -136,6 +136,46 @@ def test_grad_with_tensor_wrapper():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


+def test_wrt_intermediate_var():
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+
+    result = {}
+
+    with Grad() as grad:
+        grad.wrt(x, callback=lambda dx: result.update(dx=dx))
+        y = mul(x, x)
+        grad.wrt(y, callback=lambda dy: result.update(dy=dy))
+        z = mul(y, y)
+        grad(z, mge.Tensor(np.ones_like(x_np)))
+
+    np.testing.assert_almost_equal(result["dx"].numpy(), 4 * x_np ** 3, decimal=6)
+    np.testing.assert_almost_equal(result["dy"].numpy(), 2 * (x_np ** 2), decimal=6)
+
+
+@pytest.mark.parametrize("in_path", [False, True])
+def test_wrt_visibility(in_path):
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+
+    def copy(x):
+        xx = mge.Tensor(x)
+        xx._reset(x)
+        return xx
+
+    result = {}
+
+    with Grad() as grad:
+        if in_path:
+            grad.wrt(x, callback=lambda _: None)
+        y = mul(x, x)
+        grad.wrt(copy(y), callback=lambda dy: result.update(dy=dy))
+        z = mul(y, y)
+        grad(z, mge.Tensor(np.ones_like(x_np)))
+
+    assert not result
+
+
 def test_release():
     def check(f):
         n = 0
...
@@ -265,20 +265,21 @@ void GradKey::backward() {

 GradValue::ref_t GradKey::attach(
         ValueRef tensor, std::function<void(ValueRef)> callback) {
-    auto grad_value = tensor.as_ref(m_value_type);
-    if (grad_value) {
-        mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists");
-    } else {
-        GradSlotPtr grad_slot;
-        auto& grad_fn = grad_slot.m_fn;
-        grad_fn = LocalPtr<GradFn>::make();
-        grad_fn->m_key = shared_from_this();
-        grad_fn->m_slots.resize(1);
-        grad_slot.m_index = 0;
-        grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot);
+    // always create a new grad value
+    GradSlotPtr grad_slot;
+    auto& grad_fn = grad_slot.m_fn;
+    grad_fn = LocalPtr<GradFn>::make();
+    grad_fn->m_key = shared_from_this();
+    grad_fn->m_slots.resize(1);
+    grad_fn->m_slots[0].callback = callback;
+    grad_slot.m_index = 0;
+    if (auto&& grad_value = tensor.as_ref(m_value_type)) {
+        grad_fn->m_backward.emplace<IdentityBackward>();
+        grad_fn->m_dests.push_back(grad_value->m_slot);
+        tensor = grad_value->m_value;
+        m_tape.emplace_back(grad_fn, nullptr);
     }
-    grad_value->slot().m_fn->m_slots[0].callback = callback;
-    return grad_value;
+    return m_value_type.make(tensor, shared_from_this(), grad_slot);
 }

 void GradKey::freeze() {
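Summary of the new attach() above: every call builds a fresh GradFn with one slot and installs the callback on that slot; only if the tensor already carries a GradValue for this key does the GradFn additionally get an IdentityBackward, record the pre-existing slot as its destination and get pushed onto the tape, so whatever gradient reaches the new slot during backward() is also forwarded to the slot that was already in the path. A toy pure-Python model of that wiring (hypothetical names, not the MegEngine API; the real accumulation and bookkeeping live in the C++ above):

class Slot:
    def __init__(self, callback=None):
        self.grad = None          # gradient accumulated during backward
        self.callback = callback  # user callback registered via wrt/attach

def attach(existing_slot, callback, tape):
    new_slot = Slot(callback)          # always a fresh slot per attach
    if existing_slot is not None:      # tensor was already in the gradient path
        def identity_backward():
            # forward the gradient that reached the new slot to the old slot
            if new_slot.grad is not None:
                existing_slot.grad = new_slot.grad
        tape.append(identity_backward)
    return new_slot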
@@ -424,22 +425,17 @@ ValueRefList GradTransformation::apply_transformation(
         return outputs;
     } else if (op.is<CreateTensor>()) {
         return imperative::apply(op, inputs);
-    }
-    if (auto* attach_grad = op.as<AttachGrad>()) {
-        auto& tensor = inputs[0];
-        if (auto&& grad_value = tensor.as_ref(m_value_type)) {
-            mgb_assert(!has_key(attach_grad->key()));
-            auto output = fallback()[0];
-            return record_grad(m_value_type.make(output, m_key, grad_value->slot()));
-        } else if (!has_key(attach_grad->key())) {
+    } else if (auto* attach_grad = op.as<AttachGrad>()) {
+        if (!has_key(attach_grad->key())) {
             return fallback();
         } else {
             GenericFunction callback =
                     (GenericFunction&)inputs[1].cast<FunctionValue>();
-            auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
-                auto ret = callback({&grad, 1});
-                assert(ret.empty());
-            });
+            auto output =
+                    attach_grad->key()->attach(inputs[0], [callback](ValueRef grad) {
+                        auto ret = callback({&grad, 1});
+                        mgb_assert(ret.empty());
+                    });
             return {record_grad(output)};
         }
     } else if (auto* grad_backward = op.as<GradBackward>()) {
...
@@ -83,6 +83,20 @@ public:
     static BackwardRule lookup_grad_rule(Typeinfo* typeinfo);
 };

+struct IdentityBackward {
+    bool input_has_grad(size_t i) { mgb_assert(0); }
+    bool output_requires_grad(size_t i) { mgb_assert(0); }
+
+    template <typename F>
+    void operator()(Span<ValueRef> grads, F&& receiver) {
+        for (size_t i = 0; i < grads.size(); ++i) {
+            if (grads[i]) {
+                receiver(i, grads[i]);
+            }
+        }
+    }
+};
+
 class GradSlot;
 class GradSlotPtr;
 class GradSlotProducerPtr;
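The IdentityBackward added above has a deliberately trivial contract: both input_has_grad and output_requires_grad are mgb_assert(0), meaning they are never expected to be queried for this node, and the call operator simply hands every non-empty incoming gradient to the receiver at the same index. A one-function Python model of that loop (illustrative only, not MegEngine API):

def identity_backward(grads, receiver):
    # pass each non-empty gradient straight through, preserving its index
    for i, g in enumerate(grads):
        if g is not None:
            receiver(i, g)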
@@ -165,7 +179,9 @@
     std::weak_ptr<GradKey> m_key;
     SmallVector<GradSlot> m_slots;
     SmallVector<GradSlotProducerPtr> m_dests;
-    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
+    std::variant<
+            std::monostate, BackwardGraphWithClosure, CustomBackward, IdentityBackward>
+            m_backward;

 public:
     void clear() {
...