Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
454b0a96
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
454b0a96
编写于
3月 21, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove the extra call of ValidateShape in ReshapeKernel
上级
437f7a32
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
77 addition
and
77 deletion
+77
-77
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+76
-0
paddle/fluid/operators/reshape_op.h
paddle/fluid/operators/reshape_op.h
+1
-77
未找到文件。
paddle/fluid/operators/reshape_op.cc
浏览文件 @
454b0a96
...
...
@@ -17,6 +17,82 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
class
ReshapeOp
:
public
framework
::
OperatorWithKernel
{
public:
ReshapeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE
(
!
shape
.
empty
(),
"The shape information must be set by Attr(shape)."
);
std
::
vector
<
int64_t
>
output_shape
;
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ValidateShape
(
shape
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
// NOTE: Reshape op cannot reshape an input sequence batch into an
// output sequence batch that has a different number of time steps. Here
// output always shares the LoD information with input. But if
// Attr(shape) contains 0 or -1, the actual output shape can only be
// determined during runtime. The check for wheather it is a valid
// output sequence batch is performed in runtime.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
private:
framework
::
DDim
ValidateShape
(
const
std
::
vector
<
int
>
shape
,
const
framework
::
DDim
&
in_dims
)
const
{
const
int64_t
in_size
=
framework
::
product
(
in_dims
);
// only one dimension canbe set to -1, whose size will be automatically
// infered.
const
int64_t
unk_dim_val
=
-
1
;
const
int64_t
copy_dim_val
=
0
;
std
::
vector
<
int64_t
>
output_shape
(
shape
.
size
(),
0
);
int64_t
capacity
=
1
;
int
unk_dim_idx
=
-
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
unk_dim_val
)
{
PADDLE_ENFORCE
(
unk_dim_idx
==
-
1
,
"Only one input dimension of Attr(shape) can be unknown."
);
unk_dim_idx
=
i
;
}
else
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
i
)
<
in_dims
.
size
(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape."
);
}
else
{
PADDLE_ENFORCE
(
shape
[
i
]
>
0
,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension."
);
}
capacity
*=
(
shape
[
i
]
?
shape
[
i
]
:
in_dims
[
i
]);
output_shape
[
i
]
=
(
shape
[
i
]
?
static_cast
<
int64_t
>
(
shape
[
i
])
:
in_dims
[
i
]);
}
if
(
unk_dim_idx
!=
-
1
)
{
output_shape
[
unk_dim_idx
]
=
-
in_size
/
capacity
;
PADDLE_ENFORCE_EQ
(
output_shape
[
unk_dim_idx
]
*
capacity
,
-
in_size
,
"Invalid shape is given."
);
}
else
{
PADDLE_ENFORCE_EQ
(
capacity
,
in_size
,
"Invalid shape is given."
);
}
return
framework
::
make_ddim
(
output_shape
);
}
};
class
ReshapeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
ReshapeOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
paddle/fluid/operators/reshape_op.h
浏览文件 @
454b0a96
...
...
@@ -20,81 +20,6 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
class
ReshapeOp
:
public
framework
::
OperatorWithKernel
{
public:
ReshapeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE
(
!
shape
.
empty
(),
"The shape information must be set by Attr(shape)."
);
std
::
vector
<
int64_t
>
output_shape
;
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ValidateShape
(
shape
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
// NOTE: Reshape op cannot reshape an input sequence batch into an
// output sequence batch that has a different number of time steps. Here
// output always shares the LoD information with input. But if
// Attr(shape) contains 0 or -1, the actual output shape can only be
// determined during runtime. The check for wheather it is a valid
// output sequence batch is performed in runtime.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
static
framework
::
DDim
ValidateShape
(
const
std
::
vector
<
int
>
shape
,
const
framework
::
DDim
&
in_dims
)
{
const
int64_t
in_size
=
framework
::
product
(
in_dims
);
// only one dimension canbe set to -1, whose size will be automatically
// infered.
const
int64_t
unk_dim_val
=
-
1
;
const
int64_t
copy_dim_val
=
0
;
std
::
vector
<
int64_t
>
output_shape
(
shape
.
size
(),
0
);
int64_t
capacity
=
1
;
int
unk_dim_idx
=
-
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
unk_dim_val
)
{
PADDLE_ENFORCE
(
unk_dim_idx
==
-
1
,
"Only one input dimension of Attr(shape) can be unknown."
);
unk_dim_idx
=
i
;
}
else
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
i
)
<
in_dims
.
size
(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape."
);
}
else
{
PADDLE_ENFORCE
(
shape
[
i
]
>
0
,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension."
);
}
capacity
*=
(
shape
[
i
]
?
shape
[
i
]
:
in_dims
[
i
]);
output_shape
[
i
]
=
(
shape
[
i
]
?
static_cast
<
int64_t
>
(
shape
[
i
])
:
in_dims
[
i
]);
}
if
(
unk_dim_idx
!=
-
1
)
{
output_shape
[
unk_dim_idx
]
=
-
in_size
/
capacity
;
PADDLE_ENFORCE_EQ
(
output_shape
[
unk_dim_idx
]
*
capacity
,
-
in_size
,
"Invalid shape is given."
);
}
else
{
PADDLE_ENFORCE_EQ
(
capacity
,
in_size
,
"Invalid shape is given."
);
}
return
framework
::
make_ddim
(
output_shape
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ReshapeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -102,8 +27,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out_dims
=
ReshapeOp
::
ValidateShape
(
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"shape"
),
in
->
dims
());
auto
out_dims
=
out
->
dims
();
if
(
!
in
->
lod
().
empty
())
{
PADDLE_ENFORCE_EQ
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录