Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e240ba29
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e240ba29
编写于
12月 12, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement backward
test=develop
上级
8760d23c
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
154 addition
and
28 deletion
+154
-28
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+2
-0
paddle/fluid/framework/op_desc.h
paddle/fluid/framework/op_desc.h
+2
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+5
-0
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+5
-0
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+112
-15
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+27
-12
未找到文件。
paddle/fluid/framework/op_desc.cc
浏览文件 @
e240ba29
...
@@ -34,6 +34,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
...
@@ -34,6 +34,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
public:
public:
CompileTimeInferShapeContext
(
const
OpDesc
&
op
,
const
BlockDesc
&
block
);
CompileTimeInferShapeContext
(
const
OpDesc
&
op
,
const
BlockDesc
&
block
);
InferShapeOpPtr
GetOp
()
const
override
{
return
&
op_
;
}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
;
bool
HasInput
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
;
...
...
paddle/fluid/framework/op_desc.h
浏览文件 @
e240ba29
...
@@ -121,6 +121,8 @@ class OpDesc {
...
@@ -121,6 +121,8 @@ class OpDesc {
BlockDesc
*
Block
()
{
return
this
->
block_
;
}
BlockDesc
*
Block
()
{
return
this
->
block_
;
}
const
BlockDesc
*
Block
()
const
{
return
this
->
block_
;
}
private:
private:
template
<
typename
MapType
>
template
<
typename
MapType
>
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
e240ba29
...
@@ -481,6 +481,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -481,6 +481,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
:
op_
(
op
),
scope_
(
scope
)
{}
InferShapeOpPtr
GetOp
()
const
override
{
return
&
op_
;
}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
// has only one input
// has only one input
const
auto
&
ins
=
op_
.
Inputs
();
const
auto
&
ins
=
op_
.
Inputs
();
...
@@ -879,6 +881,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
...
@@ -879,6 +881,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
}
if
(
t
!=
nullptr
)
{
if
(
t
!=
nullptr
)
{
PADDLE_ENFORCE
(
t
->
IsInitialized
(),
"Input %s(%s) does not exist in Operator %s"
,
input
.
first
,
ipt_name
,
DebugString
());
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
tmp
==
data_type
||
data_type
==
-
1
,
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
e240ba29
...
@@ -25,7 +25,10 @@ limitations under the License. */
...
@@ -25,7 +25,10 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
OperatorBase
;
using
InferShapeVarPtr
=
boost
::
variant
<
VarDesc
*
,
Variable
*>
;
using
InferShapeVarPtr
=
boost
::
variant
<
VarDesc
*
,
Variable
*>
;
using
InferShapeOpPtr
=
boost
::
variant
<
const
OpDesc
*
,
const
OperatorBase
*>
;
class
InferShapeContext
{
class
InferShapeContext
{
public:
public:
...
@@ -38,6 +41,8 @@ class InferShapeContext {
...
@@ -38,6 +41,8 @@ class InferShapeContext {
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
name
)
const
;
virtual
InferShapeOpPtr
GetOp
()
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
e240ba29
...
@@ -24,34 +24,34 @@ namespace operators {
...
@@ -24,34 +24,34 @@ namespace operators {
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
static
std
::
mutex
g_py_callables_mtx
;
static
std
::
vector
<
py
::
object
>
g_py_callables
;
static
std
::
vector
<
py
::
object
>
g_py_callables
;
size_t
AppendPythonCallableObjectAndReturnId
(
py
::
object
py_obj
)
{
size_t
AppendPythonCallableObjectAndReturnId
(
py
::
object
py_obj
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
g_py_callables_mtx
);
g_py_callables
.
emplace_back
(
py_obj
);
g_py_callables
.
emplace_back
(
py_obj
);
return
g_py_callables
.
size
()
-
1
;
return
g_py_callables
.
size
()
-
1
;
}
}
static
py
::
object
*
GetPythonCallableObject
(
size_t
i
)
{
static
py
::
object
*
GetPythonCallableObject
(
size_t
i
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
g_py_callables_mtx
);
PADDLE_ENFORCE_LT
(
i
,
g_py_callables
.
size
());
PADDLE_ENFORCE_LT
(
i
,
g_py_callables
.
size
());
return
&
g_py_callables
[
i
];
return
&
g_py_callables
[
i
];
}
}
void
Do
CallPythonFunc
(
py
::
object
*
callable
,
const
std
::
string
&
func_token
,
void
CallPythonFunc
(
py
::
object
*
callable
,
const
std
::
string
&
func_token
,
const
std
::
vector
<
framework
::
LoDTensor
>
&
ins
,
const
std
::
vector
<
framework
::
LoDTensor
>
&
ins
,
std
::
vector
<
framework
::
LoDTensor
*>
*
out
)
{
std
::
vector
<
framework
::
LoDTensor
*>
*
out
)
{
py
::
gil_scoped_acquire
guard
{};
py
::
gil_scoped_acquire
guard
{};
py
::
tuple
in_args
(
ins
.
size
());
py
::
tuple
in_args
(
ins
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
in_args
[
i
]
=
py
::
cast
(
ins
[
i
]
);
in_args
[
i
]
=
ins
[
i
].
IsInitialized
()
?
py
::
cast
(
ins
[
i
])
:
py
::
cast
(
nullptr
);
}
}
auto
ret
=
(
*
callable
)(
func_token
,
*
in_args
);
auto
ret
=
(
*
callable
)(
func_token
,
*
in_args
);
auto
ret_tuple
=
py
::
cast
<
py
::
tuple
>
(
ret
);
auto
ret_tuple
=
py
::
cast
<
py
::
tuple
>
(
ret
);
PADDLE_ENFORCE_EQ
(
py
::
len
(
ret_tuple
),
out
->
size
(),
"Output number not match"
);
PADDLE_ENFORCE_EQ
(
py
::
len
(
ret_tuple
),
out
->
size
(),
"Output number not match"
);
for
(
size_t
i
=
0
;
i
<
out
->
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out
->
size
();
++
i
)
{
if
((
*
out
)[
i
]
==
nullptr
)
{
continue
;
}
try
{
try
{
auto
*
out_tensor
=
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
i
]);
auto
*
out_tensor
=
py
::
cast
<
framework
::
LoDTensor
*>
(
ret_tuple
[
i
]);
PADDLE_ENFORCE_NOT_NULL
(
out_tensor
,
PADDLE_ENFORCE_NOT_NULL
(
out_tensor
,
...
@@ -67,8 +67,43 @@ void DoCallPythonFunc(py::object *callable, const std::string &func_token,
...
@@ -67,8 +67,43 @@ void DoCallPythonFunc(py::object *callable, const std::string &func_token,
class
PyFuncOpShapeInference
:
public
framework
::
InferShapeBase
{
class
PyFuncOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
!
ctx
->
IsRuntime
(),
"Infer shape cannot be called in runtime."
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Input(X) must exist"
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Input(X) must exist"
);
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
"Output(Out) must exist"
);
PADDLE_ENFORCE
(
ctx
->
HasOutputs
(
"Out"
),
"Output(Out) must exist"
);
auto
*
op
=
boost
::
get
<
const
framework
::
OpDesc
*>
(
ctx
->
GetOp
());
auto
*
block
=
op
->
Block
();
// No need to infer shape in forward part
if
(
block
->
ForwardBlockID
()
<
0
)
{
return
;
}
PADDLE_ENFORCE
(
!
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"token"
).
empty
(),
"Function token cannot be empty"
);
const
std
::
string
kGradVarSuffix
=
framework
::
kGradVarSuffix
;
auto
out_vars
=
ctx
->
GetOutputVarPtrs
(
"Out"
);
for
(
auto
&
out_var
:
out_vars
)
{
auto
*
out_var_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
out_var
);
auto
out_name
=
out_var_desc
->
Name
();
if
(
out_name
==
framework
::
kEmptyVarName
||
out_name
.
size
()
<
kGradVarSuffix
.
size
())
{
continue
;
}
size_t
len
=
out_name
.
size
()
-
kGradVarSuffix
.
size
();
if
(
out_name
.
substr
(
len
)
==
kGradVarSuffix
)
{
auto
fwd_var_name
=
out_name
.
substr
(
0
,
len
);
auto
*
in_var_desc
=
block
->
FindVarRecursive
(
fwd_var_name
);
PADDLE_ENFORCE_NOT_NULL
(
in_var_desc
,
"Forward variable %s not found"
,
fwd_var_name
);
out_var_desc
->
SetShape
(
in_var_desc
->
GetShape
());
out_var_desc
->
SetDataType
(
in_var_desc
->
GetDataType
());
out_var_desc
->
SetLoDLevel
(
in_var_desc
->
GetLoDLevel
());
out_var_desc
->
SetType
(
in_var_desc
->
GetType
());
}
}
}
}
};
};
...
@@ -77,12 +112,68 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -77,12 +112,68 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
"Inputs of py_func op."
).
AsDuplicable
();
AddInput
(
"X"
,
"Inputs of py_func op."
).
AsDuplicable
();
AddOutput
(
"Out"
,
"Outputs of py_func op"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"Outputs of py_func op"
).
AsDuplicable
();
AddAttr
<
std
::
string
>
(
"token"
,
"function token"
);
AddAttr
<
int
>
(
"handle_idx"
,
"Index of the registered py_func handle"
)
AddAttr
<
int
>
(
"handle_idx"
,
"handle index"
).
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"token"
,
"Token of function token to be called"
)
.
SetDefault
(
""
);
AddAttr
<
std
::
string
>
(
"backward_token"
,
"Token of backward function to be called"
)
.
SetDefault
(
""
);
AddComment
(
R"DOC("PyFunc Op")DOC"
);
AddComment
(
R"DOC("PyFunc Op")DOC"
);
}
}
};
};
class
PyFuncOpGradDescMaker
:
public
framework
::
GradOpDescMakerBase
{
public:
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
operator
()()
const
override
{
auto
&
fwd_attrs
=
Attrs
();
if
(
fwd_attrs
.
at
(
"backward_token"
).
empty
())
{
return
{};
}
std
::
unique_ptr
<
framework
::
OpDesc
>
grad_op
(
new
framework
::
OpDesc
());
grad_op
->
SetType
(
"py_func"
);
framework
::
AttributeMap
bwd_attrs
;
bwd_attrs
[
"token"
]
=
fwd_attrs
.
at
(
"backward_token"
);
bwd_attrs
[
"backward_token"
]
=
std
::
string
(
""
);
grad_op
->
SetAttrMap
(
bwd_attrs
);
auto
bwd_in
=
Input
(
"X"
);
auto
fwd_out
=
Output
(
"Out"
);
auto
fwd_out_grad
=
OutputGrad
(
"Out"
);
bwd_in
.
insert
(
bwd_in
.
end
(),
fwd_out
.
begin
(),
fwd_out
.
end
());
bwd_in
.
insert
(
bwd_in
.
end
(),
fwd_out_grad
.
begin
(),
fwd_out_grad
.
end
());
auto
bwd_out
=
InputGrad
(
"X"
,
false
);
if
(
VLOG_IS_ON
(
10
))
{
std
::
string
in_str
=
"PyFunc Grad Input: "
;
for
(
auto
&
in
:
bwd_in
)
{
in_str
+=
in
;
in_str
+=
" "
;
}
VLOG
(
10
)
<<
in_str
;
std
::
string
out_str
=
"PyFunc Grad Output: "
;
for
(
auto
&
out
:
bwd_out
)
{
out_str
+=
out
;
out
+=
" "
;
}
VLOG
(
10
)
<<
out_str
;
}
grad_op
->
SetInput
(
"X"
,
bwd_in
);
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
,
false
));
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
ret
(
1
);
ret
[
0
]
=
std
::
move
(
grad_op
);
return
ret
;
}
};
class
PyFuncOp
:
public
framework
::
OperatorBase
{
class
PyFuncOp
:
public
framework
::
OperatorBase
{
public:
public:
using
framework
::
OperatorBase
::
OperatorBase
;
using
framework
::
OperatorBase
::
OperatorBase
;
...
@@ -95,8 +186,14 @@ class PyFuncOp : public framework::OperatorBase {
...
@@ -95,8 +186,14 @@ class PyFuncOp : public framework::OperatorBase {
std
::
vector
<
framework
::
LoDTensor
>
inputs
(
in_arg_names
.
size
());
std
::
vector
<
framework
::
LoDTensor
>
inputs
(
in_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
in_arg_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
in_arg_names
.
size
();
++
i
)
{
auto
&
in_tensor
=
auto
in_var
=
scope
.
FindVar
(
in_arg_names
[
i
]);
scope
.
FindVar
(
in_arg_names
[
i
])
->
Get
<
framework
::
LoDTensor
>
();
if
(
in_var
==
nullptr
)
{
continue
;
}
auto
&
in_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
if
(
!
in_tensor
.
IsInitialized
())
{
continue
;
}
if
(
platform
::
is_gpu_place
(
in_tensor
.
place
()))
{
if
(
platform
::
is_gpu_place
(
in_tensor
.
place
()))
{
framework
::
TensorCopySync
(
in_tensor
,
platform
::
CPUPlace
(),
&
inputs
[
i
]);
framework
::
TensorCopySync
(
in_tensor
,
platform
::
CPUPlace
(),
&
inputs
[
i
]);
}
else
{
}
else
{
...
@@ -107,8 +204,9 @@ class PyFuncOp : public framework::OperatorBase {
...
@@ -107,8 +204,9 @@ class PyFuncOp : public framework::OperatorBase {
std
::
vector
<
framework
::
LoDTensor
*>
outputs
(
out_arg_names
.
size
());
std
::
vector
<
framework
::
LoDTensor
*>
outputs
(
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
out_arg_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out_arg_names
.
size
();
++
i
)
{
auto
*
out_var
=
scope
.
FindVar
(
out_arg_names
[
i
]);
auto
*
out_tensor
=
auto
*
out_tensor
=
scope
.
FindVar
(
out_arg_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
()
;
out_var
?
out_var
->
GetMutable
<
framework
::
LoDTensor
>
()
:
nullptr
;
outputs
[
i
]
=
out_tensor
;
outputs
[
i
]
=
out_tensor
;
}
}
...
@@ -117,7 +215,7 @@ class PyFuncOp : public framework::OperatorBase {
...
@@ -117,7 +215,7 @@ class PyFuncOp : public framework::OperatorBase {
auto
*
py_callable
=
GetPythonCallableObject
(
handle_idx
);
auto
*
py_callable
=
GetPythonCallableObject
(
handle_idx
);
VLOG
(
10
)
<<
"Call py_func_op with token "
<<
token
<<
", and handle_idx "
VLOG
(
10
)
<<
"Call py_func_op with token "
<<
token
<<
", and handle_idx "
<<
handle_idx
;
<<
handle_idx
;
Do
CallPythonFunc
(
py_callable
,
token
,
inputs
,
&
outputs
);
CallPythonFunc
(
py_callable
,
token
,
inputs
,
&
outputs
);
}
}
};
};
...
@@ -127,5 +225,4 @@ class PyFuncOp : public framework::OperatorBase {
...
@@ -127,5 +225,4 @@ class PyFuncOp : public framework::OperatorBase {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
py_func
,
ops
::
PyFuncOp
,
ops
::
PyFuncOpMaker
,
REGISTER_OPERATOR
(
py_func
,
ops
::
PyFuncOp
,
ops
::
PyFuncOpMaker
,
ops
::
PyFuncOpShapeInference
,
ops
::
PyFuncOpShapeInference
,
ops
::
PyFuncOpGradDescMaker
);
paddle
::
framework
::
EmptyGradOpMaker
);
paddle/fluid/pybind/protobuf.cc
浏览文件 @
e240ba29
...
@@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) {
...
@@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) {
.
def
(
"infer_var_type"
,
&
pd
::
OpDesc
::
InferVarType
)
.
def
(
"infer_var_type"
,
&
pd
::
OpDesc
::
InferVarType
)
.
def
(
"set_is_target"
,
&
pd
::
OpDesc
::
SetIsTarget
)
.
def
(
"set_is_target"
,
&
pd
::
OpDesc
::
SetIsTarget
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
pd
::
OpDesc
>
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
pd
::
OpDesc
>
)
.
def
(
"block"
,
&
pd
::
OpDesc
::
Block
,
.
def
(
"block"
,
[](
pd
::
OpDesc
&
self
)
{
return
self
.
Block
();
}
,
pybind11
::
return_value_policy
::
reference
);
pybind11
::
return_value_policy
::
reference
);
}
}
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
e240ba29
...
@@ -9096,12 +9096,9 @@ def py_func(func, x, out, backward_func=None):
...
@@ -9096,12 +9096,9 @@ def py_func(func, x, out, backward_func=None):
_main_program_to_register
=
dict
()
_main_program_to_register
=
dict
()
@
classmethod
@
classmethod
def
get_instance
(
cls
,
prog
=
None
):
def
get_instance
(
cls
,
prog
):
if
prog
is
None
:
prog
=
fluid
.
default_main_program
()
if
not
isinstance
(
prog
,
Program
):
if
not
isinstance
(
prog
,
Program
):
raise
ValueError
(
"prog must be None or
type of Program"
)
raise
TypeError
(
"prog must be
type of Program"
)
ret
=
cls
.
_main_program_to_register
.
get
(
prog
,
None
)
ret
=
cls
.
_main_program_to_register
.
get
(
prog
,
None
)
if
ret
is
None
:
if
ret
is
None
:
...
@@ -9155,6 +9152,10 @@ def py_func(func, x, out, backward_func=None):
...
@@ -9155,6 +9152,10 @@ def py_func(func, x, out, backward_func=None):
ret
=
[]
ret
=
[]
for
i
in
six
.
moves
.
range
(
len
(
ret0
)):
for
i
in
six
.
moves
.
range
(
len
(
ret0
)):
if
ret0
[
i
]
is
None
:
ret
.
append
(
None
)
continue
if
isinstance
(
ret0
[
i
],
core
.
LoDTensor
):
if
isinstance
(
ret0
[
i
],
core
.
LoDTensor
):
ret
.
append
(
ret0
[
i
])
ret
.
append
(
ret0
[
i
])
continue
continue
...
@@ -9175,20 +9176,34 @@ def py_func(func, x, out, backward_func=None):
...
@@ -9175,20 +9176,34 @@ def py_func(func, x, out, backward_func=None):
x
=
[
x
]
x
=
[
x
]
if
isinstance
(
out
,
Variable
):
if
isinstance
(
out
,
Variable
):
out
=
[
out
]
out_list
=
[
out
]
else
:
out_list
=
out
if
func
is
None
or
not
hasattr
(
func
,
'__call__'
):
raise
TypeError
(
'Input func must be a function'
)
for
each_out
in
out
:
if
backward_func
is
not
None
and
not
hasattr
(
backward_func
,
'__call__'
):
raise
TypeError
(
'Input backward_func must be a function'
)
for
each_out
in
out_list
:
if
len
(
each_out
.
shape
)
==
0
:
if
len
(
each_out
.
shape
)
==
0
:
raise
ValueError
(
raise
ValueError
(
'users should infer shapes of outputs of py_func op manually'
)
'Output shapes of py_func op should be provided by users manually'
)
py_func_reg
=
PyFuncRegister
.
get_instance
(
helper
.
main_program
)
py_func_reg
=
PyFuncRegister
.
get_instance
(
helper
.
main_program
)
token
=
py_func_reg
.
unique_token
(
func
)
forward_token
=
py_func_reg
.
unique_token
(
func
)
backward_token
=
py_func_reg
.
unique_token
(
backward_func
)
if
backward_func
is
not
None
else
''
helper
.
append_op
(
helper
.
append_op
(
type
=
'py_func'
,
type
=
'py_func'
,
inputs
=
{
'X'
:
x
},
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
outputs
=
{
'Out'
:
out_list
},
attrs
=
{
'handle_idx'
:
py_func_reg
.
handle_idx
,
attrs
=
{
'token'
:
token
})
'handle_idx'
:
py_func_reg
.
handle_idx
,
'token'
:
forward_token
,
'backward_token'
:
backward_token
})
return
out
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录