Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0cd9b8c0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0cd9b8c0
编写于
9月 20, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify the input\output name to X\Out
上级
a9a7ba3c
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
44 addition
and
46 deletion
+44
-46
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+20
-22
paddle/operators/transpose_op.h
paddle/operators/transpose_op.h
+21
-21
python/paddle/v2/framework/tests/test_transpose_op.py
python/paddle/v2/framework/tests/test_transpose_op.py
+3
-3
未找到文件。
paddle/operators/transpose_op.cc
浏览文件 @
0cd9b8c0
...
@@ -25,19 +25,18 @@ class TransposeOp : public framework::OperatorWithKernel {
...
@@ -25,19 +25,18 @@ class TransposeOp : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Input"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
"Input(Input) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Out"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Output"
),
"Output(Out) should not be null"
);
"Output(Output) should not be null"
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
input_dim
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
std
::
vector
<
int
>
axis
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
size_t
input_rank
=
input_dim
.
size
();
size_t
x_rank
=
x_dims
.
size
();
size_t
axis_size
=
axis
.
size
();
size_t
axis_size
=
axis
.
size
();
PADDLE_ENFORCE_EQ
(
input
_rank
,
axis_size
,
PADDLE_ENFORCE_EQ
(
x
_rank
,
axis_size
,
"the input tensor's rank(%d) "
"the input tensor's rank(%d) "
"should be equal to the axis's size(%d)"
,
"should be equal to the axis's size(%d)"
,
input
_rank
,
axis_size
);
x
_rank
,
axis_size
);
std
::
vector
<
int
>
count
(
axis_size
,
0
);
std
::
vector
<
int
>
count
(
axis_size
,
0
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
...
@@ -48,11 +47,11 @@ class TransposeOp : public framework::OperatorWithKernel {
...
@@ -48,11 +47,11 @@ class TransposeOp : public framework::OperatorWithKernel {
"where the dims is the axis's size"
);
"where the dims is the axis's size"
);
}
}
framework
::
DDim
out
put_dim
(
input_dim
);
framework
::
DDim
out
_dims
(
x_dims
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
out
put_dim
[
i
]
=
input_dim
[
axis
[
i
]];
out
_dims
[
i
]
=
x_dims
[
axis
[
i
]];
}
}
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out
put"
)
->
Resize
(
output_dim
);
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out
"
)
->
Resize
(
out_dims
);
}
}
};
};
...
@@ -62,9 +61,9 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -62,9 +61,9 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
AddInput
(
"
Input
"
,
"
X
"
,
"(Tensor)The input tensor, tensors with rank at most 6 are supported"
);
"(Tensor)The input tensor, tensors with rank at most 6 are supported"
);
AddOutput
(
"Out
put
"
,
"(Tensor)The output tensor"
);
AddOutput
(
"Out"
,
"(Tensor)The output tensor"
);
AddAttr
<
std
::
vector
<
int
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"axis"
,
"axis"
,
"(vector<int>)a list of values, and the size of the list should be "
"(vector<int>)a list of values, and the size of the list should be "
...
@@ -96,15 +95,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
...
@@ -96,15 +95,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Input"
),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
"Input(Input) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Output"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Output@GRAD) should not be null"
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
input_dim
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
auto
*
x_grad
=
auto
*
input_grad
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Input"
));
if
(
x_grad
)
x_grad
->
Resize
(
x_dims
);
if
(
input_grad
)
input_grad
->
Resize
(
input_dim
);
}
}
};
};
...
...
paddle/operators/transpose_op.h
浏览文件 @
0cd9b8c0
...
@@ -41,30 +41,30 @@ template <typename Place, typename T>
...
@@ -41,30 +41,30 @@ template <typename Place, typename T>
class
TransposeKernel
:
public
framework
::
OpKernel
{
class
TransposeKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
framework
::
Tensor
>
(
"Input
"
);
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X
"
);
auto
*
out
put
=
context
.
Output
<
framework
::
Tensor
>
(
"Outp
ut"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"O
ut"
);
out
put
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
std
::
vector
<
int
>
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
int
ndims
=
axis
.
size
();
int
ndims
=
axis
.
size
();
switch
(
ndims
)
{
switch
(
ndims
)
{
case
1
:
case
1
:
EigenTranspose
<
Place
,
T
,
1
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
1
>
(
context
,
*
x
,
*
o
ut
,
axis
);
break
;
break
;
case
2
:
case
2
:
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
x
,
*
o
ut
,
axis
);
break
;
break
;
case
3
:
case
3
:
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
x
,
*
o
ut
,
axis
);
break
;
break
;
case
4
:
case
4
:
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
x
,
*
o
ut
,
axis
);
break
;
break
;
case
5
:
case
5
:
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
x
,
*
o
ut
,
axis
);
break
;
break
;
case
6
:
case
6
:
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
x
,
*
o
ut
,
axis
);
break
;
break
;
default:
default:
PADDLE_THROW
(
"Tensors with rank at most 6 are supported"
);
PADDLE_THROW
(
"Tensors with rank at most 6 are supported"
);
...
@@ -76,12 +76,12 @@ template <typename Place, typename T>
...
@@ -76,12 +76,12 @@ template <typename Place, typename T>
class
TransposeGradKernel
:
public
framework
::
OpKernel
{
class
TransposeGradKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
out
put
_grad
=
auto
*
out_grad
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out
put
"
));
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
input
_grad
=
auto
*
x
_grad
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"
Input
"
));
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"
X
"
));
if
(
input
_grad
)
{
if
(
x
_grad
)
{
input
_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
x
_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
std
::
vector
<
int
>
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
reversed_axis
(
axis
);
std
::
vector
<
int
>
reversed_axis
(
axis
);
...
@@ -94,27 +94,27 @@ class TransposeGradKernel : public framework::OpKernel {
...
@@ -94,27 +94,27 @@ class TransposeGradKernel : public framework::OpKernel {
switch
(
ndims
)
{
switch
(
ndims
)
{
case
1
:
case
1
:
EigenTranspose
<
Place
,
T
,
1
>
(
context
,
*
out
put_grad
,
*
input
_grad
,
EigenTranspose
<
Place
,
T
,
1
>
(
context
,
*
out
_grad
,
*
x
_grad
,
reversed_axis
);
reversed_axis
);
break
;
break
;
case
2
:
case
2
:
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
out
put_grad
,
*
input
_grad
,
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
out
_grad
,
*
x
_grad
,
reversed_axis
);
reversed_axis
);
break
;
break
;
case
3
:
case
3
:
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
out
put_grad
,
*
input
_grad
,
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
out
_grad
,
*
x
_grad
,
reversed_axis
);
reversed_axis
);
break
;
break
;
case
4
:
case
4
:
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
out
put_grad
,
*
input
_grad
,
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
out
_grad
,
*
x
_grad
,
reversed_axis
);
reversed_axis
);
break
;
break
;
case
5
:
case
5
:
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
out
put_grad
,
*
input
_grad
,
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
out
_grad
,
*
x
_grad
,
reversed_axis
);
reversed_axis
);
break
;
break
;
case
6
:
case
6
:
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
out
put_grad
,
*
input
_grad
,
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
out
_grad
,
*
x
_grad
,
reversed_axis
);
reversed_axis
);
break
;
break
;
default:
default:
...
...
python/paddle/v2/framework/tests/test_transpose_op.py
浏览文件 @
0cd9b8c0
...
@@ -7,15 +7,15 @@ class TestTransposeOp(OpTest):
...
@@ -7,15 +7,15 @@ class TestTransposeOp(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
initTestCase
()
self
.
op_type
=
"transpose"
self
.
op_type
=
"transpose"
self
.
inputs
=
{
'
Input
'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
'
X
'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
'axis'
:
list
(
self
.
axis
)}
self
.
attrs
=
{
'axis'
:
list
(
self
.
axis
)}
self
.
outputs
=
{
'Out
put'
:
self
.
inputs
[
'Input
'
].
transpose
(
self
.
axis
)}
self
.
outputs
=
{
'Out
'
:
self
.
inputs
[
'X
'
].
transpose
(
self
.
axis
)}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'
Input'
],
'Outp
ut'
)
self
.
check_grad
([
'
X'
],
'O
ut'
)
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
shape
=
(
3
,
4
)
self
.
shape
=
(
3
,
4
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录