diff --git a/paddle/fluid/operators/sparse_attention_op.cu b/paddle/fluid/operators/sparse_attention_op.cu index 49f8263ab289a131ac58f5d995999c9043d0a33f..7991ca9e767bb4c62046af72b3c64d44c459e43b 100644 --- a/paddle/fluid/operators/sparse_attention_op.cu +++ b/paddle/fluid/operators/sparse_attention_op.cu @@ -378,7 +378,7 @@ void DotSdd(const platform::CUDADeviceContext& ctx, const Tensor* a, const_cast(b_data), gpu_type, CUSPARSE_ORDER_ROW); // Create sparse matrix C in CSR format - int c_nnz = c_columns->dims()[1]; + int c_nnz = c_columns->numel(); platform::dynload::cusparseCreateCsr( &mat_c, num_rows, num_rows, c_nnz, const_cast(c_offset_data), const_cast(c_columns_data), c_value_data, CUSPARSE_INDEX_32I, @@ -427,7 +427,7 @@ void DotDsd(const platform::CUDADeviceContext& ctx, const Tensor* a_offset, platform::dynload::cusparseCreate(&handle); // Create sparse matrix A in CSR format - int a_nnz = a_columns->dims()[1]; + int a_nnz = a_columns->numel(); platform::dynload::cusparseCreateCsr( &mat_a, num_rows, num_rows, a_nnz, const_cast(a_offset_data), const_cast(a_columns_data), const_cast(a_value_data), @@ -600,7 +600,7 @@ class SparseAttentionGradCUDAKernel : public framework::OpKernel { &dvalue_lists[i], M, N, true, false); // dSoftmax = dOut * transpose(Value) - int nnz_num = columns.dims()[0]; + int nnz_num = columns_lists[i].numel(); Tensor dsoftmax; dsoftmax.Resize({nnz_num}); dsoftmax.mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/platform/device_event_test.cc b/paddle/fluid/platform/device_event_test.cc index d9f744b26256b1f00bd256319a5ab606fe7a0b4c..c0646b454646a8fdc5e855713c465a2b20dd0a18 100644 --- a/paddle/fluid/platform/device_event_test.cc +++ b/paddle/fluid/platform/device_event_test.cc @@ -43,8 +43,6 @@ TEST(DeviceEvent, CUDA) { ASSERT_EQ(status, true); // case 2. test for event_recorder event.Record(context); - status = event.Query(); - ASSERT_EQ(status, false); // case 3. test for event_finisher event.Finish(); status = event.Query(); diff --git a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py index 3cef228d14d6eb4293f14c9e93f3f7f2945871b1..d52882acfc9acf0e4003c9cb6817396481d9dffe 100644 --- a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py +++ b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import site from paddle.fluid import core from distutils.sysconfig import get_python_lib from distutils.core import setup, Extension @@ -42,10 +43,11 @@ if core.is_compiled_with_npu(): paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0'] # include path -site_packages_path = get_python_lib() -paddle_custom_kernel_include = [ - os.path.join(site_packages_path, 'paddle', 'include'), -] +site_packages_path = site.getsitepackages() +paddle_custom_kernel_include = list( + map(lambda path: os.path.join(path, 'paddle', 'include'), + site_packages_path)) + # include path third_party compile_third_party_path = os.path.join(os.environ['PADDLE_ROOT'], 'build/third_party') @@ -56,9 +58,8 @@ paddle_custom_kernel_include += [ ] # libs path -paddle_custom_kernel_library_dir = [ - os.path.join(site_packages_path, 'paddle', 'fluid'), -] +paddle_custom_kernel_library_dir = list( + map(lambda path: os.path.join(path, 'paddle', 'fluid'), site_packages_path)) # libs libs = [':core_avx.so'] diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7face5f8662c4ea9cdfe9a73f0fba95873026cc2..a225be8536b561cfee1be388957898fa0034624b 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1001,7 +1001,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_gather_op PROPERTIES TIMEOUT 180) set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250) set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120) set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120) @@ -1203,6 +1203,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) test_collective_alltoall_api test_collective_global_gather test_collective_global_scatter + test_collective_sendrecv_api PROPERTIES LABELS "RUN_TYPE=DIST") endif() set_tests_properties(test_paddle_multiprocessing PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py index fb0c37e3d659d82e9186cd4438b614d066e726a3..53683ae3b92bfb125c72b4bc0032ccb611a256fd 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py @@ -22,6 +22,8 @@ import config from config import ATOL, DEVICES, RTOL from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +np.random.seed(2022) + @place(DEVICES) @parameterize_cls((TEST_CASE_NAME, 'alpha', 'beta'), diff --git a/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py b/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py index 711891216b68a1c50a4a6469b84d0367925de83b..3422135c229005541c5ca5460c2eb4204fb0d78a 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py +++ b/python/paddle/fluid/tests/unittests/ir/test_fuse_resnet_unit.py @@ -24,9 +24,11 @@ paddle.enable_static() np.random.seed(0) -@unittest.skipIf(not paddle.is_compiled_with_cuda() or - paddle.get_cudnn_version() < 8000, - "only support with cuda and cudnn version is at least 8.0.") +@unittest.skipIf(not paddle.is_compiled_with_cuda() + or paddle.get_cudnn_version() < 8000 + or paddle.device.cuda.get_device_capability()[0] < 7, + "only support with cuda and cudnn version is at least 8.0 " + "and device's compute capability is at least 7.0") class TestFuseResNetUnit(unittest.TestCase): def test_fuse_resenet_unit(self): place = paddle.CUDAPlace(0) diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 821702a30feeccda7a572079212bdee22a092ddb..f4eb31032da252a2b6973f70d4573880def490fb 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -206,21 +206,15 @@ class TestDistBase(unittest.TestCase): with_gloo = '0' else: with_gloo = '1' - required_envs = { - "FLAGS_fraction_of_gpu_memory_to_use": "0.15", - "FLAGS_eager_delete_tensor_gb": "0.0", - "PATH": os.getenv("PATH"), - "PYTHONPATH": os.getenv("PYTHONPATH", ""), - "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), - "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), - "FLAGS_call_stack_level": "2", - "GLOG_v": "3", + required_envs = os.environ.copy() + additional_envs = { "NCCL_P2P_DISABLE": "1", "STATIC_MODE": static_mode, "PADDLE_WITH_GLOO": with_gloo, "BACKEND": backend, "PATH_ID": path_id } + required_envs.update(additional_envs) required_envs.update(need_envs) if check_error_log: required_envs["GLOG_v"] = "3" diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 3d7dc2da052f35ae213ebdb65e4864a7f89d81c9..87c1c728226f0b7acde700395917cb4e3ebd847d 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -300,7 +300,7 @@ class API_TestDygraphGather(unittest.TestCase): return x = np.random.rand(226862, 256).astype("float32") - index = np.random.randint(0, 22682, size=(11859027)) + index = np.random.randint(0, 22682, size=(8859027)) def test_dygraph(): with fluid.dygraph.guard(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index f14606ca2d91da647165fe3651b949ffb2b78b21..0c7f375baba3eb6a227a2b24c633994da8d50abb 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -15,6 +15,7 @@ import unittest import paddle import paddle.fluid as fluid +import paddle.fluid.core as core import numpy as np import six import cv2 @@ -1283,6 +1284,10 @@ class TestLayerNormFp16(unittest.TestCase): func_isinstance() +@unittest.skipIf( + paddle.is_compiled_with_cuda() + and not core.is_bfloat16_supported(core.CUDAPlace(0)), + "skip bf16 test if cuda is in use but bf16 is not supported by gpu arch.") class TestBf16(unittest.TestCase): ''' test amp for BF16 @@ -1300,17 +1305,13 @@ class TestBf16(unittest.TestCase): def test_bf16(self): def func_isinstance(): - if fluid.core.is_compiled_with_cuda( - ) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)): - out_fp32 = self.train(enable_amp=False) - out_bf16_O1 = self.train(enable_amp=True, amp_level='O1') - out_bf16_O2 = self.train(enable_amp=True, amp_level='O2') - self.assertTrue( - np.allclose( - out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1)) - self.assertTrue( - np.allclose( - out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1)) + out_fp32 = self.train(enable_amp=False) + out_bf16_O1 = self.train(enable_amp=True, amp_level='O1') + out_bf16_O2 = self.train(enable_amp=True, amp_level='O2') + self.assertTrue( + np.allclose(out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1)) + self.assertTrue( + np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1)) with _test_eager_guard(): func_isinstance()