提交 2f3d185d 编写于 作者: M Megvii Engine Team

fix(mgb): register invalid grad for AddUpdate

GitOrigin-RevId: f9bbf570dcdc2aca714f3d9f3b7831c82bc7c3a9
上级 b3b14fdf
......@@ -947,6 +947,11 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
MGB_IMPL_OPR_GRAD(AddUpdate) {
// actually valid, just not implemented
return InvalidGrad::make(opr, wrt_idx);
}
/* =========================== Reduce =========================== */
class Reduce::KernScheduler {
......
......@@ -372,7 +372,7 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
for (size_t i = 0; i < SIZE * 2; i ++) {
MGB_ASSERT_FLOAT_EQ(expect(i), z[i]);
}
mgb_assert(host_sub.shape().total_nr_elems() == 4 &&
mgb_assert(host_sub.shape().total_nr_elems() == 4 &&
host_sub.layout().is_contiguous());
for (size_t i = 0; i < 4; ++ i) {
size_t idx = i * (SIZE >> 1);
......@@ -390,6 +390,27 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
}
}
// AddUpdate in gradient path but no gradient flows through it
TEST(TestOprBasicArith, AddUpdateInGradPath) {
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
auto dest = opr::SharedDeviceTensor::make(*graph, *gen({42}));
auto host_x = gen({42});
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
// delta depends on x, but not differentiable wrt x
// a invalid grad is registered for AddUpdate to fix this case
auto delta = opr::VirtualDep::make({opr::SetGrad::make(x, nullptr), x});
auto updated = opr::AddUpdate::make(dest, delta);
auto y = opr::reduce_ax_sum(updated + x, 0);
auto dx = cg::grad(y, x);
HostTensorND host_dx;
auto func = graph->compile({make_callback_copy(dx, host_dx)});
func->execute();
for (size_t i = 0; i < host_dx.shape(0); ++i) {
MGB_ASSERT_FLOAT_EQ(host_dx.ptr<float>()[i], 1.f);
}
}
TEST(TestOprBasicArith, MemFwd) {
constexpr size_t SIZE = 12321;
HostTensorGenerator<> gen;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册