Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ca725c82
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca725c82
编写于
7月 15, 2020
作者:
Z
Zhang Ting
提交者:
GitHub
7月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve fp16 performance of slice_grad, test=develop (#25523)
上级
5d3766ff
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
3 addition
and
136 deletion
+3
-136
paddle/fluid/operators/slice_op.cu
paddle/fluid/operators/slice_op.cu
+1
-134
paddle/fluid/operators/slice_op.h
paddle/fluid/operators/slice_op.h
+2
-2
未找到文件。
paddle/fluid/operators/slice_op.cu
浏览文件 @
ca725c82
...
...
@@ -12,145 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
size_t
D
>
__global__
void
Padding
(
const
paddle
::
platform
::
float16
*
d_out
,
const
int64_t
*
out_dims
,
const
int64_t
*
in_dims
,
const
int64_t
*
offsets
,
int64_t
n
,
paddle
::
platform
::
float16
*
d_in
)
{
int64_t
out_idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
out_idx
<
n
)
{
int64_t
out_idx_tmp
=
out_idx
;
int64_t
coords
[
D
]
=
{
0
};
for
(
int
i
=
D
-
1
;
i
>=
0
;
--
i
)
{
coords
[
i
]
=
out_idx_tmp
%
out_dims
[
i
];
out_idx_tmp
/=
out_dims
[
i
];
coords
[
i
]
+=
offsets
[
i
];
}
int64_t
in_idx
=
0
;
for
(
int
i
=
0
;
i
<
D
;
++
i
)
{
in_idx
=
in_idx
*
in_dims
[
i
]
+
coords
[
i
];
}
d_in
[
in_idx
]
=
d_out
[
out_idx
];
}
}
template
<
>
class
SliceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
:
public
framework
::
OpKernel
<
paddle
::
platform
::
float16
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_in
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
d_in
->
mutable_data
<
paddle
::
platform
::
float16
>
(
ctx
.
GetPlace
());
auto
out_dims
=
d_out
->
dims
();
auto
in_dims
=
d_in
->
dims
();
int
rank
=
out_dims
.
size
();
std
::
vector
<
int64_t
>
offsets
(
rank
,
0
);
auto
axes
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"axes"
);
auto
starts_int
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"starts"
);
std
::
vector
<
int64_t
>
starts
(
starts_int
.
begin
(),
starts_int
.
end
());
auto
list_new_starts_tensor
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"StartsTensorList"
);
if
(
list_new_starts_tensor
.
size
()
>
0
)
{
starts
=
GetDataFromTensorList
<
int64_t
>
(
list_new_starts_tensor
);
}
else
if
(
ctx
.
HasInput
(
"StartsTensor"
))
{
auto
*
starts_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"StartsTensor"
);
starts
=
GetDataFromTensor
<
int64_t
>
(
starts_tensor
);
}
for
(
size_t
i
=
0
;
i
<
starts
.
size
();
++
i
)
{
if
(
starts
[
i
]
<
0
)
{
starts
[
i
]
+=
in_dims
[
axes
[
i
]];
}
offsets
[
axes
[
i
]]
=
std
::
max
(
starts
[
i
],
static_cast
<
int64_t
>
(
0
));
}
math
::
SetConstant
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
set_zero
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
paddle
::
platform
::
CUDADeviceContext
>();
set_zero
(
dev_ctx
,
d_in
,
static_cast
<
paddle
::
platform
::
float16
>
(
0
));
int64_t
numel
=
d_out
->
numel
();
dim3
blocks
((
numel
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
);
dim3
threads
(
PADDLE_CUDA_NUM_THREADS
);
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
const
std
::
vector
<
int64_t
>
out_shape
=
framework
::
vectorize
<
int64_t
>
(
out_dims
);
const
std
::
vector
<
int64_t
>
in_shape
=
framework
::
vectorize
<
int64_t
>
(
in_dims
);
framework
::
Tensor
out_dims_tensor
;
framework
::
Tensor
in_dims_tensor
;
framework
::
Tensor
offsets_tensor
;
framework
::
TensorFromVector
(
out_shape
,
ctx
.
device_context
(),
&
out_dims_tensor
);
framework
::
TensorFromVector
(
in_shape
,
ctx
.
device_context
(),
&
in_dims_tensor
);
framework
::
TensorFromVector
(
offsets
,
ctx
.
device_context
(),
&
offsets_tensor
);
const
int64_t
*
out_dims_ptr
=
out_dims_tensor
.
data
<
int64_t
>
();
const
int64_t
*
in_dims_ptr
=
in_dims_tensor
.
data
<
int64_t
>
();
const
int64_t
*
offsets_ptr
=
offsets_tensor
.
data
<
int64_t
>
();
switch
(
rank
)
{
case
1
:
Padding
<
1
><<<
blocks
,
threads
,
0
,
stream
>>>
(
d_out
->
data
<
paddle
::
platform
::
float16
>
(),
out_dims_ptr
,
in_dims_ptr
,
offsets_ptr
,
numel
,
d_in
->
data
<
paddle
::
platform
::
float16
>
());
break
;
case
2
:
Padding
<
2
><<<
blocks
,
threads
,
0
,
stream
>>>
(
d_out
->
data
<
paddle
::
platform
::
float16
>
(),
out_dims_ptr
,
in_dims_ptr
,
offsets_ptr
,
numel
,
d_in
->
data
<
paddle
::
platform
::
float16
>
());
break
;
case
3
:
Padding
<
3
><<<
blocks
,
threads
,
0
,
stream
>>>
(
d_out
->
data
<
paddle
::
platform
::
float16
>
(),
out_dims_ptr
,
in_dims_ptr
,
offsets_ptr
,
numel
,
d_in
->
data
<
paddle
::
platform
::
float16
>
());
break
;
case
4
:
Padding
<
4
><<<
blocks
,
threads
,
0
,
stream
>>>
(
d_out
->
data
<
paddle
::
platform
::
float16
>
(),
out_dims_ptr
,
in_dims_ptr
,
offsets_ptr
,
numel
,
d_in
->
data
<
paddle
::
platform
::
float16
>
());
break
;
case
5
:
Padding
<
5
><<<
blocks
,
threads
,
0
,
stream
>>>
(
d_out
->
data
<
paddle
::
platform
::
float16
>
(),
out_dims_ptr
,
in_dims_ptr
,
offsets_ptr
,
numel
,
d_in
->
data
<
paddle
::
platform
::
float16
>
());
break
;
case
6
:
Padding
<
6
><<<
blocks
,
threads
,
0
,
stream
>>>
(
d_out
->
data
<
paddle
::
platform
::
float16
>
(),
out_dims_ptr
,
in_dims_ptr
,
offsets_ptr
,
numel
,
d_in
->
data
<
paddle
::
platform
::
float16
>
());
break
;
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
slice
,
ops
::
SliceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SliceKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
...
...
paddle/fluid/operators/slice_op.h
浏览文件 @
ca725c82
...
...
@@ -350,7 +350,7 @@ class SliceGradKernel : public framework::OpKernel<T> {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
context
.
GetPlace
());
T
value
=
0.0
;
T
value
=
T
(
0
)
;
math
::
SetConstant
<
DeviceContext
,
T
>
functor
;
for
(
int
i
=
0
;
i
<
d_in_size
;
++
i
)
{
auto
dim
=
input_array
->
at
(
i
).
dims
();
...
...
@@ -440,7 +440,7 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto
d_out_t
=
framework
::
EigenTensor
<
T
,
D
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>::
From
(
*
d_out
,
out_dims
);
d_in_t
.
device
(
place
)
=
d_out_t
.
pad
(
paddings
,
0
);
d_in_t
.
device
(
place
)
=
d_out_t
.
pad
(
paddings
,
T
(
0
)
);
}
};
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录