Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
df707d04
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
df707d04
编写于
7月 17, 2017
作者:
Y
Yu Yang
提交者:
GitHub
7月 17, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2893 from reyoung/feature/op_creation_methods
Python Generate OpCreation Methods by OpProto
上级
a0caf234
0e77b31a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
542 addition
and
20 deletion
+542
-20
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+24
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+13
-15
paddle/framework/operator.h
paddle/framework/operator.h
+7
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+17
-0
python/paddle/v2/framework/create_op_creation_methods.py
python/paddle/v2/framework/create_op_creation_methods.py
+235
-0
python/paddle/v2/framework/tests/test_op_creation_methods.py
python/paddle/v2/framework/tests/test_op_creation_methods.py
+241
-2
python/paddle/v2/optimizer.py
python/paddle/v2/optimizer.py
+5
-1
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
df707d04
#pragma once
#include <algorithm>
#include <atomic>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
...
...
@@ -214,11 +215,14 @@ class OpRegistry {
}
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
//! Create a OpPtr by type.
std
::
string
op_type
=
op_desc
.
type
();
OperatorPtr
op
(
creators
().
at
(
op_type
)());
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
const
OpProto
&
op_proto
=
protos
().
at
(
op_type
);
// set op's inputs_ from desc.
op
->
type_
=
op_desc
.
type
();
// set op's inputs_ from desc.
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
std
::
back_inserter
(
op
->
inputs_
));
...
...
@@ -226,13 +230,20 @@ class OpRegistry {
op
->
outputs_
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
back_inserter
(
op
->
outputs_
));
// set op's attr;
//! Fill attrs, and validate attrs.
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
op
->
attrs_
[
attr
.
name
()]
=
AttrTypeHelper
::
GetAttrValue
(
attr
);
}
op_checkers
().
at
(
op_type
).
Check
(
op
->
attrs_
);
//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName
(
op
.
get
());
// set argument offsets stored in op.
CreateInOutOffsetMap
(
op
,
op_proto
);
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
op
->
Init
();
return
op
;
}
...
...
@@ -248,6 +259,17 @@ class OpRegistry {
};
private:
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
if
(
outname
==
OperatorBase
::
TMP_VAR_NAME
())
{
outname
+=
op
->
type_
;
outname
+=
"@"
;
outname
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
...
...
paddle/framework/operator.cc
浏览文件 @
df707d04
...
...
@@ -77,23 +77,21 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
ss
<<
"=================
\n
"
;
ss
<<
"type = "
<<
type_
<<
"
\n
"
;
ss
<<
"inputs = ["
;
for
(
auto
&
ipt
:
inputs_
)
{
ss
<<
ipt
<<
", "
;
ss
<<
"Op("
<<
type_
<<
"), inputs:("
;
for
(
size_t
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
ss
<<
inputs_
[
i
];
if
(
i
!=
inputs_
.
size
()
-
1
)
{
ss
<<
", "
;
}
}
ss
<<
"]
\n
"
;
ss
<<
"outputs = ["
;
for
(
auto
&
opt
:
outputs_
)
{
ss
<<
opt
<<
", "
;
ss
<<
"), outputs:("
;
for
(
size_t
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
ss
<<
outputs_
[
i
];
if
(
i
!=
outputs_
.
size
()
-
1
)
{
ss
<<
", "
;
}
}
ss
<<
"]
\n
"
;
ss
<<
"attr_keys = ["
;
for
(
auto
&
attr
:
attrs_
)
{
ss
<<
attr
.
first
<<
", "
;
}
ss
<<
"]
\n
"
;
ss
<<
")."
;
return
ss
.
str
();
}
...
...
paddle/framework/operator.h
浏览文件 @
df707d04
...
...
@@ -41,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
*/
class
OperatorBase
{
public:
/// If a variable is a empty variable, that name will be used.
static
std
::
string
EMPTY_VAR_NAME
()
{
return
"@EMPTY@"
;
}
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
...
...
paddle/pybind/pybind.cc
浏览文件 @
df707d04
...
...
@@ -63,6 +63,23 @@ All parameter, weight, gradient are variables in Paddle.
}
return
ret_values
;
});
m
.
def_submodule
(
"var_names"
,
"The module will return special predefined variable name in Paddle"
)
.
def
(
"empty"
,
pd
::
OperatorBase
::
EMPTY_VAR_NAME
)
.
def
(
"temp"
,
pd
::
OperatorBase
::
TMP_VAR_NAME
);
py
::
class_
<
pd
::
OperatorBase
,
pd
::
OperatorPtr
>
(
m
,
"Operator"
)
.
def
(
"__str__"
,
&
pd
::
OperatorBase
::
DebugString
)
.
def_static
(
"create"
,
[](
const
std
::
string
&
protobin
)
{
pd
::
OpDesc
desc
;
PADDLE_ENFORCE
(
desc
.
ParsePartialFromString
(
protobin
),
"Cannot parse user input to OpDesc"
);
PADDLE_ENFORCE
(
desc
.
IsInitialized
(),
"User OpDesc is not initialized, reason %s"
,
desc
.
InitializationErrorString
());
return
pd
::
OpRegistry
::
CreateOp
(
desc
);
});
return
m
.
ptr
();
}
python/paddle/v2/framework/create_op_creation_methods.py
浏览文件 @
df707d04
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attr_type_pb2
as
attr_type_pb2
import
cStringIO
def
get_all_op_protos
():
"""
Get all registered op proto from Paddle C++
:return: list of OpProto
"""
protostrs
=
core
.
get_all_op_protos
()
ret_values
=
[]
for
pbstr
in
protostrs
:
op_proto
=
op_proto_pb2
.
OpProto
.
FromString
(
str
(
pbstr
))
ret_values
.
append
(
op_proto
)
return
ret_values
class
OpDescCreationMethod
(
object
):
"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""
def
__init__
(
self
,
op_proto
):
if
not
isinstance
(
op_proto
,
op_proto_pb2
.
OpProto
):
raise
TypeError
(
"Argument should be OpProto"
)
self
.
__op_proto__
=
op_proto
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
Convert user input to OpDesc. Only key-word args are supported.
:return: OpDesc based on user input
:rtype: op_desc_pb2.OpDesc
"""
if
len
(
args
)
!=
0
:
raise
ValueError
(
"Only keyword arguments is supported by Paddle"
)
op_desc
=
op_desc_pb2
.
OpDesc
()
# Inputs
ipts
,
ipt_format
,
_
=
OpDescCreationMethod
.
extract_input_or_output
(
"input"
,
kwargs
,
self
.
__op_proto__
.
inputs
)
op_desc
.
inputs
.
extend
(
ipts
)
if
ipt_format
is
not
None
:
op_desc
.
attrs
.
extend
([
ipt_format
])
# Outputs
outs
,
out_format
,
tmp_index
=
OpDescCreationMethod
.
extract_input_or_output
(
"output"
,
kwargs
,
self
.
__op_proto__
.
outputs
)
op_desc
.
outputs
.
extend
(
outs
)
if
out_format
is
not
None
:
op_desc
.
attrs
.
extend
([
out_format
])
if
len
(
tmp_index
)
!=
0
:
tmp_index_attr
=
op_desc
.
attrs
.
add
()
tmp_index_attr
.
type
=
attr_type_pb2
.
INTS
tmp_index_attr
.
name
=
"temporary_index"
tmp_index_attr
.
ints
.
extend
(
tmp_index
)
# Types
op_desc
.
type
=
self
.
__op_proto__
.
type
# Attrs
for
attr
in
self
.
__op_proto__
.
attrs
:
if
attr
.
generated
:
continue
user_defined_attr
=
kwargs
.
get
(
attr
.
name
,
None
)
if
user_defined_attr
is
not
None
:
new_attr
=
op_desc
.
attrs
.
add
()
new_attr
.
name
=
attr
.
name
new_attr
.
type
=
attr
.
type
if
attr
.
type
==
attr_type_pb2
.
INT
:
new_attr
.
i
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
FLOAT
:
new_attr
.
f
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
STRING
:
new_attr
.
s
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
INTS
:
new_attr
.
ints
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attr_type_pb2
.
FLOATS
:
new_attr
.
floats
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attr_type_pb2
.
STRINGS
:
new_attr
.
strings
.
extend
(
user_defined_attr
)
else
:
raise
NotImplementedError
(
"Not support attribute type "
+
attr
.
type
)
return
op_desc
@
staticmethod
def
extract_input_or_output
(
in_out
,
kwargs
,
meta
):
"""
Extract input variable names or output variable names from key-word
arguments, which base on VarProtos.
:param in_out: "input" or "output"
:param kwargs: key-word arguments that user inputted.
:param meta: a list of VarProto
:return: The three object will be return. The variable names. The
input_format or output_format attribute(None if the input or output is
not multiple). The temporary variable index list.
"""
multiple
=
OpDescCreationMethod
.
any_is_true
((
m
.
multiple
for
m
in
meta
))
tmp_index
=
[]
retv
=
[]
if
multiple
:
var_format
=
op_desc_pb2
.
AttrDesc
()
var_format
.
type
=
attr_type_pb2
.
INTS
var_format
.
name
=
"%s_format"
%
in_out
var_format
.
ints
.
append
(
0
)
for
var
in
meta
:
var_name
=
var
.
name
if
var
.
temporary
:
var_name
=
[
core
.
var_names
.
temp
()]
tmp_index
.
append
(
len
(
retv
))
else
:
var_name
=
kwargs
.
get
(
var_name
,
[])
if
not
isinstance
(
var_name
,
list
):
var_name
=
[
var_name
]
retv
.
extend
(
var_name
)
var_format
.
ints
.
append
(
len
(
var_name
)
+
var_format
.
ints
[
-
1
])
return
retv
,
var_format
,
tmp_index
else
:
for
var
in
meta
:
if
var
.
temporary
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
temp
()))
tmp_index
.
append
(
len
(
retv
))
else
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
empty
()))
return
retv
,
None
,
tmp_index
@
staticmethod
def
any_is_true
(
generator
):
"""
Reduce a bool array to one. If any of them is True, then return True.
"""
for
flag
in
generator
:
if
flag
:
return
True
return
False
def
get_docstring_from_op_proto
(
op_proto
):
"""
Generate docstring from a OpProto
:param op_proto: a OpProto instance.
:type op_proto: op_proto_pb2.OpProto
:return: docstring
"""
if
not
isinstance
(
op_proto
,
op_proto_pb2
.
OpProto
):
raise
TypeError
(
"Input must be OpProto"
)
f
=
cStringIO
.
StringIO
()
f
.
write
(
op_proto
.
comment
)
f
.
write
(
"
\n
"
)
def
__append_param__
(
name
,
comment
,
type
):
# Maybe replace the following line with template engine is better.
f
.
write
(
":param "
)
f
.
write
(
name
)
f
.
write
(
": "
)
f
.
write
(
comment
)
f
.
write
(
"
\n
"
)
f
.
write
(
":type "
)
f
.
write
(
name
)
f
.
write
(
": "
)
f
.
write
(
type
)
f
.
write
(
"
\n
"
)
for
ipt
in
op_proto
.
inputs
:
__append_param__
(
ipt
.
name
,
ipt
.
comment
,
"list | basestr"
if
ipt
.
multiple
else
"basestr"
)
temp_var_prefix
=
\
"This is a temporary variable. It does not have to set by user. "
for
opt
in
op_proto
.
outputs
:
__append_param__
(
opt
.
name
,
opt
.
comment
if
not
opt
.
temporary
else
temp_var_prefix
+
opt
.
comment
,
"list | basestr"
if
opt
.
multiple
else
"basestr"
)
for
attr
in
op_proto
.
attrs
:
attr_type
=
None
if
attr
.
type
==
attr_type_pb2
.
INT
:
attr_type
=
"int"
elif
attr
.
type
==
attr_type_pb2
.
FLOAT
:
attr_type
=
"float"
elif
attr
.
type
==
attr_type_pb2
.
STRING
:
attr_type
=
"basestr"
elif
attr
.
type
==
attr_type_pb2
.
INTS
:
attr_type
=
"list of int"
elif
attr
.
type
==
attr_type_pb2
.
FLOATS
:
attr_type
=
"list of float"
elif
attr
.
type
==
attr_type_pb2
.
STRINGS
:
attr_type
=
"list of basestr"
if
attr_type
is
None
:
raise
RuntimeError
(
"Not supported attribute type "
+
attr
.
type
)
__append_param__
(
attr
.
name
,
attr
.
comment
,
attr_type
)
return
f
.
getvalue
()
def
create_op_creation_method
(
op_proto
):
"""
Generate op creation method for an OpProto
"""
method
=
OpDescCreationMethod
(
op_proto
)
def
__impl__
(
*
args
,
**
kwargs
):
opdesc
=
method
(
*
args
,
**
kwargs
)
return
core
.
Operator
.
create
(
opdesc
.
SerializeToString
())
__impl__
.
__doc__
=
get_docstring_from_op_proto
(
op_proto
)
return
__impl__
class
OpCreationsHolder
(
object
):
"""
A object will holds all op creation methods.
Use `op_creations.xxx_op` to access them.
"""
pass
op_creations
=
OpCreationsHolder
()
def
__bootstrap__
():
"""
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for
op_proto
in
get_all_op_protos
():
func
=
create_op_creation_method
(
op_proto
)
func
.
__name__
=
str
(
op_proto
.
type
)
setattr
(
op_creations
,
func
.
__name__
,
func
)
__bootstrap__
()
python/paddle/v2/framework/tests/test_op_creation_methods.py
浏览文件 @
df707d04
import
unittest
import
paddle.v2.framework.create_op_creation_methods
as
creation
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attr_type_pb2
as
attr_type_pb2
class
Test
OpCreationsMethod
s
(
unittest
.
TestCase
):
def
test_all
_protos
(
self
):
class
Test
GetAllProto
s
(
unittest
.
TestCase
):
def
test_all
(
self
):
all_protos
=
creation
.
get_all_op_protos
()
self
.
assertNotEqual
(
0
,
len
(
all_protos
))
...
...
@@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase):
self
.
assertTrue
(
each
.
IsInitialized
())
class
TestOpDescCreationMethod
(
unittest
.
TestCase
):
def
test_plain_input_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
comment
=
"not matter"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"Y"
ipt
.
comment
=
"not matter"
opt
=
op
.
outputs
.
add
()
opt
.
name
=
"Z"
opt
.
comment
=
"not matter"
op
.
comment
=
"not matter"
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
output
=
method
(
X
=
"a"
,
Y
=
"b"
,
Z
=
"c"
)
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
"a"
,
"b"
])
expected
.
outputs
.
append
(
"c"
)
self
.
assertEqual
(
expected
,
output
)
def
test_multiple_input_plain_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"fc"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
comment
=
""
ipt
.
multiple
=
True
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"W"
ipt
.
comment
=
""
ipt
.
multiple
=
True
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"b"
ipt
.
comment
=
""
out
=
op
.
outputs
.
add
()
out
.
name
=
"Y"
out
.
comment
=
""
op
.
comment
=
""
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
generated1
=
method
(
X
=
"x"
,
W
=
"w"
,
b
=
"b"
,
Y
=
"y"
)
expected1
=
op_desc_pb2
.
OpDesc
()
expected1
.
inputs
.
extend
([
'x'
,
'w'
,
'b'
])
expected1
.
outputs
.
extend
([
'y'
])
expected1
.
type
=
'fc'
attr
=
expected1
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
])
self
.
assertEqual
(
expected1
,
generated1
)
generated2
=
method
(
X
=
[
'x1'
,
'x2'
,
'x3'
],
b
=
'b'
,
W
=
[
'w1'
,
'w2'
,
'w3'
],
Y
=
'y'
)
expected2
=
op_desc_pb2
.
OpDesc
()
expected2
.
inputs
.
extend
([
'x1'
,
'x2'
,
'x3'
,
'w1'
,
'w2'
,
'w3'
,
'b'
])
expected2
.
outputs
.
extend
([
'y'
])
expected2
.
type
=
'fc'
attr
=
expected2
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
3
,
6
,
7
])
self
.
assertEqual
(
expected2
,
generated2
)
def
test_attrs
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
'X'
ipt
.
comment
=
""
def
__add_attr__
(
name
,
type
):
attr
=
op
.
attrs
.
add
()
attr
.
name
=
name
attr
.
comment
=
""
attr
.
type
=
type
__add_attr__
(
"int_attr"
,
attr_type_pb2
.
INT
)
__add_attr__
(
"float_attr"
,
attr_type_pb2
.
FLOAT
)
__add_attr__
(
"string_attr"
,
attr_type_pb2
.
STRING
)
__add_attr__
(
"ints_attr"
,
attr_type_pb2
.
INTS
)
__add_attr__
(
"floats_attr"
,
attr_type_pb2
.
FLOATS
)
__add_attr__
(
"strings_attr"
,
attr_type_pb2
.
STRINGS
)
op
.
comment
=
""
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
generated
=
method
(
X
=
"a"
,
int_attr
=
10
,
float_attr
=
3.2
,
string_attr
=
"test_str"
,
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
strings_attr
=
[
"a"
,
"b"
,
"c"
])
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
'a'
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"int_attr"
attr
.
type
=
attr_type_pb2
.
INT
attr
.
i
=
10
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"float_attr"
attr
.
type
=
attr_type_pb2
.
FLOAT
attr
.
f
=
3.2
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"string_attr"
attr
.
type
=
attr_type_pb2
.
STRING
attr
.
s
=
"test_str"
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"ints_attr"
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
,
4
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'floats_attr'
attr
.
type
=
attr_type_pb2
.
FLOATS
attr
.
floats
.
extend
([
0.2
,
3.2
,
4.5
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'strings_attr'
attr
.
type
=
attr_type_pb2
.
STRINGS
attr
.
strings
.
extend
([
'a'
,
'b'
,
'c'
])
self
.
assertEqual
(
expected
,
generated
)
def
test_input_temporary_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
out
=
op
.
outputs
.
add
()
out
.
name
=
"OUT"
out
.
comment
=
""
out
=
op
.
outputs
.
add
()
out
.
name
=
"TMP"
out
.
comment
=
""
out
.
temporary
=
True
out
=
op
.
outputs
.
add
()
out
.
name
=
"OUT2"
out
.
comment
=
""
op
.
comment
=
""
method
=
creation
.
OpDescCreationMethod
(
op
)
generated
=
method
(
OUT
=
"a"
,
OUT2
=
"b"
)
desc
=
op_desc_pb2
.
OpDesc
()
desc
.
outputs
.
extend
([
"a"
,
core
.
var_names
.
temp
(),
"b"
])
desc
.
type
=
"test"
attr
=
desc
.
attrs
.
add
()
attr
.
name
=
"temporary_index"
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
append
(
2
)
self
.
assertEqual
(
generated
,
desc
)
class
TestOpCreationDocStr
(
unittest
.
TestCase
):
def
test_all
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
op
.
comment
=
"""Test Op.
This op is used for unit test, not a real op.
"""
a
=
op
.
inputs
.
add
()
a
.
name
=
"a"
a
.
comment
=
"Input a for test op"
a
.
multiple
=
True
b
=
op
.
inputs
.
add
()
b
.
name
=
"b"
b
.
comment
=
"Input b for test op"
self
.
assertTrue
(
op
.
IsInitialized
())
o1
=
op
.
outputs
.
add
()
o1
.
name
=
"output"
o1
.
comment
=
"The output of test op"
o2
=
op
.
outputs
.
add
()
o2
.
name
=
"temp output"
o2
.
comment
=
"The temporary output of test op"
o2
.
temporary
=
True
test_str
=
op
.
attrs
.
add
()
test_str
.
name
=
"str_attr"
test_str
.
type
=
attr_type_pb2
.
STRING
test_str
.
comment
=
"A string attribute for test op"
actual
=
creation
.
get_docstring_from_op_proto
(
op
)
expected_docstring
=
'''Test Op.
This op is used for unit test, not a real op.
:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
self
.
assertEqual
(
expected_docstring
,
actual
)
class
TestOpCreations
(
unittest
.
TestCase
):
def
test_all
(
self
):
add_op
=
creation
.
op_creations
.
add_two
(
X
=
"a"
,
Y
=
"b"
,
Out
=
"z"
)
self
.
assertIsNotNone
(
add_op
)
# Invoke C++ DebugString()
self
.
assertEqual
(
'Op(add_two), inputs:(a, b), outputs:(z).'
,
str
(
add_op
))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/optimizer.py
浏览文件 @
df707d04
import
py_paddle.swig_paddle
as
swig_api
import
paddle.trainer_config_helpers.config_parser_utils
as
config_parser_utils
import
paddle.trainer_config_helpers.optimizers
as
v1_optimizers
"""
...
...
@@ -17,6 +16,7 @@ __all__ = [
class
Optimizer
(
object
):
def
__init__
(
self
,
**
kwargs
):
import
py_paddle.swig_paddle
as
swig_api
if
'batch_size'
in
kwargs
:
del
kwargs
[
'batch_size'
]
# not important for python library.
...
...
@@ -35,18 +35,22 @@ class Optimizer(object):
For each optimizer(SGD, Adam), GradientMachine should enable different
buffers.
"""
import
py_paddle.swig_paddle
as
swig_api
tmp
=
swig_api
.
ParameterOptimizer
.
create
(
self
.
__opt_conf__
)
assert
isinstance
(
tmp
,
swig_api
.
ParameterOptimizer
)
return
tmp
.
getParameterTypes
()
def
__create_local_updater__
(
self
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createLocalUpdater
(
self
.
__opt_conf__
)
def
__create_remote_updater__
(
self
,
pass_num
,
use_sparse_updater
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createRemoteUpdater
(
self
.
__opt_conf__
,
pass_num
,
use_sparse_updater
)
def
__create_new_remote_updater__
(
self
,
pserver_spec
,
use_etcd
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createNewRemoteUpdater
(
self
.
__opt_conf__
,
pserver_spec
,
use_etcd
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录