BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit c42d662e (unverified)
Authored Jul 13, 2020 by yaoxuefeng; committed via GitHub on Jul 13, 2020
modify roll test=develop (#25321)
Parent: bdc2c2db
Showing 4 changed files with 70 additions and 47 deletions (+70 -47)
paddle/fluid/operators/roll_op.cc                      +2 -2
paddle/fluid/operators/roll_op.h                       +4 -4
python/paddle/fluid/tests/unittests/test_roll_op.py    +20 -6
python/paddle/tensor/manipulation.py                   +44 -35
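The user-visible effect of this commit is that paddle.roll now takes an axis argument instead of dims. A minimal sketch of the renamed call, based on the dygraph test case in this diff (the expected output values come from the test's expect_out):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    data = np.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0],
                     [7.0, 8.0, 9.0]])
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(data)
        out = paddle.roll(x, shifts=1, axis=0)   # previously: paddle.roll(x, shifts=1, dims=0)
        print(out.numpy())
        # [[7. 8. 9.]
        #  [1. 2. 3.]
        #  [4. 5. 6.]]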
paddle/fluid/operators/roll_op.cc
@@ -33,7 +33,7 @@ class RollOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "Output(Out) of RollOp should not be null."));
-    auto dims = ctx->Attrs().Get<std::vector<int64_t>>("dims");
+    auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
     auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
     PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
@@ -92,7 +92,7 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
              "of the tensor are shifted.")
         .SetDefault({});
-    AddAttr<std::vector<int64_t>>("dims",
+    AddAttr<std::vector<int64_t>>("axis",
                                   "Axis along which to roll. It must have the same size "
                                   "with shifts.")
         .SetDefault({});
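The op definition keeps the existing check that the rolled axes and the shifts pair up one-to-one (PADDLE_ENFORCE_EQ(dims.size(), shifts.size(), ...)); only the attribute name changes from "dims" to "axis". As a NumPy reference for that paired-list form, using the shapes and values from TestRollOp in this diff:

    import numpy as np
    x = np.random.random((100, 4, 5)).astype('float64')
    # one shift per axis; both lists must have the same length
    y = np.roll(x, shift=[101, -1], axis=[0, -2])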
paddle/fluid/operators/roll_op.h
@@ -82,7 +82,7 @@ class RollKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
-    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
+    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     std::vector<T> out_vec;
     TensorToVector(input, context.device_context(), &out_vec);
@@ -94,8 +94,8 @@ class RollKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(
           dims[i] < input_dim.size() && dims[i] >= (0 - input_dim.size()), true,
           platform::errors::OutOfRange(
-              "Attr(dims[%d]) is out of range, It's expected "
-              "to be in range of [-%d, %d]. But received Attr(dims[%d]) = %d.",
+              "Attr(axis[%d]) is out of range, It's expected "
+              "to be in range of [-%d, %d]. But received Attr(axis[%d]) = %d.",
               i, input_dim.size(), input_dim.size() - 1, i, dims[i]));
       shift_along_dim(out_vec.data(), input_dim, dims[i], shifts[i]);
     }
@@ -114,7 +114,7 @@ class RollGradKernel : public framework::OpKernel<T> {
     auto& input = input_var->Get<LoDTensor>();
     auto* output = output_var->GetMutable<LoDTensor>();
     std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
-    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("dims");
+    std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
     std::vector<T> out_vec;
     TensorToVector(input, context.device_context(), &out_vec);
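RollKernel itself is unchanged apart from the attribute it reads and the wording of the error message: it still copies the tensor into out_vec and calls shift_along_dim once per entry of the axis list. Conceptually, each such call is a wrap-around rotation along one axis, the same operation NumPy performs for a single-axis roll (a sketch of the semantics, not the kernel's implementation):

    import numpy as np
    x = np.arange(6).reshape(2, 3)    # [[0 1 2], [3 4 5]]
    np.roll(x, shift=1, axis=1)       # each row rotated right by one: [[2 0 1], [5 3 4]]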
python/paddle/fluid/tests/unittests/test_roll_op.py
@@ -28,17 +28,17 @@ class TestRollOp(OpTest):
         self.op_type = "roll"
         self.init_dtype_type()
         self.inputs = {'X': np.random.random(self.x_shape).astype(self.dtype)}
-        self.attrs = {'shifts': self.shifts, 'dims': self.dims}
+        self.attrs = {'shifts': self.shifts, 'axis': self.axis}
         self.outputs = {
             'Out': np.roll(self.inputs['X'], self.attrs['shifts'],
-                           self.attrs['dims'])
+                           self.attrs['axis'])
         }

     def init_dtype_type(self):
         self.dtype = np.float64
         self.x_shape = (100, 4, 5)
         self.shifts = [101, -1]
-        self.dims = [0, -2]
+        self.axis = [0, -2]

     def test_check_output(self):
         self.check_output()
@@ -52,7 +52,7 @@ class TestRollOpCase2(TestRollOp):
         self.dtype = np.float32
         self.x_shape = (100, 10, 5)
         self.shifts = [8, -1]
-        self.dims = [-1, -2]
+        self.axis = [-1, -2]


 class TestRollAPI(unittest.TestCase):
@@ -78,7 +78,7 @@ class TestRollAPI(unittest.TestCase):
         # case 2:
         with program_guard(Program(), Program()):
             x = fluid.layers.data(name='x', shape=[-1, 3])
-            z = paddle.roll(x, shifts=1, dims=0)
+            z = paddle.roll(x, shifts=1, axis=0)
             exe = fluid.Executor(fluid.CPUPlace())
             res, = exe.run(feed={'x': self.data_x},
                            fetch_list=[z.name],
@@ -101,12 +101,26 @@ class TestRollAPI(unittest.TestCase):
         # case 2:
         with fluid.dygraph.guard():
             x = fluid.dygraph.to_variable(self.data_x)
-            z = paddle.roll(x, shifts=1, dims=0)
+            z = paddle.roll(x, shifts=1, axis=0)
             np_z = z.numpy()
             expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
                                    [4.0, 5.0, 6.0]])
             self.assertTrue(np.allclose(expect_out, np_z))

+    def test_roll_op_false(self):
+        self.input_data()
+
+        def test_axis_out_range():
+            with program_guard(Program(), Program()):
+                x = fluid.layers.data(name='x', shape=[-1, 3])
+                z = paddle.roll(x, shifts=1, axis=10)
+                exe = fluid.Executor(fluid.CPUPlace())
+                res, = exe.run(feed={'x': self.data_x},
+                               fetch_list=[z.name],
+                               return_numpy=False)
+
+        self.assertRaises(ValueError, test_axis_out_range)
+

 if __name__ == "__main__":
     unittest.main()
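The expect_out asserted in the dygraph case can be cross-checked against NumPy, which the OpTest cases above already use as the reference implementation:

    import numpy as np
    data = np.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0],
                     [7.0, 8.0, 9.0]])
    np.roll(data, 1, axis=0)
    # [[7. 8. 9.]
    #  [1. 2. 3.]
    #  [4. 5. 6.]]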
python/paddle/tensor/manipulation.py
@@ -104,23 +104,24 @@ def flip(input, dims, name=None):
     return out


-def roll(input, shifts, dims=None):
+def roll(x, shifts, axis=None, name=None):
     """
     :alias_main: paddle.roll
     :alias: paddle.roll,paddle.tensor.roll,paddle.tensor.manipulation.roll

-    Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond
-    the last position are re-introduced at the first position. If a dimension is not specified,
+    Roll the `x` tensor along the given axis(axes). With specific 'shifts', Elements that
+    roll beyond the last position are re-introduced at the first according to 'shifts'.
+    If a axis is not specified,
     the tensor will be flattened before rolling and then restored to the original shape.

     Args:
-        input (Variable): The input tensor variable.
+        x (Variable): The x tensor variable as input.
         shifts (int|list|tuple): The number of places by which the elements
-                           of the `input` tensor are shifted.
-        dims (int|list|tuple|None): Dimentions along which to roll.
+                           of the `x` tensor are shifted.
+        axis (int|list|tuple|None): axis(axes) along which to roll.

     Returns:
-        Variable: A Tensor with same data type as `input`.
+        Variable: A Tensor with same data type as `x`.

     Examples:
         .. code-block:: python
@@ -131,48 +132,56 @@ def roll(input, shifts, dims=None):
             data = np.array([[1.0, 2.0, 3.0],
                              [4.0, 5.0, 6.0],
                              [7.0, 8.0, 9.0]])
-            with fluid.dygraph.guard():
-                x = fluid.dygraph.to_variable(data)
+            paddle.enable_imperative()
+            x = paddle.imperative.to_variable(data)
             out_z1 = paddle.roll(x, shifts=1)
             print(out_z1.numpy())
             #[[9. 1. 2.]
             # [3. 4. 5.]
             # [6. 7. 8.]]
-            out_z2 = paddle.roll(x, shifts=1, dims=0)
+            out_z2 = paddle.roll(x, shifts=1, axis=0)
             print(out_z2.numpy())
             #[[7. 8. 9.]
             # [1. 2. 3.]
             # [4. 5. 6.]]
     """
     helper = LayerHelper("roll", **locals())
-    origin_shape = input.shape
+    origin_shape = x.shape
     if type(shifts) == int:
         shifts = [shifts]
-    if type(dims) == int:
-        dims = [dims]
-
-    if dims:
-        check_type(dims, 'dims', (list, tuple), 'roll')
+    if type(axis) == int:
+        axis = [axis]
+
+    len_origin_shape = len(origin_shape)
+    if axis:
+        for i in range(len(axis)):
+            if axis[i] >= len_origin_shape or axis[i] < -len_origin_shape:
+                raise ValueError(
+                    "axis is out of range, it should be in range [{}, {}), but received {}".
+                    format(-len_origin_shape, len_origin_shape, axis))
+
+    if axis:
+        check_type(axis, 'axis', (list, tuple), 'roll')
     check_type(shifts, 'shifts', (list, tuple), 'roll')

     if in_dygraph_mode():
-        if dims is None:
-            input = core.ops.reshape(input, 'shape', [-1, 1])
-            dims = [0]
-        out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
+        if axis is None:
+            x = core.ops.reshape(x, 'shape', [-1, 1])
+            axis = [0]
+        out = core.ops.roll(x, 'axis', axis, 'shifts', shifts)
         return core.ops.reshape(out, 'shape', origin_shape)

-    out = helper.create_variable_for_type_inference(input.dtype)
+    out = helper.create_variable_for_type_inference(x.dtype)

-    if dims is None:
-        input = reshape(input, shape=[-1, 1])
-        dims = [0]
+    if axis is None:
+        x = reshape(x, shape=[-1, 1])
+        axis = [0]

     helper.append_op(
         type='roll',
-        inputs={'X': input},
+        inputs={'X': x},
         outputs={'Out': out},
-        attrs={'dims': dims,
-               'shifts': shifts})
+        attrs={'axis': axis,
+               'shifts': shifts})
     out = reshape(out, shape=origin_shape, inplace=True)
     return out
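When axis is None, the updated wrapper still reshapes x to [-1, 1], rolls along axis 0, and reshapes back to origin_shape. That flatten-then-restore behavior matches NumPy's default roll, which is what the docstring's out_z1 example shows:

    import numpy as np
    data = np.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0],
                     [7.0, 8.0, 9.0]])
    np.roll(data, 1)    # rolled as a flat array, then the original shape is restored
    # [[9. 1. 2.]
    #  [3. 4. 5.]
    #  [6. 7. 8.]]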