Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7d95e598
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7d95e598
编写于
3月 04, 2021
作者:
Z
Zhang Ting
提交者:
GitHub
3月 04, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support float16 for temporal_shift op (#31432)
上级
3a8ef10e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
10 deletion
+34
-10
paddle/fluid/operators/temporal_shift_op.cu
paddle/fluid/operators/temporal_shift_op.cu
+12
-9
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
...on/paddle/fluid/tests/unittests/test_temporal_shift_op.py
+22
-1
未找到文件。
paddle/fluid/operators/temporal_shift_op.cu
浏览文件 @
7d95e598
...
...
@@ -33,8 +33,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
...
...
@@ -69,8 +69,8 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
...
...
@@ -163,8 +163,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
temporal_shift
,
ops
::
TemporalShiftOpCUDAKernel
<
float
>
,
ops
::
TemporalShiftOpCUDAKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
temporal_shift_grad
,
ops
::
TemporalShiftGradOpCUDAKernel
<
float
>
,
ops
::
TemporalShiftGradOpCUDAKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
temporal_shift
,
ops
::
TemporalShiftOpCUDAKernel
<
float
>
,
ops
::
TemporalShiftOpCUDAKernel
<
double
>
,
ops
::
TemporalShiftOpCUDAKernel
<
paddle
::
platform
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
temporal_shift_grad
,
ops
::
TemporalShiftGradOpCUDAKernel
<
float
>
,
ops
::
TemporalShiftGradOpCUDAKernel
<
double
>
,
ops
::
TemporalShiftGradOpCUDAKernel
<
paddle
::
platform
::
float16
>
);
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
浏览文件 @
7d95e598
...
...
@@ -40,7 +40,7 @@ class TestTemporalShift(OpTest):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
op_type
=
'temporal_shift'
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
'float64'
)
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
self
.
attrs
=
{
"seg_num"
:
self
.
seg_num
,
...
...
@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
self
.
x_shape
=
(
6
,
4
,
4
,
4
)
self
.
seg_num
=
3
self
.
shift_ratio
=
0.25
self
.
dtype
=
'float64'
class
TestTemporalShift2
(
TestTemporalShift
):
...
...
@@ -78,6 +79,26 @@ class TestTemporalShift3(TestTemporalShift):
self
.
shift_ratio
=
0.3
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestTemporalShiftFP16
(
TestTemporalShift
):
def
initTestCase
(
self
):
self
.
x_shape
=
(
3
,
10
,
5
,
5
)
self
.
seg_num
=
1
self
.
shift_ratio
=
0.3
self
.
dtype
=
'float16'
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
)
def
test_check_grad_ignore_uv
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
)
class
TestTemporalShiftAPI
(
unittest
.
TestCase
):
def
test_api
(
self
):
input
=
paddle
.
randn
([
6
,
4
,
2
,
2
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录