Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
757c32f2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
757c32f2
编写于
9月 20, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
lstm unit gpu
上级
2b10d322
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
4 addition
and
4 deletion
+4
-4
paddle/operators/lstm_unit_op.cu
paddle/operators/lstm_unit_op.cu
+4
-4
未找到文件。
paddle/operators/lstm_unit_op.cu
浏览文件 @
757c32f2
...
...
@@ -35,7 +35,7 @@ __device__ Dtype cuda_tanh(const Dtype x) {
}
template
<
typename
T
>
__global__
void
LSTMUnitKernel
(
const
int
nthreads
,
const
int
dim
,
const
int
t
,
__global__
void
LSTMUnitKernel
(
const
int
nthreads
,
const
int
dim
,
const
T
*
C_prev
,
const
T
*
X
,
T
*
C
,
T
*
H
,
const
T
forget_bias
)
{
CUDA_1D_KERNEL_LOOP
(
index
,
nthreads
)
{
...
...
@@ -159,9 +159,9 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
int
n
=
N
*
D
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
LSTMUnitGradientKernel
<
T
><<<
N
*
D
,
block
>>>
(
n
,
D
,
C_prev
,
X
,
C
,
H
,
C_diff
,
H_diff
,
C_prev_diff
,
X_diff
,
T
forget_bias
)
LSTMUnitGradientKernel
<
T
><<<
grid
,
block
>>>
(
n
,
D
,
C_prev
,
X
,
C
,
H
,
C_diff
,
H_diff
,
C_prev_diff
,
X_diff
,
forget_bias
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录