Unverified commit 9edbe4aa, authored by Aganlengzi, committed by GitHub

[cherry-pick] NVIDIA fixes (#43780)

* Use all site-packages paths as the library/include path (#42940)

* Fix several unit tests and increase the unit tests stability (#43670)

* Reduce the gather op unit test size and increase its timeout

* Add NVIDIA_TF32_OVERRIDE for multi-process environments

* Remove record test for device event ut

* Fix 3 unittest errors (#43532)

* Fix test_fuse_resnet_unit failure

* Fix test_imperative_auto_mixed_precision failure

* Fix sparse_attention_op error

* Fix sparse_attention_op error

* Use fixed random seed (#43659)

* For CI test_collective_sendrecv_api
Co-authored-by: zlsh80826 <rewang@nvidia.com>
Co-authored-by: Shijie <505749828@qq.com>
Parent edff59b1
@@ -378,7 +378,7 @@ void DotSdd(const platform::CUDADeviceContext& ctx, const Tensor* a,
                             const_cast<T*>(b_data), gpu_type,
                             CUSPARSE_ORDER_ROW);
   // Create sparse matrix C in CSR format
-  int c_nnz = c_columns->dims()[1];
+  int c_nnz = c_columns->numel();
   platform::dynload::cusparseCreateCsr(
       &mat_c, num_rows, num_rows, c_nnz, const_cast<int*>(c_offset_data),
       const_cast<int*>(c_columns_data), c_value_data, CUSPARSE_INDEX_32I,
@@ -427,7 +427,7 @@ void DotDsd(const platform::CUDADeviceContext& ctx, const Tensor* a_offset,
   platform::dynload::cusparseCreate(&handle);
   // Create sparse matrix A in CSR format
-  int a_nnz = a_columns->dims()[1];
+  int a_nnz = a_columns->numel();
   platform::dynload::cusparseCreateCsr(
       &mat_a, num_rows, num_rows, a_nnz, const_cast<int*>(a_offset_data),
       const_cast<int*>(a_columns_data), const_cast<T*>(a_value_data),
@@ -600,7 +600,7 @@ class SparseAttentionGradCUDAKernel : public framework::OpKernel<T> {
                &dvalue_lists[i], M, N, true, false);
       // dSoftmax = dOut * transpose(Value)
-      int nnz_num = columns.dims()[0];
+      int nnz_num = columns_lists[i].numel();
       Tensor dsoftmax;
       dsoftmax.Resize({nnz_num});
       dsoftmax.mutable_data<T>(ctx.GetPlace());
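The three hunks above stop reading the non-zero count from a single dimension of the columns tensor and use `numel()` instead, so the count stays correct when that tensor carries extra (e.g. batch) dimensions. A minimal sketch of the CSR convention involved, using NumPy/SciPy in place of the kernel's cuSPARSE calls (the concrete 4x4 layout is a made-up example):

```python
import numpy as np
from scipy.sparse import csr_matrix

# Hypothetical 4x4 sparse layout mirroring the kernel's
# (offset, columns, values) CSR triplet.
offsets = np.array([0, 2, 3, 5, 6], dtype=np.int32)     # row pointers
columns = np.array([0, 2, 1, 0, 3, 3], dtype=np.int32)  # column indices
values = np.arange(1, 7, dtype=np.float32)              # stored entries

# nnz is the total number of stored entries -- the element count
# (numel) of `columns` -- not the size of any single dimension.
nnz = columns.size
assert nnz == values.size == offsets[-1]
mat = csr_matrix((values, columns, offsets), shape=(4, 4))
print(mat.toarray())
```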
@@ -43,8 +43,6 @@ TEST(DeviceEvent, CUDA) {
   ASSERT_EQ(status, true);
   // case 2. test for event_recorder
   event.Record(context);
-  status = event.Query();
-  ASSERT_EQ(status, false);
   // case 3. test for event_finisher
   event.Finish();
   status = event.Query();
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
+import site
 from paddle.fluid import core
 from distutils.sysconfig import get_python_lib
 from distutils.core import setup, Extension
@@ -42,10 +43,11 @@ if core.is_compiled_with_npu():
     paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0']

 # include path
-site_packages_path = get_python_lib()
-paddle_custom_kernel_include = [
-    os.path.join(site_packages_path, 'paddle', 'include'),
-]
+site_packages_path = site.getsitepackages()
+paddle_custom_kernel_include = list(
+    map(lambda path: os.path.join(path, 'paddle', 'include'),
+        site_packages_path))

 # include path third_party
 compile_third_party_path = os.path.join(os.environ['PADDLE_ROOT'],
                                         'build/third_party')
@@ -56,9 +58,8 @@ paddle_custom_kernel_include += [
 ]

 # libs path
-paddle_custom_kernel_library_dir = [
-    os.path.join(site_packages_path, 'paddle', 'fluid'),
-]
+paddle_custom_kernel_library_dir = list(
+    map(lambda path: os.path.join(path, 'paddle', 'fluid'), site_packages_path))

 # libs
 libs = [':core_avx.so']
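For context, a minimal standalone sketch of why the setup script switched APIs: `site.getsitepackages()` returns every site-packages directory of the interpreter, while distutils' `get_python_lib()` returns only one, which breaks when paddle is installed under a different one. (The `paddle/include` and `paddle/fluid` subdirectories are taken from the diff above; the printout is illustrative.)

```python
import os
import site
from distutils.sysconfig import get_python_lib

# get_python_lib() yields a single directory; getsitepackages() yields all
# of them (e.g. the platform-specific and pure-Python locations).
print(get_python_lib())
print(site.getsitepackages())

# Build the include/library search paths from every candidate directory.
include_dirs = [os.path.join(p, 'paddle', 'include')
                for p in site.getsitepackages()]
library_dirs = [os.path.join(p, 'paddle', 'fluid')
                for p in site.getsitepackages()]
```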
@@ -1001,7 +1001,7 @@ set_tests_properties(test_matmul_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_nearest_interp_v2_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_gather_op PROPERTIES TIMEOUT 180)
 set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250)
 set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120)
@@ -1203,6 +1203,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
                test_collective_alltoall_api
                test_collective_global_gather
                test_collective_global_scatter
+               test_collective_sendrecv_api
                PROPERTIES LABELS "RUN_TYPE=DIST")
 endif()
 set_tests_properties(test_paddle_multiprocessing PROPERTIES TIMEOUT 120)
@@ -22,6 +22,8 @@ import config
 from config import ATOL, DEVICES, RTOL
 from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand

+np.random.seed(2022)
+

 @place(DEVICES)
 @parameterize_cls((TEST_CASE_NAME, 'alpha', 'beta'),
@@ -24,9 +24,11 @@ paddle.enable_static()
 np.random.seed(0)

-@unittest.skipIf(not paddle.is_compiled_with_cuda() or
-                 paddle.get_cudnn_version() < 8000,
-                 "only support with cuda and cudnn version is at least 8.0.")
+@unittest.skipIf(not paddle.is_compiled_with_cuda()
+                 or paddle.get_cudnn_version() < 8000
+                 or paddle.device.cuda.get_device_capability()[0] < 7,
+                 "only support with cuda and cudnn version is at least 8.0 "
+                 "and device's compute capability is at least 7.0")
 class TestFuseResNetUnit(unittest.TestCase):

     def test_fuse_resenet_unit(self):
         place = paddle.CUDAPlace(0)
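The tightened guard adds a compute-capability check alongside the existing cuDNN version check. A small sketch of the same gate as a helper (the helper name is ours; the Paddle calls are the ones used in the decorator above):

```python
import paddle

def can_run_fuse_resnet_unit():
    # Mirror of the decorator's gate: requires CUDA, cuDNN >= 8.0, and a
    # GPU of compute capability >= 7.0 (Volta or newer).
    if not paddle.is_compiled_with_cuda():
        return False
    if paddle.get_cudnn_version() < 8000:
        return False
    major, _minor = paddle.device.cuda.get_device_capability()
    return major >= 7

print("test enabled:", can_run_fuse_resnet_unit())
```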
@@ -206,21 +206,15 @@ class TestDistBase(unittest.TestCase):
             with_gloo = '0'
         else:
             with_gloo = '1'
-        required_envs = {
-            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
-            "FLAGS_eager_delete_tensor_gb": "0.0",
-            "PATH": os.getenv("PATH"),
-            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
-            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
-            "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
-            "FLAGS_call_stack_level": "2",
-            "GLOG_v": "3",
+        required_envs = os.environ.copy()
+        additional_envs = {
             "NCCL_P2P_DISABLE": "1",
             "STATIC_MODE": static_mode,
             "PADDLE_WITH_GLOO": with_gloo,
             "BACKEND": backend,
             "PATH_ID": path_id
         }
+        required_envs.update(additional_envs)
         required_envs.update(need_envs)
         if check_error_log:
             required_envs["GLOG_v"] = "3"
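This rewrite is what backs the commit's "Add NVIDIA_TF32_OVERRIDE for multi-process environments" item: copying the whole parent environment forwards variables such as NVIDIA_TF32_OVERRIDE to the spawned test processes instead of dropping everything outside a fixed allow-list. A minimal standalone sketch (the child command is illustrative):

```python
import os
import subprocess

# Start from the full parent environment so opt-in variables like
# NVIDIA_TF32_OVERRIDE survive, then layer test-specific settings on top.
env = os.environ.copy()
env.update({"NCCL_P2P_DISABLE": "1"})

# Illustrative child process: prints what it inherited from the parent.
subprocess.run(
    ["python", "-c",
     "import os; print(os.environ.get('NVIDIA_TF32_OVERRIDE'))"],
    env=env, check=True)
```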
@@ -300,7 +300,7 @@ class API_TestDygraphGather(unittest.TestCase):
             return
         x = np.random.rand(226862, 256).astype("float32")
-        index = np.random.randint(0, 22682, size=(11859027))
+        index = np.random.randint(0, 22682, size=(8859027))

         def test_dygraph():
             with fluid.dygraph.guard():
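For scale, a rough footprint calculation for the shrunken index tensor (assuming the int64 dtype that np.random.randint defaults to on 64-bit Linux):

```python
import numpy as np

# Index tensor footprint before and after the change, assuming int64 indices.
itemsize = np.dtype(np.int64).itemsize   # 8 bytes
before = 11859027 * itemsize / 2**20     # ~90.5 MiB
after = 8859027 * itemsize / 2**20       # ~67.6 MiB
print(f"{before:.1f} MiB -> {after:.1f} MiB "
      f"(saves {before - after:.1f} MiB per copy)")
```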
@@ -15,6 +15,7 @@
 import unittest
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import numpy as np
 import six
 import cv2
@@ -1283,6 +1284,10 @@ class TestLayerNormFp16(unittest.TestCase):
             func_isinstance()


+@unittest.skipIf(
+    paddle.is_compiled_with_cuda()
+    and not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "skip bf16 test if cuda is in use but bf16 is not supported by gpu arch.")
 class TestBf16(unittest.TestCase):
     '''
     test amp for BF16
@@ -1300,17 +1305,13 @@ class TestBf16(unittest.TestCase):
     def test_bf16(self):

         def func_isinstance():
-            if fluid.core.is_compiled_with_cuda(
-            ) and fluid.core.is_bfloat16_supported(paddle.CUDAPlace(0)):
-                out_fp32 = self.train(enable_amp=False)
-                out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
-                out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
-                self.assertTrue(
-                    np.allclose(
-                        out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
-                self.assertTrue(
-                    np.allclose(
-                        out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
+            out_fp32 = self.train(enable_amp=False)
+            out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+            out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+            self.assertTrue(
+                np.allclose(out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+            self.assertTrue(
+                np.allclose(out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))

         with _test_eager_guard():
             func_isinstance()
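Hoisting the capability check out of the test body and into `@unittest.skipIf` means an unsupported GPU now reports the case as skipped instead of silently passing an empty body. A minimal sketch of the pattern (the class and test body are illustrative; the Paddle calls are the ones used in the diff):

```python
import unittest
import paddle
import paddle.fluid.core as core

def bf16_unsupported():
    # True when CUDA is in use but the GPU arch lacks bfloat16 support.
    return (paddle.is_compiled_with_cuda()
            and not core.is_bfloat16_supported(core.CUDAPlace(0)))

@unittest.skipIf(bf16_unsupported(), "bf16 not supported by this gpu arch")
class TestBf16Sketch(unittest.TestCase):

    def test_something(self):
        self.assertTrue(True)  # placeholder assertion

if __name__ == '__main__':
    unittest.main()
```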