From ba415ee77a8ef3d2519075417b7d6cf68ff98f0c Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 25 Apr 2023 14:56:11 +0800 Subject: [PATCH] [CppExtension Cuda] Add cuda unit test for CppExtension (#52900) (#53269) * [CppExtension Cuda] Add cuda unit test for CppExtension * update extra_compile_args for CUDAExtension * add debug info * Add patch to fix CUDA12 compile error * patch for all env * add windows judgement * Try to fix setup function not found error * fix mix_relu_and_extension include file * fix setup compile error * remove useless debug comments * add sleep, debug CI-build * add space to disable cmake cache * remove debug info * add space to pass CI-build --- cmake/external/pybind11.cmake | 9 ++++ cmake/external/warpctc.cmake | 1 + patches/pybind/cast.h.patch | 15 +++++++ .../cpp_extension/cpp_extension_setup.py | 14 ++++-- .../tests/cpp_extension/custom_extension.cc | 3 ++ .../cpp_extension/custom_relu_forward.cu | 45 +++++++++++++++++++ .../mix_relu_and_extension_setup.py | 2 +- .../cpp_extension/test_cpp_extension_jit.py | 13 +++++- .../cpp_extension/test_cpp_extension_setup.py | 13 ++++++ .../paddle/fluid/tests/cpp_extension/utils.py | 42 +++++++++++++++++ .../utils/cpp_extension/extension_utils.py | 9 ++-- 11 files changed, 156 insertions(+), 10 deletions(-) create mode 100644 patches/pybind/cast.h.patch create mode 100644 python/paddle/fluid/tests/cpp_extension/custom_relu_forward.cu diff --git a/cmake/external/pybind11.cmake b/cmake/external/pybind11.cmake index 6abd24e8730..db53e3511be 100644 --- a/cmake/external/pybind11.cmake +++ b/cmake/external/pybind11.cmake @@ -21,6 +21,14 @@ set(PYBIND_TAG v2.10.3) set(PYBIND_INCLUDE_DIR ${THIRD_PARTY_PATH}/pybind/src/extern_pybind/include) include_directories(${PYBIND_INCLUDE_DIR}) +set(PYBIND_PATCH_COMMAND "") +if(NOT WIN32) + file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/pybind/cast.h.patch + native_dst) + set(PYBIND_PATCH_COMMAND patch -d ${PYBIND_INCLUDE_DIR}/pybind11 < + ${native_dst}) +endif() + ExternalProject_Add( extern_pybind ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE} @@ -33,6 +41,7 @@ ExternalProject_Add( # third-party library version changes cannot be incorporated. # reference: https://cmake.org/cmake/help/latest/module/ExternalProject.html UPDATE_COMMAND "" + PATCH_COMMAND ${PYBIND_PATCH_COMMAND} CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index e1e7234da0e..46befee8bd2 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -82,6 +82,7 @@ else() set(WARPCTC_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) set(WARPCTC_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) endif() + ExternalProject_Add( extern_warpctc ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE} diff --git a/patches/pybind/cast.h.patch b/patches/pybind/cast.h.patch new file mode 100644 index 00000000000..ebd65571ebf --- /dev/null +++ b/patches/pybind/cast.h.patch @@ -0,0 +1,15 @@ +diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h +index 3a404602..9054478c 100644 +--- a/include/pybind11/cast.h ++++ b/include/pybind11/cast.h +@@ -42,7 +42,9 @@ using make_caster = type_caster>; + // Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T + template + typename make_caster::template cast_op_type cast_op(make_caster &caster) { +- return caster.operator typename make_caster::template cast_op_type(); ++ // https://github.com/pybind/pybind11/issues/4606 with CUDA 12 ++ //return caster.operator typename make_caster::template cast_op_type(); ++ return caster; + } + template + typename make_caster::template cast_op_type::type> diff --git a/python/paddle/fluid/tests/cpp_extension/cpp_extension_setup.py b/python/paddle/fluid/tests/cpp_extension/cpp_extension_setup.py index b5c12284c11..5a4ff2afd6c 100644 --- a/python/paddle/fluid/tests/cpp_extension/cpp_extension_setup.py +++ b/python/paddle/fluid/tests/cpp_extension/cpp_extension_setup.py @@ -15,7 +15,9 @@ import os from site import getsitepackages -from paddle.utils.cpp_extension import CppExtension, setup +from utils import extra_compile_args + +from paddle.utils.cpp_extension import CUDAExtension, setup paddle_includes = [] for site_packages_path in getsitepackages(): @@ -30,10 +32,14 @@ paddle_includes.append(os.path.dirname(os.path.abspath(__file__))) setup( name='custom_cpp_extension', - ext_modules=CppExtension( - sources=["custom_extension.cc", "custom_sub.cc"], + ext_modules=CUDAExtension( + sources=[ + "custom_extension.cc", + "custom_sub.cc", + "custom_relu_forward.cu", + ], include_dirs=paddle_includes, - extra_compile_args={'cc': ['-w', '-g']}, + extra_compile_args=extra_compile_args, verbose=True, ), ) diff --git a/python/paddle/fluid/tests/cpp_extension/custom_extension.cc b/python/paddle/fluid/tests/cpp_extension/custom_extension.cc index 2334e23af53..2fc5c42a80d 100644 --- a/python/paddle/fluid/tests/cpp_extension/custom_extension.cc +++ b/python/paddle/fluid/tests/cpp_extension/custom_extension.cc @@ -20,6 +20,8 @@ paddle::Tensor custom_sub(paddle::Tensor x, paddle::Tensor y); +paddle::Tensor relu_cuda_forward(const paddle::Tensor& x); + paddle::Tensor custom_add(const paddle::Tensor& x, const paddle::Tensor& y) { return x.exp() + y.exp(); } @@ -46,6 +48,7 @@ PYBIND11_MODULE(custom_cpp_extension, m) { m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None"); m.def( "optional_tensor", &optional_tensor, "returned Tensor might be optional"); + m.def("relu_cuda_forward", &relu_cuda_forward, "relu(x)"); py::class_(m, "Power") .def(py::init()) diff --git a/python/paddle/fluid/tests/cpp_extension/custom_relu_forward.cu b/python/paddle/fluid/tests/cpp_extension/custom_relu_forward.cu new file mode 100644 index 00000000000..e0405309f7a --- /dev/null +++ b/python/paddle/fluid/tests/cpp_extension/custom_relu_forward.cu @@ -0,0 +1,45 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +template +__global__ void relu_cuda_forward_kernel(const data_t* x, + data_t* y, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = x[i] > static_cast(0.) ? x[i] : static_cast(0.); + } +} + +paddle::Tensor relu_cuda_forward(const paddle::Tensor& x) { + CHECK_GPU_INPUT(x); + auto out = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out.data(), numel); + })); + + return out; +} diff --git a/python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension_setup.py b/python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension_setup.py index 3766d33f034..823d0183cfd 100644 --- a/python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension_setup.py +++ b/python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension_setup.py @@ -21,7 +21,7 @@ from paddle.utils.cpp_extension import CppExtension, setup setup( name='mix_relu_extension', ext_modules=CppExtension( - sources=["mix_relu_and_extension.cc", "custom_sub.cc"], + sources=["mix_relu_and_extension.cc"], include_dirs=paddle_includes + [os.path.dirname(os.path.abspath(__file__))], extra_compile_args={'cc': ['-w', '-g']}, diff --git a/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_jit.py b/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_jit.py index 9ed330a2b4a..bc6f8113afd 100644 --- a/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_jit.py +++ b/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_jit.py @@ -18,6 +18,7 @@ import unittest from site import getsitepackages import numpy as np +from utils import check_output import paddle from paddle.utils.cpp_extension import load @@ -27,7 +28,7 @@ if os.name == 'nt' or sys.platform.startswith('darwin'): sys.exit() # Compile and load cpp extension Just-In-Time. -sources = ["custom_extension.cc", "custom_sub.cc"] +sources = ["custom_extension.cc", "custom_sub.cc", "custom_relu_forward.cu"] paddle_includes = [] for site_packages_path in getsitepackages(): paddle_includes.append( @@ -69,6 +70,8 @@ class TestCppExtensionJITInstall(unittest.TestCase): self._test_extension_class() self._test_nullable_tensor() self._test_optional_tensor() + if paddle.is_compiled_with_cuda(): + self._test_cuda_relu() def _test_extension_function(self): for dtype in self.dtypes: @@ -130,6 +133,14 @@ class TestCppExtensionJITInstall(unittest.TestCase): err_msg=f'extension out: {x},\n numpy out: {x_np}', ) + def _test_cuda_relu(self): + paddle.set_device('gpu') + x = np.random.uniform(-1, 1, [4, 8]).astype('float32') + x = paddle.to_tensor(x, dtype='float32') + out = custom_cpp_extension.relu_cuda_forward(x) + pd_out = paddle.nn.functional.relu(x) + check_output(out, pd_out, "out") + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py b/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py index 5c8c91ed303..53dffde4320 100644 --- a/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py +++ b/python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py @@ -18,6 +18,7 @@ import sys import unittest import numpy as np +from utils import check_output import paddle from paddle import static @@ -154,6 +155,8 @@ class TestCppExtensionSetupInstall(unittest.TestCase): self._test_static() self._test_dynamic() self._test_double_grad_dynamic() + if paddle.is_compiled_with_cuda(): + self._test_cuda_relu() def _test_extension_function_plain(self): import custom_cpp_extension @@ -314,6 +317,16 @@ class TestCppExtensionSetupInstall(unittest.TestCase): ), ) + def _test_cuda_relu(self): + import custom_cpp_extension + + paddle.set_device('gpu') + x = np.random.uniform(-1, 1, [4, 8]).astype('float32') + x = paddle.to_tensor(x, dtype='float32') + out = custom_cpp_extension.relu_cuda_forward(x) + pd_out = paddle.nn.functional.relu(x) + check_output(out, pd_out, "out") + if __name__ == '__main__': if os.name == 'nt' or sys.platform.startswith('darwin'): diff --git a/python/paddle/fluid/tests/cpp_extension/utils.py b/python/paddle/fluid/tests/cpp_extension/utils.py index 5c5a458a5c7..19659c6d5d7 100644 --- a/python/paddle/fluid/tests/cpp_extension/utils.py +++ b/python/paddle/fluid/tests/cpp_extension/utils.py @@ -16,6 +16,8 @@ import os import sys from site import getsitepackages +import numpy as np + from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS IS_MAC = sys.platform.startswith('darwin') @@ -37,3 +39,43 @@ for site_packages_path in getsitepackages(): extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] extra_nvcc_args = ['-O3'] extra_compile_args = {'cc': extra_cc_args, 'nvcc': extra_nvcc_args} + + +def check_output(out, pd_out, name): + if out is None and pd_out is None: + return + assert out is not None, "out value of " + name + " is None" + assert pd_out is not None, "pd_out value of " + name + " is None" + if isinstance(out, list) and isinstance(pd_out, list): + for idx in range(len(out)): + np.testing.assert_array_equal( + out[idx], + pd_out[idx], + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out[idx], name, pd_out[idx] + ), + ) + else: + np.testing.assert_array_equal( + out, + pd_out, + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out, name, pd_out + ), + ) + + +def check_output_allclose(out, pd_out, name, rtol=5e-5, atol=1e-2): + if out is None and pd_out is None: + return + assert out is not None, "out value of " + name + " is None" + assert pd_out is not None, "pd_out value of " + name + " is None" + np.testing.assert_allclose( + out, + pd_out, + rtol, + atol, + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out, name, pd_out + ), + ) diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 8ff70ca4c0e..582cd560ddc 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -180,6 +180,9 @@ def custom_write_stub(resource, pyfile): def __bootstrap__(): assert os.path.exists(so_path) + # load custom op shared library with abs path + custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) + if os.name == 'nt' or sys.platform.startswith('darwin'): # Cpp Extension only support Linux now mod = types.ModuleType(__name__) @@ -193,10 +196,8 @@ def custom_write_stub(resource, pyfile): except ImportError: mod = types.ModuleType(__name__) - # load custom op shared library with abs path - custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path) - for custom_ops in custom_ops: - setattr(mod, custom_ops, eval(custom_ops)) + for custom_op in custom_ops: + setattr(mod, custom_op, eval(custom_op)) __bootstrap__() -- GitLab