Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5c1920b7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5c1920b7
编写于
3月 08, 2019
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add Attr shift_ratio. test=develop
上级
71101c9c
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
62 addition
and
23 deletion
+62
-23
paddle/fluid/operators/temporal_shift_op.cc
paddle/fluid/operators/temporal_shift_op.cc
+13
-2
paddle/fluid/operators/temporal_shift_op.cu
paddle/fluid/operators/temporal_shift_op.cu
+18
-8
paddle/fluid/operators/temporal_shift_op.h
paddle/fluid/operators/temporal_shift_op.h
+12
-4
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+7
-3
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+1
-1
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
...on/paddle/fluid/tests/unittests/test_temporal_shift_op.py
+11
-5
未找到文件。
paddle/fluid/operators/temporal_shift_op.cc
浏览文件 @
5c1920b7
...
@@ -33,8 +33,12 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
...
@@ -33,8 +33,12 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."
);
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."
);
int
seg_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"seg_num"
);
int
seg_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
->
Attrs
().
Get
<
float
>
(
"shift_ratio"
);
PADDLE_ENFORCE_GT
(
seg_num
,
0
,
PADDLE_ENFORCE_GT
(
seg_num
,
0
,
"Attr(seg_num) should be greater then 0."
);
"Attr(seg_num) should be greater than 0."
);
PADDLE_ENFORCE
(
shift_ratio
>
0
||
shift_ratio
<
.5
,
"Attr(shift_ratio) should be greater than 0 and less "
"than 0.5."
);
if
(
ctx
->
IsRuntime
())
{
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
dim_x
[
0
]
%
seg_num
,
0
,
PADDLE_ENFORCE_EQ
(
dim_x
[
0
]
%
seg_num
,
0
,
...
@@ -69,6 +73,12 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -69,6 +73,12 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int
>
(
"seg_num"
,
AddAttr
<
int
>
(
"seg_num"
,
"The temporal segment number, this should be a positive "
"The temporal segment number, this should be a positive "
"interger."
);
"interger."
);
AddAttr
<
float
>
(
"shift_ratio"
,
"The shift ratio of the channels, the first shift ratio part "
"of channels will be shifted by -1 along the temporal dimension, "
"and the second shift ratio part of channels will be shifted by "
"1 along the temporal dimension. Default 0.25."
)
.
SetDefault
(
0.25
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
This operator calculates the temporal shifting features for Input(X).
This operator calculates the temporal shifting features for Input(X).
...
@@ -85,7 +95,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -85,7 +95,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
padding width as 1 on each side, padding result will be in shape
padding width as 1 on each side, padding result will be in shape
of [N, T+2, C, H, W].
of [N, T+2, C, H, W].
Step 3: Slice padding result as follows:
Step 3: Assume :attr:`shift_ratio` is :math:`0.25`, slice padding
result as follows:
slice1 = x[:, :T, :C/4, :, :]
slice1 = x[:, :T, :C/4, :, :]
slice2 = x[:, 2:T+2, C/4:C/2, :, :]
slice2 = x[:, 2:T+2, C/4:C/2, :, :]
...
...
paddle/fluid/operators/temporal_shift_op.cu
浏览文件 @
5c1920b7
...
@@ -20,7 +20,8 @@ using framework::Tensor;
...
@@ -20,7 +20,8 @@ using framework::Tensor;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KeTemporalShiftFw
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
__global__
void
KeTemporalShiftFw
(
const
T
*
input
,
T
*
output
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
)
{
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
float
shift_ratio
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
int
src_it
=
0
;
...
@@ -31,9 +32,12 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
...
@@ -31,9 +32,12 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
int
ih
=
(
tid
%
hw
)
/
w
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
int
iw
=
tid
%
w
;
if
(
ic
<
c
/
4
)
{
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
src_it
=
it
-
1
;
}
else
if
(
ic
<
c
/
2
)
{
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
src_it
=
it
+
1
;
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
...
@@ -50,7 +54,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
...
@@ -50,7 +54,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KeTemporalShiftBw
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
__global__
void
KeTemporalShiftBw
(
const
T
*
output_grad
,
T
*
input_grad
,
const
int
ntchw
,
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
)
{
const
int
tchw
,
const
int
chw
,
const
int
hw
,
const
int
w
,
const
int
t
,
const
int
c
,
const
float
shift_ratio
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
src_it
=
0
;
int
src_it
=
0
;
...
@@ -61,9 +66,12 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int
...
@@ -61,9 +66,12 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int
int
ih
=
(
tid
%
hw
)
/
w
;
int
ih
=
(
tid
%
hw
)
/
w
;
int
iw
=
tid
%
w
;
int
iw
=
tid
%
w
;
if
(
ic
<
c
/
4
)
{
const
int
c1
=
static_cast
<
T
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
T
>
(
c
*
2
*
shift_ratio
);
if
(
ic
<
c1
)
{
src_it
=
it
-
1
;
src_it
=
it
-
1
;
}
else
if
(
ic
<
c
/
2
)
{
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
src_it
=
it
+
1
;
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
...
@@ -85,6 +93,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -85,6 +93,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
c
=
input
->
dims
()[
1
];
...
@@ -105,7 +114,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -105,7 +114,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
KeTemporalShiftFw
<
KeTemporalShiftFw
<
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
);
input_data
,
output_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
}
}
};
};
...
@@ -116,6 +125,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -116,6 +125,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
c
=
output_grad
->
dims
()[
1
];
const
int
c
=
output_grad
->
dims
()[
1
];
...
@@ -139,7 +149,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -139,7 +149,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
KeTemporalShiftBw
<
KeTemporalShiftBw
<
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
);
output_grad_data
,
input_grad_data
,
ntchw
,
tchw
,
chw
,
hw
,
w
,
t
,
c
,
shift_ratio
);
}
}
};
};
...
...
paddle/fluid/operators/temporal_shift_op.h
浏览文件 @
5c1920b7
...
@@ -30,12 +30,16 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
...
@@ -30,12 +30,16 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
int
nt
=
input
->
dims
()[
0
];
const
int
nt
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
const
int
w
=
input
->
dims
()[
3
];
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
const
int
hw
=
h
*
w
;
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
tchw
=
t
*
chw
;
...
@@ -51,9 +55,9 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
...
@@ -51,9 +55,9 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
int
ih
=
(
i
%
hw
)
/
w
;
int
ih
=
(
i
%
hw
)
/
w
;
int
iw
=
i
%
w
;
int
iw
=
i
%
w
;
if
(
ic
<
c
/
4
)
{
if
(
ic
<
c
1
)
{
src_it
=
it
-
1
;
src_it
=
it
-
1
;
}
else
if
(
ic
<
c
/
2
)
{
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
src_it
=
it
+
1
;
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
...
@@ -76,12 +80,16 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
...
@@ -76,12 +80,16 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
int
t
=
ctx
.
Attr
<
int
>
(
"seg_num"
);
float
shift_ratio
=
ctx
.
Attr
<
float
>
(
"shift_ratio"
);
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
nt
=
output_grad
->
dims
()[
0
];
const
int
c
=
output_grad
->
dims
()[
1
];
const
int
c
=
output_grad
->
dims
()[
1
];
const
int
h
=
output_grad
->
dims
()[
2
];
const
int
h
=
output_grad
->
dims
()[
2
];
const
int
w
=
output_grad
->
dims
()[
3
];
const
int
w
=
output_grad
->
dims
()[
3
];
const
int
c1
=
static_cast
<
int
>
(
c
*
shift_ratio
);
const
int
c2
=
static_cast
<
int
>
(
c
*
2
*
shift_ratio
);
const
int
hw
=
h
*
w
;
const
int
hw
=
h
*
w
;
const
int
chw
=
c
*
hw
;
const
int
chw
=
c
*
hw
;
const
int
tchw
=
t
*
chw
;
const
int
tchw
=
t
*
chw
;
...
@@ -98,9 +106,9 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
...
@@ -98,9 +106,9 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
int
ih
=
(
i
%
hw
)
/
w
;
int
ih
=
(
i
%
hw
)
/
w
;
int
iw
=
i
%
w
;
int
iw
=
i
%
w
;
if
(
ic
<
c
/
4
)
{
if
(
ic
<
c
1
)
{
src_it
=
it
-
1
;
src_it
=
it
-
1
;
}
else
if
(
ic
<
c
/
2
)
{
}
else
if
(
ic
<
c2
)
{
src_it
=
it
+
1
;
src_it
=
it
+
1
;
}
else
{
}
else
{
src_it
=
it
;
src_it
=
it
;
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
5c1920b7
...
@@ -10266,7 +10266,7 @@ def shuffle_channel(x, group, name=None):
...
@@ -10266,7 +10266,7 @@ def shuffle_channel(x, group, name=None):
@
templatedoc
()
@
templatedoc
()
def
temporal_shift
(
x
,
seg_num
,
name
=
None
):
def
temporal_shift
(
x
,
seg_num
,
shift_ratio
=
0.25
,
name
=
None
):
"""
"""
**Temporal Shift Operator**
**Temporal Shift Operator**
...
@@ -10275,6 +10275,7 @@ def temporal_shift(x, seg_num, name=None):
...
@@ -10275,6 +10275,7 @@ def temporal_shift(x, seg_num, name=None):
Args:
Args:
x(Variable): ${x_comment}
x(Variable): ${x_comment}
seg_num(int): ${seg_num_comment}
seg_num(int): ${seg_num_comment}
shift_ratio(float): ${shift_ratio_comment}
Returns:
Returns:
out(Variable): The temporal shifting result is a tensor variable with the
out(Variable): The temporal shifting result is a tensor variable with the
...
@@ -10287,7 +10288,7 @@ def temporal_shift(x, seg_num, name=None):
...
@@ -10287,7 +10288,7 @@ def temporal_shift(x, seg_num, name=None):
.. code-block:: python
.. code-block:: python
input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32')
input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32')
out = fluid.layers.temporal_shift(x=input, seg_num=2)
out = fluid.layers.temporal_shift(x=input, seg_num=2
, shift_ratio=0.2
)
"""
"""
helper
=
LayerHelper
(
"temporal_shift"
,
**
locals
())
helper
=
LayerHelper
(
"temporal_shift"
,
**
locals
())
...
@@ -10300,7 +10301,10 @@ def temporal_shift(x, seg_num, name=None):
...
@@ -10300,7 +10301,10 @@ def temporal_shift(x, seg_num, name=None):
type
=
"temporal_shift"
,
type
=
"temporal_shift"
,
inputs
=
{
"X"
:
x
},
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
out
},
outputs
=
{
"Out"
:
out
},
attrs
=
{
"seg_num"
:
seg_num
})
attrs
=
{
"seg_num"
:
seg_num
,
"shift_ratio"
:
shift_ratio
})
return
out
return
out
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
5c1920b7
...
@@ -1052,7 +1052,7 @@ class TestBook(unittest.TestCase):
...
@@ -1052,7 +1052,7 @@ class TestBook(unittest.TestCase):
program
=
Program
()
program
=
Program
()
with
program_guard
(
program
):
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
"X"
,
shape
=
[
16
,
4
,
4
],
dtype
=
"float32"
)
x
=
layers
.
data
(
name
=
"X"
,
shape
=
[
16
,
4
,
4
],
dtype
=
"float32"
)
out
=
layers
.
temporal_shift
(
x
,
seg_num
=
4
)
out
=
layers
.
temporal_shift
(
x
,
seg_num
=
4
,
shift_ratio
=
0.2
)
self
.
assertIsNotNone
(
out
)
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
print
(
str
(
program
))
...
...
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
浏览文件 @
5c1920b7
...
@@ -21,13 +21,15 @@ from op_test import OpTest
...
@@ -21,13 +21,15 @@ from op_test import OpTest
from
paddle.fluid
import
core
from
paddle.fluid
import
core
def
temporal_shift
(
x
,
seg_num
):
def
temporal_shift
(
x
,
seg_num
,
shift_ratio
):
shape
=
x
.
shape
shape
=
x
.
shape
reshape_x
=
x
.
reshape
((
-
1
,
seg_num
,
shape
[
1
],
shape
[
2
],
shape
[
3
]))
reshape_x
=
x
.
reshape
((
-
1
,
seg_num
,
shape
[
1
],
shape
[
2
],
shape
[
3
]))
pad_x
=
np
.
pad
(
reshape_x
,
((
0
,
0
),
(
1
,
1
),
(
0
,
0
),
(
0
,
0
),
(
0
,
0
)),
'constant'
)
pad_x
=
np
.
pad
(
reshape_x
,
((
0
,
0
),
(
1
,
1
),
(
0
,
0
),
(
0
,
0
),
(
0
,
0
)),
'constant'
)
slice1
=
pad_x
[:,
:
seg_num
,
:
shape
[
1
]
//
4
,
:,
:]
c1
=
int
(
shape
[
1
]
*
shift_ratio
)
slice2
=
pad_x
[:,
2
:
seg_num
+
2
,
shape
[
1
]
//
4
:
shape
[
1
]
//
2
,
:,
:]
c2
=
int
(
shape
[
1
]
*
2
*
shift_ratio
)
slice3
=
pad_x
[:,
1
:
seg_num
+
1
,
shape
[
1
]
//
2
:,
:,
:]
slice1
=
pad_x
[:,
:
seg_num
,
:
c1
,
:,
:]
slice2
=
pad_x
[:,
2
:
seg_num
+
2
,
c1
:
c2
,
:,
:]
slice3
=
pad_x
[:,
1
:
seg_num
+
1
,
c2
:,
:,
:]
concat_x
=
np
.
concatenate
([
slice1
,
slice2
,
slice3
],
axis
=
2
)
concat_x
=
np
.
concatenate
([
slice1
,
slice2
,
slice3
],
axis
=
2
)
return
concat_x
.
reshape
(
shape
)
return
concat_x
.
reshape
(
shape
)
...
@@ -39,13 +41,14 @@ class TestTemporalShift(OpTest):
...
@@ -39,13 +41,14 @@ class TestTemporalShift(OpTest):
self
.
attrs
=
{
self
.
attrs
=
{
"seg_num"
:
self
.
seg_num
,
"seg_num"
:
self
.
seg_num
,
"shift_ratio"
:
self
.
shift_ratio
,
}
}
self
.
inputs
=
{
self
.
inputs
=
{
"X"
:
x
,
"X"
:
x
,
}
}
output
=
temporal_shift
(
x
,
self
.
seg_num
)
output
=
temporal_shift
(
x
,
self
.
seg_num
,
self
.
shift_ratio
)
self
.
outputs
=
{
"Out"
:
output
}
self
.
outputs
=
{
"Out"
:
output
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
@@ -57,17 +60,20 @@ class TestTemporalShift(OpTest):
...
@@ -57,17 +60,20 @@ class TestTemporalShift(OpTest):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
x_shape
=
(
6
,
4
,
4
,
4
)
self
.
x_shape
=
(
6
,
4
,
4
,
4
)
self
.
seg_num
=
3
self
.
seg_num
=
3
self
.
shift_ratio
=
0.25
class
TestTemporalShift2
(
TestTemporalShift
):
class
TestTemporalShift2
(
TestTemporalShift
):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
x_shape
=
(
4
,
9
,
7
,
7
)
self
.
x_shape
=
(
4
,
9
,
7
,
7
)
self
.
seg_num
=
2
self
.
seg_num
=
2
self
.
shift_ratio
=
0.2
class
TestTemporalShift2
(
TestTemporalShift
):
class
TestTemporalShift2
(
TestTemporalShift
):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
x_shape
=
(
3
,
10
,
5
,
5
)
self
.
x_shape
=
(
3
,
10
,
5
,
5
)
self
.
seg_num
=
1
self
.
seg_num
=
1
self
.
shift_ratio
=
0.3
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录