Commit 02455941 authored by Megvii Engine Team, committed by huangxinda

test(autograd): test jvp emulated by 2nd grad

GitOrigin-RevId: 47114fcd99f53b335c1615896370dd6e45dd3693
Parent 8480302d
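For context, the trick under test is the standard emulation of forward-mode autodiff with two reverse-mode passes: for f with Jacobian J, reverse mode computes the VJP u ↦ Jᵀu; since that map is linear in u, a second reverse pass through its pairing with a tangent v recovers the JVP,

\[
    J v = \nabla_u \langle J^{\top} u,\, v \rangle ,
\]

which is independent of u, hence the test below can initialize the output gradient with zeros.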
@@ -163,7 +163,9 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
 apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Reduce>();
     if (op.mode == Reduce::Mode::SUM) {
-        mgb_assert(ctx.nargs == 1);
+        if (ctx.nargs != 1) {
+            throw GradRuleFallback();
+        }
         std::array<std::shared_ptr<Tensor>, 1> input_shapes;
         if (input_requires_grad(ctx, 0)) {
             input_shapes[0] = get_shape(ctx.args[0]);
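Effect of the change above: reduce_grad_rule only knows how to differentiate the single-input SUM case, so when a Reduce is applied to more than one argument (e.g. with an explicit target-shape input) the rule now declines by throwing GradRuleFallback instead of aborting via mgb_assert. A minimal sketch of the dispatch side, assuming the caller wraps each registered rule in a try/catch as the exception's name suggests (apply_grad_rule, grad_rule_registry, and apply_generic_backward_graph are illustrative names, not part of this diff):

    // Hypothetical dispatch site for grad rules; names are illustrative.
    apply_result_t apply_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
        auto& registry = grad_rule_registry();  // assumed op-type -> rule map
        if (auto it = registry.find(ctx.op->dyn_typeinfo()); it != registry.end()) {
            try {
                // specialized rule, e.g. reduce_grad_rule above
                return it->second(ctx, maker);
            } catch (GradRuleFallback&) {
                // the rule declined this case (here: SUM with ctx.nargs != 1);
                // fall through to the generic path instead of crashing
            }
        }
        return apply_generic_backward_graph(ctx, maker);  // generic fallback
    }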
@@ -349,3 +349,29 @@ def test_grad_manager_visibility_by_order():
 
     gm.backward(x.grad)
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
+
+
+@pytest.mark.require_higher_order_directive()
+@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1])
+def test_emulate_forward_mode_with_reverse_mode(target):
+    def jvp(inp, expr):
+        with GradManager() as gm:
+            with GradManager().attach([inp]) as gm2:
+                oup = expr(inp)
+                oup_grad = F.zeros_like(oup)
+                gm.attach(oup_grad)
+                gm2.backward(oup, oup_grad)
+            gm.backward(inp.grad)
+        return oup, oup_grad.grad
+
+    def fake_jvp(inp, expr):
+        delta = 0.001
+        return expr(inp), (expr(inp + delta) - expr(inp - delta)) / (2 * delta)
+
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.tensor(x_np)
+    y, dy = jvp(x, target)
+    y1, dy1 = fake_jvp(x, target)
+    np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5)
+    np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3)
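Note that gm.backward(inp.grad) passes no explicit gradient, so the emulated JVP uses an implicit all-ones tangent; for the elementwise targets parametrized here that is exactly the elementwise derivative, which the central-difference fake_jvp approximates (hence the looser decimal=3 tolerance on dy). A sketch of the same trick with an explicit tangent v, relying only on the fact, already used by gm2.backward above, that backward accepts the output gradient as a second argument; it needs the same higher-order-grad capability the test is gated on, and jvp_with_tangent is our name, not part of this commit:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    def jvp_with_tangent(inp, v, expr):
        # Same nested-GradManager trick as the test, with an explicit tangent v.
        with GradManager() as gm:
            with GradManager().attach([inp]) as gm2:
                oup = expr(inp)
                oup_grad = F.zeros_like(oup)  # u = 0; the JVP does not depend on u
                gm.attach(oup_grad)
                gm2.backward(oup, oup_grad)   # inp.grad = J^T u, recorded on gm's tape
            gm.backward(inp.grad, v)          # oup_grad.grad = J v
        return oup, oup_grad.grad

    x = mge.tensor(np.random.rand(10).astype("float32"))
    v = mge.tensor(np.full(10, 0.5, dtype="float32"))
    y, jv = jvp_with_tangent(x, v, F.cos)
    # for elementwise cos, J v = -sin(x) * v
    np.testing.assert_almost_equal(jv.numpy(), -np.sin(x.numpy()) * 0.5, decimal=5)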