Unverified commit c088f9ec, authored by Hui Zhang, committed by GitHub

add rnn-t loss and api (#49199)

* add warp transducer code
Parent 4ed6eeab
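For orientation, here is a minimal sketch of the Python API this commit introduces (names and shapes follow the docstrings added in the diff below; the tensor values are illustrative only, not from the tests):

```python
import paddle
import paddle.nn.functional as F

# (B, Tmax, Umax, D) log-probabilities from an RNN-T joint network
acts = F.log_softmax(paddle.uniform([2, 4, 3, 5], dtype='float32'), axis=-1)
labels = paddle.to_tensor([[1, 2], [3, 4]], dtype='int32')    # (B, Umax - 1), zero-padded
input_lens = paddle.to_tensor([4, 4], dtype='int32')          # (B,)
label_lens = paddle.to_tensor([2, 2], dtype='int32')          # (B,)

# Functional API added by this PR
loss = F.rnnt_loss(acts, labels, input_lens, label_lens,
                   blank=0, fastemit_lambda=0.0, reduction='mean')

# Layer API added by this PR
criterion = paddle.nn.RNNTLoss(blank=0, fastemit_lambda=0.0, reduction='none')
per_example = criterion(acts, labels, input_lens, label_lens)  # shape (B,)
```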
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
if(WITH_ROCM)
add_definitions(-DWARPRNNT_WITH_HIP)
endif()
set(WARPRNNT_PREFIX_DIR ${THIRD_PARTY_PATH}/warprnnt)
set(WARPRNNT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warprnnt)
set(WARPRNNT_REPOSITORY ${GIT_URL}/PaddlePaddle/warp-transducer.git)
set(WARPRNNT_TAG 7ea6bfe748779c245a0fcaa5dd9383826273eff2)
set(WARPRNNT_INCLUDE_DIR
"${WARPRNNT_INSTALL_DIR}/include"
CACHE PATH "Warp-rnnt Directory" FORCE)
# Warp-rnnt library directory.
set(WARPRNNT_LIB_DIR
"${WARPRNNT_INSTALL_DIR}/lib"
CACHE PATH "Warp-rnnt Library Directory" FORCE)
if(WIN32)
set(WARPRNNT_LIBRARIES
"${WARPRNNT_INSTALL_DIR}/bin/warprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-rnnt Library" FORCE)
else()
set(WARPRNNT_LIBRARIES
"${WARPRNNT_INSTALL_DIR}/lib/libwarprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-rnnt Library" FORCE)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()
if(WIN32)
set(WARPRNNT_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(WARPRNNT_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(WARPRNNT_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(WARPRNNT_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(WARPRNNT_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(WARPRNNT_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(WARPRNNT_C_FLAGS ${CMAKE_C_FLAGS})
set(WARPRNNT_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(WARPRNNT_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(WARPRNNT_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(WARPRNNT_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(WARPRNNT_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
ExternalProject_Add(
extern_warprnnt
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${WARPRNNT_REPOSITORY}
GIT_TAG ${WARPRNNT_TAG}
PREFIX ${WARPRNNT_PREFIX_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${WARPRNNT_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${WARPRNNT_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${WARPRNNT_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${WARPRNNT_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${WARPRNNT_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${WARPRNNT_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${WARPRNNT_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DBUILD_TESTS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${WARPRNNT_INSTALL_DIR}
BUILD_BYPRODUCTS ${WARPRNNT_LIBRARIES})
message(STATUS "warp-rnnt library: ${WARPRNNT_LIBRARIES}")
get_filename_component(WARPRNNT_LIBRARY_PATH ${WARPRNNT_LIBRARIES} DIRECTORY)
include_directories(${WARPRNNT_INCLUDE_DIR}) # For warprnnt code to include its headers.
add_library(warprnnt INTERFACE)
# set_property(TARGET warprnnt PROPERTY IMPORTED_LOCATION ${WARPRNNT_LIBRARIES})
add_dependencies(warprnnt extern_warprnnt)
...@@ -254,6 +254,7 @@ include(external/threadpool) # download threadpool
include(external/dlpack) # download dlpack
include(external/xxhash) # download, build, install xxhash
include(external/warpctc) # download, build, install warpctc
include(external/warprnnt) # download, build, install warprnnt
include(external/utf8proc) # download, build, install utf8proc
list(APPEND third_party_deps extern_eigen3 extern_gflags extern_glog
...@@ -264,6 +265,7 @@ list(
extern_zlib
extern_dlpack
extern_warpctc
extern_warprnnt
extern_threadpool
extern_utf8proc)
include(external/lapack) # download, build, install lapack
...@@ -276,6 +278,7 @@ list(
extern_zlib
extern_dlpack
extern_warpctc
extern_warprnnt
extern_threadpool
extern_lapack)
...
...@@ -1291,6 +1291,17 @@
kernel :
func : unstack_grad
- backward_op : warprnnt_grad
forward : warprnnt (Tensor input, Tensor label, Tensor input_lengths, Tensor label_lengths, int blank = 0, float fastemit_lambda = 0.0) -> Tensor(loss), Tensor(warprnntgrad)
args : (Tensor input, Tensor input_lengths, Tensor warprnntgrad, Tensor loss_grad, int blank = 0, float fastemit_lambda = 0.0)
output : Tensor(input_grad)
infer_meta :
func : UnchangedInferMeta
param : [input]
kernel :
func : warprnnt_grad
no_need_buffer : input
- backward_op : where_grad
forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out)
args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad)
...
...@@ -1119,6 +1119,17 @@
func : viterbi_decode
data_type : potentials
- op : warprnnt
args : (Tensor input, Tensor label, Tensor input_lengths, Tensor label_lengths, int blank = 0, float fastemit_lambda = 0.0)
output : Tensor(loss), Tensor(warprnntgrad)
infer_meta :
func : WarprnntInferMeta
kernel :
func : warprnnt
data_type: input
intermediate: warprnntgrad
backward : warprnnt_grad
- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
...
...@@ -62,6 +62,10 @@ if(WITH_ROCM)
phi_dynload_warpctc
SRCS warpctc.cc
DEPS phi_dynamic_loader warpctc)
cc_library(
phi_dynload_warprnnt
SRCS warprnnt.cc
DEPS phi_dynamic_loader warprnnt)
elseif(WITH_ASCEND_CL)
cc_library(
phi_dynload_warpctc
...@@ -76,6 +80,10 @@ else()
phi_dynload_warpctc
SRCS warpctc.cc
DEPS phi_dynamic_loader warpctc)
cc_library(
phi_dynload_warprnnt
SRCS warprnnt.cc
DEPS phi_dynamic_loader warprnnt)
endif()
if(WITH_MKLML)
cc_library(
...
...@@ -470,6 +470,20 @@ void* GetWarpCTCDsoHandle() {
#endif
}
void* GetWarpRNNTDsoHandle() {
std::string warprnnt_dir = "";
if (!s_py_site_pkg_path.path.empty()) {
warprnnt_dir = s_py_site_pkg_path.path;
}
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(warprnnt_dir, "libwarprnnt.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(warprnnt_dir, "warprnnt.dll");
#else
return GetDsoHandleFromSearchPath(warprnnt_dir, "libwarprnnt.so");
#endif
}
void* GetNCCLDsoHandle() {
#ifdef PADDLE_WITH_HIP
std::string warning_msg(
...
...@@ -35,6 +35,7 @@ void* GetCusparseDsoHandle();
void* GetNVRTCDsoHandle();
void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetNCCLDsoHandle();
void* GetHCCLDsoHandle();
void* GetTensorRtDsoHandle();
...
...@@ -39,8 +39,8 @@ extern void* warpctc_dso_handle;
std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = phi::dynload::GetWarpCTCDsoHandle(); \
}); \
static void* p_##__name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
...
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/dynload/warprnnt.h"
namespace phi {
namespace dynload {
std::once_flag warprnnt_dso_flag;
void* warprnnt_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
WARPRNNT_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex> // NOLINT
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
#include "warprnnt/include/rnnt.h"
namespace phi {
namespace dynload {
extern std::once_flag warprnnt_dso_flag;
extern void* warprnnt_dso_handle;
/**
* The following macro definition generates a struct
* (for each function) that dynamically loads the warprnnt routine
* via operator overloading.
*/
#define DYNAMIC_LOAD_WARPRNNT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using warprnntFunc = decltype(&::__name); \
std::call_once(warprnnt_dso_flag, []() { \
warprnnt_dso_handle = phi::dynload::GetWarpRNNTDsoHandle(); \
}); \
static void* p_##__name = dlsym(warprnnt_dso_handle, #__name); \
return reinterpret_cast<warprnntFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_WARPRNNT_WRAP(__name) \
DYNAMIC_LOAD_WARPRNNT_WRAP(__name)
#define WARPRNNT_ROUTINE_EACH(__macro) \
__macro(get_warprnnt_version); \
__macro(rnntGetStatusString); \
__macro(compute_rnnt_loss); \
__macro(compute_rnnt_loss_fp64); \
__macro(get_rnnt_workspace_size);
WARPRNNT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPRNNT_WRAP);
#undef DYNAMIC_LOAD_WARPRNNT_WRAP
} // namespace dynload
} // namespace phi
...@@ -2759,6 +2759,36 @@ void WarpctcInferMeta(const MetaTensor& logits,
loss->set_dtype(logits.dtype());
}
void WarprnntInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& input_lengths,
const MetaTensor& label_lengths,
int blank,
float fastemit_lambda,
MetaTensor* loss,
MetaTensor* warpctcgrad) {
auto acts_dims = input.dims();
int D = acts_dims[3];
PADDLE_ENFORCE_GE(
blank,
0,
errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
D,
blank));
PADDLE_ENFORCE_LT(
blank,
D,
errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
D,
blank));
loss->set_dims({-1});
loss->set_dtype(input.dtype());
}
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
...
...@@ -503,6 +503,15 @@ void WarpctcInferMeta(const MetaTensor& logits,
MetaTensor* loss,
MetaTensor* warpctcgrad);
void WarprnntInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& input_lengths,
const MetaTensor& label_lengths,
int blank,
float fastemit_lambda,
MetaTensor* loss,
MetaTensor* warpctcgrad);
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
...
...@@ -70,6 +70,7 @@ set(COMMON_KERNEL_DEPS
matrix_inverse
matrix_solve
phi_dynload_warpctc
phi_dynload_warprnnt
sequence_padding
sequence_scale
fft
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warprnnt_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/warprnnt_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
warprnnt_grad, CPU, ALL_LAYOUT, phi::WarprnntGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warprnnt_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/warprnnt_kernel_impl.h"
PD_REGISTER_KERNEL(
warprnnt, CPU, ALL_LAYOUT, phi::WarprnntKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warprnnt_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/warprnnt_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
warprnnt_grad, GPU, ALL_LAYOUT, phi::WarprnntGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warprnnt_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/warprnnt_kernel_impl.h"
PD_REGISTER_KERNEL(
warprnnt, GPU, ALL_LAYOUT, phi::WarprnntKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void WarprnntGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& input_lengths,
const DenseTensor& warprnntgrad,
const DenseTensor& loss_grad,
int blank,
float fastemit_lambda,
DenseTensor* input_grad) {
dev_ctx.template Alloc<T>(input_grad);
int B = warprnntgrad.dims()[0]; // B
int Tmax = warprnntgrad.dims()[1]; // Tmax
int Umax = warprnntgrad.dims()[2]; // Umax
int D = warprnntgrad.dims()[3]; // D
// (B,)
auto loss_grad_e = EigenTensor<T, 1>::From(loss_grad);
// (B, T, U, D)
auto warprnntgrad_e = EigenTensor<T, 4>::From(warprnntgrad);
auto acts_grad_e = EigenTensor<T, 4>::From(*input_grad);
Eigen::DSizes<int, 4> grad_shape(B, 1, 1, 1);
Eigen::DSizes<int, 4> bcast(1, Tmax, Umax, D);
auto acts_g =
warprnntgrad_e * loss_grad_e.reshape(grad_shape).broadcast(bcast).eval();
auto* place = dev_ctx.eigen_device();
acts_grad_e.device(*place) = acts_g;
}
} // namespace phi
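In NumPy terms, the backward kernel above simply rescales the gradient cached by the forward pass with the incoming per-example loss gradient, broadcast over (Tmax, Umax, D). A rough sketch with made-up shapes:

```python
import numpy as np

B, Tmax, Umax, D = 2, 4, 3, 5
warprnntgrad = np.random.randn(B, Tmax, Umax, D).astype(np.float32)  # cached by the forward op
loss_grad = np.random.randn(B).astype(np.float32)                    # gradient w.r.t. each example's loss

# Equivalent of the Eigen reshape(grad_shape).broadcast(bcast) in the kernel
input_grad = warprnntgrad * loss_grad.reshape(B, 1, 1, 1)
assert input_grad.shape == (B, Tmax, Umax, D)
```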
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/backends/dynload/warprnnt.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename Context, typename T>
class ComputeRnntLossFunctor {
public:
rnntStatus_t operator()(const T* const activations,
T* gradients,
const int* const label,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
T* costs,
void* workspace,
rnntOptions options) {
return RNNT_STATUS_EXECUTION_FAILED;
}
};
template <typename Context>
class ComputeRnntLossFunctor<Context, float> {
public:
rnntStatus_t operator()(const float* const activations,
float* gradients,
const int* const label,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
float* costs,
void* workspace,
rnntOptions options) {
return phi::dynload::compute_rnnt_loss(activations,
gradients,
label,
label_lengths,
input_lengths,
static_cast<int>(alphabet_size),
static_cast<int>(minibatch),
costs,
workspace,
options);
}
};
template <typename Context>
class ComputeRnntLossFunctor<Context, double> {
public:
rnntStatus_t operator()(const double* const activations,
double* gradients,
const int* const label,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
double* costs,
void* workspace,
rnntOptions options) {
return phi::dynload::compute_rnnt_loss_fp64(activations,
gradients,
label,
label_lengths,
input_lengths,
static_cast<int>(alphabet_size),
static_cast<int>(minibatch),
costs,
workspace,
options);
}
};
template <typename Context, typename T>
class WarpRNNTFunctor {
public:
/*
* \brief Compute the RNN-T loss, and optionally compute the gradient
* with respect to the inputs.
*
* If gradient is nullptr, only the rnnt loss is computed;
* otherwise both the rnnt loss and the gradient are computed.
*
* \param ctx execution context of this functor
* \param input batch matrix of input probabilities, in
* (B, Tmax, Umax, D), (row-major) format
* \param gradient batch matrix of gradient, with the same shape as
* input, (B, Tmax, Umax, D)
* \param label label, (B, Umax)
* \param label_lengths length of all label, (B,).
* \param input_lengths length of all sequences, (B,).
* \param D number of vocab symbols, w/ blank.
* \param B number of examples (batch size).
* \param blank blank label used in the rnnt loss function.
* \param cpu_loss loss of each example in CPU memory.
*/
void operator()(const Context& dev_ctx,
const T* input,
T* gradient,
const int* label,
const int* label_lengths,
const int* input_lengths,
const size_t D,
const size_t B,
const size_t maxT,
const size_t maxU,
const int blank,
const float fastemit_lambda,
const int num_threads,
T* cpu_loss) {
// Init warp-rnnt options
init(dev_ctx, maxT, maxU, blank, fastemit_lambda, num_threads);
// Compute the required workspace size.
// There are no memory allocation operations within warp-rnnt.
rnntStatus_t status = RNNT_STATUS_UNKNOWN_ERROR;
bool gpu = false;
if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpu = true;
#else
PADDLE_THROW(errors::PreconditionNotMet(
"[WarpRNNTFunctor Operator] GPU is not enabled."));
#endif
}
size_t workspace_bytes = 0;
status = phi::dynload::get_rnnt_workspace_size(
maxT, maxU, B, gpu, &workspace_bytes, sizeof(T));
PADDLE_ENFORCE_EQ(
RNNT_STATUS_SUCCESS,
status,
errors::PreconditionNotMet(
"warp-rnnt [version %d] Error in get_rnnt_workspace_size: %s",
warprnnt_version_,
phi::dynload::rnntGetStatusString(status)));
PADDLE_ENFORCE_GT(
workspace_bytes,
0UL,
errors::InvalidArgument("Bytes of workspace got by warp-rnnt function, "
"get_rnnt_workspace_size() should be larger "
"than 0, but received %d",
workspace_bytes));
size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL;
DenseTensor workspace = phi::Full<T, Context>(
dev_ctx, {static_cast<int64_t>(workspace_elements)}, static_cast<T>(0));
T* workspace_data = workspace.data<T>();
// compute loss and gradient
status = ComputeRnntLossFunctor<Context, T>()(input,
gradient,
label,
label_lengths,
input_lengths,
static_cast<int>(D),
static_cast<int>(B),
cpu_loss,
workspace_data,
options_);
PADDLE_ENFORCE_EQ(
RNNT_STATUS_SUCCESS,
status,
errors::PreconditionNotMet(
"warp-rnnt [version %d] Error in get_workspace_size: %s",
warprnnt_version_,
phi::dynload::rnntGetStatusString(status)));
}
protected:
void init(const Context& dev_ctx,
const size_t maxT,
const size_t maxU,
const size_t blank,
const float fastemit_lambda,
const int num_threads) {
warprnnt_version_ = phi::dynload::get_warprnnt_version();
options_.maxT = maxT;
options_.maxU = maxU;
options_.blank_label = blank;
options_.fastemit_lambda = fastemit_lambda;
options_.batch_first = true;
if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
options_.loc = RNNT_GPU;
options_.stream =
reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
#else
PADDLE_THROW(
errors::PreconditionNotMet("[warprnnt init] GPU is not enabled."));
#endif
} else {
options_.loc = RNNT_CPU;
options_.num_threads = num_threads;
#ifdef PADDLE_WITH_MKLML
// have to use at least one thread
options_.num_threads = std::max(options_.num_threads, (unsigned int)1);
#endif
}
}
private:
int warprnnt_version_;
rnntOptions options_;
};
template <typename T, typename Context>
void WarprnntKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& input_lengths,
const DenseTensor& label_lengths,
int blank,
float fastemit_lambda,
DenseTensor* loss,
DenseTensor* warprnntgrad) {
PADDLE_ENFORCE_EQ(
input.dims().size(),
4,
phi::errors::InvalidArgument("The rank of Input(Logits) should be 4 "
"but received %d. ",
input.dims().size()));
PADDLE_ENFORCE_EQ(
label.dims().size(),
2,
phi::errors::InvalidArgument("The rank of Input(Label) should be 2 "
"but received %d. ",
label.dims().size()));
PADDLE_ENFORCE_EQ(input_lengths.dims().size(),
1,
phi::errors::InvalidArgument(
"The rank of Input(LogitsLength) should be 1 "
"but received %d. ",
input_lengths.dims().size()));
PADDLE_ENFORCE_EQ(
label_lengths.dims().size(),
1,
phi::errors::InvalidArgument("The rank of Input(LabelLength) should be 1 "
"but received %d. ",
label_lengths.dims().size()));
size_t B, Tmax, Umax, D;
B = input.dims()[0];
Tmax = input.dims()[1];
Umax = input.dims()[2];
D = input.dims()[3];
PADDLE_ENFORCE_GT(B,
0,
phi::errors::InvalidArgument(
"The first dimension (B) of Input(Logits) should be "
"greater than zero, "
"but received %d. ",
B));
PADDLE_ENFORCE_GT(Tmax,
0,
phi::errors::InvalidArgument(
"The second dimension (T) of Input(Logits) should be "
"greater than zero, "
"but received %d. ",
Tmax));
PADDLE_ENFORCE_GT(Umax,
0,
phi::errors::InvalidArgument(
"The third dimension (U) of Input(Logits) should be "
"greater than zero, "
"but received %d. ",
Umax));
PADDLE_ENFORCE_GT(D,
0,
phi::errors::InvalidArgument(
"The fourth dimension (D) of Input(Logits) should be "
"greater than zero, "
"but received %d. ",
D));
warprnntgrad->Resize(input.dims());
T* warprnntgrad_data = dev_ctx.template Alloc<T>(warprnntgrad);
phi::funcs::SetConstant<Context, T>()(
dev_ctx, warprnntgrad, static_cast<T>(0));
// loss on cpu (B,)
auto loss_dims = phi::make_ddim({static_cast<int64_t>(B)});
DenseTensor warprnnt_loss;
warprnnt_loss.Resize(loss_dims);
T* warprnnt_loss_data = dev_ctx.template HostAlloc<T>(&warprnnt_loss);
WarpRNNTFunctor<Context, T>()(dev_ctx,
input.data<T>(),
warprnntgrad_data,
label.data<int>(),
label_lengths.data<int>(),
input_lengths.data<int>(),
D,
B,
Tmax,
Umax,
blank,
fastemit_lambda,
1 /*num_threads*/,
warprnnt_loss_data);
phi::Copy(dev_ctx, warprnnt_loss, dev_ctx.GetPlace(), true, loss);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void WarprnntGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& input_lengths,
const DenseTensor& warprnntgrad,
const DenseTensor& loss_grad,
int blank,
float fastemit_lambda,
DenseTensor* input_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void WarprnntKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& input_lengths,
const DenseTensor& label_lengths,
int blank,
float fastemit_lambda,
DenseTensor* loss,
DenseTensor* warprnntgrad);
} // namespace phi
...@@ -11,6 +11,7 @@ env_dict={
'WITH_PSLI':'@WITH_PSLI@',
'FLUID_CORE_NAME':'@FLUID_CORE_NAME@',
'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@',
'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@',
'LAPACK_LIB':'@LAPACK_LIB@',
'GFORTRAN_LIB':'@GFORTRAN_LIB@',
'GNU_RT_LIB_1':'@GNU_RT_LIB_1@',
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle import _C_ops
from paddle.fluid import Program, program_guard
paddle.enable_static()
def python_api(
logits,
label,
logits_length,
labels_length,
blank=0,
fastemit_lambda=0.0,
num_threads=1,
):
loss_out = _C_ops.warprnnt(
logits,
label,
logits_length,
labels_length,
blank,
fastemit_lambda,
)
return loss_out
class TestWarpRNNTOp(OpTest):
def set_act(self):
# logsoftmax
self.acts = np.array(
[
[
[
[-1.40493705, -0.68276381, -1.38870219],
[-1.25243963, -1.03148021, -1.02802034],
[-1.19624572, -0.93786934, -1.18347801],
],
[
[-1.03417513, -0.84465814, -1.53815849],
[-0.96884241, -1.01432347, -1.35545407],
[-0.82076925, -1.10135010, -1.48067081],
],
[
[-1.43828803, -1.16579869, -0.79630424],
[-1.38401855, -0.83654478, -1.15129927],
[-1.05188255, -1.29604414, -0.97522265],
],
[
[-1.34330978, -0.86678589, -1.14344457],
[-0.72518815, -1.32106859, -1.39063758],
[-1.09984781, -1.00059987, -1.20590993],
],
],
[
[
[-1.02221057, -1.47617485, -0.88748174],
[-1.18362952, -0.78488945, -1.43689575],
[-1.00784739, -1.28566450, -1.02574476],
],
[
[-1.02589709, -1.13153743, -1.14260096],
[-1.09942215, -1.12238913, -1.07459704],
[-1.09359647, -0.89829379, -1.35585602],
],
[
[-1.07782876, -0.84361953, -1.47178440],
[-1.23424792, -1.00248783, -1.07299990],
[-0.96521771, -1.19895815, -1.14698912],
],
[
[-1.50722446, -1.15380039, -0.76994115],
[-1.19125975, -0.89919308, -1.24041594],
[-0.91301359, -1.19665577, -1.21576258],
],
],
[
[
[-1.02221057, -1.47617485, -0.88748174],
[-1.18362952, -0.78488945, -1.43689575],
[-1.00784739, -1.28566450, -1.02574476],
],
[
[-1.02589709, -1.13153743, -1.14260096],
[-1.09942215, -1.12238913, -1.07459704],
[-1.09359647, -0.89829379, -1.35585602],
],
[
[-1.07782876, -0.84361953, -1.47178440],
[-1.23424792, -1.00248783, -1.07299990],
[-0.96521771, -1.19895815, -1.14698912],
],
[
[-1.50722446, -1.15380039, -0.76994115],
[-1.19125975, -0.89919308, -1.24041594],
[-0.91301359, -1.19665577, -1.21576258],
],
],
],
dtype=np.float32,
)
def set_gradient(self):
self.gradient = np.array(
[
[
[
[-0.43222645, -0.56777355, 0.0],
[-0.3656501, 0.0, -0.20212345],
[-0.20212345, 0.0, 0.0],
],
[
[-0.16521672, -0.26700973, 0.0],
[-0.39436539, 0.0, -0.23829444],
[-0.44041789, 0.0, 0.0],
],
[
[-0.05212979, -0.11308693, 0.0],
[-0.18313787, 0.0, -0.32431445],
[-0.76473234, 0.0, 0.0],
],
[
[0.0, -0.05212979, 0.0],
[0.0, 0.0, -0.23526766],
[-1.0, 0.0, 0.0],
],
],
[
[
[-0.71614241, -0.28385759, 0.0],
[-0.18382932, -0.10002826, 0.0],
[-0.10002826, 0.0, 0.0],
],
[
[-0.41121795, -0.30492447, 0.0],
[-0.32957594, -0.15917785, 0.0],
[-0.25920611, 0.0, 0.0],
],
[
[-0.11607642, -0.29514153, 0.0],
[-0.28653336, -0.3381841, 0.0],
[-0.59739022, 0.0, 0.0],
],
[
[0.0, -0.11607642, 0.0],
[0.0, -0.40260978, 0.0],
[-1.0, 0.0, 0.0],
],
],
[
[
[-0.71614241, -0.28385759, 0.0],
[-0.18382932, -0.10002826, 0.0],
[-0.10002826, 0.0, 0.0],
],
[
[-0.41121795, -0.30492447, 0.0],
[-0.32957594, -0.15917785, 0.0],
[-0.25920611, 0.0, 0.0],
],
[
[-0.11607642, -0.29514153, 0.0],
[-0.28653336, -0.3381841, 0.0],
[-0.59739022, 0.0, 0.0],
],
[
[0.0, -0.11607642, 0.0],
[0.0, -0.40260978, 0.0],
[-1.0, 0.0, 0.0],
],
],
],
dtype=np.float32,
)
def config(self):
self.blank = 0
self.fastemit_lambda = 0.0
self.set_act()
self.labels = np.array([[1, 2], [1, 1], [1, 1]], dtype=np.int32)
self.logit_lens = np.array([4, 4, 4], dtype=np.int32)
self.label_lens = np.array([2, 2, 2], dtype=np.int32)
self.loss = np.array(
[4.2806528590890736, 3.9384369822503591, 3.9384369822503591],
dtype=np.float64,
)
self.set_gradient()
def setUp(self):
self.op_type = "warprnnt"
self.config()
self.python_api = python_api
self.python_out_sig = ["loss"]
self.inputs = {
"input": self.acts,
"label": self.labels,
"input_lengths": self.logit_lens,
"label_lengths": self.label_lens,
}
self.outputs = {"loss": self.loss}
self.attrs = {
"blank": self.blank,
"fastemit_lambda": self.fastemit_lambda,
"num_threads": 1,
}
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.outputs["warprnntgrad"] = self.gradient
if core.is_compiled_with_rocm():
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
check_eager=True,
)
else:
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
check_eager=True,
)
class TestWarpRNNTFP64Op(TestWarpRNNTOp):
def test_check_output(self):
self.acts = self.acts.astype(np.float64)
self.check_output(check_eager=True)
def test_check_grad(self):
self.acts = self.acts.astype(np.float64)
self.outputs["warprnntgrad"] = self.gradient
if core.is_compiled_with_rocm():
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
check_eager=True,
)
else:
self.check_grad(
["input"],
"loss",
numeric_grad_delta=0.009,
check_eager=True,
)
class TestWarpRNNTOpError(unittest.TestCase):
def test_errors(self):
print("test_errors")
with program_guard(Program(), Program()):
logits = fluid.data(name='input', shape=[5, 16, 6], dtype='float32')
logits_length = fluid.data(
name='logit_lengths', shape=[None], dtype='int32'
)
label = fluid.data(name='labels', shape=[16, 3], dtype='int32')
label_length = fluid.data(
name='label_lengths', shape=[None], dtype='int32'
)
def test_logits_Variable():
logits_data = fluid.data(
name='logits_data', shape=[5, 16, 6], dtype='int32'
)
paddle.nn.functional.rnnt_loss(
input=logits_data,
label=label,
input_lengths=logits_length,
label_lengths=label_length,
)
self.assertRaises(TypeError, test_logits_Variable)
def test_label_Variable():
label_data = fluid.data(
name='label_data', shape=[16, 3], dtype='int64'
)
paddle.nn.functional.rnnt_loss(
input=logits,
label=label_data,
input_lengths=logits_length,
label_lengths=label_length,
)
self.assertRaises(TypeError, test_label_Variable)
def test_logits_len_Variable():
logits_length_data = fluid.data(
name='logits_length_data', shape=[None], dtype='int64'
)
paddle.nn.functional.rnnt_loss(
input=logits,
label=label,
input_lengths=logits_length_data,
label_lengths=label_length,
)
self.assertRaises(TypeError, test_logits_len_Variable)
def test_label_len_Variable():
label_length_data = fluid.data(
name='label_length_data', shape=[None], dtype='int64'
)
paddle.nn.functional.rnnt_loss(
input=logits,
label=label,
input_lengths=logits_length,
label_lengths=label_length_data,
)
self.assertRaises(TypeError, test_label_len_Variable)
def test_dygraph_errors(self):
def test_dygraph_with_lod():
print("test_dygraph_with_lod")
logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32")
# labels should not be blank
labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32")
labels_len = np.random.randint(0, 15 - 1, [15, 1], dtype="int64")
logits_len = np.random.randint(0, 15 - 1, [15, 1], dtype="int32")
softmax = paddle.to_tensor(logits)
labels = paddle.to_tensor(labels)
logits_len = paddle.to_tensor(logits_len)
labels_len = paddle.to_tensor(labels_len)
paddle.nn.functional.rnnt_loss(
input=softmax,
label=labels,
input_lengths=logits_len,
label_lengths=labels_len,
)
paddle.disable_static()
self.assertRaises(ValueError, test_dygraph_with_lod)
paddle.enable_static()
class TestRNNTLossAPICase(unittest.TestCase):
def set_act(self):
# logsoftmax
self.acts = np.array(
[
[
[
[-1.40493705, -0.68276381, -1.38870219],
[-1.25243963, -1.03148021, -1.02802034],
[-1.19624572, -0.93786934, -1.18347801],
],
[
[-1.03417513, -0.84465814, -1.53815849],
[-0.96884241, -1.01432347, -1.35545407],
[-0.82076925, -1.10135010, -1.48067081],
],
[
[-1.43828803, -1.16579869, -0.79630424],
[-1.38401855, -0.83654478, -1.15129927],
[-1.05188255, -1.29604414, -0.97522265],
],
[
[-1.34330978, -0.86678589, -1.14344457],
[-0.72518815, -1.32106859, -1.39063758],
[-1.09984781, -1.00059987, -1.20590993],
],
],
[
[
[-1.02221057, -1.47617485, -0.88748174],
[-1.18362952, -0.78488945, -1.43689575],
[-1.00784739, -1.28566450, -1.02574476],
],
[
[-1.02589709, -1.13153743, -1.14260096],
[-1.09942215, -1.12238913, -1.07459704],
[-1.09359647, -0.89829379, -1.35585602],
],
[
[-1.07782876, -0.84361953, -1.47178440],
[-1.23424792, -1.00248783, -1.07299990],
[-0.96521771, -1.19895815, -1.14698912],
],
[
[-1.50722446, -1.15380039, -0.76994115],
[-1.19125975, -0.89919308, -1.24041594],
[-0.91301359, -1.19665577, -1.21576258],
],
],
[
[
[-1.02221057, -1.47617485, -0.88748174],
[-1.18362952, -0.78488945, -1.43689575],
[-1.00784739, -1.28566450, -1.02574476],
],
[
[-1.02589709, -1.13153743, -1.14260096],
[-1.09942215, -1.12238913, -1.07459704],
[-1.09359647, -0.89829379, -1.35585602],
],
[
[-1.07782876, -0.84361953, -1.47178440],
[-1.23424792, -1.00248783, -1.07299990],
[-0.96521771, -1.19895815, -1.14698912],
],
[
[-1.50722446, -1.15380039, -0.76994115],
[-1.19125975, -0.89919308, -1.24041594],
[-0.91301359, -1.19665577, -1.21576258],
],
],
],
dtype=np.float32,
)
def config(self):
self.blank = 0
self.fastemit_lambda = 0.0
self.set_act()
self.labels = np.array([[1, 2], [1, 1], [1, 1]], dtype=np.int32)
self.logit_lens = np.array([4, 4, 4], dtype=np.int32)
self.label_lens = np.array([2, 2, 2], dtype=np.int32)
self.loss = np.array(
[4.2806528590890736, 3.9384369822503591, 3.9384369822503591],
dtype=np.float64,
)
def test_functional_api(self):
self.config()
paddle.disable_static()
acts = paddle.to_tensor(self.acts)
labels = paddle.to_tensor(self.labels)
logit_lens = paddle.to_tensor(self.logit_lens)
label_lens = paddle.to_tensor(self.label_lens)
loss_pd_mean = paddle.nn.functional.rnnt_loss(
acts,
labels,
logit_lens,
label_lens,
blank=self.blank,
reduction='mean',
fastemit_lambda=self.fastemit_lambda,
)
loss_pd_mean = loss_pd_mean.numpy()
loss_pd_sum = paddle.nn.functional.rnnt_loss(
acts,
labels,
logit_lens,
label_lens,
blank=self.blank,
reduction='sum',
fastemit_lambda=self.fastemit_lambda,
)
loss_pd_sum = loss_pd_sum.numpy()
paddle.enable_static()
B = self.loss.shape[0]
loss_np_mean = self.loss.sum() / B
loss_np_sum = self.loss.sum()
np.testing.assert_allclose(
loss_pd_mean, loss_np_mean, rtol=1e-05, atol=1
)
np.testing.assert_allclose(loss_pd_sum, loss_np_sum, rtol=1e-05, atol=1)
def test_class_api(self):
self.config()
paddle.disable_static()
acts = paddle.to_tensor(self.acts)
labels = paddle.to_tensor(self.labels)
logit_lens = paddle.to_tensor(self.logit_lens)
label_lens = paddle.to_tensor(self.label_lens)
loss_pd = paddle.nn.RNNTLoss(self.blank, self.fastemit_lambda, 'none')(
acts, labels, logit_lens, label_lens
)
loss_pd = loss_pd.numpy()
paddle.enable_static()
np.testing.assert_allclose(loss_pd, self.loss, rtol=1e-05, atol=1)
if __name__ == "__main__":
unittest.main()
...@@ -75,6 +75,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'trilinear_interp_v2',
'var_conv_2d',
'warpctc',
'warprnnt',
'bilateral_slice',
'cast',
]
...
...@@ -106,6 +106,7 @@ from .layer.loss import KLDivLoss # noqa: F401
from .layer.loss import MarginRankingLoss # noqa: F401
from .layer.loss import MultiLabelSoftMarginLoss
from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import RNNTLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.loss import CosineEmbeddingLoss # noqa: F401
...@@ -285,6 +286,7 @@ __all__ = [ # noqa
'Silu',
'Conv2DTranspose',
'CTCLoss',
'RNNTLoss',
'ThresholdedReLU',
'AdaptiveAvgPool2D',
'MaxPool1D',
...
...@@ -90,6 +90,7 @@ from .loss import softmax_with_cross_entropy # noqa: F401
from .loss import margin_cross_entropy # noqa: F401
from .loss import square_error_cost # noqa: F401
from .loss import ctc_loss # noqa: F401
from .loss import rnnt_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .loss import cosine_embedding_loss # noqa: F401
from .loss import multi_margin_loss
...@@ -220,6 +221,7 @@ __all__ = [ # noqa
'margin_cross_entropy',
'square_error_cost',
'ctc_loss',
'rnnt_loss',
'hinge_embedding_loss',
'affine_grid',
'grid_sample',
...
...@@ -1757,7 +1757,7 @@ def ctc_loss(
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
reduction (string, optional): Indicate how to average the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
norm_by_times (bool, default False): Whether to normalize the gradients by the number of time-step, which is also the sequence’s length. There is no need to normalize the gradients if reduction mode is 'mean'.
Returns:
Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
...@@ -1895,6 +1895,130 @@
return loss_out
def rnnt_loss(
input,
label,
input_lengths,
label_lengths,
blank=0,
fastemit_lambda=0.001,
reduction='mean',
name=None,
):
"""
An operator integrating the open source Warp-Transducer library (https://github.com/b-flo/warp-transducer.git)
to compute Sequence Transduction with Recurrent Neural Networks (RNN-T) loss.
Parameters:
input (Tensor): The logprobs sequence with padding, which is a 4-D Tensor. The tensor shape is [B, Tmax, Umax, D], where Tmax, is the longest length of input logit sequence. The data type should be float32 or float64.
label (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [B, Umax], where Umax is the longest length of label sequence. The data type must be int32.
input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int32.
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int32.
blank (int, optional): The blank label index of RNN-T loss, which is in the half-opened interval [0, D), where D is the vocabulary size. The data type must be int32. Default is 0.
fastemit_lambda (float, optional): Regularization parameter for FastEmit (https://arxiv.org/pdf/2010.11148.pdf). Default is 0.001.
reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output will be the sum of loss divided by the batch_size; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, The RNN-T loss between ``logprobs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``logprobs``.
Examples:
.. code-block:: python
# declarative mode
import paddle.nn.functional as F
import numpy as np
import paddle
import functools
fn = functools.partial(F.rnnt_loss, reduction='sum', fastemit_lambda=0.0, blank=0)
acts = np.array([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
labels = [[1, 2]]
acts = paddle.to_tensor(acts, stop_gradient=False)
lengths = [acts.shape[1]] * acts.shape[0]
label_lengths = [len(l) for l in labels]
labels = paddle.to_tensor(labels, paddle.int32)
lengths = paddle.to_tensor(lengths, paddle.int32)
label_lengths = paddle.to_tensor(label_lengths, paddle.int32)
costs = fn(acts, labels, lengths, label_lengths)
print(costs)
# Tensor(shape=[1], dtype=float64, place=Place(gpu:0), stop_gradient=False,
# [4.49566677])
"""
def warprnnt(
input, label, input_length, label_length, blank=0, fastemit_lambda=0.001
):
if in_dygraph_mode():
loss_out = _C_ops.warprnnt(
input,
label,
input_length,
label_length,
blank,
fastemit_lambda,
)
return loss_out
helper = LayerHelper('warprnnt', **locals())
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], "warprnnt"
)
check_variable_and_dtype(label, 'label', ['int32'], "warprnnt")
check_variable_and_dtype(
input_length, 'input_lengths', ['int32'], "warprnnt"
)
check_variable_and_dtype(
label_length, 'label_lengths', ['int32'], "warprnnt"
)
this_inputs = {
'input': [input],
'label': [label],
'input_lengths': [input_length],
'label_lengths': [label_length],
}
loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='warprnnt',
inputs=this_inputs,
outputs={'warprnntgrad': [grad_out], 'loss': [loss_out]},
attrs={
'blank': blank,
'fastemit_lambda': fastemit_lambda,
},
)
return loss_out
B = input.shape[0]
# NOTE: log_softmax has to be applied manually for the CPU version;
# the GPU version computes log_softmax internally.
# (B,)
loss_out = warprnnt(
input, label, input_lengths, label_lengths, blank, fastemit_lambda
)
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.sum(loss_out, name=name) / B
elif reduction == 'sum':
loss_out = paddle.sum(loss_out, name=name)
return loss_out
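Tying the NOTE above together: the CPU path expects log-probabilities (apply log_softmax yourself), the GPU path normalizes internally, and ``reduction='mean'`` divides the summed per-example losses by B. A hedged usage sketch (the shapes below are illustrative, not from the tests):

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([2, 5, 4, 8])            # (B, Tmax, Umax, D) raw joint-network outputs
log_probs = F.log_softmax(logits, axis=-1)     # required for the CPU implementation
labels = paddle.to_tensor([[1, 2, 3], [4, 5, 0]], dtype='int32')
input_lens = paddle.to_tensor([5, 5], dtype='int32')
label_lens = paddle.to_tensor([3, 2], dtype='int32')

per_example = F.rnnt_loss(log_probs, labels, input_lens, label_lens, reduction='none')
mean_loss = F.rnnt_loss(log_probs, labels, input_lens, label_lens, reduction='mean')
# reduction='mean' equals per_example.sum() / B, as implemented above
```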
def margin_cross_entropy(
logits,
label,
...
...@@ -77,6 +77,7 @@ from .loss import KLDivLoss # noqa: F401
from .loss import MarginRankingLoss # noqa: F401
from .loss import MultiLabelSoftMarginLoss
from .loss import CTCLoss # noqa: F401
from .loss import RNNTLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import HingeEmbeddingLoss # noqa: F401
from .loss import TripletMarginWithDistanceLoss
...
...@@ -1121,6 +1121,79 @@ class CTCLoss(Layer):
)
class RNNTLoss(Layer):
"""
Parameters:
blank (int, optional): blank label. Default: 0.
fastemit_lambda (float, optional): Regularization parameter for FastEmit (https://arxiv.org/pdf/2010.11148.pdf). Default: 0.001.
reduction (string, optional): Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
'mean': the summed loss is divided by the batch size,
'sum': the loss is summed over the batch. Default: 'mean'
Shape:
input: logprob Tensor of (batch x seqLength x labelLength x outputDim) containing the output from the network
label: 2-D Tensor of (batch, labelLength) containing all the targets of the batch, zero padded
input_lengths: Tensor of size (batch) containing the length of each input sequence from the network
label_lengths: Tensor of size (batch) containing the label length of each example
Returns:
Tensor, The RNN-T loss between ``logprobs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``logprobs``.
Examples:
.. code-block:: python
# declarative mode
import numpy as np
import paddle
from paddle.nn import RNNTLoss
fn = RNNTLoss(reduction='sum', fastemit_lambda=0.0)
acts = np.array([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
labels = [[1, 2]]
acts = paddle.to_tensor(acts, stop_gradient=False)
lengths = [acts.shape[1]] * acts.shape[0]
label_lengths = [len(l) for l in labels]
labels = paddle.to_tensor(labels, paddle.int32)
lengths = paddle.to_tensor(lengths, paddle.int32)
label_lengths = paddle.to_tensor(label_lengths, paddle.int32)
costs = fn(acts, labels, lengths, label_lengths)
print(costs)
# Tensor(shape=[1], dtype=float64, place=Place(gpu:0), stop_gradient=False,
# [4.49566677])
"""
def __init__(
self, blank=0, fastemit_lambda=0.001, reduction='mean', name=None
):
super().__init__()
self.blank = blank
self.reduction = reduction
self.fastemit_lambda = fastemit_lambda
self.name = name
def forward(self, input, label, input_lengths, label_lengths):
return paddle.nn.functional.rnnt_loss(
input,
label,
input_lengths,
label_lengths,
blank=self.blank,
fastemit_lambda=self.fastemit_lambda,
reduction=self.reduction,
name=self.name,
)
class SmoothL1Loss(Layer):
r"""
This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
...
...@@ -456,8 +456,12 @@ package_dir={
libs_path='${PADDLE_BINARY_DIR}/python/paddle/libs'
package_data['paddle.libs']= []
package_data['paddle.libs']=[
('libwarpctc' if os.name != 'nt' else 'warpctc') + ext_name,
('libwarprnnt' if os.name != 'nt' else 'warprnnt') + ext_name,
]
shutil.copy('${WARPCTC_LIBRARIES}', libs_path)
shutil.copy('${WARPRNNT_LIBRARIES}', libs_path)
package_data['paddle.libs']+=[
os.path.basename('${LAPACK_LIB}'),
...
...@@ -767,9 +767,11 @@ def get_package_data_and_package_dir():
libs_path = paddle_binary_dir + '/python/paddle/libs'
package_data['paddle.libs'] = []
package_data['paddle.libs'] = [
('libwarpctc' if os.name != 'nt' else 'warpctc') + ext_suffix,
('libwarprnnt' if os.name != 'nt' else 'warprnnt') + ext_suffix,
]
shutil.copy(env_dict.get("WARPCTC_LIBRARIES"), libs_path)
shutil.copy(env_dict.get("WARPRNNT_LIBRARIES"), libs_path)
package_data['paddle.libs'] += [
os.path.basename(env_dict.get("LAPACK_LIB")),
os.path.basename(env_dict.get("BLAS_LIB")),
...@@ -962,7 +964,7 @@ def get_package_data_and_package_dir():
package_dir['paddle.libs'] = libs_path
# change rpath of ${FLUID_CORE_NAME}.ext, add $ORIGIN/../libs/ to it.
# The reason is that libwarpctc.ext, libwarprnnt.ext, libiomp5.ext etc are in paddle.libs, and
# ${FLUID_CORE_NAME}.ext is in paddle.fluid, thus paddle/fluid/../libs will pointer to above libraries.
# This operation will fix https://github.com/PaddlePaddle/Paddle/issues/3213
if env_dict.get("CMAKE_BUILD_TYPE") == 'Release':
...