Unverified · Commit a9134dc2 · authored by Shijie, committed by GitHub

Fix 3 unittest errors (#43532)

* Fix test_fuse_resnet_unit failure

* Fix test_imperative_auto_mixed_precision failure

* Fix sparse_attention_op error

* Fix sparse_attention_op error
Parent c41c5e63
@@ -378,7 +378,7 @@ void DotSdd(const platform::CUDADeviceContext& ctx, const Tensor* a,
       const_cast<T*>(b_data), gpu_type,
       CUSPARSE_ORDER_ROW);
   // Create sparse matrix C in CSR format
-  int c_nnz = c_columns->dims()[1];
+  int c_nnz = c_columns->numel();
   platform::dynload::cusparseCreateCsr(
       &mat_c, num_rows, num_rows, c_nnz, const_cast<int*>(c_offset_data),
       const_cast<int*>(c_columns_data), c_value_data, CUSPARSE_INDEX_32I,
@@ -427,7 +427,7 @@ void DotDsd(const platform::CUDADeviceContext& ctx, const Tensor* a_offset,
   platform::dynload::cusparseCreate(&handle);
   // Create sparse matrix A in CSR format
-  int a_nnz = a_columns->dims()[1];
+  int a_nnz = a_columns->numel();
   platform::dynload::cusparseCreateCsr(
       &mat_a, num_rows, num_rows, a_nnz, const_cast<int*>(a_offset_data),
       const_cast<int*>(a_columns_data), const_cast<T*>(a_value_data),
@@ -600,7 +600,7 @@ class SparseAttentionGradCUDAKernel : public framework::OpKernel<T> {
           &dvalue_lists[i], M, N, true, false);
       // dSoftmax = dOut * transpose(Value)
-      int nnz_num = columns.dims()[0];
+      int nnz_num = columns_lists[i].numel();
       Tensor dsoftmax;
       dsoftmax.Resize({nnz_num});
       dsoftmax.mutable_data<T>(ctx.GetPlace());
......
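The three sparse_attention_op hunks above share one idea: take the number of non-zeros (nnz) as the total element count of the column-index tensor (numel()), not as the size of one hard-coded axis (dims()[1] / dims()[0]), which silently miscounts when the tensor's layout changes. A minimal numpy sketch of the distinction; the shapes and names are hypothetical, chosen only to mirror the kernels above:

import numpy as np

# Hypothetical layout: CSR column indices for every (batch * head) slice.
batch_heads, nnz = 8, 16
columns = np.zeros((batch_heads, nnz), dtype=np.int32)
# Per-slice views, analogous to columns_lists[i] in the grad kernel.
columns_lists = [columns[i] for i in range(batch_heads)]

nnz_from_dim = columns.shape[0]          # old logic: wrong axis -> 8
nnz_from_numel = columns_lists[0].size   # fixed logic: element count -> 16
assert nnz_from_numel == nnz
assert nnz_from_dim != nnz               # why dims()[0] under-counted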
@@ -25,8 +25,10 @@ np.random.seed(0)
 @unittest.skipIf(not paddle.is_compiled_with_cuda()
-                 or paddle.get_cudnn_version() < 8000,
-                 "only support with cuda and cudnn version is at least 8.0.")
+                 or paddle.get_cudnn_version() < 8000
+                 or paddle.device.cuda.get_device_capability()[0] < 7,
+                 "only support with cuda and cudnn version is at least 8.0 "
+                 "and device's compute capability is at least 7.0")
 class TestFuseResNetUnit(unittest.TestCase):
     def test_fuse_resenet_unit(self):
......
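The widened skip condition above keeps test_fuse_resnet_unit from failing on pre-Volta GPUs: the fused resnet_unit path needs compute capability 7.0 or higher in addition to cuDNN 8.0. The same guard could be factored into a reusable decorator; a sketch, relying on short-circuit evaluation so get_device_capability() is never called on CPU-only builds (requires_sm and SomeFusedOpTest are hypothetical names, not part of the commit):

import unittest
import paddle

def requires_sm(major):
    # Skip unless CUDA is available and the GPU's SM major version suffices.
    ok = (paddle.is_compiled_with_cuda()
          and paddle.device.cuda.get_device_capability()[0] >= major)
    return unittest.skipUnless(
        ok, "needs CUDA and compute capability >= %d.0" % major)

@requires_sm(7)  # the test above inlines this condition instead
class SomeFusedOpTest(unittest.TestCase):
    def test_runs_on_volta_or_newer(self):
        pass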
@@ -15,6 +15,7 @@
 import unittest
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import numpy as np
 import six
 import cv2
@@ -1304,6 +1305,10 @@ class TestLayerNormFp16(unittest.TestCase):
         func_isinstance()
 
+@unittest.skipIf(
+    paddle.is_compiled_with_cuda()
+    and not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "skip bf16 test if cuda is in use but bf16 is not supported by gpu arch.")
 class TestBf16(unittest.TestCase):
     '''
     test amp for BF16
@@ -1323,15 +1328,13 @@ class TestBf16(unittest.TestCase):
     def test_bf16(self):
         def func_isinstance():
-            if fluid.core.is_compiled_with_cuda(
-            ) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)):
-                out_fp32 = self.train(enable_amp=False)
-                out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
-                out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
-                self.assertTrue(
-                    np.allclose(out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
-                self.assertTrue(
-                    np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
+            out_fp32 = self.train(enable_amp=False)
+            out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+            out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+            self.assertTrue(
+                np.allclose(out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+            self.assertTrue(
+                np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
 
         with _test_eager_guard():
             func_isinstance()
......
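Both hunks in test_imperative_auto_mixed_precision move the capability check out of the test body and onto the class. The practical difference: with the old inline if, an unsupported GPU made func_isinstance a no-op and the test still counted as passed; with the class-level skipIf, the test is reported as skipped. A condensed sketch of the pattern (class and method names are placeholders):

import unittest
import paddle
import paddle.fluid.core as core

# True only when a CUDA build is running on a GPU without bf16 support.
bf16_unsupported = (paddle.is_compiled_with_cuda()
                    and not core.is_bfloat16_supported(core.CUDAPlace(0)))

@unittest.skipIf(bf16_unsupported,
                 "bf16 is not supported by gpu arch")
class SomeBf16Test(unittest.TestCase):
    def test_train(self):
        # Runs unguarded: the skipIf above guarantees bf16 support here,
        # so a silently empty "pass" is no longer possible.
        pass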