Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e4024962
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e4024962
编写于
7月 02, 2018
作者:
C
chenweihang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
complete unsqueeze op and related unittest.
上级
a1e7f2d5
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
167 addition
and
48 deletion
+167
-48
paddle/fluid/operators/unsqueeze_op.cc
paddle/fluid/operators/unsqueeze_op.cc
+78
-35
paddle/fluid/operators/unsqueeze_op.cu
paddle/fluid/operators/unsqueeze_op.cu
+2
-2
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
+87
-11
未找到文件。
paddle/fluid/operators/unsqueeze_op.cc
浏览文件 @
e4024962
...
...
@@ -32,42 +32,85 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UnsqueezeOp should not be null."
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
&
axes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"axes"
);
// Check output tensor dims (<9).
PADDLE_ENFORCE_LE
(
x_dims
.
size
()
+
axes
.
size
(),
9
,
"Invalid dimnesions, dynamic dimensions must have "
"between [1, 9] dimensions."
);
// Check the range of unsqueeze aixs.
for
(
int
a
:
axes
)
{
PADDLE_ENFORCE_LT
(
a
,
static_cast
<
int64_t
>
(
x_dims
.
size
()
+
axes
.
size
()),
"The axis must be less than output tensor's rank."
);
PADDLE_ENFORCE
(
!
axes
.
empty
(),
"The unsqueeze axes information must be set by Attr(axes)."
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE
(
x_dims
.
size
()
<
6
,
"Invalid dimensions, dynamic dimensions should within "
"[0, 5] dimensions (Eigen limit)."
);
// Validity Check: the range of unsqueeze aixs.
// TODO(chenweihang): Don't consider negative axis?.
for
(
unsigned
int
idx
=
0
;
idx
<
axes
.
size
();
++
idx
)
{
PADDLE_ENFORCE
(
axes
[
idx
]
<
6
,
"Invalid dimensions, input axis should within "
"[0, 5] dimensions (Eigen limit)."
);
}
auto
out_dims
=
GetOutputShape
(
axes
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
}
static
framework
::
DDim
GetOutputShape
(
const
std
::
vector
<
int
>
unsq
ueeze
_dims
,
static
framework
::
DDim
GetOutputShape
(
const
std
::
vector
<
int
>
unsq
z
_dims
,
const
framework
::
DDim
&
in_dims
)
{
int
out_dims_size
=
in_dims
.
size
()
+
unsqueeze_dims
.
size
();
bool
should_unsqueeze
[
9
]
=
{
false
};
// Determines the dimensions should be unsqueezed in output tensor after.
for
(
unsigned
int
idx
=
0
;
idx
<
unsqueeze_dims
.
size
();
++
idx
)
{
int
current
=
unsqueeze_dims
[
idx
]
<
0
?
unsqueeze_dims
[
idx
]
+
out_dims_size
:
unsqueeze_dims
[
idx
];
// Check current index.
PADDLE_ENFORCE_GE
(
current
,
0
,
"Invaild axis, negative axis is out of range."
);
should_unsqueeze
[
idx
]
=
true
;
/*
* STL version
* Test Error! don't know why?.
std::vector<int64_t> output_shape;
// Contruct base output shape
for(int idx = 0; idx < in_dims.size(); ++idx) {
output_shape.emplace_back(in_dims[idx]);
}
// Validity Check: output dimensions limit.
PADDLE_ENFORCE(unsqz_dims.size() + output_shape.size() < 6,
"The Attr(axes) size is too large. The output shape should "
"be less than 6 (Eigne limit).");
// Insert the unsqueeze axis in turn.
auto it = output_shape.begin();
for (int axis : unsqz_dims) {
int cur = axis < 0 ? (axis + output_shape.size() + 1)
: axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE(cur >= 0 && cur <= static_cast<int>(output_shape.size()),
"The unsqueeze dims must be within range of current
rank.");
output_shape.emplace(it + axis, 1);
}
*/
unsigned
int
unsqz_mask
=
0
;
unsigned
int
front
=
0
,
back
=
0
;
int
output_dims_size
=
in_dims
.
size
();
// Simulate insert by bit calc.
for
(
int
axis
:
unsqz_dims
)
{
int
cur
=
axis
<
0
?
axis
+
output_dims_size
+
1
:
axis
;
// Vaildity Check: the axis bound
PADDLE_ENFORCE
(
cur
>=
0
&&
cur
<=
output_dims_size
,
"The unsqueeze dims must be within range of current rank."
);
// Save the front part.
front
=
unsqz_mask
&
((
1
<<
axis
)
-
1
);
// Move the back part.
back
=
unsqz_mask
&
~
((
1
<<
axis
)
-
1
);
back
<<=
1
;
// Merge two part.
back
|=
(
1
<<
axis
);
unsqz_mask
=
front
|
back
;
// Add the output size.
output_dims_size
++
;
// Validity Check: rank range.
PADDLE_ENFORCE
(
output_dims_size
<
6
,
"The output tensor's rank should be less than 6."
);
}
// Make output
dimensions
std
::
vector
<
int64_t
>
output_shape
(
out_dims_size
,
0
);
for
(
int
in_idx
=
0
,
out_idx
=
0
;
out_idx
<
out_dims_size
;
++
out_idx
)
{
if
(
!
should_unsqueeze
[
out_idx
]
)
{
// Make output
shape
std
::
vector
<
int64_t
>
output_shape
(
out
put
_dims_size
,
0
);
for
(
int
in_idx
=
0
,
out_idx
=
0
;
out_idx
<
out
put
_dims_size
;
++
out_idx
)
{
if
(
(
unsqz_mask
&
(
1
<<
out_idx
))
==
0
)
{
output_shape
[
out_idx
]
=
in_dims
[
in_idx
++
];
}
else
{
output_shape
[
out_idx
]
=
1
;
...
...
paddle/fluid/operators/unsqueeze_op.cu
浏览文件 @
e4024962
...
...
@@ -18,12 +18,12 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
squeeze
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
un
squeeze
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
squeeze_grad
,
un
squeeze_grad
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
...
...
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
浏览文件 @
e4024962
...
...
@@ -19,7 +19,7 @@ from op_test import OpTest
# Correct: General.
class
Test
SqueezeOp1
(
OpTest
):
class
Test
UnsqueezeOp
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
5
)
axes
=
(
0
,
2
)
...
...
@@ -38,7 +38,7 @@ class TestSqueezeOp1(OpTest):
# Correct: There is mins axis.
class
Test
S
queezeOp2
(
OpTest
):
class
Test
Uns
queezeOp2
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
5
)
axes
=
(
0
,
-
2
)
...
...
@@ -56,6 +56,82 @@ class TestSqueezeOp2(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
# Correct: There is duplicated axis.
class
TestUnsqueezeOp3
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
3
,
3
)
new_shape
=
(
1
,
3
,
2
,
1
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inpalce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
# Error: Output dimension is error.
class
TestUnsqueezeOp4
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
3
)
new_shape
=
(
1
,
3
,
2
,
2
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inpalce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
# Error: Input axes is invalid case 1.
class
TestUnsqueezeOp5
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
5
)
new_shape
=
(
1
,
3
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inpalce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
# Error: Input axes is invalid case 2.
class
TestUnsqueezeOp5
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
2
,
10
)
new_shape
=
(
1
,
3
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inpalce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
# Correct: Inplace.
class
TestUnsqueezeOpInplace1
(
OpTest
):
def
setUp
(
self
):
...
...
@@ -75,12 +151,12 @@ class TestUnsqueezeOpInplace1(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
# Correct: Inplace. There is
mins
axis.
# Correct: Inplace. There is
duplicated
axis.
class
TestUnsqueezeOpInplace2
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
5
)
axes
=
(
0
,
-
2
)
new_shape
=
(
1
,
3
,
1
,
5
)
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
3
,
3
)
new_shape
=
(
1
,
3
,
2
,
1
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录