Created by: chenwhql
To verify that the single-card and multi-card losses are almost equal, this PR adds two parallel dygraph unit tests:
- test_parallel_dygraph_transformer.py
- test_parallel_dygraph_resnet.py
Test data for batches 1-5, printed in the format:
======= single card loss : multi card loss =======
transformer
- original optimizer in model:
======= [9.274224] : [9.266623] =======
======= [9.26067] : [9.261173] =======
======= [9.256851] : [9.262501] =======
======= [9.241505] : [9.231884] =======
======= [9.283638] : [9.267639] =======
- simple optimizer (SGD)
======= [9.274224] : [9.266623] =======
======= [9.25505] : [9.255424] =======
======= [9.246191] : [9.251655] =======
======= [9.225809] : [9.216118] =======
======= [9.263782] : [9.247366] =======
resnet
- original optimizer in model: (Failed!!!)
======= [4.8580146] : [4.8252993] =======
======= [29.669163] : [25.71742] =======
- simple optimizer (SGD)
======= [4.8580146] : [4.8252993] =======
======= [4.785389] : [4.8186526] =======
======= [4.756192] : [4.7837706] =======
======= [4.742502] : [4.7632484] =======
======= [4.759595] : [4.7871733] =======
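The per-batch comparison above can be condensed into a single number. The following is a minimal sketch, using the transformer SGD losses copied from the results above; the helper name `max_relative_gap` and the 1% threshold are illustrative assumptions and are not part of the unit tests in this PR.

```python
import numpy as np

# Transformer losses with the simple SGD optimizer, copied from the results above.
single_card = np.array([9.274224, 9.25505, 9.246191, 9.225809, 9.263782])
multi_card = np.array([9.266623, 9.255424, 9.251655, 9.216118, 9.247366])


def max_relative_gap(single, multi):
    """Largest per-batch relative difference between two loss sequences."""
    return float(np.max(np.abs(single - multi) / np.abs(single)))


gap = max_relative_gap(single_card, multi_card)
print(f"max relative gap: {gap:.4%}")

# Hypothetical threshold: treat the runs as "almost equal" below 1%.
assert gap < 1e-2
```

The same helper could be applied to the resnet losses; the threshold would need to be chosen per model.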
Compare gradients of single-card & allreduce multi-cards
- set `self.assertTrue(np.allclose(grad, tr0_losses[1][name], rtol=1e-01))` (https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html#numpy.allclose)
  - `rtol=1e-01` is a very relaxed condition
  - mnist: success
  - transformer & resnet: failed!!! (Is this reasonable?)
- use `np.testing.assert_array_almost_equal(grad, tr0_losses[1][name])` (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_array_almost_equal.html)
  - this method may not be suitable for small values
    - it checks `abs(desired-actual) < 1.5 * 10**(-decimal)`, default decimal=6
    - example: 0.00000012 & 0.00000013 will pass this test
  - but if we cannot pass this test, that may indicate a more serious problem!!!
  - mnist: also success
  - transformer & resnet: also failed!!!
- transformer test result:
======= [9.274224] : [9.266623] =======
======= gradients: transformer/TransFormer_0/WrapDecoderLayer_0/DecoderLayer_0/DecoderSubLayer_5/MultiHeadAttentionLayer_1/FC_3.w_0 =======
F
======================================================================
FAIL: test_transformer (__main__.TestParallelDygraphTransformer)
----------------------------------------------------------------------
AssertionError:
Arrays are not almost equal to 6 decimals
Mismatch: 95.7%
Max absolute difference: 0.00025227
Max relative difference: 32303.217
x: array([[ 1.729866e-04, 2.968760e-04, -3.134187e-04, ..., -1.848387e-05,
-1.290701e-04, 1.529645e-04],
[-3.651518e-05, -7.046186e-05, 2.662834e-05, ..., 1.388673e-05,...
y: array([[ 1.933558e-04, 3.038431e-04, -2.942407e-04, ..., 1.316451e-05,
-1.025933e-04, 1.219973e-04],
[-1.290588e-05, -5.520150e-05, 1.526501e-05, ..., 1.666747e-05,...
- resnet test result:
======= [4.8580146] : [4.8252993] =======
======= gradients: resnet/ResNet_0/BottleneckBlock_13/ConvBNLayer_1/BatchNorm_0.b_0 =======
F
======================================================================
FAIL: test_transformer (__main__.TestParallelDygraphResNet)
----------------------------------------------------------------------
AssertionError:
Arrays are not almost equal to 6 decimals
Mismatch: 100%
Max absolute difference: 0.00427286
Max relative difference: 2213.9375
x: array([-1.131162e-03, 1.687959e-03, 6.250179e-04, -1.508491e-03,
7.233161e-04, 1.471173e-03, 4.064720e-04, 1.250255e-03,
9.171964e-04, -1.898815e-05, 1.040716e-03, -8.769262e-04,...
y: array([-7.811059e-04, 1.366269e-03, -1.232496e-04, -4.984437e-04,
9.567069e-06, 1.042704e-03, -4.914380e-04, -4.864433e-04,
-1.015805e-04, 8.403676e-04, 9.958970e-04, 8.907004e-04,...
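The two checks used above bound the error differently: `np.allclose` uses a relative bound, `|a - b| <= atol + rtol * |b|`, while `np.testing.assert_array_almost_equal` uses a purely absolute bound, `abs(desired-actual) < 1.5 * 10**(-decimal)`. The sketch below applies both to the six visible entries of the first row in the transformer failure log; the variable names are illustrative, and the elided entries of the real gradient are not reconstructed.

```python
import numpy as np

# Visible entries of the first row from the transformer failure above.
x = np.array([1.729866e-04, 2.968760e-04, -3.134187e-04,
              -1.848387e-05, -1.290701e-04, 1.529645e-04])
y = np.array([1.933558e-04, 3.038431e-04, -2.942407e-04,
              1.316451e-05, -1.025933e-04, 1.219973e-04])

# np.allclose passes only if every element satisfies the relative bound.
rtol, atol = 1e-1, 1e-8  # atol is NumPy's default
within_rtol = np.abs(x - y) <= atol + rtol * np.abs(y)
print("elements within rtol=1e-1:", within_rtol)
print("np.allclose:", np.allclose(x, y, rtol=rtol))

# assert_array_almost_equal compares absolute differences only (decimal=6).
try:
    np.testing.assert_array_almost_equal(x, y)
    almost_equal = True
except AssertionError:
    almost_equal = False
print("almost equal to 6 decimals:", almost_equal)

# Tiny values pass the absolute check even if they differ relatively.
np.testing.assert_array_almost_equal(np.array([0.00000012]),
                                     np.array([0.00000013]))
```

Note that the element which changes sign between `x` and `y` can never satisfy a relative bound with `rtol < 1`, which is one reason small, noisy gradients fail even a relaxed `rtol`.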
SetDygraphSortedGradient
- transformer
======= [9.274224] : [9.266623] =======
======= gradients: transformer/TransFormer_0/WrapDecoderLayer_0/DecoderLayer_0/DecoderSubLayer_1/MultiHeadAttentionLayer_1/FC_0.w_0 =======
F
======================================================================
FAIL: test_transformer (__main__.TestParallelDygraphTransformer)
----------------------------------------------------------------------
AssertionError:
Arrays are not almost equal to 6 decimals
Mismatch: 82.9%
Max absolute difference: 4.816447e-05
Max relative difference: 146327.75
x: array([[-2.436040e-06, 3.248527e-06, -1.358558e-09, ..., -1.128117e-06,
-1.294629e-06, -3.612499e-06],
[-5.579529e-06, 2.376522e-06, 9.744054e-06, ..., 4.107386e-06,...
y: array([[ 9.107453e-09, -1.423047e-06, 2.923735e-06, ..., -1.055777e-06,
-4.661979e-06, -8.509627e-06],
[ 2.073039e-07, 4.023343e-07, -4.841148e-06, ..., 3.270911e-06,...
- resnet
======= [4.8580146] : [4.8252993] =======
======= gradients: resnet/ResNet_0/BottleneckBlock_13/ConvBNLayer_2/BatchNorm_0.w_0 =======
F
======================================================================
FAIL: test_transformer (__main__.TestParallelDygraphResNet)
----------------------------------------------------------------------
AssertionError:
Arrays are not almost equal to 6 decimals
Mismatch: 99.7%
Max absolute difference: 0.00212005
Max relative difference: 2589.7422
x: array([-2.287483e-05, -1.007247e-04, 4.445101e-04, ..., -6.187817e-04,
3.487974e-04, 2.720110e-03], dtype=float32)
y: array([-3.022037e-04, 9.708891e-05, -2.009416e-06, ..., -1.034389e-03,
-1.033636e-03, 1.144969e-03], dtype=float32)
ResNet: Replace BatchNorm with LayerNorm
- gradients are almost equal (mismatch drops to 0.895%):
Arrays are not almost equal to 6 decimals
Mismatch: 0.895%
Max absolute difference: 9.115698e-06
Max relative difference: nan
x: array([-3.322816e-07, 1.172189e-06, 6.136031e-07, ..., -4.012745e-06,
-6.015944e-07, -2.520906e-06], dtype=float32)
y: array([-3.327114e-07, 1.163721e-06, 6.368654e-07, ..., -3.860220e-06,
-5.732491e-07, -2.468568e-06], dtype=float32)
- remove sorted gradient strategy (almost equal)
- Correctness verified!!!
======= [5.2025075] : [5.202508] =======
======= [5.070281] : [5.070281] =======
======= [4.8536005] : [4.8536043] =======
======= [4.9283624] : [4.928356] =======
======= [5.007301] : [5.007288] =======
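The source does not state why replacing BatchNorm with LayerNorm closes the gap. One plausible explanation, offered here as an assumption, is that unsynchronized batch normalization computes its mean and variance over each card's local mini-batch, so splitting a batch across cards changes the statistics, whereas layer normalization normalizes each sample on its own. The NumPy sketch below illustrates only that property; it is not Paddle's implementation, the helper names are hypothetical, and it assumes the single-card run sees the full batch while each of two cards sees half of it.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # Normalize each feature with statistics taken over the batch axis (axis 0).
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def layer_norm(x, eps=1e-5):
    # Normalize each sample with statistics taken over its own features (axis 1).
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
full_batch = rng.standard_normal((8, 4))   # what a single card would see
shards = np.split(full_batch, 2, axis=0)   # what each of two cards would see

bn_single = batch_norm(full_batch)
bn_multi = np.concatenate([batch_norm(s) for s in shards])

ln_single = layer_norm(full_batch)
ln_multi = np.concatenate([layer_norm(s) for s in shards])

print("BatchNorm outputs match:", np.allclose(bn_single, bn_multi))
print("LayerNorm outputs match:", np.allclose(ln_single, ln_multi))
```

Layer normalization is unaffected by the split because no statistic crosses sample boundaries; batch normalization depends on which samples share a card.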
Transformer: remove dropout & set weight to constant
- Arrays are almost equal to 6 decimals
- Correctness verified!!!
======= [9.255363] : [9.255363] =======
======= [9.262659] : [9.262657] =======
======= [9.2535305] : [9.2535305] =======
======= [9.262439] : [9.26244] =======
======= [9.252597] : [9.252596] =======
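A likely reason dropout had to be removed for this comparison, stated as an assumption rather than something the source confirms, is that each run draws its own random dropout mask, so two otherwise identical forward passes diverge. The sketch below shows this with a hypothetical NumPy `dropout` helper and arbitrary seeds; it does not model Paddle's dropout or its seeding.

```python
import numpy as np


def dropout(x, rate, rng):
    # Inverted dropout: zero each element with probability `rate`, rescale the rest.
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)


x = np.ones((4, 8))

# Two runs whose random streams differ (hypothetical seeds 1 and 2).
out_a = dropout(x, 0.5, np.random.default_rng(1))
out_b = dropout(x, 0.5, np.random.default_rng(2))
print("with dropout, outputs match:", np.array_equal(out_a, out_b))

# With dropout disabled (rate 0) the output no longer depends on the stream.
out_c = dropout(x, 0.0, np.random.default_rng(1))
out_d = dropout(x, 0.0, np.random.default_rng(2))
print("without dropout, outputs match:", np.array_equal(out_c, out_d))
```

Setting the weights to a constant plays the same role for initialization: it removes the other source of randomness that could differ between the single-card and multi-card runs.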