提交 4d487c6f 编写于 作者: L Liu Yiqun

Integrate warp-ctc as WarpCTCLayer, including unitest and layer interface.

上级 9e65ceed
...@@ -94,6 +94,11 @@ endif() ...@@ -94,6 +94,11 @@ endif()
if(NOT WITH_GPU) if(NOT WITH_GPU)
add_definitions(-DPADDLE_ONLY_CPU) add_definitions(-DPADDLE_ONLY_CPU)
add_definitions(-DHPPL_STUB_FUNC) add_definitions(-DHPPL_STUB_FUNC)
if(WITH_DSO)
add_definitions(-DPADDLE_USE_DSO)
endif(WITH_DSO)
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu) list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
else() else()
if(${CUDA_VERSION_MAJOR} GREATER 6) if(${CUDA_VERSION_MAJOR} GREATER 6)
......
...@@ -148,6 +148,11 @@ function(link_paddle_exe TARGET_NAME) ...@@ -148,6 +148,11 @@ function(link_paddle_exe TARGET_NAME)
target_link_libraries(${TARGET_NAME} rt) target_link_libraries(${TARGET_NAME} rt)
endif() endif()
endif() endif()
if(NOT WITH_DSO)
target_link_libraries(${TARGET_NAME}
${WARPCTC_LIBRARY})
endif()
endfunction() endfunction()
# link_paddle_test # link_paddle_test
......
...@@ -15,20 +15,29 @@ else() ...@@ -15,20 +15,29 @@ else()
endif() endif()
set(CUDA_CXX_WITH_GPU_SOURCES set(CUDA_CXX_WITH_GPU_SOURCES
src/hl_cudart_wrap.cc
src/hl_cuda_cublas.cc src/hl_cuda_cublas.cc
src/hl_cuda_cudnn.cc src/hl_cuda_cudnn.cc
src/hl_cuda_device.cc) src/hl_cuda_device.cc
)
set_source_files_properties(${CUDA_CXX_WITH_GPU_SOURCES} if(WITH_GPU)
PROPERTIES COMPILE_FLAGS "-D__NVCC__") set(CUDA_CXX_SOURCES
src/hl_dso_loader.cc
src/hl_warpctc_wrap.cc
${CUDA_CXX_WITH_GPU_SOURCES})
set_source_files_properties(${CUDA_CXX_SOURCES}
PROPERTIES COMPILE_FLAGS "-D__NVCC__")
else()
set(CUDA_CXX_SOURCES
src/hl_dso_loader.cc
src/hl_warpctc_wrap.cc)
endif()
set_source_files_properties(${AVX_SOURCES} set_source_files_properties(${AVX_SOURCES}
PROPERTIES COMPILE_FLAGS "-mavx") PROPERTIES COMPILE_FLAGS "-mavx")
set(CUDA_DSO_SOURCES
src/hl_dso_loader.cc
src/hl_cudart_wrap.cc)
set(CUDA_CU_SOURCES set(CUDA_CU_SOURCES
src/hl_perturbation_util.cu src/hl_perturbation_util.cu
src/hl_cuda_aggregate.cu src/hl_cuda_aggregate.cu
...@@ -44,6 +53,7 @@ set(CUDA_CU_SOURCES ...@@ -44,6 +53,7 @@ set(CUDA_CU_SOURCES
set(CUDA_HEADERS set(CUDA_HEADERS
include/hl_time.h include/hl_time.h
include/hl_dso_loader.h include/hl_dso_loader.h
include/hl_warpctc_wrap.h
include/hl_sequence.h include/hl_sequence.h
include/hl_cuda_cublas.h include/hl_cuda_cublas.h
include/hl_batch_transpose.h include/hl_batch_transpose.h
...@@ -75,14 +85,14 @@ if(WITH_GPU) ...@@ -75,14 +85,14 @@ if(WITH_GPU)
cuda_add_library(paddle_cuda cuda_add_library(paddle_cuda
${CUDA_SOURCES} ${CUDA_SOURCES}
${CUDA_CU_SOURCES} ${CUDA_CU_SOURCES}
${CUDA_DSO_SOURCES} ${CUDA_CXX_SOURCES})
${CUDA_CXX_WITH_GPU_SOURCES})
else() else()
add_library(paddle_cuda ${CUDA_SOURCES}) add_library(paddle_cuda
${CUDA_SOURCES}
${CUDA_CXX_SOURCES})
endif() endif()
add_style_check_target(paddle_cuda add_style_check_target(paddle_cuda
${CUDA_SOURCES} ${CUDA_SOURCES}
${CUDA_HEADERS} ${CUDA_HEADERS}
${CUDA_DSO_SOURCES} ${CUDA_CXX_SOURCES})
${CUDA_CXX_WITH_GPU_SOURCES})
...@@ -18,10 +18,6 @@ limitations under the License. */ ...@@ -18,10 +18,6 @@ limitations under the License. */
#include <dlfcn.h> #include <dlfcn.h>
#include <string> #include <string>
#include <memory> #include <memory>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <curand.h>
#include <cudnn.h>
#include "hl_base.h" #include "hl_base.h"
/** /**
...@@ -56,4 +52,12 @@ void GetCudartDsoHandle(void** dso_handle); ...@@ -56,4 +52,12 @@ void GetCudartDsoHandle(void** dso_handle);
*/ */
void GetCurandDsoHandle(void** dso_handle); void GetCurandDsoHandle(void** dso_handle);
/**
* @brief load the DSO of warp-ctc
*
* @param **dso_handle dso handler
*
*/
void GetWarpctcDsoHandle(void** dso_handle);
#endif // HL_DSO_LOADER_H_ #endif // HL_DSO_LOADER_H_
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
#include "hl_sparse.h" #include "hl_sparse.h"
#include "hl_lstm.h" #include "hl_lstm.h"
#include "hl_sequence.h" #include "hl_sequence.h"
#include "hl_warpctc_wrap.h"
#ifdef HPPL_STUB_FUNC #ifdef HPPL_STUB_FUNC
#include "stub/hl_cuda_stub.h" #include "stub/hl_cuda_stub.h"
......
...@@ -172,6 +172,39 @@ extern void hl_sequence2batch_add(real* batch, ...@@ -172,6 +172,39 @@ extern void hl_sequence2batch_add(real* batch,
int batchCount, int batchCount,
bool seq2batch); bool seq2batch);
/**
* @brief Memory copy from sequence to batch,
* while padding all sequences to the same length.
*
* if seq2batch == true
*
* copy from sequence to batch:
* batch[i] = sequence[sequenceStartPositions[i]]
*
* if seq2batch == false
*
* copy from batch to sequence:
* sequence[sequenceStartPositions[i]] = batch[i]
*
* @param[in,out] batch batch matrix.
* @param[in,out] sequence sequence matrix.
* @param[in] sequenceStartPositions index vector.
* @param[in] sequenceWidth width of sequence.
* @param[in] maxSequenceLength maximum length of sequences.
* @param[in] numSequences number of sequences.
* @param[in] normByTimes whether dividing sequence's length.
* @param[in] seq2batch copy direction.
*
*/
extern void hl_sequence2batch_copy_padding(real* batch,
real* sequence,
const int* sequenceStartPositions,
const size_t sequenceWidth,
const size_t maxSequenceLength,
const size_t numSequences,
bool normByTimes,
bool seq2batch);
/** /**
* @brief dst = Op(src), src is sequence. * @brief dst = Op(src), src is sequence.
* *
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_WARPCTC_WRAP_H_
#define HL_WARPCTC_WRAP_H_
#include "hl_base.h"
/// #include "hl_cuda.h"
#include "warp-ctc/include/ctc.h"
typedef ctcStatus_t hl_warpctc_status_t;
typedef ctcOptions hl_warpctc_options_t;
/**
* @brief Init ctc options.
*
* @param[in] blank blank label used in ctc loss function.
* @param[in] useGpu whether use gpu.
* @param[out] options handle to store cpu or gpu informations.
*
*/
extern void hl_warpctc_init(const size_t blank,
bool useGpu,
hl_warpctc_options_t* options);
/**
* @brief Compute the connectionist temporal classification loss,
* and optionally compute the gradient with respect to the inputs.
*
* if batchGrad == nullptr
*
* only compute the ctc loss.
*
* if batchGrad != nullptr
*
* compute both ctc loss and gradient.
*
* @param[in] batchInput batch matrix of input probabilities,
* in maxSequenceLength x numSequence x numClasses
* (row-major) format.
* @param[out] batchGrad batch matrix of gradient.
* @param[in] cpuLabels labels always in CPU memory.
* @param[in] cpuLabelLengths length of all labels in CPU memory.
* @param[in] cpuInputLengths length of all sequences in CPU memory.
* @param[in] numClasses number of possible output symbols.
* @param[in] numSequences number of sequence.
* @param[out] cpuCosts cost of each sequence in CPU memory.
* @param[out] workspace workspace to store some temporary results.
* @param[in] options handle to store cpu or gpu informations.
*
*/
extern void hl_warpctc_compute_loss(const real* batchInput,
real* batchGrad,
const int* cpuLabels,
const int* cpuLabelLengths,
const int* cpuInputLengths,
const size_t numClasses,
const size_t numSequences,
real* cpuCosts,
void* workspace,
hl_warpctc_options_t* options);
/**
* @brief Compute the required workspace size.
* There is no memory allocated operations within warp-ctc.
*
* @param[in] cpuLabelLengths length of all labels in CPU memory.
* @param[in] cpuInputLengths length of all sequences in CPU memory.
* @param[in] numClasses number of possible output symbols.
* @param[in] numSequences number of sequence.
* @param[in] options handle to store cpu or gpu informations.
* @param[out] bytes pointer to a scalar where the memory
* requirement in bytes will be placed.
*
*/
extern void hl_warpctc_get_workspace_size(const int* cpuLabelLengths,
const int* cpuInputLengths,
const size_t numClasses,
const size_t numSequences,
hl_warpctc_options_t* options,
size_t* bytes);
#endif // HL_WARPCTC_WRAP_H_
...@@ -70,6 +70,15 @@ inline void hl_sequence2batch_add(real* batch, ...@@ -70,6 +70,15 @@ inline void hl_sequence2batch_add(real* batch,
int batchCount, int batchCount,
bool seq2batch) {} bool seq2batch) {}
inline void hl_sequence2batch_copy_padding(real* batch,
real* sequence,
const int* sequenceStartPositions,
const size_t sequenceWidth,
const size_t maxSequenceLength,
const size_t numSequences,
bool normByTimes,
bool seq2batch) {}
inline void hl_sequence_avg_forward(real* dst, inline void hl_sequence_avg_forward(real* dst,
real* src, real* src,
const int* starts, const int* starts,
......
...@@ -447,6 +447,124 @@ void hl_sequence2batch_add(real *batch, ...@@ -447,6 +447,124 @@ void hl_sequence2batch_add(real *batch,
CHECK_SYNC("hl_sequence2batch_add failed"); CHECK_SYNC("hl_sequence2batch_add failed");
} }
template<bool normByTimes, bool seq2batch>
__global__
void KeSequence2BatchPadding(real* batch,
real* sequence,
const int* sequenceStartPositions,
const size_t sequenceWidth,
const size_t maxSequenceLength,
const size_t numSequences) {
int batchIdx = blockIdx.y;
int sequenceStart = sequenceStartPositions[batchIdx];
int sequenceLength = sequenceStartPositions[batchIdx + 1] - sequenceStart;
int sequenceIdx = blockIdx.x * blockDim.y + threadIdx.y;
int batchBaseIdx = (sequenceIdx * numSequences + batchIdx) * sequenceWidth;
int sequenceBaseIdx = (sequenceStart + sequenceIdx) * sequenceWidth;
if (sequenceIdx < sequenceLength) {
if (seq2batch) {
/* sequence -> batch */
if (normByTimes) {
real scale = 1.0f / (real)sequenceLength;
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
}
} else {
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
batch[batchBaseIdx + i] = sequence[sequenceBaseIdx + i];
}
}
} else {
/* batch -> sequence */
if (normByTimes) {
real scale = 1.0f / (real)sequenceLength;
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
}
} else {
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
sequence[sequenceBaseIdx + i] = batch[batchBaseIdx + i];
}
}
}
} else if (sequenceIdx < maxSequenceLength) {
if (seq2batch) {
/* sequence -> batch */
for (int i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
batch[batchBaseIdx + i] = 0;
}
}
}
}
void hl_sequence2batch_copy_padding(real* batch,
real* sequence,
const int* sequenceStartPositions,
const size_t sequenceWidth,
const size_t maxSequenceLength,
const size_t numSequences,
bool normByTimes,
bool seq2batch) {
CHECK_NOTNULL(batch);
CHECK_NOTNULL(sequence);
CHECK_NOTNULL(sequenceStartPositions);
if (!normByTimes && numSequences == 1) {
size_t elementCount = maxSequenceLength * sequenceWidth;
if (seq2batch) {
/* sequence -> batch */
hl_memcpy_device2device(batch, sequence, sizeof(real) * elementCount);
} else {
/* batch -> sequence */
hl_memcpy_device2device(sequence, batch, sizeof(real) * elementCount);
}
return;
}
const int CUDA_BLOCK_SIZE = 512;
/* At least use 32 threads to copy sequenceWidth elements,
and at least 8 elements for each thread. */
int blockDimX = ((((sequenceWidth + 7) >> 3) + 31) >> 5) << 5;
blockDimX = (blockDimX < CUDA_BLOCK_SIZE) ? blockDimX : CUDA_BLOCK_SIZE;
int blockDimY = CUDA_BLOCK_SIZE / blockDimX;
dim3 threads(blockDimX, blockDimY);
int gridDimX = (maxSequenceLength * blockDimX + CUDA_BLOCK_SIZE - 1) /
CUDA_BLOCK_SIZE;
int gridDimY = numSequences;
dim3 grid(gridDimX, gridDimY);
if (seq2batch) {
/* sequence -> batch */
if (normByTimes) {
KeSequence2BatchPadding<1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
} else {
KeSequence2BatchPadding<0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
}
} else {
/* batch -> sequence */
if (normByTimes) {
KeSequence2BatchPadding<1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
} else {
KeSequence2BatchPadding<0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
batch, sequence, sequenceStartPositions,
sequenceWidth, maxSequenceLength, numSequences);
}
}
CHECK_SYNC("hl_sequence2batch_copy_padding failed");
}
__device__ inline float my_rsqrt(float x) { __device__ inline float my_rsqrt(float x) {
return rsqrtf(x); return rsqrtf(x);
} }
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#include <mutex> #include <mutex>
#include <cuda_runtime.h>
#include "hl_dso_loader.h" #include "hl_dso_loader.h"
/** /**
......
...@@ -30,6 +30,8 @@ P_DEFINE_string(cuda_dir, ...@@ -30,6 +30,8 @@ P_DEFINE_string(cuda_dir,
"build-in function in cudart already ran before main entry). " "build-in function in cudart already ran before main entry). "
"If default, dlopen will search cuda from LD_LIBRARY_PATH"); "If default, dlopen will search cuda from LD_LIBRARY_PATH");
P_DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
static inline std::string join(const std::string& part1, static inline std::string join(const std::string& part1,
const std::string& part2) { const std::string& part2) {
// directory separator // directory separator
...@@ -92,27 +94,28 @@ static inline void GetDsoHandleFromSearchPath(const std::string& search_root, ...@@ -92,27 +94,28 @@ static inline void GetDsoHandleFromSearchPath(const std::string& search_root,
*dso_handle = dlopen(dlPath.c_str(), dynload_flags); *dso_handle = dlopen(dlPath.c_str(), dynload_flags);
// if not found, search from default path // if not found, search from default path
if (nullptr == *dso_handle) { if (nullptr == *dso_handle) {
LOG(WARNING) << "Failed to find cuda library: " << dlPath; LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " ("
<< dlerror() << ")";
dlPath = dso_name; dlPath = dso_name;
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags);
} }
} }
CHECK(nullptr != *dso_handle) << "Failed to find cuda library: " << dlPath CHECK(nullptr != *dso_handle) << "Failed to find dynamic library: " << dlPath
<< std::endl << " (" << dlerror() << ") \n"
<< "Please specify its path correctly using " << "Please specify its path correctly using "
"one of the following ways: \n" // NOLINT "one of the following ways: \n"
<< "Method 1. set cuda and cudnn lib path at " << "Method 1. set cuda and cudnn lib path at "
"runtime. " "runtime. "
<< "http://www.paddlepaddle.org/doc/ui/" << "http://www.paddlepaddle.org/doc/ui/"
"cmd_argument/" "cmd_argument/"
"argument_outline.html \n" // NOLINT "argument_outline.html \n"
<< "For instance, issue command: paddle train " << "For instance, issue command: paddle train "
"--use_gpu=1 " "--use_gpu=1 "
<< "--cuda_dir=/usr/local/cuda/lib64 " << "--cuda_dir=/usr/local/cuda/lib64 "
"--cudnn_dir=/usr/local/cudnn/lib " "--cudnn_dir=/usr/local/cudnn/lib "
"...\n" // NOLINT "...\n"
<< "Method 2. set environment variable " << "Method 2. set environment variable "
"LD_LIBRARY_PATH on Linux or " "LD_LIBRARY_PATH on Linux or "
...@@ -124,7 +127,7 @@ static inline void GetDsoHandleFromSearchPath(const std::string& search_root, ...@@ -124,7 +127,7 @@ static inline void GetDsoHandleFromSearchPath(const std::string& search_root,
"DYLD_LIBRARY_PATH is impossible " "DYLD_LIBRARY_PATH is impossible "
<< "unless System Integrity Protection (SIP) " << "unless System Integrity Protection (SIP) "
"is disabled. However, " "is disabled. However, "
"method 1 " // NOLINT "method 1 "
<< "always work well."; << "always work well.";
} }
...@@ -159,3 +162,11 @@ void GetCurandDsoHandle(void** dso_handle) { ...@@ -159,3 +162,11 @@ void GetCurandDsoHandle(void** dso_handle) {
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle); GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle);
#endif #endif
} }
void GetWarpctcDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle);
#endif
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <mutex>
#include "hl_warpctc_wrap.h"
#include "hl_dso_loader.h"
#include "paddle/utils/Logging.h"
namespace dynload {
std::once_flag warpctc_dso_flag;
void* warpctc_dso_handle = nullptr;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load warpctc routine
* via operator overloading. When PADDLE_USE_DSO is
* false, you need to add the path of libwarp-ctc.so to
* the linked-libs of paddle or to LD_PRELOAD.
*/
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name, __type) \
struct DynLoad__##__name { \
template <typename... Args> \
__type operator()(Args... args) { \
typedef __type (*warpctcFunc)(Args...); \
std::call_once( \
warpctc_dso_flag, GetWarpctcDsoHandle, &warpctc_dso_handle); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \
} __name; // struct DynLoad__##__name
#else
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name, __type) \
struct DynLoad__##__name { \
template <typename... Args> \
__type operator()(Args... args) { \
return __name(args...); \
} \
} __name; // struct DynLoad__##__name
#endif
// include all needed warp-ctc functions
DYNAMIC_LOAD_WARPCTC_WRAP(get_warpctc_version, int)
DYNAMIC_LOAD_WARPCTC_WRAP(ctcGetStatusString, const char*)
DYNAMIC_LOAD_WARPCTC_WRAP(compute_ctc_loss, hl_warpctc_status_t)
DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size, hl_warpctc_status_t)
#undef DYNAMIC_LOAD_WARPCTC_WRAP
} /* namespace dynload */
#define WARPCTC_GET_VERSION dynload::get_warpctc_version
#define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString
#ifndef PADDLE_TYPE_DOUBLE
#define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss
#define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size
#else
#define WARPCTC_LOG_FATAL \
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion \
<< "] Error: not support double precision."
#define WARPCTC_COMPUTE_LOSS(...) WARPCTC_LOG_FATAL(__VA_ARGS__)
#define WARPCTC_GET_WORKSPACE_SIZE(...) WARPCTC_LOG_FATAL(__VA_ARGS__)
#endif
/**
* Check build-in warp-ctc function using glog and it also
* support << operator for more details error info.
*/
static int g_warpctcVersion = -1;
#define CHECK_WARPCTC(warpctcStat) \
CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \
<< "warp-ctc [version " << g_warpctcVersion \
<< "] Error: " << WARPCTC_GET_STATUS_STRING(warpctcStat) << " "
void hl_warpctc_init(const size_t blank,
bool useGpu,
hl_warpctc_options_t* options) {
CHECK_NOTNULL(options);
g_warpctcVersion = WARPCTC_GET_VERSION();
if (useGpu) {
#ifdef __NVCC__
options->loc = CTC_GPU;
options->stream = STREAM_DEFAULT;
#else
LOG(FATAL) << "[warpctc init] GPU is not enabled.";
#endif
} else {
options->loc = CTC_CPU;
options->num_threads = 1;
}
options->blank_label = blank;
}
void hl_warpctc_compute_loss(const real* batchInput,
real* batchGrad,
const int* cpuLabels,
const int* cpuLabelLengths,
const int* cpuInputLengths,
const size_t numClasses,
const size_t numSequences,
real* cpuCosts,
void* workspace,
hl_warpctc_options_t* options) {
CHECK_NOTNULL(batchInput);
CHECK_NOTNULL(cpuLabels);
CHECK_NOTNULL(cpuLabelLengths);
CHECK_NOTNULL(cpuInputLengths);
CHECK_NOTNULL(cpuCosts);
CHECK_NOTNULL(workspace);
CHECK_NOTNULL(options);
CHECK_WARPCTC(WARPCTC_COMPUTE_LOSS(batchInput,
batchGrad,
cpuLabels,
cpuLabelLengths,
cpuInputLengths,
numClasses,
numSequences,
cpuCosts,
workspace,
*options));
}
void hl_warpctc_get_workspace_size(const int* cpuLabelLengths,
const int* cpuInputLengths,
const size_t numClasses,
const size_t numSequences,
hl_warpctc_options_t* options,
size_t* bytes) {
CHECK_NOTNULL(cpuLabelLengths);
CHECK_NOTNULL(cpuInputLengths);
CHECK_NOTNULL(options);
CHECK_NOTNULL(bytes);
CHECK_WARPCTC(WARPCTC_GET_WORKSPACE_SIZE(cpuLabelLengths,
cpuInputLengths,
numClasses,
numSequences,
*options,
bytes));
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "WarpCTCLayer.h"
namespace paddle {
REGISTER_LAYER(warp_ctc, WarpCTCLayer);
bool WarpCTCLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parament class */
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 2UL);
/* The inputLayers_[0] must be sequence output without softmax */
numClasses_ = config_.size();
CHECK_GE(numClasses_, 2UL);
CHECK_EQ(numClasses_, inputLayers_[0]->getSize());
blank_ = config_.blank();
CHECK_GE(blank_, 0UL);
CHECK_LT(blank_, numClasses_);
normByTimes_ = config_.norm_by_times();
// We don't need sequenceStartPositions because each sample of output_ is
// for the cost of one sequence.
setNeedSequenceInfo(false);
return true;
}
void WarpCTCLayer::forward(PassType passType) {
Layer::forward(passType);
const Argument& output = getInput(0);
const Argument& labels = getInput(1);
CHECK(output.sequenceStartPositions);
CHECK(labels.sequenceStartPositions);
CHECK(labels.ids);
size_t numSequences = labels.sequenceStartPositions->getSize() - 1;
CHECK_EQ(numSequences, output.sequenceStartPositions->getSize() - 1);
resizeOutput(numSequences, 1);
const int* cpuLabelStartPositions =
labels.sequenceStartPositions->getData(false);
const int* cpuOutputStartPositions =
output.sequenceStartPositions->getData(false);
std::vector<int> cpuLabelLengths(numSequences);
std::vector<int> cpuOutputLengths(numSequences);
for (size_t i = 0; i < numSequences; i++) {
cpuLabelLengths[i] =
cpuLabelStartPositions[i + 1] - cpuLabelStartPositions[i];
cpuOutputLengths[i] =
cpuOutputStartPositions[i + 1] - cpuOutputStartPositions[i];
}
/* Get the maximum sequence length */
maxSequenceLength_ = 0;
maxSequenceLength_ = *std::max_element(
cpuOutputLengths.data(), cpuOutputLengths.data() + numSequences);
Matrix::resizeOrCreate(batchValue_,
/* height */ numSequences * maxSequenceLength_,
/* width */ numClasses_,
/* trans */ false,
/* useGpu */ useGpu_);
Matrix::resizeOrCreate(batchGrad_,
/* height */ numSequences * maxSequenceLength_,
/* width */ numClasses_,
/* trans */ false,
/* useGpu */ useGpu_);
batchGrad_->zeroMem();
seq2batchPadding(output.value, batchValue_, output.sequenceStartPositions);
/* labels always in CPU memory */
IVector::resizeOrCreate(cpuLabels_,
/* size */ (labels.ids)->getSize(),
/* useGpu */ false);
cpuLabels_->copyFrom(*(labels.ids));
/* labels always in CPU memory */
Matrix::resizeOrCreate(cpuCosts_,
/* width */ numSequences,
/* height */ 1,
/* trans */ false,
/* useGpu */ false);
/* Init warp-ctc options */
hl_warpctc_options_t options;
hl_warpctc_init(blank_, useGpu_, &options);
/* Get the needed workspace size */
size_t workspaceBytes = 0;
hl_warpctc_get_workspace_size(cpuLabelLengths.data(),
cpuOutputLengths.data(),
numClasses_,
numSequences,
&options,
&workspaceBytes);
CHECK_GT(workspaceBytes, 0UL);
size_t workspaceLength = workspaceBytes / sizeof(real) + 1;
Vector::resizeOrCreate(workspace_,
/* size */ workspaceLength,
/* useGpu */ useGpu_);
hl_warpctc_compute_loss(batchValue_->getData(),
batchGrad_->getData(),
cpuLabels_->getData(),
cpuLabelLengths.data(),
cpuOutputLengths.data(),
numClasses_,
numSequences,
cpuCosts_->getData(),
workspace_->getData(),
&options);
/* Copy the costs */
output_.value->copyFrom(*cpuCosts_);
}
void WarpCTCLayer::backward(const UpdateCallback& callback) {
(void)callback;
const Argument& output = getInput(0);
CHECK(batchGrad_);
batch2seqPadding(
output.grad, batchGrad_, output.sequenceStartPositions, normByTimes_);
}
void WarpCTCLayer::seq2batchPadding(const MatrixPtr& seqValue,
MatrixPtr& batchValue,
const ICpuGpuVectorPtr& seqStartPositions) {
size_t numSequences = seqStartPositions->getSize() - 1;
const int* seqStartPositionsData = seqStartPositions->getData(useGpu_);
real* seqData = seqValue->getData();
real* batchData = batchValue->getData();
if (useGpu_) {
hl_sequence2batch_copy_padding(batchData,
seqData,
seqStartPositionsData,
numClasses_,
maxSequenceLength_,
numSequences,
false,
true);
} else {
for (size_t i = 0; i < maxSequenceLength_; i++) {
for (size_t j = 0; j < numSequences; j++) {
size_t sequenceStart = seqStartPositionsData[j];
size_t sequenceLength =
seqStartPositionsData[j + 1] - seqStartPositionsData[j];
if (i < sequenceLength) {
memcpy(batchData + (i * numSequences + j) * numClasses_,
seqData + (sequenceStart + i) * numClasses_,
numClasses_ * sizeof(real));
} else {
memset(batchData + (i * numSequences + j) * numClasses_,
0,
numClasses_ * sizeof(real));
}
}
}
}
}
void WarpCTCLayer::batch2seqPadding(const MatrixPtr& seqValue,
MatrixPtr& batchValue,
const ICpuGpuVectorPtr& seqStartPositions,
bool normByTimes) {
size_t numSequences = seqStartPositions->getSize() - 1;
const int* seqStartPositionsData = seqStartPositions->getData(useGpu_);
real* seqData = seqValue->getData();
real* batchData = batchValue->getData();
if (useGpu_) {
hl_sequence2batch_copy_padding(batchData,
seqData,
seqStartPositionsData,
numClasses_,
maxSequenceLength_,
numSequences,
normByTimes,
false);
} else {
for (size_t i = 0; i < numSequences; i++) {
int sequenceStart = seqStartPositionsData[i];
int sequenceLength =
seqStartPositionsData[i + 1] - seqStartPositionsData[i];
for (int j = 0; j < sequenceLength; j++) {
if (normByTimes) {
for (size_t k = 0; k < numClasses_; k++) {
seqData[(sequenceStart + j) * numClasses_ + k] =
batchData[(j * numSequences + i) * numClasses_ + k] /
sequenceLength;
}
} else {
memcpy(seqData + (sequenceStart + j) * numClasses_,
batchData + (j * numSequences + i) * numClasses_,
numClasses_ * sizeof(real));
}
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
* @brief A layer integrating the open-source warp-ctc library
* <https://github.com/baidu-research/warp-ctc> to compute connectionist
* temporal classification cost.
*
* The config file api is warp_ctc_layer.
*/
class WarpCTCLayer : public Layer {
public:
explicit WarpCTCLayer(const LayerConfig& config) : Layer(config) {}
~WarpCTCLayer() {}
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
virtual void forward(PassType passType);
virtual void backward(const UpdateCallback& callback);
protected:
/**
* sequence matrix and batch matrix copy:
* sequence (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
* batch (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
*/
void seq2batchPadding(const MatrixPtr& seqValue,
MatrixPtr& batchValue,
const ICpuGpuVectorPtr& seqStartPositions);
void batch2seqPadding(const MatrixPtr& seqValue,
MatrixPtr& batchValue,
const ICpuGpuVectorPtr& seqStartPositions,
bool normByTimes);
protected:
size_t numClasses_;
size_t blank_;
size_t maxSequenceLength_;
bool normByTimes_;
MatrixPtr batchValue_;
MatrixPtr batchGrad_;
VectorPtr workspace_;
IVectorPtr cpuLabels_;
MatrixPtr cpuCosts_;
};
} // namespace paddle
...@@ -62,6 +62,13 @@ add_unittest(test_RecurrentLayer ...@@ -62,6 +62,13 @@ add_unittest(test_RecurrentLayer
test_RecurrentLayer.cpp test_RecurrentLayer.cpp
TestUtil.cpp) TestUtil.cpp)
############### test_WarpCTCLayer #######################
if(NOT WITH_DOUBLE)
add_unittest(test_WarpCTCLayer
test_WarpCTCLayer.cpp
TestUtil.cpp)
endif()
############### test_RecurrentGradientMachine ############### ############### test_RecurrentGradientMachine ###############
# TODO(yuyang18): There is some bug in test_RecurrentGradientMachine # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine
# I will fix it. # I will fix it.
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/utils/Version.h>
#include "paddle/gserver/layers/Layer.h"
#include "paddle/gserver/layers/DataLayer.h"
#include "paddle/gserver/layers/CTCLayer.h"
#include "paddle/gserver/layers/WarpCTCLayer.h"
#include "ModelConfig.pb.h"
#include "TestUtil.h"
using namespace paddle; // NOLINT
using namespace std; // NOLINT
P_DECLARE_bool(use_gpu);
const real* getData(const Matrix& matrix) {
if (matrix.useGpu()) {
MatrixPtr cpuMatrix = Matrix::create(
matrix.getWidth(), matrix.getHeight(), matrix.isTransposed(), false);
cpuMatrix->copyFrom(matrix);
return cpuMatrix->getData();
} else {
return matrix.getData();
}
}
void checkError(const Matrix& matrix1, const Matrix& matrix2) {
CHECK_EQ(matrix1.getHeight(), matrix2.getHeight());
CHECK_EQ(matrix1.getWidth(), matrix2.getWidth());
CHECK_EQ(matrix1.isTransposed(), matrix2.isTransposed());
#ifndef PADDLE_TYPE_DOUBLE
real err = 1e-3;
#else
real err = 1e-10;
#endif
int height = matrix1.getHeight();
int width = matrix1.getWidth();
const real* data1 = getData(matrix1);
const real* data2 = getData(matrix2);
int count = 0;
for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) {
if (fabs(data1[i * width + j] - data2[i * width + j]) > err) {
count++;
}
}
}
EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}
void initArgument(size_t batchSize,
int layerSize,
bool useGpu,
Argument& data) {
data.value = Matrix::create(batchSize, layerSize, false, useGpu);
data.grad = Matrix::create(batchSize, layerSize, false, useGpu);
data.value->randomizeUniform();
data.value->add(-0.5);
/// data.value->sigmoid(*data.value);
data.grad->zeroMem();
generateSequenceStartPositions(batchSize, data.sequenceStartPositions);
}
LayerPtr createDataLayer(
string name, size_t batchSize, int layerSize, bool useGpu, Argument& data) {
LayerConfig layerConfig;
layerConfig.set_name(name);
layerConfig.set_type("data");
layerConfig.set_size(layerSize);
LayerPtr layer = LayerPtr(new DataLayer(layerConfig));
DataLayerPtr dataLayer = std::dynamic_pointer_cast<DataLayer>(layer);
dataLayer->setData(data);
dataLayer->forward(PASS_GC);
/// std::cout << "dataLayer: " << std::endl;
/// (dataLayer->getOutput().value)->print(std::cout);
return layer;
}
LayerPtr createLabelLayer(string name,
size_t batchSize,
size_t numClasses,
bool useGpu) {
LayerConfig layerConfig;
layerConfig.set_name(name);
layerConfig.set_type("data");
layerConfig.set_size(1);
LayerPtr layer = LayerPtr(new DataLayer(layerConfig));
Argument data;
data.ids = IVector::create(batchSize, useGpu);
data.ids->rand(numClasses - 1);
generateSequenceStartPositions(batchSize, data.sequenceStartPositions);
DataLayerPtr labelLayer = std::dynamic_pointer_cast<DataLayer>(layer);
labelLayer->setData(data);
labelLayer->forward(PASS_GC);
return layer;
}
LayerPtr createCTCLayer(string name,
size_t numClasses,
bool useGpu,
bool normByTimes,
LayerPtr dataLayer,
LayerPtr labelLayer) {
LayerMap layerMap;
layerMap[dataLayer->getName()] = dataLayer;
layerMap[labelLayer->getName()] = labelLayer;
ParameterMap parameterMap;
LayerConfig layerConfig;
layerConfig.set_name(name);
layerConfig.set_type("ctc");
layerConfig.set_size(numClasses);
layerConfig.set_norm_by_times(normByTimes);
layerConfig.add_inputs();
LayerInputConfig& input0 = *(layerConfig.mutable_inputs(0));
input0.set_input_layer_name(dataLayer->getName());
layerConfig.add_inputs();
LayerInputConfig& input1 = *(layerConfig.mutable_inputs(1));
input1.set_input_layer_name(labelLayer->getName());
LayerPtr layer = LayerPtr(new CTCLayer(layerConfig));
layerMap[layer->getName()] = layer;
layer->init(layerMap, parameterMap);
ActivationFunction* softmaxActivation = ActivationFunction::create("softmax");
softmaxActivation->forward(dataLayer->getOutput());
layer->forward(PASS_GC);
layer->backward();
softmaxActivation->backward(dataLayer->getOutput());
return layer;
}
LayerPtr createWarpCTCLayer(string name,
size_t numClasses,
bool useGpu,
bool normByTimes,
LayerPtr dataLayer,
LayerPtr labelLayer) {
LayerMap layerMap;
layerMap[dataLayer->getName()] = dataLayer;
layerMap[labelLayer->getName()] = labelLayer;
ParameterMap parameterMap;
LayerConfig layerConfig;
layerConfig.set_name(name);
layerConfig.set_type("warp_ctc");
layerConfig.set_size(numClasses);
layerConfig.set_blank(numClasses - 1);
layerConfig.set_norm_by_times(normByTimes);
layerConfig.add_inputs();
LayerInputConfig& input0 = *(layerConfig.mutable_inputs(0));
input0.set_input_layer_name(dataLayer->getName());
layerConfig.add_inputs();
LayerInputConfig& input1 = *(layerConfig.mutable_inputs(1));
input1.set_input_layer_name(labelLayer->getName());
LayerPtr layer = LayerPtr(new WarpCTCLayer(layerConfig));
layerMap[layer->getName()] = layer;
layer->init(layerMap, parameterMap);
layer->forward(PASS_GC);
layer->backward();
return layer;
}
TEST(Layer, WarpCTCLayer) {
for (auto layerSize : {10, 64, 128}) {
for (auto batchSize : {1, 10, 20, 64}) {
for (auto useGpu : {false, true}) {
#ifdef PADDLE_ONLY_CPU
if (useGpu) continue;
#endif
LOG(INFO) << " layerSize=" << layerSize << " batchSize=" << batchSize
<< " useGpu=" << useGpu;
FLAGS_use_gpu = useGpu;
Argument data0;
initArgument(batchSize, layerSize, useGpu, data0);
Argument data1;
data1.resizeAndCopyFrom(data0);
LayerPtr dataLayer0 =
createDataLayer("data", batchSize, layerSize, useGpu, data0);
LayerPtr dataLayer1 =
createDataLayer("data", batchSize, layerSize, useGpu, data1);
LayerPtr labelLayer =
createLabelLayer("label", batchSize, layerSize, useGpu);
LayerPtr warpctcLayer = createWarpCTCLayer(
"cost", layerSize, useGpu, false, dataLayer0, labelLayer);
LayerPtr ctcLayer = createCTCLayer(
"cost", layerSize, useGpu, false, dataLayer1, labelLayer);
/// Check loss
checkError(*(warpctcLayer->getOutput().value),
*(ctcLayer->getOutput().value));
/// Check gradients
checkError(*(dataLayer0->getOutput().grad),
*(dataLayer1->getOutput().grad));
}
}
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
return RUN_ALL_TESTS();
}
...@@ -414,6 +414,8 @@ sinclude(`ModelConfigLayer.proto.m4') ...@@ -414,6 +414,8 @@ sinclude(`ModelConfigLayer.proto.m4')
// to string and reinterpreted in the user's own layer implementation. // to string and reinterpreted in the user's own layer implementation.
optional string user_arg = 49; optional string user_arg = 49;
// For WarpCTCLayer
optional uint32 blank = 50 [default = 0];
} }
message EvaluatorConfig { message EvaluatorConfig {
......
...@@ -2993,6 +2993,27 @@ class CTCLayer(LayerBase): ...@@ -2993,6 +2993,27 @@ class CTCLayer(LayerBase):
config_assert(len(self.inputs) == 2, 'CTCLayer must have 2 inputs') config_assert(len(self.inputs) == 2, 'CTCLayer must have 2 inputs')
@config_layer('warp_ctc')
class WarpCTCLayer(LayerBase):
def __init__(self,
name,
size,
inputs,
blank=0,
norm_by_times=False,
device=None):
super(WarpCTCLayer, self).__init__(
name, 'warp_ctc', size=size, inputs=inputs, device=device)
self.config.blank = blank
self.config.norm_by_times = norm_by_times
config_assert(len(self.inputs) == 2, 'WarpCTCLayer must have 2 inputs')
input_layer = self.get_input_layer(0)
config_assert(
(input_layer.active_type == '' or
input_layer.active_type == 'linear'),
"Expecting the active_type of input layer to be linear or null")
@config_layer('recurrent_layer_group') @config_layer('recurrent_layer_group')
class RecurrentLayerGroup(LayerBase): class RecurrentLayerGroup(LayerBase):
def __init__(self, name, device=None): def __init__(self, name, device=None):
......
...@@ -91,6 +91,7 @@ __all__ = [ ...@@ -91,6 +91,7 @@ __all__ = [
'linear_comb_layer', 'linear_comb_layer',
'convex_comb_layer', 'convex_comb_layer',
'ctc_layer', 'ctc_layer',
'warp_ctc_layer',
'crf_layer', 'crf_layer',
'crf_decoding_layer', 'crf_decoding_layer',
'nce_layer', 'nce_layer',
...@@ -169,6 +170,7 @@ class LayerType(object): ...@@ -169,6 +170,7 @@ class LayerType(object):
PRINT_LAYER = "print" PRINT_LAYER = "print"
CTC_LAYER = "ctc" CTC_LAYER = "ctc"
WARP_CTC_LAYER = "warp_ctc"
CRF_LAYER = "crf" CRF_LAYER = "crf"
CRF_DECODING_LAYER = "crf_decoding" CRF_DECODING_LAYER = "crf_decoding"
NCE_LAYER = 'nce' NCE_LAYER = 'nce'
...@@ -4085,6 +4087,83 @@ def ctc_layer(input, ...@@ -4085,6 +4087,83 @@ def ctc_layer(input,
return LayerOutput(name, LayerType.CTC_LAYER, [input, label], size=size) return LayerOutput(name, LayerType.CTC_LAYER, [input, label], size=size)
@wrap_name_default()
@layer_support()
def warp_ctc_layer(input,
label,
size=None,
name=None,
blank=0,
norm_by_times=False,
layer_attr=None):
"""
A layer intergrating the open-source `warp-ctc
<https://github.com/baidu-research/warp-ctc>` library, which is used in
`Deep Speech 2: End-toEnd Speech Recognition in English and Mandarin
<https://arxiv.org/pdf/1512.02595v1.pdf>`, to compute Connectionist Temporal
Classification (CTC) loss.
More details of CTC can be found by referring to `Connectionist Temporal
Classification: Labelling Unsegmented Sequence Data with Recurrent
Neural Networks <http://machinelearning.wustl.edu/mlpapers/paper_files/
icml2006_GravesFGS06.pdf>`_
Note:
- Let num_classes represent the category number. Considering the 'blank'
label needed by CTC, you need to use (num_classes + 1) as the input size.
Thus, the size of both warp_ctc_layer and 'input' layer should be set to
num_classes + 1.
- You can set 'blank' to [0, num_classes - 1], which should be consistent
as that used in your labels.
- As a native 'softmax' activation is interated to the warp-ctc library,
'linear' activation is expected instead in the 'input' layer.
The simple usage:
.. code-block:: python
ctc = warp_ctc_layer(input=input,
label=label,
size=1001,
blank=1000,
norm_by_times=False)
:param input: The input layer.
:type input: LayerOutput
:param label: The data layer of label with variable length.
:type label: LayerOutput
:param size: category numbers + 1.
:type size: int
:param name: The name of this layer, which can not specify.
:type name: basestring|None
:param blank: the 'blank' label used in ctc
:type blank: int
:param norm_by_times: Whether to normalization by times. False by default.
:type norm_by_times: bool
:param layer_attr: Extra Layer config.
:type layer_attr: ExtraLayerAttribute|None
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
assert isinstance(label, LayerOutput)
if label.size is not None:
if size is not None:
assert size == label.size + 1
else:
size = label.size + 1
Layer(
name=name,
type=LayerType.WARP_CTC_LAYER,
size=size,
blank=blank,
norm_by_times=norm_by_times,
inputs=[input.name, label.name],
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.WARP_CTC_LAYER, parents=[input, label], size=size)
@wrap_name_default() @wrap_name_default()
@wrap_param_attr_default() @wrap_param_attr_default()
@layer_support() @layer_support()
......
...@@ -47,6 +47,20 @@ layers { ...@@ -47,6 +47,20 @@ layers {
} }
norm_by_times: false norm_by_times: false
} }
layers {
name: "__warp_ctc_layer_0__"
type: "warp_ctc"
size: 5001
active_type: ""
inputs {
input_layer_name: "input"
}
inputs {
input_layer_name: "labels"
}
norm_by_times: false
blank: 0
}
layers { layers {
name: "crf_label" name: "crf_label"
type: "data" type: "data"
...@@ -244,6 +258,7 @@ input_layer_names: "xe-label" ...@@ -244,6 +258,7 @@ input_layer_names: "xe-label"
input_layer_names: "huber_probs" input_layer_names: "huber_probs"
input_layer_names: "huber_label" input_layer_names: "huber_label"
output_layer_names: "__ctc_layer_0__" output_layer_names: "__ctc_layer_0__"
output_layer_names: "__warp_ctc_layer_0__"
output_layer_names: "__crf_layer_0__" output_layer_names: "__crf_layer_0__"
output_layer_names: "__rank_cost_0__" output_layer_names: "__rank_cost_0__"
output_layer_names: "__lambda_cost_0__" output_layer_names: "__lambda_cost_0__"
...@@ -260,6 +275,7 @@ sub_models { ...@@ -260,6 +275,7 @@ sub_models {
layer_names: "xe-label" layer_names: "xe-label"
layer_names: "__fc_layer_0__" layer_names: "__fc_layer_0__"
layer_names: "__ctc_layer_0__" layer_names: "__ctc_layer_0__"
layer_names: "__warp_ctc_layer_0__"
layer_names: "crf_label" layer_names: "crf_label"
layer_names: "__crf_layer_0__" layer_names: "__crf_layer_0__"
layer_names: "left" layer_names: "left"
...@@ -289,6 +305,7 @@ sub_models { ...@@ -289,6 +305,7 @@ sub_models {
input_layer_names: "huber_probs" input_layer_names: "huber_probs"
input_layer_names: "huber_label" input_layer_names: "huber_label"
output_layer_names: "__ctc_layer_0__" output_layer_names: "__ctc_layer_0__"
output_layer_names: "__warp_ctc_layer_0__"
output_layer_names: "__crf_layer_0__" output_layer_names: "__crf_layer_0__"
output_layer_names: "__rank_cost_0__" output_layer_names: "__rank_cost_0__"
output_layer_names: "__lambda_cost_0__" output_layer_names: "__lambda_cost_0__"
......
...@@ -12,6 +12,8 @@ hidden = fc_layer(input=seq_in, size=4) ...@@ -12,6 +12,8 @@ hidden = fc_layer(input=seq_in, size=4)
outputs( outputs(
ctc_layer( ctc_layer(
input=seq_in, label=labels), input=seq_in, label=labels),
warp_ctc_layer(
input=seq_in, label=labels, blank=0),
crf_layer( crf_layer(
input=hidden, label=data_layer( input=hidden, label=data_layer(
name='crf_label', size=4)), name='crf_label', size=4)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册