From b561a05ee34bc10e9a174e81dae28c7a591420c6 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 10 Aug 2023 16:35:37 +0800
Subject: [PATCH] fix A100 fused linear grad add ut bug (#56136)

---
 .../test_fused_linear_param_grad_add.py | 61 ++++++++++++-------
 tools/gpups_test.sh                     |  2 +-
 2 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/test/legacy_test/test_fused_linear_param_grad_add.py b/test/legacy_test/test_fused_linear_param_grad_add.py
index e707bbc41fa..762b2a99b52 100644
--- a/test/legacy_test/test_fused_linear_param_grad_add.py
+++ b/test/legacy_test/test_fused_linear_param_grad_add.py
@@ -54,7 +54,7 @@ def recreate(x, multi_precision):
     return paddle.to_tensor(x.numpy())
 
 
-def run_ground_truth(x, dy, dweight, dbias, multi_precision):
+def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
     x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)
 
     dweight_tmp = paddle.matmul(
@@ -69,24 +69,35 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
         assert dweight.dtype == dweight.dtype
         dweight += dweight_tmp
 
-    dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
-    if dbias is None:
-        dbias = dbias_tmp
-    else:
-        assert dbias.shape == dbias_tmp.shape
-        assert dbias.dtype == dbias_tmp.dtype
-        dbias += dbias_tmp
+    if has_bias:
+        dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
+        if dbias is None:
+            dbias = dbias_tmp
+        else:
+            assert dbias.shape == dbias_tmp.shape
+            assert dbias.dtype == dbias_tmp.dtype
+            dbias += dbias_tmp
 
-    return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+        return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+    else:
+        return promote_dtype(dweight).numpy()
 
 
-def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
+def run_fused_linear_param_grad_add(
+    x, dy, dweight, dbias, multi_precision, has_bias
+):
     dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
-        x, dy, dweight, dbias, multi_precision
+        x, dy, dweight, dbias, multi_precision, has_bias
     )
     if dweight is not None:
         assert dweight_new.data_ptr() == dweight.data_ptr()
-    return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
+    if has_bias:
+        return (
+            promote_dtype(dweight_new).numpy(),
+            promote_dtype(dbias_new).numpy(),
+        )
+    else:
+        return promote_dtype(dweight_new).numpy()
 
 
 class TestMainClassBase(unittest.TestCase):
@@ -103,7 +114,9 @@ class TestMainClassBase(unittest.TestCase):
         x = paddle.to_tensor(x)
         return x.astype(dtype or self.dtype)
 
-    def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
+    def generate_rand_inputs(
+        self, has_dweight, has_dbias, multi_precision, has_bias
+    ):
         x_shape = self.shape
         dy_shape = self.shape[:-1] + [self.output_size]
         dweight_shape = [self.shape[-1], self.output_size]
@@ -118,7 +131,7 @@ class TestMainClassBase(unittest.TestCase):
         else:
             dweight = None
 
-        if has_dbias:
+        if has_bias and has_dbias:
             dbias = self.rand(dbias_shape)
             if multi_precision:
                 dbias = promote_dtype(dbias)
@@ -126,14 +139,15 @@ class TestMainClassBase(unittest.TestCase):
             dbias = None
         return x, dy, dweight, dbias
 
-    def check_main(self, has_dweight, has_dbias, multi_precision):
-        print(has_dweight, has_dbias, multi_precision)
+    def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
         x, dy, dweight, dbias = self.generate_rand_inputs(
-            has_dweight, has_dbias, multi_precision
+            has_dweight, has_dbias, multi_precision, has_bias
+        )
+        res1 = run_ground_truth(
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
-        res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
         res2 = run_fused_linear_param_grad_add(
-            x, dy, dweight, dbias, multi_precision
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
         self.assertEqual(len(res1), len(res2))
         for r1, r2 in zip(res1, res2):
@@ -153,9 +167,12 @@ class TestMainClassBase(unittest.TestCase):
             return
 
         for has_dweight in [False, True]:
-            for has_dbias in [False, True]:
-                for multi_precision in [False, True]:
-                    self.check_main(has_dweight, has_dbias, multi_precision)
+            for has_bias in [False, True]:
+                for has_dbias in [False, True]:
+                    for multi_precision in [False, True]:
+                        self.check_main(
+                            has_dweight, has_dbias, multi_precision, has_bias
+                        )
 
 
 class TestMainClassBF16(TestMainClassBase):
diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh
index a833e48efc3..d0f9dd19341 100644
--- a/tools/gpups_test.sh
+++ b/tools/gpups_test.sh
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Disable Test list: test_fused_linear_param_grad_add
 
 serial_list="^test_conv2d_op$|\
 ^test_conv2d_transpose_op$|\
@@ -69,6 +68,7 @@ parallel_list="^init_phi_test$|\
 ^test_fused_gemm_epilogue_op$|\
 ^test_fused_gemm_epilogue_op_with_es$|\
 ^test_fused_layernorm_residual_dropout_bias$|\
+^test_fused_linear_param_grad_add$|\
 ^test_fused_linear_pass$|\
 ^test_fused_matmul_bias$|\
 ^test_fused_multi_transformer_decoder_pass$|\
-- 
GitLab
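
Note (not part of the patch above): a minimal standalone sketch of how the updated op signature is exercised, assuming a CUDA build of Paddle running on a GPU where the fused_linear_param_grad_add kernel is supported (e.g. A100); the shapes and dtypes below are illustrative and not taken from the test.

    # Usage sketch for the has_bias flag added by this patch.
    import paddle
    from paddle import _C_ops

    if paddle.is_compiled_with_cuda():
        paddle.set_device("gpu")
        x = paddle.randn([2, 16, 64]).astype("float16")    # input activations
        dy = paddle.randn([2, 16, 128]).astype("float16")  # upstream gradient
        multi_precision = True                             # accumulate grads in fp32

        # has_bias=True: the op produces both the weight and bias gradients.
        dweight, dbias = _C_ops.fused_linear_param_grad_add(
            x, dy, None, None, multi_precision, True
        )

        # has_bias=False: only the weight gradient is meaningful, which is why
        # the patched test helper returns a single array in this case.
        dweight_only, _ = _C_ops.fused_linear_param_grad_add(
            x, dy, None, None, multi_precision, False
        )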