未验证 提交 ef8323d4 编写于 作者: F furnace 提交者: GitHub

[ROCM] Add ROCm support for warpctc op (#31817)

* bugfix for warpctc

* fix warpctc commit id

* fix warpctc commit id

* fix warpctc commit id

* fix warpctc commit id

* fix warpctc commit id

* fix WARPCTC_WITH_HIP invalid

* Add logs to find out why can not dlopen libwarpctc.so

* fix warpctc commit id

* fix unit test test_warpctc_op

* Optime failed log for dlopen

* Optime failed log for dlopen

* Delete extra changes

* fix warpctc commit id

* fix warpctc commit id

* Add is_compiled_with_rocm for test_warpctc_op

* fix warpctc commit id

* Cancel optimize dlopen failed reason, move to next pr, due to it makes windows ci failed

* Cancel optimize dlopen failed reason, move to next pr, due to it makes windows ci failed

* Cancel optimize dlopen failed reason, move to next pr, due to it makes windows ci failed

* fix code style problems
上级 95f808c8
...@@ -14,11 +14,15 @@ ...@@ -14,11 +14,15 @@
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
IF(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
ENDIF()
SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git) set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git)
set(WARPCTC_TAG cd828e5b6c3b953b82af73f7f44cddc393a20efa) set(WARPCTC_TAG c690fc5755abbdbdc98ef78d51ec10a6748a8cd1)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE) CACHE PATH "Warp-ctc Directory" FORCE)
...@@ -57,6 +61,7 @@ ExternalProject_Add( ...@@ -57,6 +61,7 @@ ExternalProject_Add(
-DCMAKE_CXX_FLAGS_DEBUG=$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline> -DCMAKE_CXX_FLAGS_DEBUG=$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU} -DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP} -DWITH_OMP=${USE_OMP}
-DWITH_TORCH=OFF -DWITH_TORCH=OFF
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
......
...@@ -159,8 +159,7 @@ class WarpCTCFunctor { ...@@ -159,8 +159,7 @@ class WarpCTCFunctor {
warpctc_version_ = platform::dynload::get_warpctc_version(); warpctc_version_ = platform::dynload::get_warpctc_version();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
// HIP not support ctcOptions in third-party warpctc #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUDA
options_.loc = CTC_GPU; options_.loc = CTC_GPU;
options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>( options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context()) ctx.device_context())
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
...@@ -240,8 +241,18 @@ class TestWarpCTCOp(OpTest): ...@@ -240,8 +241,18 @@ class TestWarpCTCOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad( if core.is_compiled_with_rocm():
["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False) self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.009,
check_dygraph=False)
else:
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False)
class TestWarpCTCOpCase1(TestWarpCTCOp): class TestWarpCTCOpCase1(TestWarpCTCOp):
...@@ -335,8 +346,18 @@ class TestWarpCTCOpWithPadding(OpTest): ...@@ -335,8 +346,18 @@ class TestWarpCTCOpWithPadding(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad( if core.is_compiled_with_rocm():
["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False) self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.009,
check_dygraph=False)
else:
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False)
class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding): class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册