Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
1d4dfc09
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d4dfc09
编写于
3月 12, 2018
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs.
上级
d3d16f76
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
69 addition
and
17 deletion
+69
-17
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+29
-10
paddle/fluid/operators/reshape_op.h
paddle/fluid/operators/reshape_op.h
+9
-5
python/paddle/fluid/tests/unittests/test_reshape_op.py
python/paddle/fluid/tests/unittests/test_reshape_op.py
+31
-2
未找到文件。
paddle/fluid/operators/reshape_op.cc
浏览文件 @
1d4dfc09
...
...
@@ -32,7 +32,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE_EQ
(
shape
.
empty
(),
ctx
->
HasInput
(
"Shape"
),
"The shape information can only be set by Attr(shape) or "
"by Input(Shape). Attr(shape) and Input(Shape) cannot be "
...
...
@@ -41,29 +40,31 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
if
(
ctx
->
HasInput
(
"Shape"
))
{
// The shape information in given by Input(Shape).
auto
shape_dims
=
ctx
->
GetInputDim
(
"Shape"
);
PADDLE_ENFORCE
(
shape_dims
.
size
()
==
2UL
&&
shape_dims
[
0
]
==
1UL
,
"The Input(Label) should be a 2-D tensor with the 1st "
"dimensions fixed to 1 (a row vector)."
);
// The actual output shape will be set at runtime, here temporially
the
// The actual output shape will be set at runtime, here temporially
set
// the shape of output the same as the shape of input.
ctx
->
SetOutputDim
(
"Out"
,
x_dims
);
}
else
{
// The shape information in given by Attr(shape).
std
::
vector
<
int64_t
>
output_shape
;
ValidateShape
(
shape
,
framework
::
product
(
x_dims
),
output_shape
);
auto
out_dims
=
framework
::
make_ddim
(
output_shape
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
}
if
(
shape
[
0
]
==
x_dims
[
0
])
{
// Only pass LoD when the first dimension of output and input are the
//
same.
// Only pass LoD when the first dimension of output and Input(X)
// are the
same.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
}
private:
void
ValidateShape
(
const
std
::
vector
<
int
>
&
shape
,
const
int64_t
in_size
,
...
...
@@ -94,6 +95,14 @@ class ReshapeOp : public framework::OperatorWithKernel {
[](
int
a
)
{
return
static_cast
<
int64_t
>
(
a
);
});
if
(
neg_dims_idx
.
size
())
output_shape
[
neg_dims_idx
[
0
]]
=
inferred_dim
;
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
class
ReshapeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -101,11 +110,13 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of reshape operator."
);
AddInput
(
"Shape"
,
"a 1-D tensor that provides the shape information."
)
AddInput
(
"Shape"
,
"Tensor<int64_t>, a 1-D tensor that provides the shape information."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"The output tensor of reshape operator."
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"(
vector<int>) Target shape of reshape operator."
)
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"(std::
vector<int>) Target shape of reshape operator."
)
.
SetDefault
(
std
::
vector
<
int
>
());
AddComment
(
R"DOC(
Reshape Operator.
...
...
@@ -139,6 +150,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
}
// namespace operators
...
...
paddle/fluid/operators/reshape_op.h
浏览文件 @
1d4dfc09
...
...
@@ -33,9 +33,6 @@ class ReshapeKernel : public framework::OpKernel<T> {
std
::
vector
<
int64_t
>
output_shape
;
ValidateShape
(
*
shape
,
framework
::
product
(
in
->
dims
()),
output_shape
);
for
(
auto
d
:
output_shape
)
std
::
cout
<<
d
<<
" "
;
std
::
cout
<<
std
::
endl
;
out_dims
=
framework
::
make_ddim
(
output_shape
);
}
else
{
out_dims
=
out
->
dims
();
...
...
@@ -85,11 +82,18 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
auto
in_dims
=
d_x
->
dims
();
if
(
!
inplace
)
{
framework
::
TensorCopy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
d_x
->
Resize
(
in_dims
);
}
else
{
d_x
->
ShareDataWith
(
*
d_out
);
d_x
->
Resize
(
in_dims
);
}
}
};
}
// namespace operators
...
...
python/paddle/fluid/tests/unittests/test_reshape_op.py
浏览文件 @
1d4dfc09
...
...
@@ -33,7 +33,8 @@ from op_test import OpTest
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
#
#
# class TestReshapeOpDimInfer1(OpTest):
# def setUp(self):
# self.op_type = "reshape"
...
...
@@ -56,7 +57,8 @@ class TestReshapeOp2(OpTest):
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
),
"Shape"
:
np
.
array
(
new_shape
)
"Shape"
:
np
.
array
(
new_shape
,
dtype
=
"int64"
)
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
[
0
])}
...
...
@@ -67,5 +69,32 @@ class TestReshapeOp2(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
# class TestReshapeOpInplace(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [10 * 20], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
#
#
# class TestReshapeOpDimInferInplace(OpTest):
# def setUp(self):
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [4, -1, 5], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录