Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7f50bb7e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
7f50bb7e
编写于
3月 17, 2021
作者:
Z
Zhang Ting
提交者:
GitHub
3月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support NHWC for temporal_shift op (#31642)
上级
402288ad
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
338 addition
and
126 deletion
+338
-126
paddle/fluid/operators/temporal_shift_op.cc
paddle/fluid/operators/temporal_shift_op.cc
+14
-5
paddle/fluid/operators/temporal_shift_op.cu
paddle/fluid/operators/temporal_shift_op.cu
+132
-47
paddle/fluid/operators/temporal_shift_op.h
paddle/fluid/operators/temporal_shift_op.h
+147
-64
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+15
-7
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
...on/paddle/fluid/tests/unittests/test_temporal_shift_op.py
+30
-3
未找到文件。
paddle/fluid/operators/temporal_shift_op.cc
浏览文件 @
7f50bb7e
...
...
@@ -80,7 +80,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
override
{
AddInput
(
"X"
,
"The input tensor of temporal shift operator. "
"This is a 4-D tensor with shape of [N*T, C, H, W]. "
"This is a 4-D tensor with shape of [N*T, C, H, W] "
"or [N*T, H, W, C]. "
"While N is the batch size, T is the temporal segment "
"number, C is the channel number, H is the height of "
"features and W is the width of features. "
...
...
@@ -100,15 +101,23 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
"by 1 along the temporal dimension. :attr:`shift_ratio` should be in "
"range [0, 0.5]. Default 0.25."
)
.
SetDefault
(
0.25
);
AddAttr
<
std
::
string
>
(
"data_format"
,
"(string, default NCHW) Only used in "
"an optional string from:
\"
NHWC
\"
,
\"
NCHW
\"
. "
"Specify that the data format of the input and output data is "
"channel_first or channel_last."
)
.
SetDefault
(
"NCHW"
);
AddComment
(
R"DOC(
This operator calculates the temporal shifting features for Input(X).
Input(X) should be in shape of [N*T, C, H, W], while N is the batch
size, T is the temporal segment number specified by :attr:`seg_num`,
C is the channel number, H and W is the height and width of features.
Input(X) should be in shape of [N*T, C, H, W] or [N*T, H, W, C], while
N is the batch size, T is the temporal segment number specified by
:attr:`seg_num`, C is the channel number, H and W is the height and
width of features.
Temporal Shifting is calculated as follows:
Temporal Shifting is calculated as follows
when data format is NCHW
:
Step 1: Reshape Input(X) to [N, T, C, H, W].
...
...
paddle/fluid/operators/temporal_shift_op.cu
浏览文件 @
7f50bb7e
...
...
@@ -19,22 +19,46 @@ namespace operators {
using
framework
::
Tensor
;
template
<
typename
T
>
__global__
void
KeTemporalShiftFw
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
float
shift_ratio
)
{
__global__
void
KeTemporalShiftFw
NCHW
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tc
hw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
int
in
=
tid
/
tchw
;
int
it
=
(
tid
%
tchw
)
/
chw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
tid
]
=
0
;
}
else
{
output
[
tid
]
=
input
[
tid
+
(
src_it
-
it
)
*
chw
];
}
}
}
template
<
typename
T
>
__global__
void
KeTemporalShiftFwNHWC
(
const
T
*
input
,
T
*
output
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
nthwc
;
tid
+=
stride
)
{
int
it
=
(
tid
%
thwc
)
/
hwc
;
int
ic
=
tid
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
...
...
@@ -47,42 +71,65 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
tid
]
=
0
;
}
else
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
output
[
tid
]
=
input
[
src_idx
];
output
[
tid
]
=
input
[
tid
+
(
src_it
-
it
)
*
hwc
];
}
}
}
template
<
typename
T
>
__global__
void
KeTemporalShiftBw
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
float
shift_ratio
)
{
__global__
void
KeTemporalShiftBwNCHW
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
int
in
=
tid
/
tchw
;
int
it
=
(
tid
%
tchw
)
/
chw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
tid
]
=
output_grad
[
tid
+
(
src_it
-
it
)
*
chw
];
}
else
{
input_grad
[
tid
]
=
0
;
}
}
}
template
<
typename
T
>
__global__
void
KeTemporalShiftBwNHWC
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
for
(;
tid
<
nthwc
;
tid
+=
stride
)
{
int
it
=
(
tid
%
thwc
)
/
hwc
;
int
ic
=
tid
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
input_grad
[
src_idx
]
=
output_grad
[
tid
];
input_grad
[
tid
]
=
output_grad
[
tid
+
(
src_it
-
it
)
*
hwc
];
}
else
{
input_grad
[
tid
]
=
0
;
}
}
}
...
...
@@ -98,27 +145,48 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
std
::
string
data_format_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
const
DataLayout
data_layout
=
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
1
]
:
input
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
2
]
:
input
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
3
]
:
input
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
framework
::
DDim
out_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
framework
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
framework
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
{
nt
,
c
,
h
,
w
}
,
ctx
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
int
pixelNum
=
nt
*
chw
;
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
.
cuda_device_context
(),
pixelNum
);
int
threads
=
1024
;
int
grid
=
(
pixelNum
+
threads
-
1
)
/
threads
;
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
int
blocks_per_sm
=
dev_ctx
.
GetMaxPhysicalThreadCount
()
/
threads
;
grid
=
std
::
min
(
dev_ctx
.
GetSMCount
()
*
blocks_per_sm
,
grid
);
KeTemporalShiftFw
<
T
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
KeTemporalShiftFwNCHW
<
T
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
KeTemporalShiftFwNHWC
<
T
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
};
...
...
@@ -130,32 +198,49 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
std
::
string
data_format_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
const
DataLayout
data_layout
=
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
c
=
output_grad
->
dims
()[
1
];
const
int
h
=
output_grad
->
dims
()[
2
];
const
int
w
=
output_grad
->
dims
()[
3
];
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
1
]
:
output_grad
->
dims
()[
3
]);
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
2
]
:
output_grad
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
3
]
:
output_grad
->
dims
()[
2
]);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
framework
::
DDim
in_grad_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
framework
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
framework
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
()(
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
input_grad
->
mutable_data
<
T
>
(
in_grad_dims
,
ctx
.
GetPlace
());
int
pixelNum
=
nt
*
chw
;
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
.
cuda_device_context
(),
pixelNum
);
int
threads
=
1024
;
int
grid
=
(
pixelNum
+
threads
-
1
)
/
threads
;
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
int
blocks_per_sm
=
dev_ctx
.
GetMaxPhysicalThreadCount
()
/
threads
;
grid
=
std
::
min
(
dev_ctx
.
GetSMCount
()
*
blocks_per_sm
,
grid
);
KeTemporalShiftBw
<
T
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
KeTemporalShiftBwNCHW
<
T
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
KeTemporalShiftBwNHWC
<
T
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
};
...
...
paddle/fluid/operators/temporal_shift_op.h
浏览文件 @
7f50bb7e
...
...
@@ -17,12 +17,106 @@ namespace paddle {
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
DataLayout
=
framework
::
DataLayout
;
static
HOSTDEVICE
inline
int
GetEntryIndex
(
int
in
,
int
it
,
int
ic
,
int
ih
,
int
iw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
)
{
return
in
*
tchw
+
it
*
chw
+
ic
*
hw
+
ih
*
w
+
iw
;
template
<
typename
T
>
void
TemporalShiftFwNCHW
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
ntchw
;
i
++
)
{
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
input
[
i
+
(
src_it
-
it
)
*
chw
];
}
}
}
template
<
typename
T
>
void
TemporalShiftFwNHWC
(
const
T
*
input
,
T
*
output
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
nthwc
;
i
++
)
{
int
it
=
(
i
%
thwc
)
/
hwc
;
int
ic
=
i
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
input
[
i
+
(
src_it
-
it
)
*
hwc
];
}
}
}
template
<
typename
T
>
void
TemporalShiftBwNCHW
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
t
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
ntchw
;
i
++
)
{
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
i
]
=
output_grad
[
i
+
(
src_it
-
it
)
*
chw
];
}
else
{
input_grad
[
i
]
=
0
;
}
}
}
template
<
typename
T
>
void
TemporalShiftBwNHWC
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
nthwc
,
const
int
thwc
,
const
int
hwc
,
const
int
t
,
const
int
c
,
const
int
c1
,
const
int
c2
)
{
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
nthwc
;
i
++
)
{
int
it
=
(
i
%
thwc
)
/
hwc
;
int
ic
=
i
%
c
;
if
(
ic
<
c1
)
{
src_it
=
it
+
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
-
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
input_grad
[
i
]
=
output_grad
[
i
+
(
src_it
-
it
)
*
hwc
];
}
else
{
input_grad
[
i
]
=
0
;
}
}
}
template
<
typename
T
>
...
...
@@ -33,44 +127,38 @@ class TemporalShiftKernel : public framework::OpKernel<T> {
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
std
::
string
data_format_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
const
DataLayout
data_layout
=
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
]
;
const
int
w
=
input
->
dims
()[
3
];
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
1
]
:
input
->
dims
()[
3
])
;
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
2
]
:
input
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
input
->
dims
()[
3
]
:
input
->
dims
()[
2
]
);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
framework
::
DDim
out_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
framework
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
framework
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
output
->
numel
();
i
++
)
{
int
in
=
i
/
tchw
;
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
int
ih
=
(
i
%
hw
)
/
w
;
int
iw
=
i
%
w
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
output_data
[
i
]
=
0
;
}
else
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
output_data
[
i
]
=
input_data
[
src_idx
];
}
T
*
output_data
=
output
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
TemporalShiftFwNCHW
<
T
>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
TemporalShiftFwNHWC
<
T
>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
};
...
...
@@ -83,44 +171,39 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
std
::
string
data_format_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
const
DataLayout
data_layout
=
framework
::
StringToDataLayout
(
data_format_str
);
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
c
=
output_grad
->
dims
()[
1
];
const
int
h
=
output_grad
->
dims
()[
2
]
;
const
int
w
=
output_grad
->
dims
()[
3
];
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
const
int
c
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
1
]
:
output_grad
->
dims
()[
3
])
;
const
int
h
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
2
]
:
output_grad
->
dims
()[
1
]);
const
int
w
=
(
data_layout
==
DataLayout
::
kNCHW
?
output_grad
->
dims
()[
3
]
:
output_grad
->
dims
()[
2
]
);
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
framework
::
DDim
in_grad_dims
=
(
data_layout
==
DataLayout
::
kNCHW
?
framework
::
make_ddim
({
nt
,
c
,
h
,
w
})
:
framework
::
make_ddim
({
nt
,
h
,
w
,
c
}));
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
memset
(
input_grad_data
,
0
,
input_grad
->
numel
()
*
sizeof
(
T
));
int
src_it
=
0
;
for
(
int
i
=
0
;
i
<
output_grad
->
numel
();
i
++
)
{
int
in
=
i
/
tchw
;
int
it
=
(
i
%
tchw
)
/
chw
;
int
ic
=
(
i
%
chw
)
/
hw
;
int
ih
=
(
i
%
hw
)
/
w
;
int
iw
=
i
%
w
;
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
}
else
{
src_it
=
it
;
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
input_grad_data
[
src_idx
]
=
output_grad_data
[
i
];
}
input_grad
->
mutable_data
<
T
>
(
in_grad_dims
,
ctx
.
GetPlace
());
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
TemporalShiftBwNCHW
<
T
>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
t
,
c1
,
c2
);
}
else
{
TemporalShiftBwNHWC
<
T
>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
t
,
c
,
c1
,
c2
);
}
}
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
7f50bb7e
...
...
@@ -13334,7 +13334,7 @@ def shuffle_channel(x, group, name=None):
@templatedoc()
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None
, data_format="NCHW"
):
"""
**Temporal Shift Operator**
...
...
@@ -13348,6 +13348,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
data_format(str, optional): Data format that specifies the layout of input.
It can be "NCHW" or "NHWC". Default: "NCHW".
Returns:
out(Tensor): The temporal shifting result is a tensor with the
...
...
@@ -13365,6 +13367,13 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
input = paddle.randn([6, 4, 2, 2])
out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
"""
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
if in_dygraph_mode():
return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio, 'data_format', data_format)
helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
check_type(seg_num, 'seg_num', int, 'temporal_shift')
...
...
@@ -13375,16 +13384,15 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
if not isinstance(seg_num, int):
raise TypeError("seg_num must be int type.")
if in_dygraph_mode():
return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio)
helper.append_op(
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={"seg_num": seg_num,
"shift_ratio": shift_ratio})
attrs={
"seg_num": seg_num,
"shift_ratio": shift_ratio,
"data_format": data_format
})
return out
...
...
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
浏览文件 @
7f50bb7e
...
...
@@ -22,7 +22,9 @@ import paddle
from
paddle.fluid
import
core
def
temporal_shift
(
x
,
seg_num
,
shift_ratio
):
def
temporal_shift
(
x
,
seg_num
,
shift_ratio
,
data_format
):
if
data_format
==
"NHWC"
:
x
=
np
.
transpose
(
x
,
(
0
,
3
,
1
,
2
))
shape
=
x
.
shape
reshape_x
=
x
.
reshape
((
-
1
,
seg_num
,
shape
[
1
],
shape
[
2
],
shape
[
3
]))
pad_x
=
np
.
pad
(
reshape_x
,
((
0
,
0
),
(
1
,
1
),
(
0
,
0
),
(
0
,
0
),
(
0
,
0
)),
...
...
@@ -33,7 +35,10 @@ def temporal_shift(x, seg_num, shift_ratio):
slice2
=
pad_x
[:,
2
:
seg_num
+
2
,
c1
:
c2
,
:,
:]
slice3
=
pad_x
[:,
1
:
seg_num
+
1
,
c2
:,
:,
:]
concat_x
=
np
.
concatenate
([
slice1
,
slice2
,
slice3
],
axis
=
2
)
return
concat_x
.
reshape
(
shape
)
out
=
concat_x
.
reshape
(
shape
)
if
data_format
==
"NHWC"
:
out
=
np
.
transpose
(
out
,
(
0
,
2
,
3
,
1
))
return
out
class
TestTemporalShift
(
OpTest
):
...
...
@@ -45,11 +50,13 @@ class TestTemporalShift(OpTest):
self
.
attrs
=
{
"seg_num"
:
self
.
seg_num
,
"shift_ratio"
:
self
.
shift_ratio
,
"data_format"
:
self
.
data_format
}
self
.
inputs
=
{
"X"
:
x
,
}
output
=
temporal_shift
(
x
,
self
.
seg_num
,
self
.
shift_ratio
)
output
=
temporal_shift
(
x
,
self
.
seg_num
,
self
.
shift_ratio
,
self
.
data_format
)
self
.
outputs
=
{
"Out"
:
output
}
def
test_check_output
(
self
):
...
...
@@ -63,6 +70,7 @@ class TestTemporalShift(OpTest):
self
.
seg_num
=
3
self
.
shift_ratio
=
0.25
self
.
dtype
=
'float64'
self
.
data_format
=
'NCHW'
class
TestTemporalShift2
(
TestTemporalShift
):
...
...
@@ -70,6 +78,7 @@ class TestTemporalShift2(TestTemporalShift):
self
.
x_shape
=
(
4
,
9
,
7
,
7
)
self
.
seg_num
=
2
self
.
shift_ratio
=
0.2
self
.
data_format
=
'NCHW'
class
TestTemporalShift3
(
TestTemporalShift
):
...
...
@@ -77,6 +86,15 @@ class TestTemporalShift3(TestTemporalShift):
self
.
x_shape
=
(
3
,
10
,
5
,
5
)
self
.
seg_num
=
1
self
.
shift_ratio
=
0.3
self
.
data_format
=
'NCHW'
class
TestTemporalShift4
(
TestTemporalShift
):
def
initTestCase
(
self
):
self
.
x_shape
=
(
6
,
5
,
5
,
4
)
self
.
seg_num
=
3
self
.
shift_ratio
=
0.25
self
.
data_format
=
'NHWC'
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -87,6 +105,7 @@ class TestTemporalShiftFP16(TestTemporalShift):
self
.
seg_num
=
1
self
.
shift_ratio
=
0.3
self
.
dtype
=
'float16'
self
.
data_format
=
'NCHW'
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
...
...
@@ -114,6 +133,14 @@ class TestTemporalShiftAPI(unittest.TestCase):
out
=
paddle
.
nn
.
functional
.
temporal_shift
(
x
=
input
,
seg_num
=
2
,
shift_ratio
=
0.2
)
def
test_error
(
self
):
def
attr_data_format
():
input
=
paddle
.
randn
([
6
,
4
,
2
,
2
])
out
=
paddle
.
nn
.
functional
.
temporal_shift
(
x
=
input
,
seg_num
=
2
,
shift_ratio
=
0.2
,
data_format
=
"HWC"
)
self
.
assertRaises
(
ValueError
,
attr_data_format
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录