Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
008857be
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
008857be
编写于
5月 14, 2020
作者:
S
ShenLiang
提交者:
GitHub
5月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix error message for scatter and scatter_nd (#24514)
上级
14376486
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
155 addition
and
91 deletion
+155
-91
paddle/fluid/operators/scatter.cu.h
paddle/fluid/operators/scatter.cu.h
+9
-3
paddle/fluid/operators/scatter.h
paddle/fluid/operators/scatter.h
+28
-13
paddle/fluid/operators/scatter_nd_add_op.cc
paddle/fluid/operators/scatter_nd_add_op.cc
+24
-13
paddle/fluid/operators/scatter_nd_add_op.cu
paddle/fluid/operators/scatter_nd_add_op.cu
+13
-8
paddle/fluid/operators/scatter_nd_add_op.h
paddle/fluid/operators/scatter_nd_add_op.h
+15
-10
paddle/fluid/operators/scatter_op.cc
paddle/fluid/operators/scatter_op.cc
+21
-13
paddle/fluid/operators/scatter_op.cu
paddle/fluid/operators/scatter_op.cu
+22
-15
paddle/fluid/operators/scatter_op.h
paddle/fluid/operators/scatter_op.h
+23
-16
未找到文件。
paddle/fluid/operators/scatter.cu.h
浏览文件 @
008857be
...
...
@@ -95,11 +95,17 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
const
auto
&
ctx
=
context
.
device_context
();
if
(
index
.
dims
().
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index
.
dims
()[
1
],
1
,
"index.dims()[1] should be 1 when index.dims().size() == "
"2 in scatter_op."
);
platform
::
errors
::
InvalidArgument
(
"index.dims()[1] should be 1 when "
"index.dims().size() = 2 in scatter_op."
"But received value is [%d]"
,
index
.
dims
()[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
index
.
dims
().
size
(),
1
,
"index.dims().size() should be 1 or 2 in scatter_op."
);
platform
::
errors
::
InvalidArgument
(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]"
,
index
.
dims
().
size
()));
}
int
index_size
=
index
.
dims
()[
0
];
...
...
paddle/fluid/operators/scatter.h
浏览文件 @
008857be
...
...
@@ -73,15 +73,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
template
<
typename
T
,
typename
IndexT
=
int
>
void
ScatterAssign
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
// check index of shape 1-D
if
(
index
.
dims
().
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index
.
dims
()[
1
],
1
,
"index.dims()[1] should be 1 when index.dims().size() == "
"2 in scatter_op."
);
platform
::
errors
::
InvalidArgument
(
"index.dims()[1] should be 1 when "
"index.dims().size() =2 in scatter_op."
"But received value is [%d]"
,
index
.
dims
()[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
index
.
dims
().
size
(),
1
,
"index.dims().size() should be 1 or 2 in scatter_op."
);
platform
::
errors
::
InvalidArgument
(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]"
,
index
.
dims
().
size
()));
}
int
index_size
=
index
.
dims
()[
0
];
...
...
@@ -94,7 +102,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
// check src shape and dst shape should match
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
i
++
)
PADDLE_ENFORCE_EQ
(
src_dims
[
i
],
dst_dims
[
i
]);
PADDLE_ENFORCE_EQ
(
src_dims
[
i
],
dst_dims
[
i
],
platform
::
errors
::
InvalidArgument
(
"src shape and dst shape should match"
));
// slice size
size_t
slice_size
=
1
;
...
...
@@ -111,12 +121,14 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
template
<
typename
T
,
typename
IndexT
=
int
>
void
ScatterAssignAdd
(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
device_context
().
GetPlace
()),
true
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
device_context
().
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
// check index of shape 1-D
PADDLE_ENFORCE
(
index
.
dims
().
size
()
==
1
||
(
index
.
dims
().
size
()
==
2
&&
index
.
dims
()[
1
]
==
1
),
""
);
PADDLE_ENFORCE_EQ
(
index
.
dims
().
size
()
==
1
||
(
index
.
dims
().
size
()
==
2
&&
index
.
dims
()[
1
]
==
1
),
true
,
platform
::
errors
::
InvalidArgument
(
"index's shape is error."
));
int
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
.
dims
();
...
...
@@ -130,7 +142,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
// check src shape and dst shape should match
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
i
++
)
PADDLE_ENFORCE_EQ
(
src_dims
[
i
],
dst_dims
[
i
]);
PADDLE_ENFORCE_EQ
(
src_dims
[
i
],
dst_dims
[
i
],
platform
::
errors
::
InvalidArgument
(
"src shape and dst shape should match"
));
// slice size
size_t
slice_size
=
1
;
...
...
@@ -156,8 +170,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
template
<
typename
T
,
typename
IndexT
=
int
>
void
ScatterNdAdd
(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
update
,
const
Tensor
&
index
,
Tensor
*
output
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
device_context
().
GetPlace
()),
true
,
"It should be running on the CPU"
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
device_context
().
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It should be running on the CPU"
));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
auto
index_dims
=
index
.
dims
();
...
...
paddle/fluid/operators/scatter_nd_add_op.cc
浏览文件 @
008857be
...
...
@@ -26,13 +26,19 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of ScatterNdAddOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Index"
),
true
,
"Input(Index) of ScatterNdAddOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Updates"
),
true
,
"Input(Updates) of ScatterNdAddOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) of ScatterNdAddOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Index"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Index) of ScatterNdAddOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Updates"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Updates) of ScatterNdAddOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of ScatterNdAddOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Output(Out) of ScatterNdAddOp should not be null."
));
auto
ref_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
ref_dims_size
=
ref_dims
.
size
();
...
...
@@ -43,9 +49,11 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE
(
index_dims
[
index_dims_size
-
1
],
ref_dims_size
,
"Input(Index).shape[-1] should be no greater than Input(X).rank"
);
platform
::
errors
::
InvalidArgument
(
"Input(Index).shape[-1] should be no greater than Input(X).rank"
));
PADDLE_ENFORCE_GE
(
index_dims_size
,
2UL
,
"The rank of Input(Index) should be greater than 1"
);
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Index) should be greater than 1"
));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std
::
vector
<
int64_t
>
r_updates_dims
;
...
...
@@ -56,12 +64,14 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
r_updates_dims
.
emplace_back
(
ref_dims
[
i
]);
}
PADDLE_ENFORCE_EQ
(
r_updates_dims
.
size
(),
updates_dims_size
,
"Updates has wrong shape"
);
PADDLE_ENFORCE_EQ
(
r_updates_dims
.
size
(),
updates_dims_size
,
platform
::
errors
::
InvalidArgument
(
"Updates has wrong shape"
));
for
(
int64_t
i
=
0
;
i
<
updates_dims_size
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
r_updates_dims
[
i
],
updates_dims
[
i
],
"Updates has wrong shape"
);
PADDLE_ENFORCE_EQ
(
r_updates_dims
[
i
],
updates_dims
[
i
],
platform
::
errors
::
InvalidArgument
(
"Updates has wrong shape"
));
}
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
...
...
@@ -72,7 +82,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Updates"
),
"Ref and Updates must have same type"
);
platform
::
errors
::
InvalidArgument
(
"Ref and Updates must have same type"
));
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
device_context
());
}
...
...
paddle/fluid/operators/scatter_nd_add_op.cu
浏览文件 @
008857be
...
...
@@ -25,7 +25,8 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on GPU device."
);
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
Updates
=
ctx
.
Input
<
Tensor
>
(
"Updates"
);
...
...
@@ -35,12 +36,15 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUScatterNdAdd
<
DeviceContext
,
T
,
int32_t
>
(
ctx
,
*
Updates
,
*
Ids
,
Out
);
}
else
{
...
...
@@ -54,7 +58,8 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on GPU device."
);
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dUpdates
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Updates"
));
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
...
...
paddle/fluid/operators/scatter_nd_add_op.h
浏览文件 @
008857be
...
...
@@ -27,8 +27,9 @@ template <typename T>
class
ScatterNdAddOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on CPU."
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
Updates
=
ctx
.
Input
<
Tensor
>
(
"Updates"
);
...
...
@@ -39,12 +40,15 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
ScatterNdAdd
<
T
,
int32_t
>
(
ctx
,
*
Updates
,
*
Ids
,
Out
);
...
...
@@ -58,8 +62,9 @@ template <typename T>
class
ScatterNdAddGradientOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on CPU."
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dUpdates
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Updates"
));
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
...
...
paddle/fluid/operators/scatter_op.cc
浏览文件 @
008857be
...
...
@@ -24,24 +24,32 @@ class ScatterOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ScatterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) of ScatterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Updates"
),
"Input(Updates) of ScatterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ScatterOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(X) of ScatterOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Ids"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Ids) of ScatterOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Updates"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Updates) of ScatterOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Out) of ScatterOp should not be null."
));
auto
updates_dims
=
ctx
->
GetInputDim
(
"Updates"
);
auto
ref_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Ids"
).
size
(),
1
,
"Update Ids should be 1-D."
);
PADDLE_ENFORCE_EQ
(
ref_dims
.
size
(),
updates_dims
.
size
(),
"Xerence and Updates should have the same shape size"
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Ids"
).
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Update Ids should be 1-D."
));
PADDLE_ENFORCE_EQ
(
ref_dims
.
size
(),
updates_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Rerence and Updates should have the same shape size."
));
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Updates"
)[
0
],
ctx
->
GetInputDim
(
"Ids"
)[
0
],
"Updates and Ids should have same batch-size."
);
platform
::
errors
::
InvalidArgument
(
"Updates and Ids should have same batch-size."
));
ctx
->
SetOutputDim
(
"Out"
,
ref_dims
);
}
...
...
paddle/fluid/operators/scatter_op.cu
浏览文件 @
008857be
...
...
@@ -24,8 +24,9 @@ template <typename T>
class
ScatterOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
Updates
=
ctx
.
Input
<
Tensor
>
(
"Updates"
);
...
...
@@ -39,11 +40,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
"scatter_op Index holds the wrong type, it holds %s, but desires to be "
"%s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
platform
::
errors
::
InvalidArgument
(
"scatter_op Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUScatterAssign
<
T
,
int32_t
>
(
ctx
,
*
Updates
,
*
Ids
,
Out
,
overwrite
);
}
else
{
...
...
@@ -56,8 +60,9 @@ template <typename T>
class
ScatterGradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dUpdates
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Updates"
));
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
...
...
@@ -74,12 +79,14 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
"scatter_op Index holds the wrong type, it holds %s, but desires to "
"be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
platform
::
errors
::
InvalidArgument
(
"scatter_op Index holds the wrong type, it holds [%s], "
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
// Gradient by Gather: dUpdates = dO[Ids]
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUGather
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
dOut
,
*
Ids
,
dUpdates
);
...
...
paddle/fluid/operators/scatter_op.h
浏览文件 @
008857be
...
...
@@ -27,8 +27,9 @@ template <typename T>
class
ScatterOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on CPU."
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
Updates
=
ctx
.
Input
<
Tensor
>
(
"Updates"
);
...
...
@@ -41,12 +42,15 @@ class ScatterOpKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
Ids
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE
(
index_type_match
,
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
overwrite
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
ScatterAssign
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
Updates
,
*
Ids
,
Out
);
...
...
@@ -67,8 +71,9 @@ template <typename T>
class
ScatterGradientOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on CPU."
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dUpdates
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Updates"
));
auto
*
Ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
...
...
@@ -86,12 +91,14 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
"scatter_op index holds the wrong type, it holds %s, but desires to "
"be %s or %s"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
platform
::
errors
::
InvalidArgument
(
"scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CPUGather
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
dOut
,
*
Ids
,
dUpdates
);
}
else
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录