Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
043f47b2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
043f47b2
编写于
3月 23, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix concat op
上级
76ae540f
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
5 addition
and
7 deletion
+5
-7
paddle/fluid/operators/math/concat.cu
paddle/fluid/operators/math/concat.cu
+5
-7
未找到文件。
paddle/fluid/operators/math/concat.cu
浏览文件 @
043f47b2
...
...
@@ -70,9 +70,8 @@ __global__ void KernelConcat(T** inputs, const int input_col,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
double
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_
input_col
;
int
split
=
tid_x
*
1.0
/
input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
T
*
input_ptr
=
inputs
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
...
...
@@ -110,17 +109,16 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
const
int
output_col
s
,
const
int
input_col
,
const
int
output_col
,
T
**
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
double
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_in
put_col
;
int
in_offset
=
tid_x
-
split
*
in
put_col
;
int
split
=
tid_x
/
out
put_col
;
int
in_offset
=
tid_x
-
split
*
out
put_col
;
T
*
output_ptr
=
outputs
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
output_col
s
+
in_offset
]
=
output_ptr
[
tid_y
*
output_col
+
in_offset
]
=
input
[
tid_y
*
input_col
+
tid_x
];
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录