Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9514b4aa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9514b4aa
编写于
1月 22, 2021
作者:
S
ShenLiang
提交者:
GitHub
1月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix scatter grad bug (#30604)
上级
1f5841c2
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
108 addition
and
54 deletion
+108
-54
paddle/fluid/operators/scatter.cu.h
paddle/fluid/operators/scatter.cu.h
+33
-3
paddle/fluid/operators/scatter.h
paddle/fluid/operators/scatter.h
+18
-0
paddle/fluid/operators/scatter_nd_add_op.cu
paddle/fluid/operators/scatter_nd_add_op.cu
+0
-1
paddle/fluid/operators/scatter_nd_add_op.h
paddle/fluid/operators/scatter_nd_add_op.h
+1
-2
paddle/fluid/operators/scatter_op.cc
paddle/fluid/operators/scatter_op.cc
+1
-5
paddle/fluid/operators/scatter_op.cu
paddle/fluid/operators/scatter_op.cu
+21
-15
paddle/fluid/operators/scatter_op.h
paddle/fluid/operators/scatter_op.h
+20
-14
python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
+7
-7
python/paddle/fluid/tests/unittests/test_scatter_op.py
python/paddle/fluid/tests/unittests/test_scatter_op.py
+7
-7
未找到文件。
paddle/fluid/operators/scatter.cu.h
浏览文件 @
9514b4aa
...
...
@@ -28,8 +28,7 @@ using Tensor = framework::Tensor;
template
<
typename
T
,
typename
IndexT
=
int
>
__global__
void
ScatterInitCUDAKernel
(
const
IndexT
*
indices
,
T
*
output
,
size_t
index_size
,
size_t
slice_size
,
bool
overwrite
)
{
size_t
index_size
,
size_t
slice_size
)
{
CUDA_KERNEL_LOOP
(
i
,
index_size
*
slice_size
)
{
int
indices_i
=
i
/
slice_size
;
int
slice_i
=
i
-
indices_i
*
slice_size
;
// offset inside the slice
...
...
@@ -129,7 +128,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
ScatterInitCUDAKernel
<
T
,
IndexT
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
p_index
,
p_output
,
index_size
,
slice_size
,
overwrite
);
p_index
,
p_output
,
index_size
,
slice_size
);
}
ScatterCUDAKernel
<
T
,
IndexT
><<<
...
...
@@ -138,6 +137,37 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
p_src
,
p_index
,
p_output
,
index_size
,
slice_size
,
overwrite
);
}
// The function is only for scatter grad x,
// however update grad use gather
template
<
typename
T
,
typename
IndexT
=
int
>
void
GPUScatterGradForX
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
index
,
Tensor
*
output
)
{
IndexT
index_size
=
index
.
dims
()[
0
];
auto
dst_dims
=
output
->
dims
();
// slice size
IndexT
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
dst_dims
.
size
();
++
i
)
slice_size
*=
dst_dims
[
i
];
const
IndexT
*
p_index
=
index
.
data
<
IndexT
>
();
T
*
p_output
=
output
->
data
<
T
>
();
const
size_t
&
slice_bytes
=
slice_size
*
sizeof
(
T
);
// set block and grid num
int64_t
block
=
512
;
int64_t
n
=
slice_size
*
index_size
;
int64_t
height
=
(
n
+
block
-
1
)
/
block
;
int64_t
max_grid_dimx
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
GetCUDAMaxGridDimSize
()
.
x
;
int64_t
grid
=
height
<
max_grid_dimx
?
height
:
max_grid_dimx
;
ScatterInitCUDAKernel
<
T
,
IndexT
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
p_index
,
p_output
,
index_size
,
slice_size
);
}
template
<
typename
DeviceContext
,
typename
T
,
typename
IndexT
=
int
>
void
GPUScatterNdAdd
(
const
framework
::
ExecutionContext
&
context
,
const
Tensor
&
update
,
const
Tensor
&
index
,
...
...
paddle/fluid/operators/scatter.h
浏览文件 @
9514b4aa
...
...
@@ -171,6 +171,24 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
}
}
// The function is only for scatter grad x,
// however update grad use gather
template
<
typename
T
,
typename
IndexT
=
int
>
void
CPUScatterGradForX
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
index
,
Tensor
*
output
)
{
int
index_size
=
index
.
dims
()[
0
];
auto
dst_dims
=
output
->
dims
();
const
IndexT
*
p_index
=
index
.
data
<
IndexT
>
();
T
*
p_output
=
output
->
data
<
T
>
();
size_t
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
dst_dims
.
size
();
++
i
)
slice_size
*=
dst_dims
[
i
];
const
size_t
slice_bytes
=
slice_size
*
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
{
const
IndexT
&
index_
=
p_index
[
i
];
memset
(
p_output
+
slice_size
*
index_
,
0
,
slice_bytes
);
}
}
template
<
typename
T
,
typename
IndexT
=
int
>
void
ScatterNdAdd
(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
update
,
const
Tensor
&
index
,
Tensor
*
output
)
{
...
...
paddle/fluid/operators/scatter_nd_add_op.cu
浏览文件 @
9514b4aa
...
...
@@ -65,7 +65,6 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
dOut
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
if
(
dX
)
{
// In place gradient: dX = dO
framework
::
TensorCopy
(
*
dOut
,
ctx
.
GetPlace
(),
dX
);
}
if
(
dUpdates
)
{
...
...
paddle/fluid/operators/scatter_nd_add_op.h
浏览文件 @
9514b4aa
...
...
@@ -71,8 +71,7 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
auto
*
dOut
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
if
(
dX
)
{
// In place gradient: dX = dO
framework
::
TensorCopySync
(
*
dOut
,
ctx
.
GetPlace
(),
dX
);
framework
::
TensorCopy
(
*
dOut
,
ctx
.
GetPlace
(),
dX
);
}
if
(
dUpdates
)
{
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
paddle/fluid/operators/scatter_op.cc
浏览文件 @
9514b4aa
...
...
@@ -138,9 +138,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer,
"Updates"
);
DECLARE_INPLACE_OP_INFERER
(
ScatterInplaceInferer
,
{
"X"
,
"Out"
});
DECLARE_INPLACE_OP_INFERER
(
ScatterGradInplaceInferer
,
{
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
)});
}
// namespace operators
}
// namespace paddle
...
...
@@ -151,8 +148,7 @@ REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
ops
::
ScatterGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
ScatterInplaceInferer
);
REGISTER_OPERATOR
(
scatter_grad
,
ops
::
ScatterGradOp
,
ops
::
ScatterGradNoNeedBufferVarsInferer
,
ops
::
ScatterGradInplaceInferer
);
ops
::
ScatterGradNoNeedBufferVarsInferer
);
REGISTER_OP_CPU_KERNEL
(
scatter
,
ops
::
ScatterOpKernel
<
float
>
,
ops
::
ScatterOpKernel
<
double
>
,
ops
::
ScatterOpKernel
<
int
>
,
ops
::
ScatterOpKernel
<
int64_t
>
);
...
...
paddle/fluid/operators/scatter_op.cu
浏览文件 @
9514b4aa
...
...
@@ -67,27 +67,33 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
auto
*
dUpdates
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Updates"
));
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
dOut
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
dX
)
{
// In place gradient: dX = dO
framework
::
TensorCopy
(
*
dOut
,
ctx
.
GetPlace
(),
dX
);
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUScatterGradForX
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
Ids
,
dX
);
}
else
{
GPUScatterGradForX
<
T
,
int64_t
>
(
ctx
.
device_context
(),
*
Ids
,
dX
);
}
}
if
(
dUpdates
)
{
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Gradient by Gather: dUpdates = dO[Ids]
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"scatter_op Index holds the wrong type, it holds [%s], "
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
// Gradient by Gather: dUpdates = dO[Ids]
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUGather
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
dOut
,
*
Ids
,
dUpdates
);
}
else
{
...
...
paddle/fluid/operators/scatter_op.h
浏览文件 @
9514b4aa
...
...
@@ -79,26 +79,32 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
dOut
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
dX
)
{
// In place gradient: dX = dO
framework
::
TensorCopy
(
*
dOut
,
ctx
.
GetPlace
(),
dX
);
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CPUScatterGradForX
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
Ids
,
dX
);
}
else
{
CPUScatterGradForX
<
T
,
int64_t
>
(
ctx
.
device_context
(),
*
Ids
,
dX
);
}
}
if
(
dUpdates
)
{
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Gradient by Gather: dUpdates = dO[Ids]
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CPUGather
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
dOut
,
*
Ids
,
dUpdates
);
}
else
{
...
...
python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
浏览文件 @
9514b4aa
...
...
@@ -78,7 +78,7 @@ class TestScatterNdAddSimpleOp(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
'
X'
,
'Updates'
],
'Out'
)
class
TestScatterNdAddWithEmptyIndex
(
OpTest
):
...
...
@@ -101,7 +101,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
'X'
,
'Updates'
],
'Out'
)
class
TestScatterNdAddWithHighRankSame
(
OpTest
):
...
...
@@ -111,11 +111,11 @@ class TestScatterNdAddWithHighRankSame(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"scatter_nd_add"
shape
=
(
10
,
9
,
8
,
1
,
15
)
shape
=
(
3
,
2
,
2
,
1
,
10
)
ref_np
=
np
.
random
.
rand
(
*
shape
).
astype
(
"float64"
)
index_np
=
np
.
vstack
(
[
np
.
random
.
randint
(
0
,
s
,
size
=
1
5
0
)
for
s
in
shape
]).
T
.
astype
(
"int32"
)
0
,
s
,
size
=
1
0
0
)
for
s
in
shape
]).
T
.
astype
(
"int32"
)
update_shape
=
judge_update_shape
(
ref_np
,
index_np
)
updates_np
=
np
.
random
.
rand
(
*
update_shape
).
astype
(
"float64"
)
expect_np
=
numpy_scatter_nd_add
(
ref_np
.
copy
(),
index_np
,
updates_np
)
...
...
@@ -127,7 +127,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
'
X'
,
'Updates'
],
'Out'
)
class
TestScatterNdAddWithHighRankDiff
(
OpTest
):
...
...
@@ -137,7 +137,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"scatter_nd_add"
shape
=
(
10
,
9
,
8
,
1
,
15
)
shape
=
(
8
,
2
,
2
,
1
,
10
)
ref_np
=
np
.
random
.
rand
(
*
shape
).
astype
(
"double"
)
index
=
np
.
vstack
([
np
.
random
.
randint
(
0
,
s
,
size
=
500
)
for
s
in
shape
]).
T
index_np
=
index
.
reshape
([
10
,
5
,
10
,
5
]).
astype
(
"int64"
)
...
...
@@ -152,7 +152,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
'
X'
,
'Updates'
],
'Out'
)
#Test Python API
...
...
python/paddle/fluid/tests/unittests/test_scatter_op.py
浏览文件 @
9514b4aa
...
...
@@ -37,7 +37,7 @@ class TestScatterOp(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
"X"
,
"Updates"
],
"Out"
)
class
TestScatterOp0
(
OpTest
):
...
...
@@ -56,7 +56,7 @@ class TestScatterOp0(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
"X"
,
"Updates"
],
"Out"
)
class
TestScatterOp1
(
OpTest
):
...
...
@@ -78,7 +78,7 @@ class TestScatterOp1(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
"X"
,
"Updates"
],
"Out"
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -102,7 +102,7 @@ class TestScatterOp2(OpTest):
def
test_check_grad
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad_with_place
(
place
,
[
'
X'
,
'Updates'
],
'Out'
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -130,7 +130,7 @@ class TestScatterOp3(OpTest):
def
test_check_grad
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad_with_place
(
place
,
[
'
X'
,
'Updates'
],
'Out'
)
class
TestScatterOp4
(
OpTest
):
...
...
@@ -148,7 +148,7 @@ class TestScatterOp4(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad
([
'
X'
,
'Updates'
],
'Out'
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -172,7 +172,7 @@ class TestScatterOp5(OpTest):
def
test_check_grad
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'
Updates'
],
'Out'
,
in_place
=
True
)
self
.
check_grad_with_place
(
place
,
[
'
X'
,
'Updates'
],
'Out'
)
class
TestScatterAPI
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录