未验证 提交 ef8323d4 编写于 作者: F furnace 提交者: GitHub

[ROCM] Add ROCm support for warpctc op (#31817)

* bugfix for warpctc

* fix warpctc commit id

* fix warpctc commit id

* fix warpctc commit id

* fix warpctc commit id

* fix warpctc commit id

* fix WARPCTC_WITH_HIP invalid

* Add logs to find out why can not dlopen libwarpctc.so

* fix warpctc commit id

* fix unit test test_warpctc_op

* Optime failed log for dlopen

* Optime failed log for dlopen

* Delete extra changes

* fix warpctc commit id

* fix warpctc commit id

* Add is_compiled_with_rocm for test_warpctc_op

* fix warpctc commit id

* Cancel optimize dlopen failed reason, move to next pr, due to it makes windows ci failed

* Cancel optimize dlopen failed reason, move to next pr, due to it makes windows ci failed

* Cancel optimize dlopen failed reason, move to next pr, due to it makes windows ci failed

* fix code style problems
上级 95f808c8
......@@ -14,11 +14,15 @@
INCLUDE(ExternalProject)
IF(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
ENDIF()
SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git)
set(WARPCTC_TAG cd828e5b6c3b953b82af73f7f44cddc393a20efa)
set(WARPCTC_TAG c690fc5755abbdbdc98ef78d51ec10a6748a8cd1)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)
......@@ -57,6 +61,7 @@ ExternalProject_Add(
-DCMAKE_CXX_FLAGS_DEBUG=$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DWITH_TORCH=OFF
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
......
......@@ -159,8 +159,7 @@ class WarpCTCFunctor {
warpctc_version_ = platform::dynload::get_warpctc_version();
if (platform::is_gpu_place(ctx.GetPlace())) {
// HIP not support ctcOptions in third-party warpctc
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
options_.loc = CTC_GPU;
options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
......
......@@ -20,6 +20,7 @@ import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
import paddle
import paddle.nn.functional as F
......@@ -240,8 +241,18 @@ class TestWarpCTCOp(OpTest):
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
if core.is_compiled_with_rocm():
self.check_grad(
["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False)
["Logits"],
"Loss",
max_relative_error=0.009,
check_dygraph=False)
else:
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False)
class TestWarpCTCOpCase1(TestWarpCTCOp):
......@@ -335,8 +346,18 @@ class TestWarpCTCOpWithPadding(OpTest):
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
if core.is_compiled_with_rocm():
self.check_grad(
["Logits"],
"Loss",
max_relative_error=0.009,
check_dygraph=False)
else:
self.check_grad(
["Logits"], "Loss", max_relative_error=0.007, check_dygraph=False)
["Logits"],
"Loss",
max_relative_error=0.007,
check_dygraph=False)
class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册