Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
67d1ba0f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
67d1ba0f
编写于
7月 20, 2020
作者:
Y
yangyongjie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support 2-dimension target of CTCLossV2
上级
f30df6e3
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
18 additions
and
6 deletions
+18
-6
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h
...ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h
+18
-6
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h
浏览文件 @
67d1ba0f
...
...
@@ -51,10 +51,12 @@ class CtcLossGpuKernel : public GpuKernel {
float
*
grads
=
GetDeviceAddress
<
float
>
(
outputs
,
1
);
// Copy labels/input_lengths/label_length to host as cudnn7.x.x requires
void
*
labels_host
=
nullptr
;
int
*
labels_host
=
nullptr
;
int
*
no_blank_labels_host
=
nullptr
;
void
*
input_lengths_host
=
nullptr
;
void
*
label_lengths_host
=
nullptr
;
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMallocHost
(
&
labels_host
,
inputs
[
1
]
->
size
),
"cudaMallocHost failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMallocHost
(
&
no_blank_labels_host
,
inputs
[
1
]
->
size
),
"cudaMallocHost failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMallocHost
(
&
input_lengths_host
,
inputs
[
2
]
->
size
),
"cudaMallocHost failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMallocHost
(
&
label_lengths_host
,
inputs
[
3
]
->
size
),
"cudaMallocHost failed."
);
cudaStream_t
stream
=
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
);
...
...
@@ -68,12 +70,21 @@ class CtcLossGpuKernel : public GpuKernel {
"cudaMemcpyAsync failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaStreamSynchronize
(
stream
),
"cudaStreamSynchronize failed."
);
size_t
j
=
0
;
for
(
size_t
i
=
0
;
i
<
inputs
[
1
]
->
size
/
sizeof
(
int
);
i
++
)
{
if
(
labels_host
[
i
]
!=
0
)
{
no_blank_labels_host
[
j
]
=
labels_host
[
i
];
j
++
;
}
}
size_t
workspace_size
=
0
;
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnGetCTCLossWorkspaceSize
(
cudnn_handle_
,
probs_desc_
,
probs_desc_
,
reinterpret_cast
<
int
*>
(
labels_host
),
reinterpret_cast
<
int
*>
(
label_length
s_host
),
reinterpret_cast
<
int
*>
(
input_lengths_host
),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
,
ctcloss_desc_
,
&
workspace_size
),
cudnnGetCTCLossWorkspaceSize
(
cudnn_handle_
,
probs_desc_
,
probs_desc_
,
reinterpret_cast
<
int
*>
(
no_blank_label
s_host
),
reinterpret_cast
<
int
*>
(
label_lengths_host
),
reinterpret_cast
<
int
*>
(
input_lengths_host
)
,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
,
ctcloss_desc_
,
&
workspace_size
),
"cudnnGetCTCLossWorkspaceSize failed."
);
void
*
workspace
=
device
::
gpu
::
GPUMemoryAllocator
::
GetInstance
().
AllocTensorMem
(
workspace_size
);
if
(
workspace
==
nullptr
)
{
...
...
@@ -81,7 +92,7 @@ class CtcLossGpuKernel : public GpuKernel {
}
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnCTCLoss
(
cudnn_handle_
,
probs_desc_
,
probs
,
reinterpret_cast
<
int
*>
(
labels_host
),
cudnnCTCLoss
(
cudnn_handle_
,
probs_desc_
,
probs
,
reinterpret_cast
<
int
*>
(
no_blank_
labels_host
),
reinterpret_cast
<
int
*>
(
label_lengths_host
),
reinterpret_cast
<
int
*>
(
input_lengths_host
),
costs
,
probs_desc_
,
grads
,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
,
ctcloss_desc_
,
workspace
,
workspace_size
),
"cudnnCtcLoss failed."
);
...
...
@@ -91,6 +102,7 @@ class CtcLossGpuKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaFreeHost
(
label_lengths_host
),
"cudaFreeHost failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaFreeHost
(
input_lengths_host
),
"cudaFreeHost failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaFreeHost
(
labels_host
),
"cudaFreeHost failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaFreeHost
(
no_blank_labels_host
),
"cudaFreeHost failed."
);
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录