Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
53e3c534
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53e3c534
编写于
5月 14, 2020
作者:
S
ShenLiang
提交者:
GitHub
5月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix error message, test=develop (#24425)
上级
ea2c4987
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
178 addition
and
111 deletion
+178
-111
paddle/fluid/operators/distributed_ops/allreduce_op.h
paddle/fluid/operators/distributed_ops/allreduce_op.h
+6
-4
paddle/fluid/operators/distributed_ops/broadcast_op.cc
paddle/fluid/operators/distributed_ops/broadcast_op.cc
+6
-4
paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
+12
-7
paddle/fluid/operators/eye_op.cc
paddle/fluid/operators/eye_op.cc
+11
-7
paddle/fluid/operators/gather.cu.h
paddle/fluid/operators/gather.cu.h
+6
-4
paddle/fluid/operators/gather.h
paddle/fluid/operators/gather.h
+24
-12
paddle/fluid/operators/gather_nd_op.cc
paddle/fluid/operators/gather_nd_op.cc
+10
-5
paddle/fluid/operators/gather_nd_op.cu
paddle/fluid/operators/gather_nd_op.cu
+22
-14
paddle/fluid/operators/gather_nd_op.h
paddle/fluid/operators/gather_nd_op.h
+24
-16
paddle/fluid/operators/gather_op.cc
paddle/fluid/operators/gather_op.cc
+9
-6
paddle/fluid/operators/gather_op.cu
paddle/fluid/operators/gather_op.cu
+24
-16
paddle/fluid/operators/gather_op.h
paddle/fluid/operators/gather_op.h
+24
-16
未找到文件。
paddle/fluid/operators/distributed_ops/allreduce_op.h
浏览文件 @
53e3c534
...
@@ -33,8 +33,9 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
...
@@ -33,8 +33,9 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
is_gpu_place
(
place
),
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
),
true
,
"AllReduce op can run on gpu place only for now."
);
platform
::
errors
::
PreconditionNotMet
(
"AllReduce op can run on gpu place only for now."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
...
@@ -49,7 +50,8 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
...
@@ -49,7 +50,8 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
auto
*
comm
=
dev_ctx
.
nccl_comm
();
auto
*
comm
=
dev_ctx
.
nccl_comm
();
// FIXME(typhoonzero): should use nccl stream here.
// FIXME(typhoonzero): should use nccl stream here.
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
PADDLE_ENFORCE_NOT_NULL
(
stream
,
"Should initialize NCCL firstly."
);
PADDLE_ENFORCE_NOT_NULL
(
stream
,
platform
::
errors
::
NotFound
(
"Should initialize NCCL firstly."
));
int
reduce_type
=
ctx
.
Attr
<
int
>
(
"reduce_type"
);
int
reduce_type
=
ctx
.
Attr
<
int
>
(
"reduce_type"
);
ncclRedOp_t
red_type
=
ncclSum
;
ncclRedOp_t
red_type
=
ncclSum
;
...
@@ -67,7 +69,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
...
@@ -67,7 +69,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
red_type
=
ncclMin
;
red_type
=
ncclMin
;
break
;
break
;
}
}
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
PADDLE_ENFORCE
_CUDA_SUCCESS
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
comm
,
stream
));
comm
,
stream
));
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
...
...
paddle/fluid/operators/distributed_ops/broadcast_op.cc
浏览文件 @
53e3c534
...
@@ -26,10 +26,12 @@ class BroadcastOp : public framework::OperatorWithKernel {
...
@@ -26,10 +26,12 @@ class BroadcastOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of BroadcastOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Input(X) of BroadcastOp should not be null."
));
"Output(Output) of ConvOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Output) of ConvOp should not be null."
));
}
}
};
};
...
...
paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc
浏览文件 @
53e3c534
...
@@ -34,8 +34,10 @@ template <typename T>
...
@@ -34,8 +34,10 @@ template <typename T>
class
NCCLBroadcastOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
NCCLBroadcastOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE_EQ
(
"The place of ExecutionContext should be CUDAPlace."
);
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The place of ExecutionContext should be CUDAPlace."
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
int
dev_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
()).
device
;
int
dev_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
()).
device
;
...
@@ -43,19 +45,22 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
...
@@ -43,19 +45,22 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
PADDLE_ENFORCE
(
out
->
IsInitialized
(),
PADDLE_ENFORCE_EQ
(
"Currently, the output of broadcast op must be initialized, "
out
->
IsInitialized
(),
true
,
"because this op can only be an In-Place operation."
);
platform
::
errors
::
PreconditionNotMet
(
"Currently, the output of broadcast op must be initialized,"
"because this op can only be an In-Place operation."
));
void
*
send_recv_buffer
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
void
*
send_recv_buffer
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
send_recv_buffer
,
in
->
data
<
void
>
(),
send_recv_buffer
,
in
->
data
<
void
>
(),
"Currently, the broadcast op can only be an In-Place operation."
);
platform
::
errors
::
PreconditionNotMet
(
"Currently, the broadcast op can "
"only be an In-Place operation."
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
comm
=
dev_ctx
.
nccl_comm
();
auto
comm
=
dev_ctx
.
nccl_comm
();
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclBcast
(
PADDLE_ENFORCE
_CUDA_SUCCESS
(
platform
::
dynload
::
ncclBcast
(
send_recv_buffer
,
static_cast
<
size_t
>
(
in
->
numel
()),
send_recv_buffer
,
static_cast
<
size_t
>
(
in
->
numel
()),
platform
::
ToNCCLDataType
(
in
->
type
()),
root_dev_id
,
comm
,
stream
));
platform
::
ToNCCLDataType
(
in
->
type
()),
root_dev_id
,
comm
,
stream
));
...
...
paddle/fluid/operators/eye_op.cc
浏览文件 @
53e3c534
...
@@ -22,16 +22,20 @@ class EyeOp : public framework::OperatorWithKernel {
...
@@ -22,16 +22,20 @@ class EyeOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of EyeOP should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Output(Out) of EyeOP should not be null."
));
auto
num_rows
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"num_rows"
);
auto
num_rows
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"num_rows"
);
PADDLE_ENFORCE
(
num_rows
>=
0
,
PADDLE_ENFORCE_EQ
(
"The value of Input(num_rows) should be non-negative int."
);
num_rows
>=
0
,
true
,
platform
::
errors
::
InvalidArgument
(
"The value of Input(num_rows) should be non-negative int."
));
auto
num_columns
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"num_columns"
);
auto
num_columns
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"num_columns"
);
if
(
num_columns
==
-
1
)
num_columns
=
num_rows
;
if
(
num_columns
==
-
1
)
num_columns
=
num_rows
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE_EQ
(
num_columns
>=
0
,
num_columns
>=
0
,
true
,
"The value of Input(num_columns) should be non-negative int."
);
platform
::
errors
::
InvalidArgument
(
"The value of Input(num_columns) should be non-negative int."
));
ctx
->
SetOutputDim
(
"Out"
,
{
num_rows
,
num_columns
});
ctx
->
SetOutputDim
(
"Out"
,
{
num_rows
,
num_columns
});
}
}
...
...
paddle/fluid/operators/gather.cu.h
浏览文件 @
53e3c534
...
@@ -78,12 +78,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
...
@@ -78,12 +78,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// check index of shape 1-D
// check index of shape 1-D
if
(
index
.
dims
().
size
()
==
1
)
{
if
(
index
.
dims
().
size
()
==
1
)
{
PADDLE_ENFORCE_GT
(
index
.
dims
()[
0
],
0
,
PADDLE_ENFORCE_GT
(
index
.
dims
()[
0
],
0
,
"The index of gather_op should not be empty when the "
platform
::
errors
::
InvalidArgument
(
"index's rank is 1."
);
"The index of gather_op should not be empty"
"when the index's rank is 1."
));
}
else
if
(
index
.
dims
().
size
()
==
2
)
{
}
else
if
(
index
.
dims
().
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index
.
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
index
.
dims
()[
1
],
1
,
" If the index's rank of gather_op is 2, the second "
platform
::
errors
::
InvalidArgument
(
"dimension should be 1."
);
"If the index's rank of gather_op is 2,"
" the second dimension should be 1."
));
}
}
int
index_size
=
index
.
dims
()[
0
];
int
index_size
=
index
.
dims
()[
0
];
...
...
paddle/fluid/operators/gather.h
浏览文件 @
53e3c534
...
@@ -36,15 +36,23 @@ using framework::Tensor;
...
@@ -36,15 +36,23 @@ using framework::Tensor;
template
<
typename
T
,
typename
IndexT
=
int
>
template
<
typename
T
,
typename
IndexT
=
int
>
void
CPUGather
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
void
CPUGather
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
const
Tensor
&
index
,
Tensor
*
output
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
);
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It should be running on the CPU."
));
// check index of shape 1-D
// check index of shape 1-D
if
(
index
.
dims
().
size
()
==
2
)
{
if
(
index
.
dims
().
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index
.
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
"index.dims()[1] should be 1 when index.dims().size() == "
index
.
dims
()[
1
],
1
,
"2 in gather_op."
);
platform
::
errors
::
InvalidArgument
(
"index.dims()[1] should be 1 when index.dims().size() = 2"
"in gather_op, but received value is [%d]."
,
index
.
dims
()[
1
]));
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
index
.
dims
().
size
(),
1
,
PADDLE_ENFORCE_EQ
(
index
.
dims
().
size
(),
1
,
"index.dims().size() should be 1 or 2 in gather_op."
);
platform
::
errors
::
InvalidArgument
(
"index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is [%d]."
,
index
.
dims
().
size
()));
}
}
int64_t
index_size
=
index
.
dims
()[
0
];
int64_t
index_size
=
index
.
dims
()[
0
];
...
@@ -69,8 +77,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
...
@@ -69,8 +77,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
template
<
typename
T
,
typename
IndexT
=
int
>
template
<
typename
T
,
typename
IndexT
=
int
>
void
CPUGatherNd
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
input
,
void
CPUGatherNd
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
input
,
const
Tensor
&
index
,
Tensor
*
output
)
{
const
Tensor
&
index
,
Tensor
*
output
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
PADDLE_ENFORCE_EQ
(
"It should be running on the CPU"
);
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It should be running on the CPU."
));
auto
index_dims
=
index
.
dims
();
auto
index_dims
=
index
.
dims
();
auto
index_dims_size
=
index_dims
.
size
();
auto
index_dims_size
=
index_dims
.
size
();
...
@@ -98,11 +107,14 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
...
@@ -98,11 +107,14 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
int64_t
temp
=
1
;
int64_t
temp
=
1
;
for
(
int64_t
j
=
end_size
-
1
;
j
>=
0
;
--
j
)
{
for
(
int64_t
j
=
end_size
-
1
;
j
>=
0
;
--
j
)
{
IndexT
index_value
=
p_index
[
i
*
end_size
+
j
];
IndexT
index_value
=
p_index
[
i
*
end_size
+
j
];
PADDLE_ENFORCE_LT
(
index_value
,
input_dims
[
j
],
PADDLE_ENFORCE_LT
(
"Input(index[-1)] has wrong value, it is %d"
,
index_value
,
input_dims
[
j
],
index_value
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GE
(
index_value
,
0UL
,
"Input(index[-1)] has wrong value, it is [%d]"
,
index_value
));
"The value of Input(index) must be no less than 0"
);
PADDLE_ENFORCE_GE
(
index_value
,
0UL
,
platform
::
errors
::
InvalidArgument
(
"The value of Input(index) must be no less than 0"
));
index_
+=
(
index_value
*
temp
);
index_
+=
(
index_value
*
temp
);
temp
*=
input_dims
[
j
];
temp
*=
input_dims
[
j
];
...
...
paddle/fluid/operators/gather_nd_op.cc
浏览文件 @
53e3c534
...
@@ -27,11 +27,14 @@ class GatherNdOp : public framework::OperatorWithKernel {
...
@@ -27,11 +27,14 @@ class GatherNdOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of GatherNdOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) of GatherNdOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Index"
),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Index"
),
true
,
"Input(Index) of GatherNdOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(Index) of GatherNdOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of GatherNdOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Output(Out) of GatherNdOp should not be null."
));
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims_size
=
x_dims
.
size
();
auto
x_dims_size
=
x_dims
.
size
();
...
@@ -40,9 +43,11 @@ class GatherNdOp : public framework::OperatorWithKernel {
...
@@ -40,9 +43,11 @@ class GatherNdOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_LE
(
index_dims
[
index_dims_size
-
1
],
x_dims_size
,
index_dims
[
index_dims_size
-
1
],
x_dims_size
,
"Input(Index).shape[-1] should be no greater than Input(X).rank"
);
platform
::
errors
::
InvalidArgument
(
"Input(Index).shape[-1] should be no greater than Input(X).rank"
));
PADDLE_ENFORCE_GE
(
index_dims_size
,
2UL
,
PADDLE_ENFORCE_GE
(
index_dims_size
,
2UL
,
"The rank of Input(Index) should be greater than 1"
);
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Index) should be greater than 1"
));
std
::
vector
<
int64_t
>
result_dims
;
std
::
vector
<
int64_t
>
result_dims
;
// The result dims is
// The result dims is
...
...
paddle/fluid/operators/gather_nd_op.cu
浏览文件 @
53e3c534
...
@@ -25,7 +25,8 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -25,7 +25,8 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on GPU device."
);
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
...
@@ -35,12 +36,15 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -35,12 +36,15 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s], but "
paddle
::
framework
::
DataTypeToString
(
index_type
),
"desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUGatherNd
<
DeviceContext
,
T
,
int
>
(
ctx
,
*
x
,
*
index
,
output
);
GPUGatherNd
<
DeviceContext
,
T
,
int
>
(
ctx
,
*
x
,
*
index
,
output
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
...
@@ -54,7 +58,8 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -54,7 +58,8 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on GPU device."
);
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dO
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dO
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
...
@@ -70,12 +75,15 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -70,12 +75,15 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUScatterNdAdd
<
DeviceContext
,
T
,
int
>
(
ctx
,
*
dO
,
*
index
,
dX
);
GPUScatterNdAdd
<
DeviceContext
,
T
,
int
>
(
ctx
,
*
dO
,
*
index
,
dX
);
...
...
paddle/fluid/operators/gather_nd_op.h
浏览文件 @
53e3c534
...
@@ -27,8 +27,9 @@ template <typename T>
...
@@ -27,8 +27,9 @@ template <typename T>
class
GatherNdOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GatherNdOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
PADDLE_ENFORCE_EQ
(
"This kernel only runs on CPU."
);
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
...
@@ -40,12 +41,15 @@ class GatherNdOpKernel : public framework::OpKernel<T> {
...
@@ -40,12 +41,15 @@ class GatherNdOpKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CPUGatherNd
<
T
,
int
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
CPUGatherNd
<
T
,
int
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
...
@@ -58,8 +62,9 @@ template <typename T>
...
@@ -58,8 +62,9 @@ template <typename T>
class
GatherNdGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GatherNdGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
PADDLE_ENFORCE_EQ
(
"This kernel only runs on CPU."
);
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dO
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dO
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
...
@@ -73,12 +78,15 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
...
@@ -73,12 +78,15 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
ScatterNdAdd
<
T
,
int32_t
>
(
ctx
,
*
dO
,
*
index
,
dX
);
ScatterNdAdd
<
T
,
int32_t
>
(
ctx
,
*
dO
,
*
index
,
dX
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
...
...
paddle/fluid/operators/gather_op.cc
浏览文件 @
53e3c534
...
@@ -26,12 +26,15 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -26,12 +26,15 @@ class GatherOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of GatherOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Index"
),
"Input(X) of GatherOp should not be null."
));
"Input(Index) of GatherOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Index"
),
true
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
platform
::
errors
::
InvalidArgument
(
"Output(Out) of GatherOp should not be null."
);
"Input(Index) of GatherOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Out) of GatherOp should not be null."
));
auto
index_dims
=
ctx
->
GetInputDim
(
"Index"
);
auto
index_dims
=
ctx
->
GetInputDim
(
"Index"
);
PADDLE_ENFORCE
(
index_dims
.
size
()
==
1
||
PADDLE_ENFORCE
(
index_dims
.
size
()
==
1
||
...
...
paddle/fluid/operators/gather_op.cu
浏览文件 @
53e3c534
...
@@ -24,8 +24,9 @@ template <typename T>
...
@@ -24,8 +24,9 @@ template <typename T>
class
GatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GatherOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on GPU device."
);
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
...
@@ -35,12 +36,15 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -35,12 +36,15 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUGather
<
T
,
int
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
GPUGather
<
T
,
int
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
...
@@ -53,8 +57,9 @@ template <typename T>
...
@@ -53,8 +57,9 @@ template <typename T>
class
GatherGradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GatherGradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
"This kernel only runs on GPU device."
);
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on GPU device."
));
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dO
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dO
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
...
@@ -69,12 +74,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -69,12 +74,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
GPUScatterAssign
<
T
,
int
>
(
ctx
,
*
dO
,
*
index
,
dX
,
GPUScatterAssign
<
T
,
int
>
(
ctx
,
*
dO
,
*
index
,
dX
,
ctx
.
Attr
<
bool
>
(
"overwrite"
));
ctx
.
Attr
<
bool
>
(
"overwrite"
));
...
...
paddle/fluid/operators/gather_op.h
浏览文件 @
53e3c534
...
@@ -27,8 +27,9 @@ template <typename T>
...
@@ -27,8 +27,9 @@ template <typename T>
class
GatherOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GatherOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE_EQ
(
"This kernel only runs on CPU."
);
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
...
@@ -40,12 +41,15 @@ class GatherOpKernel : public framework::OpKernel<T> {
...
@@ -40,12 +41,15 @@ class GatherOpKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CPUGather
<
T
,
int
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
CPUGather
<
T
,
int
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
...
@@ -58,8 +62,9 @@ template <typename T>
...
@@ -58,8 +62,9 @@ template <typename T>
class
GatherGradientOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GatherGradientOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE_EQ
(
"This kernel only runs on CPU."
);
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on CPU."
));
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dX
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
@@ -76,12 +81,15 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
...
@@ -76,12 +81,15 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
const
auto
&
index_type
=
index
->
type
();
const
auto
&
index_type
=
index
->
type
();
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
index_type_match
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds %s, but desires to be %s or %s"
,
"Index holds the wrong type, it holds [%s],"
paddle
::
framework
::
DataTypeToString
(
index_type
),
"but desires to be [%s] or [%s]."
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
));
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
overwrite
)
{
if
(
overwrite
)
{
ScatterAssign
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
dO
,
*
index
,
dX
);
ScatterAssign
<
T
,
int32_t
>
(
ctx
.
device_context
(),
*
dO
,
*
index
,
dX
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录