Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9de45e11
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9de45e11
编写于
9月 19, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed bug when dims.size == 1, modify the variable naming, add judgement when input_grad is null
上级
35967e86
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
73 addition
and
61 deletion
+73
-61
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+20
-24
paddle/operators/transpose_op.h
paddle/operators/transpose_op.h
+47
-37
python/paddle/v2/framework/tests/test_transpose_op.py
python/paddle/v2/framework/tests/test_transpose_op.py
+6
-0
未找到文件。
paddle/operators/transpose_op.cc
浏览文件 @
9de45e11
...
...
@@ -27,26 +27,29 @@ class TransposeOp : public framework::OperatorWithKernel {
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Input"
),
"Input(Input) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Output"
),
"Output(Output) should not be null"
);
auto
input_dim
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
auto
axis
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
size_t
input_
dim_size
=
input_dim
.
size
();
std
::
vector
<
int
>
axis
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
size_t
input_
rank
=
input_dim
.
size
();
size_t
axis_size
=
axis
.
size
();
PADDLE_ENFORCE_EQ
(
input_
dim_size
,
axis_size
,
"the input tensor's
dimension
(%d) "
PADDLE_ENFORCE_EQ
(
input_
rank
,
axis_size
,
"the input tensor's
rank
(%d) "
"should be equal to the axis's size(%d)"
,
input_dim_size
,
axis_size
);
std
::
vector
<
int
>
axis_sorted
(
axis
);
std
::
sort
(
axis_sorted
.
begin
(),
axis_sorted
.
end
());
for
(
size_t
i
=
0
;
i
<
axis_sorted
.
size
();
i
++
)
{
PADDLE_ENFORCE_EQ
(
axis_sorted
[
i
],
static_cast
<
int
>
(
i
),
"the sorted axis should be [0, 1, ... dims - 1], "
input_rank
,
axis_size
);
std
::
vector
<
int
>
count
(
axis_size
,
0
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
PADDLE_ENFORCE
(
axis
[
i
]
<
static_cast
<
int
>
(
axis_size
)
&&
++
count
[
axis
[
i
]]
==
1
,
"Each element of Attribute axis should be a unique value "
"range from 0 to (dims - 1), "
"where the dims is the axis's size"
);
}
framework
::
DDim
output_dim
(
input_dim
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
()
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
axis
_size
;
i
++
)
{
output_dim
[
i
]
=
input_dim
[
axis
[
i
]];
}
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Output"
)
->
Resize
(
output_dim
);
...
...
@@ -60,12 +63,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Input"
,
"(Tensor)The input tensor, tensors with rank at most
7
are supported"
);
"(Tensor)The input tensor, tensors with rank at most
6
are supported"
);
AddOutput
(
"Output"
,
"(Tensor)The output tensor"
);
AddAttr
<
std
::
vector
<
int
>>
(
"axis"
,
"(vector<int>)a list of values, and the size of the list should be "
"the same with the input tensor
dimensions
, the tensor will "
"the same with the input tensor
rank
, the tensor will "
"permute the axes according the the values given"
);
AddComment
(
R"DOC(
The Tensor will be permuted according to the axis values given.
...
...
@@ -97,18 +100,11 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
"Input(Input) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Output"
)),
"Input(Output@GRAD) should not be null"
);
auto
input_dim
s
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
auto
input_dim
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
auto
*
input_grad
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
output_grad_dims
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
))
->
dims
();
auto
output_dims
=
ctx
.
Input
<
Tensor
>
(
"Output"
)
->
dims
();
PADDLE_ENFORCE
(
output_grad_dims
==
output_dims
,
"Output@GRAD dims must equal to Input(Input) dims"
);
input_grad
->
Resize
(
input_dims
);
if
(
input_grad
)
input_grad
->
Resize
(
input_dim
);
}
};
...
...
paddle/operators/transpose_op.h
浏览文件 @
9de45e11
...
...
@@ -20,19 +20,19 @@
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
,
int
Dims
>
template
<
typename
Place
,
typename
T
,
int
Rank
>
void
EigenTranspose
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
in
,
framework
::
Tensor
&
out
,
std
::
vector
<
int
>
axis
)
{
Eigen
::
array
<
int
,
Dims
>
permute
;
for
(
int
i
=
0
;
i
<
Dims
;
i
++
)
{
Eigen
::
array
<
int
,
Rank
>
permute
;
for
(
int
i
=
0
;
i
<
Rank
;
i
++
)
{
permute
[
i
]
=
axis
[
i
];
}
auto
in_dim
=
in
.
dims
();
auto
out_dim
=
out
.
dims
();
auto
eigen_in
=
framework
::
EigenTensor
<
T
,
Dims
>::
From
(
in
);
auto
eigen_out
=
framework
::
EigenTensor
<
T
,
Dims
>::
From
(
out
);
auto
eigen_in
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
in
);
auto
eigen_out
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
out
);
auto
&
dev
=
context
.
GetEigenDevice
<
Place
>
();
eigen_out
.
device
(
dev
)
=
eigen_in
.
shuffle
(
permute
);
}
...
...
@@ -45,10 +45,11 @@ class TransposeKernel : public framework::OpKernel {
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
int
ndims
=
axis
.
size
();
switch
(
ndims
)
{
case
1
:
EigenTranspose
<
Place
,
T
,
1
>
(
context
,
*
input
,
*
output
,
axis
);
break
;
case
2
:
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
input
,
*
output
,
axis
);
...
...
@@ -79,39 +80,48 @@ class TransposeGradKernel : public framework::OpKernel {
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
auto
*
input_grad
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
axis_temp
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
(
axis_temp
);
std
::
vector
<
int
>
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
reversed_axis
(
axis
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
axis
[
axis_temp
[
i
]]
=
i
;
reversed_axis
[
axis
[
i
]]
=
i
;
}
int
ndims
=
axis
.
size
();
switch
(
ndims
)
{
case
1
:
EigenTranspose
<
Place
,
T
,
1
>
(
context
,
*
output_grad
,
*
input_grad
,
reversed_axis
);
break
;
case
2
:
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
output_grad
,
*
input_grad
,
reversed_axis
);
break
;
case
3
:
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
output_grad
,
*
input_grad
,
reversed_axis
);
break
;
case
4
:
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
output_grad
,
*
input_grad
,
reversed_axis
);
break
;
case
5
:
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
output_grad
,
*
input_grad
,
reversed_axis
);
break
;
case
6
:
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
output_grad
,
*
input_grad
,
reversed_axis
);
break
;
default:
PADDLE_THROW
(
"Tensors with rank at most 6 are supported"
);
}
}
}
};
}
// namespace operators
...
...
python/paddle/v2/framework/tests/test_transpose_op.py
浏览文件 @
9de45e11
...
...
@@ -22,6 +22,12 @@ class TestTransposeOp(OpTest):
self
.
axis
=
(
1
,
0
)
class
TestCase0
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
3
,
)
self
.
axis
=
(
0
,
)
class
TestCase1
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
3
,
4
,
5
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录