diff --git a/paddle/fluid/operators/sparse_attention_op.cu b/paddle/fluid/operators/sparse_attention_op.cu
index 2949642d2f3dd7504baa7affdb6c3baa971f7768..c1e6a9ca3f9957e68b8e1251e730435ef14f96a8 100644
--- a/paddle/fluid/operators/sparse_attention_op.cu
+++ b/paddle/fluid/operators/sparse_attention_op.cu
@@ -378,7 +378,7 @@ void DotSdd(const platform::CUDADeviceContext& ctx, const Tensor* a,
                                          const_cast<T*>(b_data), gpu_type,
                                          CUSPARSE_ORDER_ROW);
   // Create sparse matrix C in CSR format
-  int c_nnz = c_columns->dims()[1];
+  int c_nnz = c_columns->numel();
   platform::dynload::cusparseCreateCsr(
       &mat_c, num_rows, num_rows, c_nnz, const_cast<int*>(c_offset_data),
       const_cast<int*>(c_columns_data), c_value_data, CUSPARSE_INDEX_32I,
@@ -427,7 +427,7 @@ void DotDsd(const platform::CUDADeviceContext& ctx, const Tensor* a_offset,
   platform::dynload::cusparseCreate(&handle);
 
   // Create sparse matrix A in CSR format
-  int a_nnz = a_columns->dims()[1];
+  int a_nnz = a_columns->numel();
   platform::dynload::cusparseCreateCsr(
       &mat_a, num_rows, num_rows, a_nnz, const_cast<int*>(a_offset_data),
       const_cast<int*>(a_columns_data), const_cast<T*>(a_value_data),
@@ -600,7 +600,7 @@ class SparseAttentionGradCUDAKernel : public framework::OpKernel<T> {
                                &dvalue_lists[i], M, N, true, false);
       // dSoftmax = dOut * transpose(Value)
-      int nnz_num = columns.dims()[0];
+      int nnz_num = columns_lists[i].numel();
       Tensor dsoftmax;
       dsoftmax.Resize({nnz_num});
       dsoftmax.mutable_data<T>(ctx.GetPlace());
diff --git a/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py b/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py
index 3eab857826005747e026dce6f0db561575349956..16f9ecad6d2525000ac08600f8f8de4764d4be93 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py
@@ -25,8 +25,10 @@
 np.random.seed(0)
 
 
 @unittest.skipIf(not paddle.is_compiled_with_cuda()
-                 or paddle.get_cudnn_version() < 8000,
-                 "only support with cuda and cudnn version is at least 8.0.")
+                 or paddle.get_cudnn_version() < 8000
+                 or paddle.device.cuda.get_device_capability()[0] < 7,
+                 "only support with cuda and cudnn version is at least 8.0 "
+                 "and device's compute capability is at least 7.0")
 class TestFuseResNetUnit(unittest.TestCase):
     def test_fuse_resenet_unit(self):
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 7310d19a522ff8c978fdec29efcb5c401f4506c2..0b2df9885abaff67a94d36a306513c29b00521e0 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -15,6 +15,7 @@
 import unittest
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import numpy as np
 import six
 import cv2
@@ -1304,6 +1305,10 @@ class TestLayerNormFp16(unittest.TestCase):
         func_isinstance()
 
 
+@unittest.skipIf(
+    paddle.is_compiled_with_cuda()
+    and not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "skip bf16 test if cuda is in use but bf16 is not supported by gpu arch.")
 class TestBf16(unittest.TestCase):
     '''
     test amp for BF16
@@ -1323,15 +1328,13 @@ class TestBf16(unittest.TestCase):
 
     def test_bf16(self):
         def func_isinstance():
-            if fluid.core.is_compiled_with_cuda(
-            ) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)):
-                out_fp32 = self.train(enable_amp=False)
-                out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
-                out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
-                self.assertTrue(
-                    np.allclose(out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
-                self.assertTrue(
-                    np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
+            out_fp32 = self.train(enable_amp=False)
+            out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+            out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+            self.assertTrue(
+                np.allclose(out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+            self.assertTrue(
+                np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
 
         with _test_eager_guard():
             func_isinstance()