Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ca157793
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ca157793
编写于
7月 03, 2018
作者:
C
chenweihang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rewrite, use reshape op in unsqueeze op, test passed
上级
996c157f
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
185 addition
and
231 deletion
+185
-231
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/unsqueeze_op.cc
paddle/fluid/operators/unsqueeze_op.cc
+73
-73
paddle/fluid/operators/unsqueeze_op.cu
paddle/fluid/operators/unsqueeze_op.cu
+0
-30
paddle/fluid/operators/unsqueeze_op.h
paddle/fluid/operators/unsqueeze_op.h
+0
-72
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
+111
-56
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
ca157793
...
...
@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
cos_sim_op DEPS cos_sim_functor
)
op_library
(
parallel_do_op DEPS executor
)
op_library
(
unsqueeze_op DEPS reshape_op
)
if
(
WITH_GPU
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
...
...
paddle/fluid/operators/unsqueeze_op.cc
浏览文件 @
ca157793
...
...
@@ -12,41 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
OpKernelType
;
using
framework
::
Tensor
;
class
UnsqueezeOp
:
public
framework
::
OperatorWithKernel
{
class
UnsqueezeOpInferShape
:
public
framework
::
InferShapeBase
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of UnsqueezeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UnsqueezeOp should not be null."
);
const
auto
&
axes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"axes"
);
const
auto
&
axes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"axes"
);
PADDLE_ENFORCE
(
!
axes
.
empty
(),
"The unsqueeze axes information must be set by Attr(axes)."
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE
(
x_dims
.
size
()
<
6
,
PADDLE_ENFORCE
(
static_cast
<
int
>
(
x_dims
.
size
())
<=
6
,
"Invalid dimensions, dynamic dimensions should within "
"[
0, 5
] dimensions (Eigen limit)."
);
"[
1, 6
] dimensions (Eigen limit)."
);
// Validity Check: the range of unsqueeze aixs.
// TODO(chenweihang): Don't consider negative axis?.
for
(
unsigned
int
idx
=
0
;
idx
<
axes
.
size
();
++
idx
)
{
PADDLE_ENFORCE
(
axes
[
idx
]
<
6
,
for
(
int
axis
:
axes
)
{
PADDLE_ENFORCE
(
axis
<
6
,
"Invalid dimensions, input axis should within "
"[
0, 5
] dimensions (Eigen limit)."
);
"[
1, 6
] dimensions (Eigen limit)."
);
}
auto
out_dims
=
GetOutputShape
(
axes
,
x_dims
);
...
...
@@ -54,33 +48,7 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
}
static
framework
::
DDim
GetOutputShape
(
const
std
::
vector
<
int
>
unsqz_dims
,
const
framework
::
DDim
&
in_dims
)
{
/*
* STL version
* Test Error! don't know why?.
std::vector<int64_t> output_shape;
// Contruct base output shape
for(int idx = 0; idx < in_dims.size(); ++idx) {
output_shape.emplace_back(in_dims[idx]);
}
// Validity Check: output dimensions limit.
PADDLE_ENFORCE(unsqz_dims.size() + output_shape.size() < 6,
"The Attr(axes) size is too large. The output shape should "
"be less than 6 (Eigne limit).");
// Insert the unsqueeze axis in turn.
auto it = output_shape.begin();
for (int axis : unsqz_dims) {
int cur = axis < 0 ? (axis + output_shape.size() + 1)
: axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE(cur >= 0 && cur <= static_cast<int>(output_shape.size()),
"The unsqueeze dims must be within range of current
rank.");
output_shape.emplace(it + axis, 1);
}
*/
const
framework
::
DDim
&
in_dims
)
{
unsigned
int
unsqz_mask
=
0
;
unsigned
int
front
=
0
,
back
=
0
;
int
output_dims_size
=
in_dims
.
size
();
...
...
@@ -93,17 +61,17 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
cur
>=
0
&&
cur
<=
output_dims_size
,
"The unsqueeze dims must be within range of current rank."
);
// Save the front part.
front
=
unsqz_mask
&
((
1
<<
axis
)
-
1
);
front
=
unsqz_mask
&
((
1
<<
cur
)
-
1
);
// Move the back part.
back
=
unsqz_mask
&
~
((
1
<<
axis
)
-
1
);
back
=
unsqz_mask
&
~
((
1
<<
cur
)
-
1
);
back
<<=
1
;
// Merge two part.
back
|=
(
1
<<
axis
);
back
|=
(
1
<<
cur
);
unsqz_mask
=
front
|
back
;
// Add the output size.
output_dims_size
++
;
// Validity Check: rank range.
PADDLE_ENFORCE
(
output_dims_size
<
6
,
PADDLE_ENFORCE
(
output_dims_size
<
=
6
,
"The output tensor's rank should be less than 6."
);
}
...
...
@@ -121,6 +89,31 @@ class UnsqueezeOp : public framework::OperatorWithKernel {
}
};
class
UnsqueezeOp
:
public
framework
::
OperatorBase
{
public:
UnsqueezeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
axes
=
Attr
<
std
::
vector
<
int
>>
(
"axes"
);
auto
x_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
auto
out_dims
=
UnsqueezeOpInferShape
::
GetOutputShape
(
axes
,
x_dims
);
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
out_dims
);
attrs
[
"inplace"
]
=
Attr
<
bool
>
(
"inplace"
);
// Invoke Reshape op.
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
class
UnsqueezeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -150,42 +143,49 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class
UnsqueezeGrad
Op
:
public
framework
::
OperatorWithKernel
{
class
UnsqueezeGrad
InferShape
:
public
framework
::
InferShapeBase
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of UnsqueezeGradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Output(Out@GRAD) of UnsqueezeGradOp should not be null."
);
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
}
};
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
class
UnsqueezeGradOp
:
public
framework
::
OperatorBase
{
public:
UnsqueezeGradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
dx_name
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
dout_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
x_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
x_dims
);
attrs
[
"inplace"
]
=
Attr
<
bool
>
(
"inplace"
);
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape"
,
{{
"X"
,
{
dout_name
}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
dx_name
}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
}
// namespace operators
}
// namespace paddle
// Tell linker to use reshape op.
USE_OP
(
reshape
);
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
unsqueeze
,
ops
::
UnsqueezeOp
,
ops
::
UnsqueezeOpMaker
,
ops
::
UnsqueezeOpInferShape
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
unsqueeze_grad
,
ops
::
UnsqueezeGradOp
);
REGISTER_OP_CPU_KERNEL
(
unsqueeze
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
unsqueeze_grad
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OPERATOR
(
unsqueeze_grad
,
ops
::
UnsqueezeGradOp
,
ops
::
UnsqueezeGradInferShape
);
paddle/fluid/operators/unsqueeze_op.cu
已删除
100644 → 0
浏览文件 @
996c157f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/unsqueeze_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
unsqueeze
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
UnsqueezeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
unsqueeze_grad
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
UnsqueezeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/unsqueeze_op.h
已删除
100644 → 0
浏览文件 @
996c157f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
UnsqueezeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
framework
::
DDim
out_dims
=
out
->
dims
();
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
out
->
Resize
(
out_dims
);
if
(
!
inplace
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
TensorCopySync
(
*
in
,
ctx
.
GetPlace
(),
out
);
out
->
Resize
(
out_dims
);
}
else
{
out
->
ShareDataWith
(
*
in
);
out
->
Resize
(
out_dims
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
UnsqueezeGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
auto
in_dims
=
d_x
->
dims
();
if
(
!
inplace
)
{
framework
::
TensorCopy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
ctx
.
device_context
().
Wait
();
d_x
->
Resize
(
in_dims
);
}
else
{
d_x
->
ShareDataWith
(
*
d_out
);
d_x
->
Resize
(
in_dims
);
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
浏览文件 @
ca157793
...
...
@@ -27,7 +27,7 @@ class TestUnsqueezeOp(OpTest):
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
al
ce"
:
False
}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
la
ce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
...
...
@@ -37,16 +37,35 @@ class TestUnsqueezeOp(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
# Correct: There is mins axis.
# Correct: Single input index.
class
TestUnsqueezeOp1
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
5
)
axes
=
(
-
1
,
)
new_shape
=
(
3
,
5
,
1
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inplace"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
# Correct: Mixed input axis.
class
TestUnsqueezeOp2
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
5
)
axes
=
(
0
,
-
2
)
new_shape
=
(
1
,
3
,
1
,
5
)
axes
=
(
0
,
-
1
)
new_shape
=
(
1
,
3
,
5
,
1
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
al
ce"
:
False
}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
la
ce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
...
...
@@ -65,7 +84,7 @@ class TestUnsqueezeOp3(OpTest):
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
al
ce"
:
False
}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
la
ce"
:
False
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
...
...
@@ -75,16 +94,16 @@ class TestUnsqueezeOp3(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
#
Error: Output dimension is error
.
class
TestUnsqueezeOp
4
(
OpTest
):
#
Correct: Inplace
.
class
TestUnsqueezeOp
Inplace1
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
3
)
new_shape
=
(
1
,
3
,
2
,
2
,
5
)
ori_shape
=
(
3
,
5
)
axes
=
(
0
,
2
)
new_shape
=
(
1
,
3
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
alce"
:
Fals
e
}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
lace"
:
Tru
e
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
...
...
@@ -94,16 +113,16 @@ class TestUnsqueezeOp4(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
#
Error: Input axes is invalid case 1
.
class
TestUnsqueezeOp
5
(
OpTest
):
#
Correct: Inplace. There is mins index
.
class
TestUnsqueezeOp
Inplace2
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
5
)
ori_shape
=
(
3
,
5
)
axes
=
(
0
,
-
2
)
new_shape
=
(
1
,
3
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
alce"
:
Fals
e
}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
lace"
:
Tru
e
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
...
...
@@ -113,16 +132,16 @@ class TestUnsqueezeOp5(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
#
Error: Input axes is invalid case 2
.
class
TestUnsqueezeOp
5
(
OpTest
):
#
Correct: Inplace. There is duplicated axis
.
class
TestUnsqueezeOp
Inplace3
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
2
,
10
)
new_shape
=
(
1
,
3
,
1
,
5
)
axes
=
(
0
,
3
,
3
)
new_shape
=
(
1
,
3
,
2
,
1
,
1
,
5
)
self
.
op_type
=
"unsqueeze"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
alce"
:
Fals
e
}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
lace"
:
Tru
e
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
...
...
@@ -132,16 +151,17 @@ class TestUnsqueezeOp5(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
# Correct: Inplace.
class
TestUnsqueezeOpInplace1
(
OpTest
):
'''
# Error: Output dimension is error.
class TestUnsqueezeOp4(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes
=
(
0
,
2
)
new_shape
=
(
1
,
3
,
1
,
5
)
axes = (0,
3
)
new_shape = (1, 3, 1,
1,
5)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self
.
attrs
=
{
"axes"
:
axes
,
"inplace"
:
Tru
e
}
self.attrs = {"axes": axes, "inplace":
Fals
e}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
...
...
@@ -150,17 +170,34 @@ class TestUnsqueezeOpInplace1(OpTest):
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axis is large than output range.
class TestUnsqueezeOp5(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 4)
new_shape = (1, 3, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
# Correct: Inplace. There is duplicated axis.
class
TestUnsqueezeOpInplace2
(
OpTest
):
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axes is large than Eigen limit.
class TestUnsqueezeOp6(OpTest):
def setUp(self):
ori_shape
=
(
3
,
2
,
5
)
axes
=
(
0
,
3
,
3
)
new_shape
=
(
1
,
3
,
2
,
1
,
1
,
5
)
ori_shape = (3, 5)
axes = (0,
2, 10
)
new_shape = (1, 3,
1, 5, 1
)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self
.
attrs
=
{
"axes"
:
axes
,
"inp
alce"
:
Tru
e
}
self.attrs = {"axes": axes, "inp
lace": Fals
e}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
...
...
@@ -169,6 +206,24 @@ class TestUnsqueezeOpInplace2(OpTest):
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Error: Input axes size is large than Eigen limit.
class TestUnsqueezeOp7(OpTest):
def setUp(self):
ori_shape = (3, 5)
axes = (0, 2, 2, 2, 2, 2)
new_shape = (1, 3, 1, 1, 5, 1)
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
'''
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录