提交 a816443e 编写于 作者: L Liu Yiqun

Add submodule warp-ctc.

上级 5a97c98d
[submodule "warp-ctc"]
path = warp-ctc
url = https://github.com/baidu-research/warp-ctc.git
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
sha: c25201a00e6b0514370501050cf2a8538ac12270 sha: c25201a00e6b0514370501050cf2a8538ac12270
hooks: hooks:
- id: remove-crlf - id: remove-crlf
files: (?!.*warp-ctc)^.*$
- repo: https://github.com/reyoung/mirrors-yapf.git - repo: https://github.com/reyoung/mirrors-yapf.git
sha: v0.13.2 sha: v0.13.2
hooks: hooks:
...@@ -13,6 +14,7 @@ ...@@ -13,6 +14,7 @@
- id: check-merge-conflict - id: check-merge-conflict
- id: check-symlinks - id: check-symlinks
- id: detect-private-key - id: detect-private-key
files: (?!.*warp-ctc)^.*$
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29 sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
......
...@@ -58,6 +58,6 @@ void GetCurandDsoHandle(void** dso_handle); ...@@ -58,6 +58,6 @@ void GetCurandDsoHandle(void** dso_handle);
* @param **dso_handle dso handler * @param **dso_handle dso handler
* *
*/ */
void GetWarpctcDsoHandle(void** dso_handle); void GetWarpCTCDsoHandle(void** dso_handle);
#endif // HL_DSO_LOADER_H_ #endif // HL_DSO_LOADER_H_
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#define HL_WARPCTC_WRAP_H_ #define HL_WARPCTC_WRAP_H_
#include "hl_base.h" #include "hl_base.h"
/// #include "hl_cuda.h"
#include "warp-ctc/include/ctc.h" #include "warp-ctc/include/ctc.h"
typedef ctcStatus_t hl_warpctc_status_t; typedef ctcStatus_t hl_warpctc_status_t;
......
...@@ -463,31 +463,19 @@ void KeSequence2BatchPadding(real* batch, ...@@ -463,31 +463,19 @@ void KeSequence2BatchPadding(real* batch,
int batchBaseIdx = (sequenceIdx * numSequences + batchIdx) * sequenceWidth; int batchBaseIdx = (sequenceIdx * numSequences + batchIdx) * sequenceWidth;
int sequenceBaseIdx = (sequenceStart + sequenceIdx) * sequenceWidth; int sequenceBaseIdx = (sequenceStart + sequenceIdx) * sequenceWidth;
real scale = normByTimes ? (1.0f / (real)sequenceLength) : 1.0f;
if (sequenceIdx < sequenceLength) { if (sequenceIdx < sequenceLength) {
if (seq2batch) { if (seq2batch) {
/* sequence -> batch */ /* sequence -> batch */
if (normByTimes) {
real scale = 1.0f / (real)sequenceLength;
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) { for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i]; batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
} }
} else {
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
batch[batchBaseIdx + i] = sequence[sequenceBaseIdx + i];
}
}
} else { } else {
/* batch -> sequence */ /* batch -> sequence */
if (normByTimes) {
real scale = 1.0f / (real)sequenceLength;
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) { for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i]; sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
} }
} else {
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
sequence[sequenceBaseIdx + i] = batch[batchBaseIdx + i];
}
}
} }
} else if (sequenceIdx < maxSequenceLength) { } else if (sequenceIdx < maxSequenceLength) {
if (seq2batch) { if (seq2batch) {
......
...@@ -163,7 +163,7 @@ void GetCurandDsoHandle(void** dso_handle) { ...@@ -163,7 +163,7 @@ void GetCurandDsoHandle(void** dso_handle) {
#endif #endif
} }
void GetWarpctcDsoHandle(void** dso_handle) { void GetWarpCTCDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle); GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle);
#else #else
......
...@@ -30,32 +30,32 @@ void* warpctc_dso_handle = nullptr; ...@@ -30,32 +30,32 @@ void* warpctc_dso_handle = nullptr;
* the linked-libs of paddle or to LD_PRELOAD. * the linked-libs of paddle or to LD_PRELOAD.
*/ */
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name, __type) \ #define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
__type operator()(Args... args) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
typedef __type (*warpctcFunc)(Args...); \ using warpctcFunc = decltype(__name(args...)) (*)(Args...); \
std::call_once( \ std::call_once( \
warpctc_dso_flag, GetWarpctcDsoHandle, &warpctc_dso_handle); \ warpctc_dso_flag, GetWarpCTCDsoHandle, &warpctc_dso_handle); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name); \ void* p_##_name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \ return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \ } \
} __name; // struct DynLoad__##__name } __name; // struct DynLoad__##__name
#else #else
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name, __type) \ #define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
__type operator()(Args... args) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \ return __name(args...); \
} \ } \
} __name; // struct DynLoad__##__name } __name; // struct DynLoad__##__name
#endif #endif
// include all needed warp-ctc functions // include all needed warp-ctc functions
DYNAMIC_LOAD_WARPCTC_WRAP(get_warpctc_version, int) DYNAMIC_LOAD_WARPCTC_WRAP(get_warpctc_version)
DYNAMIC_LOAD_WARPCTC_WRAP(ctcGetStatusString, const char*) DYNAMIC_LOAD_WARPCTC_WRAP(ctcGetStatusString)
DYNAMIC_LOAD_WARPCTC_WRAP(compute_ctc_loss, hl_warpctc_status_t) DYNAMIC_LOAD_WARPCTC_WRAP(compute_ctc_loss)
DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size, hl_warpctc_status_t) DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size)
#undef DYNAMIC_LOAD_WARPCTC_WRAP #undef DYNAMIC_LOAD_WARPCTC_WRAP
......
...@@ -100,8 +100,8 @@ void WarpCTCLayer::forward(PassType passType) { ...@@ -100,8 +100,8 @@ void WarpCTCLayer::forward(PassType passType) {
/* labels always in CPU memory */ /* labels always in CPU memory */
Matrix::resizeOrCreate(cpuCosts_, Matrix::resizeOrCreate(cpuCosts_,
/* width */ numSequences, /* height */ numSequences,
/* height */ 1, /* width */ 1,
/* trans */ false, /* trans */ false,
/* useGpu */ false); /* useGpu */ false);
...@@ -209,17 +209,11 @@ void WarpCTCLayer::batch2seqPadding(const MatrixPtr& seqValue, ...@@ -209,17 +209,11 @@ void WarpCTCLayer::batch2seqPadding(const MatrixPtr& seqValue,
int sequenceStart = seqStartPositionsData[i]; int sequenceStart = seqStartPositionsData[i];
int sequenceLength = int sequenceLength =
seqStartPositionsData[i + 1] - seqStartPositionsData[i]; seqStartPositionsData[i + 1] - seqStartPositionsData[i];
real scale = normByTimes ? (1.0f / (real)sequenceLength) : 1.0f;
for (int j = 0; j < sequenceLength; j++) { for (int j = 0; j < sequenceLength; j++) {
if (normByTimes) {
for (size_t k = 0; k < numClasses_; k++) { for (size_t k = 0; k < numClasses_; k++) {
seqData[(sequenceStart + j) * numClasses_ + k] = seqData[(sequenceStart + j) * numClasses_ + k] =
batchData[(j * numSequences + i) * numClasses_ + k] / batchData[(j * numSequences + i) * numClasses_ + k] * scale;
sequenceLength;
}
} else {
memcpy(seqData + (sequenceStart + j) * numClasses_,
batchData + (j * numSequences + i) * numClasses_,
numClasses_ * sizeof(real));
} }
} }
} }
......
...@@ -30,7 +30,7 @@ P_DECLARE_bool(use_gpu); ...@@ -30,7 +30,7 @@ P_DECLARE_bool(use_gpu);
const real* getData(const Matrix& matrix) { const real* getData(const Matrix& matrix) {
if (matrix.useGpu()) { if (matrix.useGpu()) {
MatrixPtr cpuMatrix = Matrix::create( MatrixPtr cpuMatrix = Matrix::create(
matrix.getWidth(), matrix.getHeight(), matrix.isTransposed(), false); matrix.getHeight(), matrix.getWidth(), matrix.isTransposed(), false);
cpuMatrix->copyFrom(matrix); cpuMatrix->copyFrom(matrix);
return cpuMatrix->getData(); return cpuMatrix->getData();
} else { } else {
...@@ -200,12 +200,13 @@ LayerPtr createWarpCTCLayer(string name, ...@@ -200,12 +200,13 @@ LayerPtr createWarpCTCLayer(string name,
TEST(Layer, WarpCTCLayer) { TEST(Layer, WarpCTCLayer) {
for (auto layerSize : {10, 64, 128}) { for (auto layerSize : {10, 64, 128}) {
for (auto batchSize : {1, 10, 20, 64}) { for (auto batchSize : {1, 10, 20, 64}) {
for (auto normByTimes : {false, true}) {
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU #ifdef PADDLE_ONLY_CPU
if (useGpu) continue; if (useGpu) continue;
#endif #endif
LOG(INFO) << " layerSize=" << layerSize << " batchSize=" << batchSize LOG(INFO) << " layerSize=" << layerSize << " batchSize=" << batchSize
<< " useGpu=" << useGpu; << " normByTimes = " << normByTimes << " useGpu=" << useGpu;
FLAGS_use_gpu = useGpu; FLAGS_use_gpu = useGpu;
...@@ -224,9 +225,9 @@ TEST(Layer, WarpCTCLayer) { ...@@ -224,9 +225,9 @@ TEST(Layer, WarpCTCLayer) {
createLabelLayer("label", batchSize, layerSize, useGpu); createLabelLayer("label", batchSize, layerSize, useGpu);
LayerPtr warpctcLayer = createWarpCTCLayer( LayerPtr warpctcLayer = createWarpCTCLayer(
"cost", layerSize, useGpu, false, dataLayer0, labelLayer); "cost", layerSize, useGpu, normByTimes, dataLayer0, labelLayer);
LayerPtr ctcLayer = createCTCLayer( LayerPtr ctcLayer = createCTCLayer(
"cost", layerSize, useGpu, false, dataLayer1, labelLayer); "cost", layerSize, useGpu, normByTimes, dataLayer1, labelLayer);
/// Check loss /// Check loss
checkError(*(warpctcLayer->getOutput().value), checkError(*(warpctcLayer->getOutput().value),
...@@ -238,6 +239,7 @@ TEST(Layer, WarpCTCLayer) { ...@@ -238,6 +239,7 @@ TEST(Layer, WarpCTCLayer) {
} }
} }
} }
}
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
......
Subproject commit bd535c8d44e03c8ebd2d768e06c8c05fdccd11d2
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册