Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e14ed71c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e14ed71c
编写于
11月 09, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
11月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine the performance of gather Op (#28458)
上级
e29ab5ea
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
24 addition
and
14 deletion
+24
-14
paddle/fluid/operators/gather.cu.h
paddle/fluid/operators/gather.cu.h
+16
-12
paddle/fluid/operators/gather_op.cc
paddle/fluid/operators/gather_op.cc
+5
-0
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+3
-2
未找到文件。
paddle/fluid/operators/gather.cu.h
浏览文件 @
e14ed71c
...
@@ -20,8 +20,8 @@ limitations under the License. */
...
@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -165,14 +165,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
...
@@ -165,14 +165,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
int
out_index_dim_size
,
int
out_index_dim_size
,
int
input_index_dim_size
,
int
size
)
{
int
input_index_dim_size
,
int
size
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
outer_size
=
outer_dim_size
*
out_index_dim_size
;
for
(;
idx
<
size
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(;
idx
<
size
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
inner_dim_index
=
idx
/
(
outer_dim_size
*
out_index_dim_size
);
int
inner_dim_index
=
idx
/
outer_size
;
int
next_idx
=
idx
%
(
outer_dim_size
*
out_index_dim_size
);
int
next_idx
=
idx
-
outer_size
*
inner_dim_index
;
int
index_dim_index
=
next_idx
/
(
outer_dim_size
);
int
index_dim_index
=
next_idx
/
outer_dim_size
;
int
out_dim_index
=
next_idx
%
outer_dim_size
;
int
index_val
=
index
[
index_dim_index
];
int
out_dim_index
=
next_idx
-
outer_dim_size
*
index_dim_index
;
int
input_index
=
int
input_index
=
inner_dim_index
*
(
outer_dim_size
*
input_index_dim_size
)
+
inner_dim_index
*
(
outer_dim_size
*
input_index_dim_size
)
+
index
[
index_dim_index
]
*
outer_dim_size
+
out_dim_index
;
index
_val
*
outer_dim_size
+
out_dim_index
;
out
[
idx
]
=
input
[
input_index
];
out
[
idx
]
=
input
[
input_index
];
}
}
}
}
...
@@ -234,10 +236,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
...
@@ -234,10 +236,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
place
);
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
place
);
int
out_size
=
out
->
numel
();
int
out_size
=
out
->
numel
();
int
threads
=
512
;
platform
::
GpuLaunchConfig
config
=
int
grid
=
(
out_size
+
threads
-
1
)
/
threads
;
platform
::
GetGpuLaunchConfig1D
(
ctx
.
cuda_device_context
(),
out_size
)
;
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
GatherGPUKernel
<
T
,
U
><<<
grid
,
threads
,
0
,
stream
>>>
(
GatherGPUKernel
<
T
,
U
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
input_data
,
index_data
,
out_data
,
outer_dim_size
,
inner_dim_size
,
input_data
,
index_data
,
out_data
,
outer_dim_size
,
inner_dim_size
,
index_size
,
index_dim_size
,
out_size
);
index_size
,
index_dim_size
,
out_size
);
}
}
...
@@ -280,10 +283,11 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
...
@@ -280,10 +283,11 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
int
out_index_dim_size
=
out_dim
[
axis_index
];
int
out_index_dim_size
=
out_dim
[
axis_index
];
operators
::
math
::
set_constant
(
*
dev_ctx
,
out
,
0.0
);
operators
::
math
::
set_constant
(
*
dev_ctx
,
out
,
0.0
);
int
threads
=
512
;
platform
::
GpuLaunchConfig
config
=
int
grid
=
(
input_size
+
threads
-
1
)
/
threads
;
platform
::
GetGpuLaunchConfig1D
(
ctx
.
cuda_device_context
(),
input_size
)
;
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
GatherGradGPUKernel
<
T
,
U
><<<
grid
,
threads
,
0
,
stream
>>>
(
GatherGradGPUKernel
<
T
,
U
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
input_data
,
index_data
,
out_data
,
outer_dim_size
,
inner_dim_size
,
input_data
,
index_data
,
out_data
,
outer_dim_size
,
inner_dim_size
,
input_index_dim_size
,
out_index_dim_size
,
input_size
);
input_index_dim_size
,
out_index_dim_size
,
input_size
);
}
}
...
...
paddle/fluid/operators/gather_op.cc
浏览文件 @
e14ed71c
...
@@ -66,6 +66,11 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -66,6 +66,11 @@ class GatherOp : public framework::OperatorWithKernel {
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
return
expected_kernel_type
;
}
};
};
class
GatherGradOp
:
public
framework
::
OperatorWithKernel
{
class
GatherGradOp
:
public
framework
::
OperatorWithKernel
{
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
e14ed71c
...
@@ -16,7 +16,7 @@ from __future__ import print_function
...
@@ -16,7 +16,7 @@ from __future__ import print_function
from
..fluid.layers
import
core
from
..fluid.layers
import
core
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.framework
import
Variable
,
OpProtoHolder
,
in_dygraph_mode
,
convert_np_dtype_to_dtype_
from
..fluid.framework
import
Variable
,
OpProtoHolder
,
in_dygraph_mode
,
convert_np_dtype_to_dtype_
,
device_guard
from
..fluid.data_feeder
import
convert_dtype
,
check_variable_and_dtype
,
check_type
,
check_dtype
from
..fluid.data_feeder
import
convert_dtype
,
check_variable_and_dtype
,
check_type
,
check_dtype
from
..fluid.layers.tensor
import
fill_constant
from
..fluid.layers.tensor
import
fill_constant
from
..fluid.layers
import
utils
from
..fluid.layers
import
utils
...
@@ -794,6 +794,7 @@ def gather(x, index, axis=None, name=None):
...
@@ -794,6 +794,7 @@ def gather(x, index, axis=None, name=None):
axis
=
0
axis
=
0
axis_tensor
=
axis
axis_tensor
=
axis
if
not
isinstance
(
axis
,
Variable
):
if
not
isinstance
(
axis
,
Variable
):
with
device_guard
(
"cpu"
):
axis_tensor
=
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
axis
)
axis_tensor
=
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
axis
)
if
in_dygraph_mode
():
if
in_dygraph_mode
():
return
core
.
ops
.
gather
(
x
,
index
,
axis_tensor
)
return
core
.
ops
.
gather
(
x
,
index
,
axis_tensor
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录