Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
490eb906
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
490eb906
编写于
12月 20, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish infer shape of py_func op
test=develop
上级
dc8847af
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
41 addition
and
45 deletion
+41
-45
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+0
-2
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+0
-2
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+0
-3
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+41
-38
未找到文件。
paddle/fluid/framework/op_desc.cc
浏览文件 @
490eb906
...
...
@@ -34,8 +34,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext
(
const
OpDesc
&
op
,
const
BlockDesc
&
block
);
InferShapeOpPtr
GetOp
()
const
override
{
return
&
op_
;
}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
;
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
490eb906
...
...
@@ -481,8 +481,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
InferShapeOpPtr
GetOp
()
const
override
{
return
&
op_
;
}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
// has only one input
const
auto
&
ins
=
op_
.
Inputs
();
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
490eb906
...
...
@@ -28,7 +28,6 @@ namespace framework {
class
OperatorBase
;
using
InferShapeVarPtr
=
boost
::
variant
<
VarDesc
*
,
Variable
*>
;
using
InferShapeOpPtr
=
boost
::
variant
<
const
OpDesc
*
,
const
OperatorBase
*>
;
class
InferShapeContext
{
public:
...
...
@@ -41,8 +40,6 @@ class InferShapeContext {
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
;
virtual
InferShapeOpPtr
GetOp
()
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
490eb906
...
...
@@ -91,66 +91,68 @@ static void CallPythonFunc(py::object *callable,
}
}
class
PyFuncOp
ShapeInference
:
public
framework
::
InferShapeBas
e
{
class
PyFuncOp
VarTypInference
:
public
framework
::
VarTypeInferenc
e
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
IsRuntime
(),
"Infer shape cannot be called in runtime."
);
void
operator
()(
const
framework
::
OpDesc
&
op
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
&
outs
=
op
.
Outputs
();
bool
has_out
=
(
outs
.
count
(
"Out"
)
>
0
&&
!
outs
.
at
(
"Out"
).
empty
());
auto
&
ins
=
op
.
Inputs
();
bool
has_in
=
(
ins
.
count
(
"X"
)
>
0
&&
!
ins
.
at
(
"X"
).
empty
());
/**
* X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output
*/
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
)
||
ctx
->
HasOutputs
(
"Out"
),
"Input(X) or Output(Out) must exist"
);
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
int
>
(
kForwardPythonCallableId
),
0
,
PADDLE_ENFORCE
(
has_in
||
has_out
,
"Input(X) or Output(Out) must exist"
);
PADDLE_ENFORCE_GE
(
boost
::
get
<
int
>
(
op
.
GetAttr
(
kForwardPythonCallableId
)
),
0
,
"Function id cannot be less than 0"
);
if
(
!
has_out
)
return
;
/**
* Traverse all outputs, check if name of any output ends with @GRAD.
* If found, set its shape, dtype, lod_level, type to be the same as
* the corresponding forward variable
*
* Why not get input dims from InferShapeContext?
* Because some variables in forward inputs/outputs may not be needed
* in backward. Those variables are not inside InferShapeContext.
*
* InferShape would be only called in compile time. During runtime,
* the shapes of outputs should be guaranteed by user-defined Python
* functions.
*/
auto
*
op
=
boost
::
get
<
const
framework
::
OpDesc
*>
(
ctx
->
GetOp
());
auto
*
block
=
op
->
Block
();
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
auto
out_vars
=
ctx
->
GetOutputVarPtrs
(
"Out"
);
for
(
auto
&
out_var
:
out_vars
)
{
auto
*
out_var_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
out_var
);
if
(
out_var_desc
==
nullptr
)
{
continue
;
}
auto
out_name
=
out_var_desc
->
Name
();
if
(
out_name
==
framework
::
kEmptyVarName
||
out_name
.
size
()
<
kGradVarSuffix
.
size
())
{
auto
&
out_var_names
=
outs
.
at
(
"Out"
);
for
(
auto
&
out_var_name
:
out_var_names
)
{
if
(
out_var_name
==
framework
::
kEmptyVarName
||
out_var_name
.
size
()
<
kGradVarSuffix
.
size
())
{
continue
;
}
size_t
len
=
out_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_name
.
substr
(
0
,
len
);
auto
*
in_var_desc
=
block
->
FindVarRecursive
(
fwd_var_name
);
PADDLE_ENFORCE_NOT_NULL
(
in_var_desc
,
"Forward variable %s not found"
,
size_t
len
=
out_var_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_var_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_var_name
.
substr
(
0
,
len
);
auto
*
out_var_desc
=
block
->
FindVarRecursive
(
out_var_name
);
auto
*
fwd_var_desc
=
block
->
FindVarRecursive
(
fwd_var_name
);
PADDLE_ENFORCE_NOT_NULL
(
out_var_desc
,
"Backward variable %s not found"
,
out_var_name
);
PADDLE_ENFORCE_NOT_NULL
(
fwd_var_desc
,
"Forward variable %s not found"
,
fwd_var_name
);
VLOG
(
10
)
<<
"Infer
shape of Output("
<<
out
_name
<<
") as Input("
<<
in_var_desc
->
Name
()
<<
")"
;
out_var_desc
->
SetShape
(
in
_var_desc
->
GetShape
());
out_var_desc
->
SetDataType
(
in
_var_desc
->
GetDataType
());
out_var_desc
->
SetLoDLevel
(
in
_var_desc
->
GetLoDLevel
());
out_var_desc
->
SetType
(
in
_var_desc
->
GetType
());
VLOG
(
10
)
<<
"Infer
var_desc of Output("
<<
out_var
_name
<<
") as Input("
<<
fwd_var_name
<<
")"
;
out_var_desc
->
SetShape
(
fwd
_var_desc
->
GetShape
());
out_var_desc
->
SetDataType
(
fwd
_var_desc
->
GetDataType
());
out_var_desc
->
SetLoDLevel
(
fwd
_var_desc
->
GetLoDLevel
());
out_var_desc
->
SetType
(
fwd
_var_desc
->
GetType
());
}
}
}
};
class
PyFuncOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
IsRuntime
(),
"Infer shape cannot be called in runtime."
);
}
};
class
PyFuncOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase {
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
py_func
,
ops
::
PyFuncOp
,
ops
::
PyFuncOpMaker
,
ops
::
PyFuncOpShapeInference
,
ops
::
PyFuncOpGradDescMaker
);
ops
::
PyFuncOpVarTypInference
,
ops
::
PyFuncOpShapeInference
,
ops
::
PyFuncOpGradDescMaker
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录