Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6f7aca9e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
6f7aca9e
编写于
9月 10, 2021
作者:
Z
Zeng Jinle
提交者:
GitHub
9月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix scatter and gather bug (#35595)
* fix scatter gather bug: * fix windows ci
上级
42847d2e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
86 addition
and
38 deletion
+86
-38
paddle/fluid/operators/gather.cu.h
paddle/fluid/operators/gather.cu.h
+6
-10
paddle/fluid/operators/scatter.cu.h
paddle/fluid/operators/scatter.cu.h
+26
-26
paddle/fluid/operators/scatter.h
paddle/fluid/operators/scatter.h
+2
-2
python/paddle/fluid/tests/unittests/test_gather_op.py
python/paddle/fluid/tests/unittests/test_gather_op.py
+12
-0
python/paddle/fluid/tests/unittests/test_scatter_op.py
python/paddle/fluid/tests/unittests/test_scatter_op.py
+40
-0
未找到文件。
paddle/fluid/operators/gather.cu.h
浏览文件 @
6f7aca9e
...
...
@@ -36,7 +36,7 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
int64_t
indices_i
=
i
/
slice_size
;
int64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
gather_i
=
indices
[
indices_i
];
IndexT
params_i
=
gather_i
*
slice_size
+
slice_i
;
int64_t
params_i
=
gather_i
*
slice_size
+
slice_i
;
*
(
output
+
i
)
=
*
(
params
+
params_i
);
}
}
...
...
@@ -49,7 +49,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
CUDA_KERNEL_LOOP_TYPE
(
i
,
remain_size
*
slice_size
,
int64_t
)
{
int64_t
indices_i
=
i
/
slice_size
;
int64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
gather_i
=
0
;
int64_t
gather_i
=
0
;
int64_t
temp
=
slice_size
;
for
(
int64_t
j
=
end_size
-
1
;
j
>=
0
;
--
j
)
{
auto
index_value
=
indices
[
indices_i
*
end_size
+
j
];
...
...
@@ -63,7 +63,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
gather_i
+=
(
index_value
*
temp
);
temp
*=
input_dims
[
j
];
}
IndexT
input_i
=
gather_i
+
slice_i
;
int64_t
input_i
=
gather_i
+
slice_i
;
*
(
output
+
i
)
=
*
(
input
+
input_i
);
}
}
...
...
@@ -78,13 +78,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
template
<
typename
T
,
typename
IndexT
=
int
>
void
GPUGather
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
// check index of shape 1-D
if
(
index
.
dims
().
size
()
==
1
)
{
PADDLE_ENFORCE_GT
(
index
.
dims
()[
0
],
0
,
platform
::
errors
::
InvalidArgument
(
"The index of gather_op should not be empty"
"when the index's rank is 1."
));
}
else
if
(
index
.
dims
().
size
()
==
2
)
{
if
(
index
.
dims
().
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index
.
dims
()[
1
],
1
,
platform
::
errors
::
InvalidArgument
(
"If the index's rank of gather_op is 2,"
...
...
@@ -93,6 +87,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// index size
int64_t
index_size
=
index
.
dims
()[
0
];
if
(
index_size
==
0
)
return
;
auto
src_dims
=
src
.
dims
();
framework
::
DDim
output_dims
(
src_dims
);
...
...
@@ -248,6 +243,7 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
out
->
Resize
(
out_dim
);
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
place
);
int64_t
out_size
=
out
->
numel
();
if
(
out_size
==
0
)
return
;
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
.
cuda_device_context
(),
out_size
);
...
...
paddle/fluid/operators/scatter.cu.h
浏览文件 @
6f7aca9e
...
...
@@ -29,9 +29,9 @@ using Tensor = framework::Tensor;
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
ScatterInitCUDAKernel
(
const
IndexT
*
indices
,
T
*
output
,
size_t
index_size
,
size_t
slice_size
)
{
CUDA_KERNEL_LOOP
(
i
,
index_size
*
slice_size
)
{
int
indices_i
=
i
/
slice_size
;
int
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
CUDA_KERNEL_LOOP
_TYPE
(
i
,
index_size
*
slice_size
,
int64_t
)
{
int
64_t
indices_i
=
i
/
slice_size
;
int
64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
scatter_i
=
indices
[
indices_i
];
PADDLE_ENFORCE
(
scatter_i
>=
0
,
...
...
@@ -41,7 +41,7 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
"be greater than or equal to 0, but received [%d]"
,
scatter_i
);
IndexT
out_i
=
scatter_i
*
slice_size
+
slice_i
;
int64_t
out_i
=
scatter_i
*
slice_size
+
slice_i
;
*
(
output
+
out_i
)
=
static_cast
<
T
>
(
0
);
}
}
...
...
@@ -50,9 +50,9 @@ template <typename T, typename IndexT = int>
__global__
void
ScatterCUDAKernel
(
const
T
*
params
,
const
IndexT
*
indices
,
T
*
output
,
size_t
index_size
,
size_t
slice_size
,
bool
overwrite
)
{
CUDA_KERNEL_LOOP
(
i
,
index_size
*
slice_size
)
{
int
indices_i
=
i
/
slice_size
;
int
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
CUDA_KERNEL_LOOP
_TYPE
(
i
,
index_size
*
slice_size
,
int64_t
)
{
int
64_t
indices_i
=
i
/
slice_size
;
int
64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
scatter_i
=
indices
[
indices_i
];
PADDLE_ENFORCE
(
scatter_i
>=
0
,
...
...
@@ -62,7 +62,7 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
"be greater than or equal to 0, but received [%d]"
,
scatter_i
);
IndexT
out_i
=
scatter_i
*
slice_size
+
slice_i
;
int64_t
out_i
=
scatter_i
*
slice_size
+
slice_i
;
if
(
overwrite
)
{
*
(
output
+
out_i
)
=
*
(
params
+
i
);
}
else
{
...
...
@@ -73,13 +73,13 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
ScatterNdCUDAKernel
(
const
T
*
update
,
const
IndexT
*
indices
,
T
*
output
,
const
int
*
output_dims
,
T
*
output
,
const
int
64_t
*
output_dims
,
size_t
remain_size
,
size_t
slice_size
,
size_t
end_size
)
{
CUDA_KERNEL_LOOP
(
i
,
remain_size
*
slice_size
)
{
int
indices_i
=
i
/
slice_size
;
int
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
IndexT
gather_i
=
0
;
CUDA_KERNEL_LOOP
_TYPE
(
i
,
remain_size
*
slice_size
,
int64_t
)
{
int
64_t
indices_i
=
i
/
slice_size
;
int
64_t
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
int64_t
gather_i
=
0
;
int64_t
temp
=
slice_size
;
for
(
int64_t
j
=
end_size
-
1
;
j
>=
0
;
--
j
)
{
IndexT
index_value
=
indices
[
indices_i
*
end_size
+
j
];
...
...
@@ -95,7 +95,7 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
gather_i
+=
(
index_value
*
temp
);
temp
*=
output_dims
[
j
];
}
IndexT
output_i
=
gather_i
+
slice_i
;
int64_t
output_i
=
gather_i
+
slice_i
;
paddle
::
platform
::
CudaAtomicAdd
(
output
+
output_i
,
*
(
update
+
i
));
}
}
...
...
@@ -128,14 +128,14 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
"But received value is [%d]"
,
index
.
dims
().
size
()));
}
int
index_size
=
index
.
dims
()[
0
];
int
64_t
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
.
dims
();
framework
::
DDim
output_dims
(
src_dims
);
output_dims
[
0
]
=
index_size
;
// slice size
int
slice_size
=
1
;
int
64_t
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
const
T
*
p_src
=
src
.
data
<
T
>
();
...
...
@@ -145,8 +145,8 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
// set block and grid num
int
block
=
512
;
int
n
=
slice_size
*
index_size
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
int
64_t
n
=
slice_size
*
index_size
;
int
64_t
grid
=
(
n
+
block
-
1
)
/
block
;
// if not overwrite mode, init data
if
(
!
overwrite
)
{
...
...
@@ -167,10 +167,10 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
template
<
typename
T
,
typename
IndexT
=
int
>
void
GPUScatterGradForX
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
index
,
Tensor
*
output
)
{
IndexT
index_size
=
index
.
dims
()[
0
];
int64_t
index_size
=
index
.
dims
()[
0
];
auto
dst_dims
=
output
->
dims
();
// slice size
IndexT
slice_size
=
1
;
int64_t
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
dst_dims
.
size
();
++
i
)
slice_size
*=
dst_dims
[
i
];
const
IndexT
*
p_index
=
index
.
data
<
IndexT
>
();
T
*
p_output
=
output
->
data
<
T
>
();
...
...
@@ -224,20 +224,20 @@ void GPUScatterNdAdd(const framework::ExecutionContext& context,
const
auto
gplace
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
auto
cplace
=
platform
::
CPUPlace
();
std
::
vector
<
int
>
v_output_dims
(
output_dims_size
);
std
::
vector
<
int
64_t
>
v_output_dims
(
output_dims_size
);
for
(
int
i
=
0
;
i
<
output_dims_size
;
++
i
)
{
v_output_dims
[
i
]
=
static_cast
<
int
>
(
output_dims
[
i
])
;
v_output_dims
[
i
]
=
output_dims
[
i
]
;
}
auto
&
dev_ctx
=
context
.
cuda_device_context
();
int
bytes
=
output_dims_size
*
sizeof
(
in
t
);
int
64_t
bytes
=
output_dims_size
*
sizeof
(
int64_
t
);
auto
output_dims_ptr
=
memory
::
Alloc
(
dev_ctx
,
bytes
);
int
*
g_output_dims
=
reinterpret_cast
<
in
t
*>
(
output_dims_ptr
->
ptr
());
int
64_t
*
g_output_dims
=
reinterpret_cast
<
int64_
t
*>
(
output_dims_ptr
->
ptr
());
memory
::
Copy
(
gplace
,
g_output_dims
,
cplace
,
v_output_dims
.
data
(),
bytes
,
ctx
.
stream
());
int
block
=
512
;
int
n
=
slice_size
*
remain_numel
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
int
64_t
n
=
slice_size
*
remain_numel
;
int
64_t
grid
=
(
n
+
block
-
1
)
/
block
;
ScatterNdCUDAKernel
<
T
,
IndexT
><<<
grid
,
block
,
0
,
...
...
paddle/fluid/operators/scatter.h
浏览文件 @
6f7aca9e
...
...
@@ -112,7 +112,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const
size_t
slice_bytes
=
slice_size
*
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
{
for
(
int
64_t
i
=
0
;
i
<
index_size
;
++
i
)
{
IndexT
index_
=
p_index
[
i
];
PADDLE_ENFORCE_GE
(
index_
,
0
,
...
...
@@ -175,7 +175,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
}
// if not in overwrite mode, need to init output data
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
{
for
(
int
64_t
i
=
0
;
i
<
index_size
;
++
i
)
{
const
IndexT
&
index_val
=
p_index
[
i
];
PADDLE_ENFORCE_GE
(
index_val
,
0
,
...
...
python/paddle/fluid/tests/unittests/test_gather_op.py
浏览文件 @
6f7aca9e
...
...
@@ -248,6 +248,17 @@ class API_TestDygraphGather(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
output_np
,
expected_output
))
paddle
.
enable_static
()
def
test_zero_index
(
self
):
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
([[
1
,
2
],
[
3
,
4
]])
index
=
paddle
.
to_tensor
(
np
.
array
([]).
astype
(
'int64'
))
for
axis
in
range
(
len
(
x
.
shape
)):
out
=
paddle
.
gather
(
x
,
index
,
axis
)
expected_shape
=
list
(
x
.
shape
)
expected_shape
[
axis
]
=
0
self
.
assertEqual
(
list
(
out
.
shape
),
expected_shape
)
paddle
.
enable_static
()
def
test_large_data
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
...
...
@@ -340,4 +351,5 @@ class TestCheckOutType(unittest.TestCase):
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_scatter_op.py
浏览文件 @
6f7aca9e
...
...
@@ -16,10 +16,12 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
os
import
paddle
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
class
TestScatterOp
(
OpTest
):
...
...
@@ -228,6 +230,44 @@ class TestScatterAPI(unittest.TestCase):
self
.
assertEqual
((
output1
.
numpy
()
==
\
np
.
array
([[
3.
,
3.
],[
6.
,
6.
],[
1.
,
1.
]])).
all
(),
True
)
def
test_large_data
(
self
):
if
os
.
name
==
"nt"
or
not
paddle
.
is_compiled_with_cuda
():
return
x
=
np
.
random
.
rand
(
183826
,
256
).
astype
(
"float32"
)
index
=
np
.
ones
(
10759233
,
dtype
=
"int64"
)
updates
=
np
.
ones
(
shape
=
[
10759233
,
256
],
dtype
=
"float32"
)
def
test_dygraph
():
with
fluid
.
dygraph
.
guard
():
gpu_out
=
paddle
.
scatter
(
paddle
.
to_tensor
(
x
),
paddle
.
to_tensor
(
index
),
paddle
.
to_tensor
(
updates
))
return
gpu_out
.
numpy
()
@
switch_to_static_graph
def
test_static_graph
():
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
x_t
=
paddle
.
static
.
data
(
name
=
"x"
,
dtype
=
x
.
dtype
,
shape
=
x
.
shape
)
index_t
=
paddle
.
static
.
data
(
name
=
"index"
,
dtype
=
index
.
dtype
,
shape
=
index
.
shape
)
updates_t
=
paddle
.
static
.
data
(
name
=
"updates"
,
dtype
=
updates
.
dtype
,
shape
=
updates
.
shape
)
out_t
=
paddle
.
scatter
(
x_t
,
index_t
,
updates_t
)
feed
=
{
x_t
.
name
:
x
,
index_t
.
name
:
index
,
updates_t
.
name
:
updates
}
fetch
=
[
out_t
]
gpu_exe
=
paddle
.
static
.
Executor
(
paddle
.
CUDAPlace
(
0
))
gpu_value
=
gpu_exe
.
run
(
feed
=
feed
,
fetch_list
=
fetch
)[
0
]
return
gpu_value
self
.
assertTrue
(
np
.
array_equal
(
test_dygraph
(),
test_static_graph
()))
class
TestScatterInplaceAPI
(
TestScatterAPI
):
def
executed_api
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录