Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
82d4f903
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
82d4f903
编写于
3月 09, 2019
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix format. test=develop
上级
28949f8e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
97 addition
and
89 deletion
+97
-89
paddle/fluid/operators/temporal_shift_op.cc
paddle/fluid/operators/temporal_shift_op.cc
+20
-18
paddle/fluid/operators/temporal_shift_op.cu
paddle/fluid/operators/temporal_shift_op.cu
+59
-55
paddle/fluid/operators/temporal_shift_op.h
paddle/fluid/operators/temporal_shift_op.h
+9
-6
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-4
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
...on/paddle/fluid/tests/unittests/test_temporal_shift_op.py
+7
-6
未找到文件。
paddle/fluid/operators/temporal_shift_op.cc
浏览文件 @
82d4f903
...
@@ -17,7 +17,7 @@ namespace operators {
...
@@ -17,7 +17,7 @@ namespace operators {
using
framework
::
Tensor
;
using
framework
::
Tensor
;
class
TemporalShiftOp
:
public
framework
::
OperatorWithKernel
{
class
TemporalShiftOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -29,23 +29,23 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
...
@@ -29,23 +29,23 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
"Output(Out) of TemporalShiftOp should not be null."
);
"Output(Out) of TemporalShiftOp should not be null."
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."
);
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."
);
int
seg_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"seg_num"
);
int
seg_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
->
Attrs
().
Get
<
float
>
(
"shift_ratio"
);
float
shift_ratio
=
ctx
->
Attrs
().
Get
<
float
>
(
"shift_ratio"
);
PADDLE_ENFORCE_GT
(
seg_num
,
0
,
PADDLE_ENFORCE_GT
(
seg_num
,
0
,
"Attr(seg_num) should be greater than 0."
);
"Attr(seg_num) should be greater than 0."
);
PADDLE_ENFORCE
(
shift_ratio
>
0
||
shift_ratio
<
.5
,
PADDLE_ENFORCE
(
shift_ratio
>
0
||
shift_ratio
<
.5
,
"Attr(shift_ratio) should be greater than 0 and less "
"Attr(shift_ratio) should be greater than 0 and less "
"than 0.5."
);
"than 0.5."
);
if
(
ctx
->
IsRuntime
())
{
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
dim_x
[
0
]
%
seg_num
,
0
,
PADDLE_ENFORCE_EQ
(
"Input(X) dims[0] should be divided exactly by Attr(seg_num)."
);
dim_x
[
0
]
%
seg_num
,
0
,
"Input(X) dims[0] should be divided exactly by Attr(seg_num)."
);
}
}
ctx
->
SetOutputDim
(
"Out"
,
dim_x
);
ctx
->
SetOutputDim
(
"Out"
,
dim_x
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
}
}
...
@@ -70,14 +70,15 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -70,14 +70,15 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
"The output tensor of temporal shift operator. "
"The output tensor of temporal shift operator. "
"This is a 4-D tensor in the same shape with Input(X)."
);
"This is a 4-D tensor in the same shape with Input(X)."
);
AddAttr
<
int
>
(
"seg_num"
,
AddAttr
<
int
>
(
"seg_num"
,
"The temporal segment number, this should be a positive "
"The temporal segment number, this should be a positive "
"interger."
);
"interger."
);
AddAttr
<
float
>
(
"shift_ratio"
,
AddAttr
<
float
>
(
"The shift ratio of the channels, the first shift ratio part "
"shift_ratio"
,
"of channels will be shifted by -1 along the temporal dimension, "
"The shift ratio of the channels, the first shift ratio part "
"and the second shift ratio part of channels will be shifted by "
"of channels will be shifted by -1 along the temporal dimension, "
"1 along the temporal dimension. Default 0.25."
)
"and the second shift ratio part of channels will be shifted by "
"1 along the temporal dimension. Default 0.25."
)
.
SetDefault
(
0.25
);
.
SetDefault
(
0.25
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
...
@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
class
TemporalShiftOpGrad
:
public
framework
::
OperatorWithKernel
{
class
TemporalShiftOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel {
...
@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
temporal_shift
,
ops
::
TemporalShiftOp
,
ops
::
TemporalShiftOpMaker
,
REGISTER_OPERATOR
(
temporal_shift
,
ops
::
TemporalShiftOp
,
ops
::
TemporalShiftOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
temporal_shift_grad
,
ops
::
TemporalShiftOpGrad
);
REGISTER_OPERATOR
(
temporal_shift_grad
,
ops
::
TemporalShiftOpGrad
);
REGISTER_OP_CPU_KERNEL
(
temporal_shift
,
ops
::
TemporalShiftKernel
<
float
>
,
REGISTER_OP_CPU_KERNEL
(
temporal_shift
,
ops
::
TemporalShiftKernel
<
float
>
,
...
...
paddle/fluid/operators/temporal_shift_op.cu
浏览文件 @
82d4f903
...
@@ -17,70 +17,72 @@ namespace operators {
...
@@ -17,70 +17,72 @@ namespace operators {
using
framework
::
Tensor
;
using
framework
::
Tensor
;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KeTemporalShiftFw
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
__global__
void
KeTemporalShiftFw
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
float
shift_ratio
)
{
const
int
w
,
const
int
t
,
const
int
c
,
const
float
shift_ratio
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
int
src_it
=
0
;
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
int
in
=
tid
/
tchw
;
int
in
=
tid
/
tchw
;
int
it
=
(
tid
%
tchw
)
/
chw
;
int
it
=
(
tid
%
tchw
)
/
chw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
int
iw
=
tid
%
w
;
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
src_it
=
it
+
1
;
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
}
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
if
(
src_it
<
0
||
src_it
>=
t
)
{
output
[
tid
]
=
0
;
output
[
tid
]
=
0
;
}
else
{
}
else
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
output
[
tid
]
=
input
[
src_idx
];
output
[
tid
]
=
input
[
src_idx
];
}
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KeTemporalShiftBw
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
__global__
void
KeTemporalShiftBw
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
int
ntchw
,
const
int
tchw
,
const
float
shift_ratio
)
{
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
float
shift_ratio
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
int
src_it
=
0
;
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
for
(;
tid
<
ntchw
;
tid
+=
stride
)
{
int
in
=
tid
/
tchw
;
int
in
=
tid
/
tchw
;
int
it
=
(
tid
%
tchw
)
/
chw
;
int
it
=
(
tid
%
tchw
)
/
chw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
int
ic
=
(
tid
%
chw
)
/
hw
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
int
iw
=
tid
%
w
;
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
src_it
=
it
-
1
;
}
else
if
(
ic
<
c2
)
{
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
src_it
=
it
+
1
;
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
}
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
if
(
src_it
>=
0
&&
src_it
<
t
)
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
input_grad
[
src_idx
]
=
output_grad
[
tid
];
input_grad
[
src_idx
]
=
output_grad
[
tid
];
}
}
}
}
}
}
...
@@ -113,8 +115,8 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -113,8 +115,8 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
KeTemporalShiftFw
<
KeTemporalShiftFw
<
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
}
}
};
};
...
@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
const
int
ntchw
=
nt
*
chw
;
const
int
ntchw
=
nt
*
chw
;
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
()(
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
()(
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
input_grad
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
static_cast
<
T
>
(
0
));
...
@@ -148,8 +151,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -148,8 +151,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
KeTemporalShiftBw
<
KeTemporalShiftBw
<
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
}
}
};
};
...
...
paddle/fluid/operators/temporal_shift_op.h
浏览文件 @
82d4f903
...
@@ -18,13 +18,15 @@ namespace operators {
...
@@ -18,13 +18,15 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
static
HOSTDEVICE
inline
int
GetEntryIndex
(
int
in
,
int
it
,
int
ic
,
int
ih
,
int
iw
,
static
HOSTDEVICE
inline
int
GetEntryIndex
(
int
in
,
int
it
,
int
ic
,
int
ih
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
)
{
int
iw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
)
{
return
in
*
tchw
+
it
*
chw
+
ic
*
hw
+
ih
*
w
+
iw
;
return
in
*
tchw
+
it
*
chw
+
ic
*
hw
+
ih
*
w
+
iw
;
}
}
template
<
typename
T
>
template
<
typename
T
>
class
TemporalShiftKernel
:
public
framework
::
OpKernel
<
T
>
{
class
TemporalShiftKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
...
@@ -62,7 +64,7 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
...
@@ -62,7 +64,7 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
}
}
if
(
src_it
<
0
||
src_it
>=
t
)
{
if
(
src_it
<
0
||
src_it
>=
t
)
{
output_data
[
i
]
=
0
;
output_data
[
i
]
=
0
;
}
else
{
}
else
{
...
@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
...
@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
const
int
tchw
=
t
*
chw
;
const
int
tchw
=
t
*
chw
;
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
({
nt
,
c
,
h
,
w
},
ctx
.
GetPlace
());
memset
(
input_grad_data
,
0
,
input_grad
->
numel
()
*
sizeof
(
T
));
memset
(
input_grad_data
,
0
,
input_grad
->
numel
()
*
sizeof
(
T
));
int
src_it
=
0
;
int
src_it
=
0
;
...
@@ -113,7 +116,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
...
@@ -113,7 +116,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
}
}
if
(
src_it
>=
0
&&
src_it
<
t
)
{
if
(
src_it
>=
0
&&
src_it
<
t
)
{
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
int
src_idx
=
GetEntryIndex
(
in
,
src_it
,
ic
,
ih
,
iw
,
tchw
,
chw
,
hw
,
w
);
input_grad_data
[
src_idx
]
=
output_grad_data
[
i
];
input_grad_data
[
src_idx
]
=
output_grad_data
[
i
];
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
82d4f903
...
@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
...
@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
type
=
"temporal_shift"
,
type
=
"temporal_shift"
,
inputs
=
{
"X"
:
x
},
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
out
},
outputs
=
{
"Out"
:
out
},
attrs
=
{
attrs
=
{
"seg_num"
:
seg_num
,
"seg_num"
:
seg_num
,
"shift_ratio"
:
shift_ratio
})
"shift_ratio"
:
shift_ratio
})
return
out
return
out
...
...
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
浏览文件 @
82d4f903
...
@@ -24,15 +24,17 @@ from paddle.fluid import core
...
@@ -24,15 +24,17 @@ from paddle.fluid import core
def
temporal_shift
(
x
,
seg_num
,
shift_ratio
):
def
temporal_shift
(
x
,
seg_num
,
shift_ratio
):
shape
=
x
.
shape
shape
=
x
.
shape
reshape_x
=
x
.
reshape
((
-
1
,
seg_num
,
shape
[
1
],
shape
[
2
],
shape
[
3
]))
reshape_x
=
x
.
reshape
((
-
1
,
seg_num
,
shape
[
1
],
shape
[
2
],
shape
[
3
]))
pad_x
=
np
.
pad
(
reshape_x
,
((
0
,
0
),
(
1
,
1
),
(
0
,
0
),
(
0
,
0
),
(
0
,
0
)),
'constant'
)
pad_x
=
np
.
pad
(
reshape_x
,
((
0
,
0
),
(
1
,
1
),
(
0
,
0
),
(
0
,
0
),
(
0
,
0
)),
'constant'
)
c1
=
int
(
shape
[
1
]
*
shift_ratio
)
c1
=
int
(
shape
[
1
]
*
shift_ratio
)
c2
=
int
(
shape
[
1
]
*
2
*
shift_ratio
)
c2
=
int
(
shape
[
1
]
*
2
*
shift_ratio
)
slice1
=
pad_x
[:,
:
seg_num
,
:
c1
,
:,
:]
slice1
=
pad_x
[:,
:
seg_num
,
:
c1
,
:,
:]
slice2
=
pad_x
[:,
2
:
seg_num
+
2
,
c1
:
c2
,
:,
:]
slice2
=
pad_x
[:,
2
:
seg_num
+
2
,
c1
:
c2
,
:,
:]
slice3
=
pad_x
[:,
1
:
seg_num
+
1
,
c2
:,
:,
:]
slice3
=
pad_x
[:,
1
:
seg_num
+
1
,
c2
:,
:,
:]
concat_x
=
np
.
concatenate
([
slice1
,
slice2
,
slice3
],
axis
=
2
)
concat_x
=
np
.
concatenate
([
slice1
,
slice2
,
slice3
],
axis
=
2
)
return
concat_x
.
reshape
(
shape
)
return
concat_x
.
reshape
(
shape
)
class
TestTemporalShift
(
OpTest
):
class
TestTemporalShift
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
initTestCase
()
...
@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest):
...
@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest):
"shift_ratio"
:
self
.
shift_ratio
,
"shift_ratio"
:
self
.
shift_ratio
,
}
}
self
.
inputs
=
{
self
.
inputs
=
{
"X"
:
x
,
}
"X"
:
x
,
}
output
=
temporal_shift
(
x
,
self
.
seg_num
,
self
.
shift_ratio
)
output
=
temporal_shift
(
x
,
self
.
seg_num
,
self
.
shift_ratio
)
self
.
outputs
=
{
"Out"
:
output
}
self
.
outputs
=
{
"Out"
:
output
}
...
@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
...
@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
self
.
seg_num
=
3
self
.
seg_num
=
3
self
.
shift_ratio
=
0.25
self
.
shift_ratio
=
0.25
class
TestTemporalShift2
(
TestTemporalShift
):
class
TestTemporalShift2
(
TestTemporalShift
):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
x_shape
=
(
4
,
9
,
7
,
7
)
self
.
x_shape
=
(
4
,
9
,
7
,
7
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录