Unverified commit adcfd269, authored by Shaden Smith and committed by GitHub

Handle activation checkpointing args that are None or non-tensors (#660)

Special thanks to @g-karthik for tracking this issue down.
Parent da5563a9
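The crash this commit fixes comes from code that treated every checkpointed argument as a tensor. Below is a tiny sketch (not part of the patch) of the failure mode and of the `torch.is_tensor()` guard the patch introduces:

```python
import torch

# Why the patch guards with torch.is_tensor(): the pre-fix code called tensor
# methods (e.g. .to(), .detach()) on every checkpointed argument, which breaks
# when an argument such as an optional attention mask is passed as None.
mask = None
print(torch.is_tensor(mask))   # False -> the patched code now passes it through untouched

try:
    mask.to('cuda')            # roughly what the old path did to every argument
except AttributeError as err:
    print(err)                 # 'NoneType' object has no attribute 'to'
```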
@@ -373,6 +373,10 @@ class CheckpointFunction(torch.autograd.Function):
             inputs = []
             for i, item in enumerate(args[:-1]):
+                if not torch.is_tensor(item):
+                    inputs.append(item)
+                    continue
+
                 partition_size = get_partition_size(item)
                 partition = item.detach().contiguous().view(-1).narrow(
                     0,
@@ -413,7 +417,12 @@ class CheckpointFunction(torch.autograd.Function):
             inputs.append(args[-1])
 
         #just in case something funky is happening such as reuse of inputs
-        inputs_cuda = [item.to(cuda_device) for item in args]
+        inputs_cuda = []
+        for item in args:
+            if torch.is_tensor(item):
+                inputs_cuda.append(item.to(cuda_device))
+            else:
+                inputs_cuda.append(item)
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -439,6 +448,10 @@ class CheckpointFunction(torch.autograd.Function):
         if PARTITION_ACTIVATIONS:
             new_args = []
             for i, (arg, inp) in enumerate(zip(args, inputs)):
+                if not torch.is_tensor(arg):
+                    new_args.append(arg)
+                    continue
+
                 size = torch.tensor(arg.size())
                 arg.data = inp.data
@@ -573,7 +586,14 @@ class CheckpointFunction(torch.autograd.Function):
             timers.log(['backward'])
         if SYNCHRONIZE:
             torch.cuda.synchronize()
-        return (None, ) + tuple(inp.grad for inp in detached_inputs)
+        ret_list = [None]  # first None for ctx
+        for inp in detached_inputs:
+            if torch.is_tensor(inp):
+                ret_list.append(inp.grad)
+            else:
+                ret_list.append(None)
+
+        return tuple(ret_list)
 
 
 def checkpoint(function, *args):
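The backward change above follows the contract of `torch.autograd.Function`: `backward()` must return exactly one entry per argument passed to `forward()`, and the slots belonging to non-tensor inputs must be `None`. A minimal standalone sketch of that contract, independent of DeepSpeed (the `ScaleOrPass` class and its optional `scale` argument are illustrative only):

```python
import torch

# Standalone illustration of the autograd contract the patch relies on:
# backward() must return one entry per forward() argument (after ctx),
# with None in the slots of non-tensor inputs such as `scale` below.
class ScaleOrPass(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale                       # scale may be None (non-tensor)
        return x.clone() if scale is None else x * scale

    @staticmethod
    def backward(ctx, grad_out):
        grad_x = grad_out if ctx.scale is None else grad_out * ctx.scale
        return grad_x, None                     # None for the non-tensor slot

x = torch.rand(4, requires_grad=True)
ScaleOrPass.apply(x, None).sum().backward()
print(x.grad)                                   # populated; the None arg gets no grad
```

The test-side changes below exercise this same case through DeepSpeed's `checkpoint()` wrapper.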
@@ -23,7 +23,7 @@ def _compute(module, *inputs, do_checkpoint=False):
     sum(o.sum() for o in outputs if o.requires_grad).backward()
     grads = [p.grad for p in module.parameters()]
-    input_grads = [inp.grad for inp in inputs]
+    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]
 
     return {
         'outputs': outputs,
@@ -32,6 +32,18 @@ def _compute(module, *inputs, do_checkpoint=False):
     }
 
 
+def _prep_inputs(*inputs):
+    _inputs = []
+
+    for inp in inputs:
+        inp = deepcopy(inp)
+        if torch.is_tensor(inp):
+            inp = inp.cuda()
+        _inputs.append(inp)
+
+    return tuple(_inputs)
+
+
 # This is distributed because checkpoint() assumes that torch.distributed is initialized.
 # torch.distributed is used with activation partitioning, but not for these simple cases.
 @distributed_test(world_size=1)
@@ -43,11 +55,11 @@ def _test_activation_checkpoint(module, *inputs):
     module.eval()
 
     module_ = deepcopy(module)
-    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    inputs_ = _prep_inputs(*inputs)
     base = _compute(module_, *inputs_, do_checkpoint=False)
 
     module_ = deepcopy(module)
-    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
+    inputs_ = _prep_inputs(*inputs)
     test = _compute(module_, *inputs_, do_checkpoint=True)
 
     for group in base.keys():
@@ -155,3 +167,15 @@ def test_ckpt_inputs2_outputs3(mask):
     inputs = torch.rand(HIDDEN_DIM)
     inputs.requires_grad = True
     _test_activation_checkpoint(module, inputs, mask)
+
+
+class DropMaskLinear(torch.nn.Linear):
+    def forward(self, x, mask):
+        return super().forward(x)
+
+
+def test_ckpt_arg_none():
+    module = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = (torch.rand(HIDDEN_DIM), None)
+    inputs[0].requires_grad = True
+    _test_activation_checkpoint(module, *inputs)
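For reference, a rough user-level sketch of what `test_ckpt_arg_none` exercises. It assumes DeepSpeed is installed and a GPU is available, stands up a single-process `gloo` process group in place of the suite's `@distributed_test` harness, leaves DeepSpeed's checkpointing at its default configuration, and uses arbitrary sizes instead of `HIDDEN_DIM`; treat it as an illustration, not project test code:

```python
import os
import torch
import torch.distributed as dist
import deepspeed

# Stand-in for the test suite's @distributed_test(world_size=1) harness:
# checkpoint() only needs torch.distributed to be initialized.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

class DropMaskLinear(torch.nn.Linear):
    def forward(self, x, mask):
        return super().forward(x)

module = DropMaskLinear(32, 32).cuda()              # arbitrary width
x = torch.rand(32, device='cuda', requires_grad=True)

# With this commit, the None mask flows through checkpointing untouched and
# simply receives no gradient; previously it tripped the tensor-only code paths.
out = deepspeed.checkpointing.checkpoint(module, x, None)
out.sum().backward()
print(x.grad is not None)   # True
```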