提交 67d1ba0f 编写于 作者: Y yangyongjie

support 2-dimension target of CTCLossV2

上级 f30df6e3
......@@ -51,10 +51,12 @@ class CtcLossGpuKernel : public GpuKernel {
float *grads = GetDeviceAddress<float>(outputs, 1);
// Copy labels/input_lengths/label_length to host as cudnn7.x.x requires
void *labels_host = nullptr;
int *labels_host = nullptr;
int *no_blank_labels_host = nullptr;
void *input_lengths_host = nullptr;
void *label_lengths_host = nullptr;
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
......@@ -68,12 +70,21 @@ class CtcLossGpuKernel : public GpuKernel {
"cudaMemcpyAsync failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
size_t j = 0;
for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
if (labels_host[i] != 0) {
no_blank_labels_host[j] = labels_host[i];
j++;
}
}
size_t workspace_size = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(labels_host),
reinterpret_cast<int *>(label_lengths_host),
reinterpret_cast<int *>(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
ctcloss_desc_, &workspace_size),
cudnnGetCTCLossWorkspaceSize(
cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host),
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
"cudnnGetCTCLossWorkspaceSize failed.");
void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
if (workspace == nullptr) {
......@@ -81,7 +92,7 @@ class CtcLossGpuKernel : public GpuKernel {
}
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(labels_host),
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host),
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
"cudnnCtcLoss failed.");
......@@ -91,6 +102,7 @@ class CtcLossGpuKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册