Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle, in sync with the fork source)
Commit 9514b4aa (unverified)
Authored by ShenLiang on Jan 22, 2021; committed via GitHub on Jan 22, 2021
Parent: 1f5841c2

Fix scatter grad bug (#30604)

Showing 9 changed files with 108 additions and 54 deletions (+108 −54)
paddle/fluid/operators/scatter.cu.h                          +33   −3
paddle/fluid/operators/scatter.h                             +18   −0
paddle/fluid/operators/scatter_nd_add_op.cu                   +0   −1
paddle/fluid/operators/scatter_nd_add_op.h                    +1   −2
paddle/fluid/operators/scatter_op.cc                          +1   −5
paddle/fluid/operators/scatter_op.cu                         +21  −15
paddle/fluid/operators/scatter_op.h                          +20  −14
python/paddle/fluid/tests/unittests/test_scatter_nd_op.py     +7   −7
python/paddle/fluid/tests/unittests/test_scatter_op.py        +7   −7
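Background for this change (an illustrative sketch, not part of the commit): in overwrite mode `Out = X` except that `Out[Ids] = Updates`, so the rows of `X` that get overwritten contribute nothing to `Out`. The gradient w.r.t. `X` must therefore be `dOut` with the scattered rows zeroed, while the gradient w.r.t. `Updates` remains the gather `dOut[Ids]`; before this commit the grad kernels simply copied `dOut` into `dX`. A minimal numpy sketch of the corrected semantics:

```python
import numpy as np

x = np.array([[1., 1.], [2., 2.], [3., 3.]])
ids = np.array([1, 2])
updates = np.array([[10., 10.], [20., 20.]])

# forward, overwrite mode: Out = X, then Out[Ids] = Updates
out = x.copy()
out[ids] = updates

# backward, given an upstream gradient dOut
dout = np.ones_like(out)
dx = dout.copy()
dx[ids] = 0.0          # scattered rows of X never reach Out, so their grad is zero (the fix)
dupdates = dout[ids]   # dUpdates = dOut[Ids], unchanged (gather)
```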
paddle/fluid/operators/scatter.cu.h

```diff
@@ -28,8 +28,7 @@ using Tensor = framework::Tensor;
 template <typename T, typename IndexT = int>
 __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
-                                      size_t index_size, size_t slice_size,
-                                      bool overwrite) {
+                                      size_t index_size, size_t slice_size) {
   CUDA_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
@@ -129,7 +128,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
     ScatterInitCUDAKernel<T, IndexT><<<
         grid, block, 0,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
-        p_index, p_output, index_size, slice_size, overwrite);
+        p_index, p_output, index_size, slice_size);
   }

   ScatterCUDAKernel<T, IndexT><<<
@@ -138,6 +137,37 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
       p_src, p_index, p_output, index_size, slice_size, overwrite);
 }

+// The function is only for scatter grad x,
+// however update grad use gather
+template <typename T, typename IndexT = int>
+void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
+                        Tensor* output) {
+  IndexT index_size = index.dims()[0];
+  auto dst_dims = output->dims();
+  // slice size
+  IndexT slice_size = 1;
+  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+  const size_t& slice_bytes = slice_size * sizeof(T);
+
+  // set block and grid num
+  int64_t block = 512;
+  int64_t n = slice_size * index_size;
+  int64_t height = (n + block - 1) / block;
+
+  int64_t max_grid_dimx =
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+          .GetCUDAMaxGridDimSize()
+          .x;
+  int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
+
+  ScatterInitCUDAKernel<T, IndexT><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_index, p_output, index_size, slice_size);
+}
+
 template <typename DeviceContext, typename T, typename IndexT = int>
 void GPUScatterNdAdd(const framework::ExecutionContext& context,
                      const Tensor& update, const Tensor& index,
```
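The launch configuration in the new `GPUScatterGradForX` rounds the element count up to whole blocks and clamps the grid to the device's x-dimension limit; the kernel then loops over `index_size * slice_size` elements. A small sketch of that same grid computation (block size 512 as in the kernel, an arbitrary grid limit assumed for illustration):

```python
def launch_grid(n, block=512, max_grid_dimx=2**31 - 1):
    """Mirror of the grid computation above: ceil(n / block), clamped to the device limit."""
    height = (n + block - 1) // block
    return min(height, max_grid_dimx)

print(launch_grid(10_000))  # 20 blocks of 512 threads cover 10,000 elements
```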
paddle/fluid/operators/scatter.h

```diff
@@ -171,6 +171,24 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
   }
 }

+// The function is only for scatter grad x,
+// however update grad use gather
+template <typename T, typename IndexT = int>
+void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
+                        Tensor* output) {
+  int index_size = index.dims()[0];
+  auto dst_dims = output->dims();
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+  size_t slice_size = 1;
+  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+  const size_t slice_bytes = slice_size * sizeof(T);
+  for (int i = 0; i < index_size; ++i) {
+    const IndexT& index_ = p_index[i];
+    memset(p_output + slice_size * index_, 0, slice_bytes);
+  }
+}
+
 template <typename T, typename IndexT = int>
 void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
                   const Tensor& index, Tensor* output) {
```
paddle/fluid/operators/scatter_nd_add_op.cu

```diff
@@ -65,7 +65,6 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Index");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
     }
     if (dUpdates) {
```
paddle/fluid/operators/scatter_nd_add_op.h

```diff
@@ -71,8 +71,7 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     if (dX) {
-      // In place gradient: dX = dO
-      framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
```
paddle/fluid/operators/scatter_op.cc

```diff
@@ -138,9 +138,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer,
                                     "Updates");

 DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ScatterGradInplaceInferer,
-                           {framework::GradVarName("Out"),
-                            framework::GradVarName("X")});

 }  // namespace operators
 }  // namespace paddle
@@ -151,8 +148,7 @@ REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
                   ops::ScatterGradMaker<paddle::imperative::OpBase>,
                   ops::ScatterInplaceInferer);
 REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp,
-                  ops::ScatterGradNoNeedBufferVarsInferer,
-                  ops::ScatterGradInplaceInferer);
+                  ops::ScatterGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>,
                        ops::ScatterOpKernel<double>, ops::ScatterOpKernel<int>,
                        ops::ScatterOpKernel<int64_t>);
```
paddle/fluid/operators/scatter_op.cu

```diff
@@ -67,26 +67,32 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

-    if (dX) {
-      // In place gradient: dX = dO
-      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-    }
-    if (dUpdates) {
-      dUpdates->mutable_data<T>(ctx.GetPlace());
-      // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = Ids->type();
-      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                              index_type == framework::proto::VarType::INT64;
-      PADDLE_ENFORCE_EQ(
-          index_type_match, true,
-          platform::errors::InvalidArgument(
-              "scatter_op Index holds the wrong type, it holds [%s], "
-              "but desires to be [%s] or [%s]",
-              paddle::framework::DataTypeToString(index_type),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT32),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT64)));
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "scatter_op index holds the wrong type, it holds [%s],"
+            "but desires to be [%s] or [%s]",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
+
+    if (dX) {
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+      if (index_type == framework::proto::VarType::INT32) {
+        GPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      } else {
+        GPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+      }
+    }
+
+    if (dUpdates) {
+      dUpdates->mutable_data<T>(ctx.GetPlace());
+      // Gradient by Gather: dUpdates = dO[Ids]
       if (index_type == framework::proto::VarType::INT32) {
         GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
```
paddle/fluid/operators/scatter_op.h

```diff
@@ -79,13 +79,6 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

-    if (dX) {
-      // In place gradient: dX = dO
-      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-    }
-    if (dUpdates) {
-      dUpdates->mutable_data<T>(ctx.GetPlace());
-      // Gradient by Gather: dUpdates = dO[Ids]
     const auto &index_type = Ids->type();
     bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                             index_type == framework::proto::VarType::INT64;
@@ -99,6 +92,19 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
             paddle::framework::DataTypeToString(
                 framework::proto::VarType::INT32),
             paddle::framework::DataTypeToString(
                 framework::proto::VarType::INT64)));
+
+    if (dX) {
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+      if (index_type == framework::proto::VarType::INT32) {
+        CPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      } else {
+        CPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+      }
+    }
+
+    if (dUpdates) {
+      dUpdates->mutable_data<T>(ctx.GetPlace());
+      // Gradient by Gather: dUpdates = dO[Ids]
       if (index_type == framework::proto::VarType::INT32) {
         CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
       } else {
```
python/paddle/fluid/tests/unittests/test_scatter_nd_op.py

```diff
@@ -78,7 +78,7 @@ class TestScatterNdAddSimpleOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')


 class TestScatterNdAddWithEmptyIndex(OpTest):
@@ -101,7 +101,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')


 class TestScatterNdAddWithHighRankSame(OpTest):
@@ -111,11 +111,11 @@ class TestScatterNdAddWithHighRankSame(OpTest):

     def setUp(self):
         self.op_type = "scatter_nd_add"
-        shape = (10, 9, 8, 1, 15)
+        shape = (3, 2, 2, 1, 10)
         ref_np = np.random.rand(*shape).astype("float64")
         index_np = np.vstack(
             [np.random.randint(
-                0, s, size=150) for s in shape]).T.astype("int32")
+                0, s, size=100) for s in shape]).T.astype("int32")
         update_shape = judge_update_shape(ref_np, index_np)
         updates_np = np.random.rand(*update_shape).astype("float64")
         expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
@@ -127,7 +127,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')


 class TestScatterNdAddWithHighRankDiff(OpTest):
@@ -137,7 +137,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):

     def setUp(self):
         self.op_type = "scatter_nd_add"
-        shape = (10, 9, 8, 1, 15)
+        shape = (8, 2, 2, 1, 10)
         ref_np = np.random.rand(*shape).astype("double")
         index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
         index_np = index.reshape([10, 5, 10, 5]).astype("int64")
@@ -152,7 +152,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')


 #Test Python API
```
python/paddle/fluid/tests/unittests/test_scatter_op.py

```diff
@@ -37,7 +37,7 @@ class TestScatterOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")


 class TestScatterOp0(OpTest):
@@ -56,7 +56,7 @@ class TestScatterOp0(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")


 class TestScatterOp1(OpTest):
@@ -78,7 +78,7 @@ class TestScatterOp1(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")


 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -102,7 +102,7 @@ class TestScatterOp2(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')


 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -130,7 +130,7 @@ class TestScatterOp3(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')


 class TestScatterOp4(OpTest):
@@ -148,7 +148,7 @@ class TestScatterOp4(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')


 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -172,7 +172,7 @@ class TestScatterOp5(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')


 class TestScatterAPI(unittest.TestCase):
```
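The updated tests now gradient-check both 'X' and 'Updates' instead of only 'Updates' with in_place=True, since dX is no longer a plain copy of dOut. As a quick manual check (not part of the commit; assumes the Paddle 2.x `paddle.scatter` API in dygraph mode on a build containing this fix), one could observe the corrected gradient directly:

```python
import paddle

x = paddle.to_tensor([[1., 1.], [2., 2.], [3., 3.]], stop_gradient=False)
index = paddle.to_tensor([1, 2])
updates = paddle.to_tensor([[10., 10.], [20., 20.]], stop_gradient=False)

out = paddle.scatter(x, index, updates, overwrite=True)
out.sum().backward()

print(x.grad)        # with this fix: rows 1 and 2 are zero, row 0 is ones
print(updates.grad)  # all ones: dUpdates = dOut[index]
```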