Created by: JepsonWong
Now paddle don't support gradient accumulate in dygraph. User may not have much gpu memory and cannot implement training with large batch data, to achieve training with large batch samples, we should support gradient accumulate in dygraph.
- User can use mini-batch data to backward multiple times to achieve the effect of large batch data backward.
sample code:
model.clear_gradient() # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
predictions = model(inputs) # Forward pass
loss = loss_function(predictions, labels) # Compute loss function
loss = loss / accumulation_steps # Normalize our loss (if averaged)
loss.backward() # Backward pass
if (i+1) % accumulation_steps == 0: # Wait for several backward steps
optimizer.minimize() # Now we can do an optimizer step
model.clear_gradient() # Reset gradients tensors
transformer mode test:
-
batch_size = 32, don't use gradient accumulate. pass : 0 finished, validation avg loss: [4.2705855] pass : 1 finished, validation avg loss: [3.3497431] pass : 2 finished, validation avg loss: [3.0039177] pass : 3 finished, validation avg loss: [2.88103] pass : 4 finished, validation avg loss: [2.8394444] pass : 5 finished, validation avg loss: [2.8676476] pass : 6 finished, validation avg loss: [2.9263651] pass : 7 finished, validation avg loss: [2.9343238] pass : 8 finished, validation avg loss: [2.9415674] pass : 9 finished, validation avg loss: [2.8939047] pass : 10 finished, validation avg loss: [2.888354] pass : 11 finished, validation avg loss: [2.8944314] pass : 12 finished, validation avg loss: [2.942391] pass : 13 finished, validation avg loss: [2.9244676] pass : 14 finished, validation avg loss: [2.9677844] pass : 15 finished, validation avg loss: [2.9790475] pass : 16 finished, validation avg loss: [3.0021255] pass : 17 finished, validation avg loss: [3.0062253] pass : 18 finished, validation avg loss: [3.0180252] pass : 19 finished, validation avg loss: [3.0087602]
-
batch_size=16, use gradient accumulate, accumulation_steps = 2. pass : 0 finished, validation avg loss: [4.279442] pass : 1 finished, validation avg loss: [3.352506] pass : 2 finished, validation avg loss: [2.9923875] pass : 3 finished, validation avg loss: [2.8756173] pass : 4 finished, validation avg loss: [2.8708618] pass : 5 finished, validation avg loss: [2.892614] pass : 6 finished, validation avg loss: [2.8977745] pass : 7 finished, validation avg loss: [2.9026713] pass : 8 finished, validation avg loss: [2.9281192] pass : 9 finished, validation avg loss: [2.895417] pass : 10 finished, validation avg loss: [2.9202948] pass : 11 finished, validation avg loss: [2.9038296] pass : 12 finished, validation avg loss: [2.9103203] pass : 13 finished, validation avg loss: [2.9175427] pass : 14 finished, validation avg loss: [2.9222167] pass : 15 finished, validation avg loss: [2.954068] pass : 16 finished, validation avg loss: [3.0088449] pass : 17 finished, validation avg loss: [2.9913201] pass : 18 finished, validation avg loss: [2.9951591] pass : 19 finished, validation avg loss: [3.008231]