Determine the lifetime of a variable's gradient
Created by: tonyyang-svail
A variable should be able to access its gradient. Ideally, this access should go through a smart pointer.
Question: should a variable hold
- a `shared_ptr` to its gradient, or
- a `weak_ptr` to its gradient?
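A minimal sketch of the two options, assuming a hypothetical `Tensor` type and struct names that are not part of the actual API:

```c++
#include <memory>

struct Tensor { /* data buffer, shape, ... */ };

// Option A: the variable shares ownership of its gradient, so the gradient
// stays alive at least as long as the variable does.
struct VariableWithSharedGrad {
  std::shared_ptr<Tensor> value;
  std::shared_ptr<Tensor> grad;
};

// Option B: the variable only observes its gradient; ownership lives
// elsewhere (e.g. on the tape), so the gradient can be released independently.
struct VariableWithWeakGrad {
  std::shared_ptr<Tensor> value;
  std::weak_ptr<Tensor> grad;

  // Returns nullptr if the gradient has already been released.
  std::shared_ptr<Tensor> Grad() const { return grad.lock(); }
};
```

With a `shared_ptr`, the gradient lives at least as long as the variable; with a `weak_ptr`, the gradient's lifetime is controlled by whoever owns it (e.g. the tape), and the variable's access can fail once it has been released.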
Case 0: forward
```c++
while (true) {
  reset_global_tape();
  loss = model.Forward(data);
}
```
Case 1: forward, backward
```c++
while (true) {
  reset_global_tape();
  loss = model.Forward(data);
  loss.Backward();
}
```
Case 2: forward, backward, optimize
```c++
while (true) {
  reset_global_tape();
  loss = model.Forward(data);
  loss.Backward();
  sgd(model.Params());
}
```
Case 3: Release Memory on Backward
We can release memory aggressively: during backward, we can delete each OpHandle as soon as its backward computation has finished. Since all variables are managed by smart pointers, their memory is automatically released when the reference count drops to 0.
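A minimal sketch of this idea, assuming a tape that owns its op handles and that each `OpHandle` keeps `shared_ptr` copies of the tensors it still needs; the names below are illustrative, not the actual implementation:

```c++
#include <functional>
#include <memory>
#include <vector>

struct Tensor { /* data buffer, shape, ... */ };

// Illustrative op record: holding shared_ptrs keeps the tensors alive until
// this OpHandle is destroyed.
struct OpHandle {
  std::function<void()> RunBackward;             // assumed backward callback
  std::vector<std::shared_ptr<Tensor>> tensors;  // inputs/outputs it still needs
};

class Tape {
 public:
  void Record(std::unique_ptr<OpHandle> op) { ops_.push_back(std::move(op)); }

  void Backward() {
    // Walk the tape in reverse and destroy each OpHandle right after its
    // backward finishes; any tensor whose last owner was that op is freed
    // immediately (its ref_count drops to 0).
    for (auto it = ops_.rbegin(); it != ops_.rend(); ++it) {
      if ((*it)->RunBackward) (*it)->RunBackward();
      it->reset();
    }
    ops_.clear();
  }

 private:
  std::vector<std::unique_ptr<OpHandle>> ops_;
};
```

Whether this works depends on who else holds an owning reference to each gradient, which is exactly the `shared_ptr` versus `weak_ptr` question above.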
In the current implementation, a `shared_ptr` cannot solve Case 1, while a `weak_ptr` cannot solve Case 3.