未验证 提交 b561a05e 编写于 作者: Y Yuang Liu 提交者: GitHub

fix A100 fused linear grad add ut bug (#56136)

上级 7c4a3556
......@@ -54,7 +54,7 @@ def recreate(x, multi_precision):
return paddle.to_tensor(x.numpy())
def run_ground_truth(x, dy, dweight, dbias, multi_precision):
def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)
dweight_tmp = paddle.matmul(
......@@ -69,24 +69,35 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
assert dweight.dtype == dweight.dtype
dweight += dweight_tmp
dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
if dbias is None:
dbias = dbias_tmp
else:
assert dbias.shape == dbias_tmp.shape
assert dbias.dtype == dbias_tmp.dtype
dbias += dbias_tmp
if has_bias:
dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
if dbias is None:
dbias = dbias_tmp
else:
assert dbias.shape == dbias_tmp.shape
assert dbias.dtype == dbias_tmp.dtype
dbias += dbias_tmp
return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
else:
return promote_dtype(dweight).numpy()
def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
def run_fused_linear_param_grad_add(
x, dy, dweight, dbias, multi_precision, has_bias
):
dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
x, dy, dweight, dbias, multi_precision
x, dy, dweight, dbias, multi_precision, has_bias
)
if dweight is not None:
assert dweight_new.data_ptr() == dweight.data_ptr()
return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
if has_bias:
return (
promote_dtype(dweight_new).numpy(),
promote_dtype(dbias_new).numpy(),
)
else:
return promote_dtype(dweight_new).numpy()
class TestMainClassBase(unittest.TestCase):
......@@ -103,7 +114,9 @@ class TestMainClassBase(unittest.TestCase):
x = paddle.to_tensor(x)
return x.astype(dtype or self.dtype)
def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
def generate_rand_inputs(
self, has_dweight, has_dbias, multi_precision, has_bias
):
x_shape = self.shape
dy_shape = self.shape[:-1] + [self.output_size]
dweight_shape = [self.shape[-1], self.output_size]
......@@ -118,7 +131,7 @@ class TestMainClassBase(unittest.TestCase):
else:
dweight = None
if has_dbias:
if has_bias and has_dbias:
dbias = self.rand(dbias_shape)
if multi_precision:
dbias = promote_dtype(dbias)
......@@ -126,14 +139,15 @@ class TestMainClassBase(unittest.TestCase):
dbias = None
return x, dy, dweight, dbias
def check_main(self, has_dweight, has_dbias, multi_precision):
print(has_dweight, has_dbias, multi_precision)
def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
x, dy, dweight, dbias = self.generate_rand_inputs(
has_dweight, has_dbias, multi_precision
has_dweight, has_dbias, multi_precision, has_bias
)
res1 = run_ground_truth(
x, dy, dweight, dbias, multi_precision, has_bias
)
res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
res2 = run_fused_linear_param_grad_add(
x, dy, dweight, dbias, multi_precision
x, dy, dweight, dbias, multi_precision, has_bias
)
self.assertEqual(len(res1), len(res2))
for r1, r2 in zip(res1, res2):
......@@ -153,9 +167,12 @@ class TestMainClassBase(unittest.TestCase):
return
for has_dweight in [False, True]:
for has_dbias in [False, True]:
for multi_precision in [False, True]:
self.check_main(has_dweight, has_dbias, multi_precision)
for has_bias in [False, True]:
for has_dbias in [False, True]:
for multi_precision in [False, True]:
self.check_main(
has_dweight, has_dbias, multi_precision, has_bias
)
class TestMainClassBF16(TestMainClassBase):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Disable Test list: test_fused_linear_param_grad_add
serial_list="^test_conv2d_op$|\
^test_conv2d_transpose_op$|\
......@@ -69,6 +68,7 @@ parallel_list="^init_phi_test$|\
^test_fused_gemm_epilogue_op$|\
^test_fused_gemm_epilogue_op_with_es$|\
^test_fused_layernorm_residual_dropout_bias$|\
^test_fused_linear_param_grad_add$|\
^test_fused_linear_pass$|\
^test_fused_matmul_bias$|\
^test_fused_multi_transformer_decoder_pass$|\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册