Unverified commit b561a05e, authored by Yuang Liu, committed by GitHub

fix A100 fused linear grad add ut bug (#56136)

Parent: 7c4a3556
```diff
@@ -54,7 +54,7 @@ def recreate(x, multi_precision):
     return paddle.to_tensor(x.numpy())


-def run_ground_truth(x, dy, dweight, dbias, multi_precision):
+def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
     x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)

     dweight_tmp = paddle.matmul(
@@ -69,6 +69,7 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
         assert dweight.dtype == dweight.dtype
         dweight += dweight_tmp

-    dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
-    if dbias is None:
-        dbias = dbias_tmp
+    if has_bias:
+        dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
+        if dbias is None:
+            dbias = dbias_tmp
@@ -78,15 +79,25 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
-        dbias += dbias_tmp
+            dbias += dbias_tmp

-    return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+        return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+    else:
+        return promote_dtype(dweight).numpy()


-def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
+def run_fused_linear_param_grad_add(
+    x, dy, dweight, dbias, multi_precision, has_bias
+):
     dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
-        x, dy, dweight, dbias, multi_precision
+        x, dy, dweight, dbias, multi_precision, has_bias
     )
     if dweight is not None:
         assert dweight_new.data_ptr() == dweight.data_ptr()
-    return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
+    if has_bias:
+        return (
+            promote_dtype(dweight_new).numpy(),
+            promote_dtype(dbias_new).numpy(),
+        )
+    else:
+        return promote_dtype(dweight_new).numpy()


 class TestMainClassBase(unittest.TestCase):
```
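For the ground truth, the reference computation is the standard linear-layer parameter gradient: `dweight` accumulates xᵀ·dy over all flattened leading dimensions of the input, and `dbias` (only when `has_bias` is set) accumulates the column sums of `dy`. A minimal NumPy sketch of that math, with shapes and variable names chosen here purely for illustration rather than taken from the test:

```python
import numpy as np

# Illustrative shapes: x is [batch, seq, in_features], dy is [batch, seq, out_features].
x = np.random.rand(2, 3, 4).astype("float32")
dy = np.random.rand(2, 3, 5).astype("float32")
has_bias = True

# Weight gradient: flatten every leading dim, then x^T @ dy, accumulated
# into whatever dweight already holds (zeros here).
dweight = np.zeros((4, 5), dtype="float32")
dweight += x.reshape(-1, x.shape[-1]).T @ dy.reshape(-1, dy.shape[-1])

# Bias gradient exists only when the layer has a bias: it is the column sum of dy.
dbias = dy.reshape(-1, dy.shape[-1]).sum(axis=0) if has_bias else None
```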
```diff
@@ -103,7 +114,9 @@ class TestMainClassBase(unittest.TestCase):
             x = paddle.to_tensor(x)
         return x.astype(dtype or self.dtype)

-    def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
+    def generate_rand_inputs(
+        self, has_dweight, has_dbias, multi_precision, has_bias
+    ):
         x_shape = self.shape
         dy_shape = self.shape[:-1] + [self.output_size]
         dweight_shape = [self.shape[-1], self.output_size]
@@ -118,7 +131,7 @@ class TestMainClassBase(unittest.TestCase):
         else:
             dweight = None

-        if has_dbias:
+        if has_bias and has_dbias:
             dbias = self.rand(dbias_shape)
             if multi_precision:
                 dbias = promote_dtype(dbias)
@@ -126,14 +139,15 @@ class TestMainClassBase(unittest.TestCase):
             dbias = None
         return x, dy, dweight, dbias

-    def check_main(self, has_dweight, has_dbias, multi_precision):
-        print(has_dweight, has_dbias, multi_precision)
+    def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
         x, dy, dweight, dbias = self.generate_rand_inputs(
-            has_dweight, has_dbias, multi_precision
+            has_dweight, has_dbias, multi_precision, has_bias
+        )
+        res1 = run_ground_truth(
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
-        res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
         res2 = run_fused_linear_param_grad_add(
-            x, dy, dweight, dbias, multi_precision
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
         self.assertEqual(len(res1), len(res2))
         for r1, r2 in zip(res1, res2):
```
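`check_main` then compares the two result tuples element by element; the actual tolerance logic sits in a part of the file this view does not show, so the helper below is only a generic sketch with placeholder tolerances, not the test's thresholds:

```python
import numpy as np

def compare(res1, res2, atol=1e-3, rtol=1e-3):
    # res1/res2: equal-length tuples of numpy arrays from the ground-truth
    # and fused paths; the tolerances here are illustrative assumptions.
    assert len(res1) == len(res2)
    for r1, r2 in zip(res1, res2):
        np.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol)
```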
```diff
@@ -153,9 +167,12 @@ class TestMainClassBase(unittest.TestCase):
             return

         for has_dweight in [False, True]:
-            for has_dbias in [False, True]:
-                for multi_precision in [False, True]:
-                    self.check_main(has_dweight, has_dbias, multi_precision)
+            for has_bias in [False, True]:
+                for has_dbias in [False, True]:
+                    for multi_precision in [False, True]:
+                        self.check_main(
+                            has_dweight, has_dbias, multi_precision, has_bias
+                        )


 class TestMainClassBF16(TestMainClassBase):
```
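With the extra `has_bias` loop the driver now sweeps every combination of the four switches, i.e. 16 cases per dtype class; note that `has_dbias` only takes effect when `has_bias` is also true (see `generate_rand_inputs` above). A condensed sketch of the same sweep, with `itertools` used purely for illustration:

```python
import itertools

# (has_dweight, has_bias, has_dbias, multi_precision) combinations covered
# by the nested loops in the test body.
cases = list(itertools.product([False, True], repeat=4))
assert len(cases) == 16
```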
Second file in the commit (the CI test-list shell script):

```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Disable Test list: test_fused_linear_param_grad_add

 serial_list="^test_conv2d_op$|\
 ^test_conv2d_transpose_op$|\
@@ -69,6 +68,7 @@ parallel_list="^init_phi_test$|\
 ^test_fused_gemm_epilogue_op$|\
 ^test_fused_gemm_epilogue_op_with_es$|\
 ^test_fused_layernorm_residual_dropout_bias$|\
+^test_fused_linear_param_grad_add$|\
 ^test_fused_linear_pass$|\
 ^test_fused_matmul_bias$|\
 ^test_fused_multi_transformer_decoder_pass$|\
```
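`serial_list` and `parallel_list` are `|`-joined, `^…$`-anchored regexes; the trailing backslashes are line continuations inside the double-quoted string, so each variable collapses into one long pattern. The runner that consumes them is not part of this diff; as a hedged illustration only, such a pattern could be handed to `ctest` like this:

```bash
# Hypothetical invocation -- the actual runner script is not shown in this commit.
ctest -R "$parallel_list" -j 8 --output-on-failure   # parallel group
ctest -R "$serial_list" --output-on-failure          # serial group, one test at a time
```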