Commit a816443e authored by Liu Yiqun

Add submodule warp-ctc.

Parent 5a97c98d
+[submodule "warp-ctc"]
+	path = warp-ctc
+	url = https://github.com/baidu-research/warp-ctc.git
......@@ -2,6 +2,7 @@
  sha: c25201a00e6b0514370501050cf2a8538ac12270
  hooks:
    - id: remove-crlf
+     files: (?!.*warp-ctc)^.*$
- repo: https://github.com/reyoung/mirrors-yapf.git
  sha: v0.13.2
  hooks:
......@@ -13,6 +14,7 @@
    - id: check-merge-conflict
    - id: check-symlinks
    - id: detect-private-key
+     files: (?!.*warp-ctc)^.*$
    - id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git
  sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
......
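Note on the `files: (?!.*warp-ctc)^.*$` lines added above: pre-commit treats `files` as a regular expression, and the negative lookahead makes each hook apply to every path except those under the vendored warp-ctc submodule, so remove-crlf and detect-private-key leave the third-party sources untouched. A standalone C++ sketch (not part of the commit) that exercises the same pattern:

    #include <iostream>
    #include <regex>

    int main() {
      // Same pattern as in the pre-commit config: match any path that does
      // not contain "warp-ctc"; ECMAScript regexes support (?!...) lookahead.
      std::regex pattern("(?!.*warp-ctc)^.*$");
      for (const char* path : {"paddle/cuda/src/hl_warpctc_wrap.cc",
                               "warp-ctc/include/ctc.h"}) {
        std::cout << path
                  << (std::regex_match(path, pattern) ? " -> checked by hooks\n"
                                                      : " -> skipped\n");
      }
      return 0;
    }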
......@@ -58,6 +58,6 @@ void GetCurandDsoHandle(void** dso_handle);
* @param **dso_handle dso handler
*
*/
-void GetWarpctcDsoHandle(void** dso_handle);
+void GetWarpCTCDsoHandle(void** dso_handle);
#endif // HL_DSO_LOADER_H_
......@@ -16,7 +16,6 @@ limitations under the License. */
#define HL_WARPCTC_WRAP_H_
#include "hl_base.h"
/// #include "hl_cuda.h"
#include "warp-ctc/include/ctc.h"
typedef ctcStatus_t hl_warpctc_status_t;
......
......@@ -463,30 +463,18 @@ void KeSequence2BatchPadding(real* batch,
  int batchBaseIdx = (sequenceIdx * numSequences + batchIdx) * sequenceWidth;
  int sequenceBaseIdx = (sequenceStart + sequenceIdx) * sequenceWidth;
+  real scale = normByTimes ? (1.0f / (real)sequenceLength) : 1.0f;
  if (sequenceIdx < sequenceLength) {
    if (seq2batch) {
      /* sequence -> batch */
-      if (normByTimes) {
-        real scale = 1.0f / (real)sequenceLength;
-        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
-          batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
-        }
-      } else {
-        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
-          batch[batchBaseIdx + i] = sequence[sequenceBaseIdx + i];
-        }
-      }
+      for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
+        batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
+      }
    } else {
      /* batch -> sequence */
-      if (normByTimes) {
-        real scale = 1.0f / (real)sequenceLength;
-        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
-          sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
-        }
-      } else {
-        for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
-          sequence[sequenceBaseIdx + i] = batch[batchBaseIdx + i];
-        }
-      }
+      for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
+        sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
+      }
    }
  } else if (sequenceIdx < maxSequenceLength) {
......
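The hunk above simplifies KeSequence2BatchPadding: instead of duplicating each copy loop inside an if (normByTimes)/else, the scale factor is computed once with a ternary and every element is multiplied by it, with scale staying 1.0f when no normalization is requested. The same rewrite is applied to the CPU path in WarpCTCLayer::batch2seqPadding further down. A minimal host-side sketch of the pattern (function and parameter names here are illustrative, not from the commit):

    #include <cstddef>

    typedef float real;

    // Unifies "copy" and "copy with 1/length normalization" into one loop:
    // a scale of 1.0f makes the multiply a no-op, so no branch duplication.
    void copyMaybeNormalized(real* dst, const real* src, size_t width,
                             int sequenceLength, bool normByTimes) {
      real scale = normByTimes ? (1.0f / (real)sequenceLength) : 1.0f;
      for (size_t i = 0; i < width; i++) {
        dst[i] = scale * src[i];
      }
    }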
......@@ -163,7 +163,7 @@ void GetCurandDsoHandle(void** dso_handle) {
#endif
}
-void GetWarpctcDsoHandle(void** dso_handle) {
+void GetWarpCTCDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle);
#else
......
......@@ -30,32 +30,32 @@ void* warpctc_dso_handle = nullptr;
* the linked-libs of paddle or to LD_PRELOAD.
*/
#ifdef PADDLE_USE_DSO
-#define DYNAMIC_LOAD_WARPCTC_WRAP(__name, __type)                        \
+#define DYNAMIC_LOAD_WARPCTC_WRAP(__name)                                \
  struct DynLoad__##__name {                                              \
    template <typename... Args>                                           \
-    __type operator()(Args... args) {                                    \
-      typedef __type (*warpctcFunc)(Args...);                            \
+    auto operator()(Args... args) -> decltype(__name(args...)) {         \
+      using warpctcFunc = decltype(__name(args...)) (*)(Args...);        \
      std::call_once(                                                     \
-          warpctc_dso_flag, GetWarpctcDsoHandle, &warpctc_dso_handle);   \
+          warpctc_dso_flag, GetWarpCTCDsoHandle, &warpctc_dso_handle);   \
      void* p_##_name = dlsym(warpctc_dso_handle, #__name);               \
      return reinterpret_cast<warpctcFunc>(p_##_name)(args...);           \
    }                                                                     \
  } __name;  // struct DynLoad__##__name
#else
-#define DYNAMIC_LOAD_WARPCTC_WRAP(__name, __type)                 \
-  struct DynLoad__##__name {                                      \
-    template <typename... Args>                                   \
-    __type operator()(Args... args) {                             \
-      return __name(args...);                                     \
-    }                                                             \
+#define DYNAMIC_LOAD_WARPCTC_WRAP(__name)                         \
+  struct DynLoad__##__name {                                      \
+    template <typename... Args>                                   \
+    auto operator()(Args... args) -> decltype(__name(args...)) {  \
+      return __name(args...);                                     \
+    }                                                             \
} __name; // struct DynLoad__##__name
#endif
// include all needed warp-ctc functions
-DYNAMIC_LOAD_WARPCTC_WRAP(get_warpctc_version, int)
-DYNAMIC_LOAD_WARPCTC_WRAP(ctcGetStatusString, const char*)
-DYNAMIC_LOAD_WARPCTC_WRAP(compute_ctc_loss, hl_warpctc_status_t)
-DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size, hl_warpctc_status_t)
+DYNAMIC_LOAD_WARPCTC_WRAP(get_warpctc_version)
+DYNAMIC_LOAD_WARPCTC_WRAP(ctcGetStatusString)
+DYNAMIC_LOAD_WARPCTC_WRAP(compute_ctc_loss)
+DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size)
#undef DYNAMIC_LOAD_WARPCTC_WRAP
......
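The macro rewrite above is the interesting part of this file: decltype(__name(args...)) deduces the wrapped function's return type from the declaration already in scope from warp-ctc/include/ctc.h, so DYNAMIC_LOAD_WARPCTC_WRAP no longer needs an explicit __type argument, and the DSO and non-DSO variants keep identical call sites. A compilable standalone sketch of the same trick (the stub declaration, namespace, and hard-coded library name are assumptions for illustration; error handling is omitted):

    #include <dlfcn.h>
    #include <mutex>

    // Stub standing in for the declaration normally pulled in from
    // "warp-ctc/include/ctc.h"; only its signature matters for decltype.
    extern "C" int get_warpctc_version();

    static std::once_flag dso_flag;
    static void* dso_handle = nullptr;

    // The return type is deduced from the declaration above, so the macro
    // takes only the symbol name instead of a (name, type) pair.
    #define DYNAMIC_LOAD_WRAP(__name)                                  \
      struct DynLoad__##__name {                                       \
        template <typename... Args>                                    \
        auto operator()(Args... args) -> decltype(__name(args...)) {   \
          using FuncType = decltype(__name(args...)) (*)(Args...);     \
          std::call_once(dso_flag, [] {                                \
            dso_handle = dlopen("libwarpctc.so", RTLD_LAZY);           \
          });                                                          \
          void* p = dlsym(dso_handle, #__name);                        \
          return reinterpret_cast<FuncType>(p)(args...);               \
        }                                                              \
      } __name;

    // Wrapping in a namespace avoids redeclaring the global function name.
    namespace dynload {
    DYNAMIC_LOAD_WRAP(get_warpctc_version)
    }  // namespace dynload
    #undef DYNAMIC_LOAD_WRAP

    int main() { return dynload::get_warpctc_version(); }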
......@@ -100,8 +100,8 @@ void WarpCTCLayer::forward(PassType passType) {
/* labels always in CPU memory */
Matrix::resizeOrCreate(cpuCosts_,
-                         /* width */ numSequences,
-                         /* height */ 1,
+                         /* height */ numSequences,
+                         /* width */ 1,
/* trans */ false,
/* useGpu */ false);
......@@ -209,17 +209,11 @@ void WarpCTCLayer::batch2seqPadding(const MatrixPtr& seqValue,
int sequenceStart = seqStartPositionsData[i];
int sequenceLength =
seqStartPositionsData[i + 1] - seqStartPositionsData[i];
+    real scale = normByTimes ? (1.0f / (real)sequenceLength) : 1.0f;
    for (int j = 0; j < sequenceLength; j++) {
-      if (normByTimes) {
-        for (size_t k = 0; k < numClasses_; k++) {
-          seqData[(sequenceStart + j) * numClasses_ + k] =
-              batchData[(j * numSequences + i) * numClasses_ + k] /
-              sequenceLength;
-        }
-      } else {
-        memcpy(seqData + (sequenceStart + j) * numClasses_,
-               batchData + (j * numSequences + i) * numClasses_,
-               numClasses_ * sizeof(real));
-      }
+      for (size_t k = 0; k < numClasses_; k++) {
+        seqData[(sequenceStart + j) * numClasses_ + k] =
+            batchData[(j * numSequences + i) * numClasses_ + k] * scale;
+      }
    }
}
......
......@@ -30,7 +30,7 @@ P_DECLARE_bool(use_gpu);
const real* getData(const Matrix& matrix) {
if (matrix.useGpu()) {
MatrixPtr cpuMatrix = Matrix::create(
-        matrix.getWidth(), matrix.getHeight(), matrix.isTransposed(), false);
+        matrix.getHeight(), matrix.getWidth(), matrix.isTransposed(), false);
cpuMatrix->copyFrom(matrix);
return cpuMatrix->getData();
} else {
......@@ -200,41 +200,43 @@ LayerPtr createWarpCTCLayer(string name,
TEST(Layer, WarpCTCLayer) {
for (auto layerSize : {10, 64, 128}) {
for (auto batchSize : {1, 10, 20, 64}) {
-      for (auto useGpu : {false, true}) {
+      for (auto normByTimes : {false, true}) {
+        for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU
-        if (useGpu) continue;
+          if (useGpu) continue;
#endif
-        LOG(INFO) << " layerSize=" << layerSize << " batchSize=" << batchSize
-                  << " useGpu=" << useGpu;
+          LOG(INFO) << " layerSize=" << layerSize << " batchSize=" << batchSize
+                    << " normByTimes = " << normByTimes << " useGpu=" << useGpu;
-        FLAGS_use_gpu = useGpu;
+          FLAGS_use_gpu = useGpu;
-        Argument data0;
-        initArgument(batchSize, layerSize, useGpu, data0);
+          Argument data0;
+          initArgument(batchSize, layerSize, useGpu, data0);
-        Argument data1;
-        data1.resizeAndCopyFrom(data0);
+          Argument data1;
+          data1.resizeAndCopyFrom(data0);
-        LayerPtr dataLayer0 =
-            createDataLayer("data", batchSize, layerSize, useGpu, data0);
-        LayerPtr dataLayer1 =
-            createDataLayer("data", batchSize, layerSize, useGpu, data1);
+          LayerPtr dataLayer0 =
+              createDataLayer("data", batchSize, layerSize, useGpu, data0);
+          LayerPtr dataLayer1 =
+              createDataLayer("data", batchSize, layerSize, useGpu, data1);
-        LayerPtr labelLayer =
-            createLabelLayer("label", batchSize, layerSize, useGpu);
+          LayerPtr labelLayer =
+              createLabelLayer("label", batchSize, layerSize, useGpu);
-        LayerPtr warpctcLayer = createWarpCTCLayer(
-            "cost", layerSize, useGpu, false, dataLayer0, labelLayer);
-        LayerPtr ctcLayer = createCTCLayer(
-            "cost", layerSize, useGpu, false, dataLayer1, labelLayer);
+          LayerPtr warpctcLayer = createWarpCTCLayer(
+              "cost", layerSize, useGpu, normByTimes, dataLayer0, labelLayer);
+          LayerPtr ctcLayer = createCTCLayer(
+              "cost", layerSize, useGpu, normByTimes, dataLayer1, labelLayer);
-        /// Check loss
-        checkError(*(warpctcLayer->getOutput().value),
-                   *(ctcLayer->getOutput().value));
+          /// Check loss
+          checkError(*(warpctcLayer->getOutput().value),
+                     *(ctcLayer->getOutput().value));
-        /// Check gradients
-        checkError(*(dataLayer0->getOutput().grad),
-                   *(dataLayer1->getOutput().grad));
+          /// Check gradients
+          checkError(*(dataLayer0->getOutput().grad),
+                     *(dataLayer1->getOutput().grad));
+        }
      }
    }
  }
}
......
+Subproject commit bd535c8d44e03c8ebd2d768e06c8c05fdccd11d2