Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
eef55ca7
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eef55ca7
编写于
8月 03, 2017
作者:
Z
Zhuoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remodify
上级
2b35fca1
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
108 addition
and
105 deletion
+108
-105
paddle/operators/gather_func.h
paddle/operators/gather_func.h
+41
-35
paddle/operators/scatter_func.h
paddle/operators/scatter_func.h
+67
-70
未找到文件。
paddle/operators/gather_func.h
浏览文件 @
eef55ca7
...
...
@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
#include "paddle/framework/ddim.h"
/**
* Return a new tensor from source tensor, gathered according to index
...
...
@@ -27,7 +27,7 @@ limitations under the License. */
template
<
typename
Place
,
typename
T
>
Tensor
*
Gather
(
Tensor
*
src
,
Tensor
*
index
)
{
// check index of shape 1-D
PADDLE_ENFORCE
(
index
->
dims
().
size
()
==
1
);
PADDLE_ENFORCE
(
index
->
dims
().
size
()
==
1
);
int
index_size
=
index
->
dims
()[
0
];
// Source shape
...
...
@@ -41,60 +41,66 @@ Tensor* Gather(Tensor* src, Tensor* index) {
/* slice size */
int
slice_size
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
for
(
size_t
i
=
0
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
/* Gathering */
if
(
place
==
CPUPlace
())
{
// init for CPU
output
=
New_tensor
.
mutable_data
<
T
>
(
output_dims
,
CPUPlace
());
CPUGather
(
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
CPUGather
(
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
}
else
{
// GPU
// init for GPU
output
=
New_tensor
.
mutable_data
<
T
>
(
output_dims
,
GPUPlace
());
/* how to specialize device??*/
GPUGather
(
d
,
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
GPUGather
(
d
,
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
}
return
New_tensor
;
}
/* Implementation of CPU copy */
template
<
typename
T
>
void
CPUGather
(
const
T
*
params
,
const
int
*
indices
,
const
int
slice_size
,
const
int
index_size
,
template
<
typename
T
>
void
CPUGather
(
const
T
*
params
,
const
int
*
indices
,
const
int
slice_size
,
const
int
index_size
,
T
*
output
)
{
const
size_t
slice_bytes
=
slice_size
*
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
for
(
size_t
i
=
0
;
i
<
index_size
;
++
i
)
{
int
index_
=
indices
[
i
];
/* copy src[index_] to output[i] */
memcpy
(
output
+
i
*
slice_bytes
,
params
+
index_
*
slice_bytes
,
slice_bytes
);
memcpy
(
output
+
i
*
slice_bytes
,
params
+
index_
*
slice_bytes
,
slice_bytes
);
}
}
/* Implementation of GPU copy:
I suppose the GPUDevice& d, contains gpu_id and thread_id
d = cuda_stream(gpu_id_, stream_id_);
*/
template
<
typename
T
>
template
<
typename
T
>
void
GPUGather
(
const
GPUDevice
&
d
,
const
T
*
src
,
const
int
*
index
,
const
int
slice_size
,
const
int
index_size
,
const
T
*
src
,
const
int
*
index
,
const
int
slice_size
,
const
int
index_size
,
T
*
output
)
{
int
block_count
=
slice_size
*
index_size
;
int
thread_per_block
=
1024
;
GatherOpKernel
<
T
>
<<<
block_count
,
thread_per_block
,
0
,
d
.
stream
()
>>>
(
src
,
index
,
output
,
slice_size
,
indices_size
,
slice_size
,
out_size
);
GatherOpKernel
<
T
><<<
block_count
,
thread_per_block
,
0
,
d
.
stream
()
>>>
(
src
,
index
,
output
,
slice_size
,
indices_size
,
slice_size
,
out_size
);
}
template
<
typename
T
>
__global__
void
GatherOpKernel
(
const
T
*
params
,
const
int
*
indices
,
T
*
out
,
__global__
void
GatherOpKernel
(
const
T
*
params
,
const
int
*
indices
,
T
*
out
,
int64
indices_size
,
int64
slice_size
,
int64
out_size
)
{
int64
slice_size
,
int64
out_size
)
{
/* I suppose we have the following macro,
which I strongly suggest that we should put in cuda:
#define CUDA_1D_KERNEL_LOOP(i, n) \
...
...
paddle/operators/scatter_func.h
浏览文件 @
eef55ca7
...
...
@@ -14,95 +14,92 @@ limitations under the License. */
#pragma once
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
#include "paddle/framework/ddim.h"
/**
* Return a updated tensor from source tensor, scattered according to index:
* dst[i] += src[index[i]]
* input[src]: type-T source Tensor
* input[
I
ndex]: type-int index Tensor (1-D)
* input[
i
ndex]: type-int index Tensor (1-D)
* return: output tensor
*/
template
<
typename
place
,
typename
T
>
void
ScatterUpdate_func
(
Tensor
*
Src
,
Tensor
*
Dst
,
Tensor
*
Index
)
{
// assert index is an int-type tensor
assert
(
Index
->
istype
(
int
));
template
<
typename
Place
,
typename
T
>
void
ScatterUpdate
(
Tensor
*
src
,
Tensor
*
dst
,
Tensor
*
index
)
{
// Source shape
auto
src_dims
=
S
rc
->
dims
();
auto
dst_dims
=
D
st
->
dims
();
auto
src_dims
=
s
rc
->
dims
();
auto
dst_dims
=
d
st
->
dims
();
DDim
output_dims
(
dims_src
);
// check Src shape and D
st shape should match
for
(
in
t
i
=
1
;
i
<
src_dims
.
size
();
i
++
)
assert
(
src_dims
[
i
]
==
dst_dims
[
i
]);
// check src shape and d
st shape should match
for
(
size_
t
i
=
1
;
i
<
src_dims
.
size
();
i
++
)
PADDLE_ENFORCE
(
src_dims
[
i
]
==
dst_dims
[
i
]);
int
index_size
=
I
ndex
->
dims
()[
0
];
int
index_size
=
i
ndex
->
dims
()[
0
];
/* slice size */
int
slice_size
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
for
(
size_t
i
=
0
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
if
(
place
==
CPUPlace
())
{
// init
output
=
new_tensor
.
mutable_data
<
T
>
(
output_dims
,
CPUPlace
());
CPUScatterUpdate
(
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
CPUScatterUpdate
(
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
}
else
{
// GPU
// init
output
=
new_tensor
.
mutable_data
<
T
>
(
output_dims
,
GPUPlace
());
/* how to specialize device??*/
GPUScatterUpdate
(
d
,
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
GPUScatterUpdate
(
d
,
src
->
data
(),
index
->
data
(),
slice_size
,
new_tensor
->
mutable_data
());
}
}
/* Implementation of CPU copy */
template
<
typename
T
>
void
CPUScatterUpdate
(
const
T
*
src
,
const
int
*
Index
,
const
int
slice_size
,
const
int
index_size
,
template
<
typename
T
>
void
CPUScatterUpdate
(
const
T
*
src
,
const
int
*
index
,
const
int
slice_size
,
const
int
index_size
,
T
*
output
)
{
//const size_t slice_bytes = slice_size * sizeof(T);
//
const size_t slice_bytes = slice_size * sizeof(T);
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
for
(
size_t
i
=
0
;
i
<
index_size
;
++
i
)
{
int
index_
=
index
[
i
];
/* dst[index_] += src[index_]
add operation size: slice_size
*/
math
::
vAdd
<
T
>
(
slice_size
,
src
+
index_
*
slice_bytes
,
math
::
vAdd
<
T
>
(
slice_size
,
src
+
index_
*
slice_bytes
,
output
+
i
*
slice_bytes
,
output
+
i
*
slice_bytes
);
/* Scatter update, not just assign
memcpy(output + i * slice_bytes,
src + index_ * slice_bytes,
slice_bytes);
*/
}
}
/* Implementation of GPU scatter:
I suppose the GPUDevice& d, contains gpu_id and thread_id
d = cuda_stream(gpu_id_, stream_id_);
*/
template
<
typename
T
>
template
<
typename
T
>
void
GPUScatterUpdate
(
const
GPUDevice
&
d
,
const
T
*
src
,
const
int
*
Index
,
const
int
slice_size
,
const
int
index_size
,
const
T
*
src
,
const
int
*
index
,
const
int
slice_size
,
const
int
index_size
,
T
*
output
)
{
int
block_count
=
slice_size
*
index_size
;
int
thread_per_block
=
1024
;
ScatterOpKernel
<
T
>
<<<
block_count
,
thread_per_block
,
0
,
d
.
stream
()
>>>
(
src
,
Index
,
output
,
slice_size
,
indices_size
,
slice_size
,
out_size
);
ScatterOpKernel
<
T
><<<
block_count
,
thread_per_block
,
0
,
d
.
stream
()
>>>
(
src
,
index
,
output
,
slice_size
,
indices_size
,
slice_size
,
out_size
);
}
template
<
typename
T
>
__global__
void
ScatterOpKernel
(
const
T
*
params
,
const
int
*
indices
,
T
*
out
,
__global__
void
ScatterOpKernel
(
const
T
*
params
,
const
int
*
indices
,
T
*
out
,
int64
indices_size
,
int64
slice_size
,
int64
out_size
)
{
int64
slice_size
,
int64
out_size
)
{
/* I suppose we have the following macro,
which I strongly suggest that we should put in cuda:
#define CUDA_1D_KERNEL_LOOP(i, n) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录