Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
df707d04
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
df707d04
编写于
7月 17, 2017
作者:
Y
Yu Yang
提交者:
GitHub
7月 17, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2893 from reyoung/feature/op_creation_methods
Python Generate OpCreation Methods by OpProto
上级
a0caf234
0e77b31a
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
542 addition
and
20 deletion
+542
-20
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+24
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+13
-15
paddle/framework/operator.h
paddle/framework/operator.h
+7
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+17
-0
python/paddle/v2/framework/create_op_creation_methods.py
python/paddle/v2/framework/create_op_creation_methods.py
+235
-0
python/paddle/v2/framework/tests/test_op_creation_methods.py
python/paddle/v2/framework/tests/test_op_creation_methods.py
+241
-2
python/paddle/v2/optimizer.py
python/paddle/v2/optimizer.py
+5
-1
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
df707d04
#pragma once
#include <algorithm>
#include <atomic>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
...
...
@@ -214,11 +215,14 @@ class OpRegistry {
}
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
//! Create a OpPtr by type.
std
::
string
op_type
=
op_desc
.
type
();
OperatorPtr
op
(
creators
().
at
(
op_type
)());
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
const
OpProto
&
op_proto
=
protos
().
at
(
op_type
);
// set op's inputs_ from desc.
op
->
type_
=
op_desc
.
type
();
// set op's inputs_ from desc.
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
std
::
back_inserter
(
op
->
inputs_
));
...
...
@@ -226,13 +230,20 @@ class OpRegistry {
op
->
outputs_
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
back_inserter
(
op
->
outputs_
));
// set op's attr;
//! Fill attrs, and validate attrs.
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
op
->
attrs_
[
attr
.
name
()]
=
AttrTypeHelper
::
GetAttrValue
(
attr
);
}
op_checkers
().
at
(
op_type
).
Check
(
op
->
attrs_
);
//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName
(
op
.
get
());
// set argument offsets stored in op.
CreateInOutOffsetMap
(
op
,
op_proto
);
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
op
->
Init
();
return
op
;
}
...
...
@@ -248,6 +259,17 @@ class OpRegistry {
};
private:
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
if
(
outname
==
OperatorBase
::
TMP_VAR_NAME
())
{
outname
+=
op
->
type_
;
outname
+=
"@"
;
outname
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
...
...
paddle/framework/operator.cc
浏览文件 @
df707d04
...
...
@@ -77,23 +77,21 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
ss
<<
"
=================
\n
"
;
ss
<<
"type = "
<<
type_
<<
"
\n
"
;
ss
<<
"inputs = ["
;
for
(
auto
&
ipt
:
inputs_
)
{
ss
<<
ipt
<<
", "
;
ss
<<
"
Op("
<<
type_
<<
"), inputs:(
"
;
for
(
size_t
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
ss
<<
inputs_
[
i
]
;
if
(
i
!=
inputs_
.
size
()
-
1
)
{
ss
<<
", "
;
}
ss
<<
"]
\n
"
;
ss
<<
"outputs = ["
;
for
(
auto
&
opt
:
outputs_
)
{
ss
<<
opt
<<
", "
;
}
ss
<<
"]
\n
"
;
ss
<<
"attr_keys = ["
;
for
(
auto
&
attr
:
attrs_
)
{
ss
<<
attr
.
first
<<
", "
;
ss
<<
"), outputs:("
;
for
(
size_t
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
ss
<<
outputs_
[
i
];
if
(
i
!=
outputs_
.
size
()
-
1
)
{
ss
<<
", "
;
}
ss
<<
"]
\n
"
;
}
ss
<<
")."
;
return
ss
.
str
();
}
...
...
paddle/framework/operator.h
浏览文件 @
df707d04
...
...
@@ -41,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
*/
class
OperatorBase
{
public:
/// If a variable is a empty variable, that name will be used.
static
std
::
string
EMPTY_VAR_NAME
()
{
return
"@EMPTY@"
;
}
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
...
...
paddle/pybind/pybind.cc
浏览文件 @
df707d04
...
...
@@ -63,6 +63,23 @@ All parameter, weight, gradient are variables in Paddle.
}
return
ret_values
;
});
m
.
def_submodule
(
"var_names"
,
"The module will return special predefined variable name in Paddle"
)
.
def
(
"empty"
,
pd
::
OperatorBase
::
EMPTY_VAR_NAME
)
.
def
(
"temp"
,
pd
::
OperatorBase
::
TMP_VAR_NAME
);
py
::
class_
<
pd
::
OperatorBase
,
pd
::
OperatorPtr
>
(
m
,
"Operator"
)
.
def
(
"__str__"
,
&
pd
::
OperatorBase
::
DebugString
)
.
def_static
(
"create"
,
[](
const
std
::
string
&
protobin
)
{
pd
::
OpDesc
desc
;
PADDLE_ENFORCE
(
desc
.
ParsePartialFromString
(
protobin
),
"Cannot parse user input to OpDesc"
);
PADDLE_ENFORCE
(
desc
.
IsInitialized
(),
"User OpDesc is not initialized, reason %s"
,
desc
.
InitializationErrorString
());
return
pd
::
OpRegistry
::
CreateOp
(
desc
);
});
return
m
.
ptr
();
}
python/paddle/v2/framework/create_op_creation_methods.py
浏览文件 @
df707d04
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attr_type_pb2
as
attr_type_pb2
import
cStringIO
def
get_all_op_protos
():
"""
Get all registered op proto from Paddle C++
:return: list of OpProto
"""
protostrs
=
core
.
get_all_op_protos
()
ret_values
=
[]
for
pbstr
in
protostrs
:
op_proto
=
op_proto_pb2
.
OpProto
.
FromString
(
str
(
pbstr
))
ret_values
.
append
(
op_proto
)
return
ret_values
class
OpDescCreationMethod
(
object
):
"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""
def
__init__
(
self
,
op_proto
):
if
not
isinstance
(
op_proto
,
op_proto_pb2
.
OpProto
):
raise
TypeError
(
"Argument should be OpProto"
)
self
.
__op_proto__
=
op_proto
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
Convert user input to OpDesc. Only key-word args are supported.
:return: OpDesc based on user input
:rtype: op_desc_pb2.OpDesc
"""
if
len
(
args
)
!=
0
:
raise
ValueError
(
"Only keyword arguments is supported by Paddle"
)
op_desc
=
op_desc_pb2
.
OpDesc
()
# Inputs
ipts
,
ipt_format
,
_
=
OpDescCreationMethod
.
extract_input_or_output
(
"input"
,
kwargs
,
self
.
__op_proto__
.
inputs
)
op_desc
.
inputs
.
extend
(
ipts
)
if
ipt_format
is
not
None
:
op_desc
.
attrs
.
extend
([
ipt_format
])
# Outputs
outs
,
out_format
,
tmp_index
=
OpDescCreationMethod
.
extract_input_or_output
(
"output"
,
kwargs
,
self
.
__op_proto__
.
outputs
)
op_desc
.
outputs
.
extend
(
outs
)
if
out_format
is
not
None
:
op_desc
.
attrs
.
extend
([
out_format
])
if
len
(
tmp_index
)
!=
0
:
tmp_index_attr
=
op_desc
.
attrs
.
add
()
tmp_index_attr
.
type
=
attr_type_pb2
.
INTS
tmp_index_attr
.
name
=
"temporary_index"
tmp_index_attr
.
ints
.
extend
(
tmp_index
)
# Types
op_desc
.
type
=
self
.
__op_proto__
.
type
# Attrs
for
attr
in
self
.
__op_proto__
.
attrs
:
if
attr
.
generated
:
continue
user_defined_attr
=
kwargs
.
get
(
attr
.
name
,
None
)
if
user_defined_attr
is
not
None
:
new_attr
=
op_desc
.
attrs
.
add
()
new_attr
.
name
=
attr
.
name
new_attr
.
type
=
attr
.
type
if
attr
.
type
==
attr_type_pb2
.
INT
:
new_attr
.
i
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
FLOAT
:
new_attr
.
f
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
STRING
:
new_attr
.
s
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
INTS
:
new_attr
.
ints
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attr_type_pb2
.
FLOATS
:
new_attr
.
floats
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attr_type_pb2
.
STRINGS
:
new_attr
.
strings
.
extend
(
user_defined_attr
)
else
:
raise
NotImplementedError
(
"Not support attribute type "
+
attr
.
type
)
return
op_desc
@
staticmethod
def
extract_input_or_output
(
in_out
,
kwargs
,
meta
):
"""
Extract input variable names or output variable names from key-word
arguments, which base on VarProtos.
:param in_out: "input" or "output"
:param kwargs: key-word arguments that user inputted.
:param meta: a list of VarProto
:return: The three object will be return. The variable names. The
input_format or output_format attribute(None if the input or output is
not multiple). The temporary variable index list.
"""
multiple
=
OpDescCreationMethod
.
any_is_true
((
m
.
multiple
for
m
in
meta
))
tmp_index
=
[]
retv
=
[]
if
multiple
:
var_format
=
op_desc_pb2
.
AttrDesc
()
var_format
.
type
=
attr_type_pb2
.
INTS
var_format
.
name
=
"%s_format"
%
in_out
var_format
.
ints
.
append
(
0
)
for
var
in
meta
:
var_name
=
var
.
name
if
var
.
temporary
:
var_name
=
[
core
.
var_names
.
temp
()]
tmp_index
.
append
(
len
(
retv
))
else
:
var_name
=
kwargs
.
get
(
var_name
,
[])
if
not
isinstance
(
var_name
,
list
):
var_name
=
[
var_name
]
retv
.
extend
(
var_name
)
var_format
.
ints
.
append
(
len
(
var_name
)
+
var_format
.
ints
[
-
1
])
return
retv
,
var_format
,
tmp_index
else
:
for
var
in
meta
:
if
var
.
temporary
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
temp
()))
tmp_index
.
append
(
len
(
retv
))
else
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
empty
()))
return
retv
,
None
,
tmp_index
@
staticmethod
def
any_is_true
(
generator
):
"""
Reduce a bool array to one. If any of them is True, then return True.
"""
for
flag
in
generator
:
if
flag
:
return
True
return
False
def
get_docstring_from_op_proto
(
op_proto
):
"""
Generate docstring from a OpProto
:param op_proto: a OpProto instance.
:type op_proto: op_proto_pb2.OpProto
:return: docstring
"""
if
not
isinstance
(
op_proto
,
op_proto_pb2
.
OpProto
):
raise
TypeError
(
"Input must be OpProto"
)
f
=
cStringIO
.
StringIO
()
f
.
write
(
op_proto
.
comment
)
f
.
write
(
"
\n
"
)
def
__append_param__
(
name
,
comment
,
type
):
# Maybe replace the following line with template engine is better.
f
.
write
(
":param "
)
f
.
write
(
name
)
f
.
write
(
": "
)
f
.
write
(
comment
)
f
.
write
(
"
\n
"
)
f
.
write
(
":type "
)
f
.
write
(
name
)
f
.
write
(
": "
)
f
.
write
(
type
)
f
.
write
(
"
\n
"
)
for
ipt
in
op_proto
.
inputs
:
__append_param__
(
ipt
.
name
,
ipt
.
comment
,
"list | basestr"
if
ipt
.
multiple
else
"basestr"
)
temp_var_prefix
=
\
"This is a temporary variable. It does not have to set by user. "
for
opt
in
op_proto
.
outputs
:
__append_param__
(
opt
.
name
,
opt
.
comment
if
not
opt
.
temporary
else
temp_var_prefix
+
opt
.
comment
,
"list | basestr"
if
opt
.
multiple
else
"basestr"
)
for
attr
in
op_proto
.
attrs
:
attr_type
=
None
if
attr
.
type
==
attr_type_pb2
.
INT
:
attr_type
=
"int"
elif
attr
.
type
==
attr_type_pb2
.
FLOAT
:
attr_type
=
"float"
elif
attr
.
type
==
attr_type_pb2
.
STRING
:
attr_type
=
"basestr"
elif
attr
.
type
==
attr_type_pb2
.
INTS
:
attr_type
=
"list of int"
elif
attr
.
type
==
attr_type_pb2
.
FLOATS
:
attr_type
=
"list of float"
elif
attr
.
type
==
attr_type_pb2
.
STRINGS
:
attr_type
=
"list of basestr"
if
attr_type
is
None
:
raise
RuntimeError
(
"Not supported attribute type "
+
attr
.
type
)
__append_param__
(
attr
.
name
,
attr
.
comment
,
attr_type
)
return
f
.
getvalue
()
def
create_op_creation_method
(
op_proto
):
"""
Generate op creation method for an OpProto
"""
method
=
OpDescCreationMethod
(
op_proto
)
def
__impl__
(
*
args
,
**
kwargs
):
opdesc
=
method
(
*
args
,
**
kwargs
)
return
core
.
Operator
.
create
(
opdesc
.
SerializeToString
())
__impl__
.
__doc__
=
get_docstring_from_op_proto
(
op_proto
)
return
__impl__
class
OpCreationsHolder
(
object
):
"""
A object will holds all op creation methods.
Use `op_creations.xxx_op` to access them.
"""
pass
op_creations
=
OpCreationsHolder
()
def
__bootstrap__
():
"""
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for
op_proto
in
get_all_op_protos
():
func
=
create_op_creation_method
(
op_proto
)
func
.
__name__
=
str
(
op_proto
.
type
)
setattr
(
op_creations
,
func
.
__name__
,
func
)
__bootstrap__
()
python/paddle/v2/framework/tests/test_op_creation_methods.py
浏览文件 @
df707d04
import
unittest
import
paddle.v2.framework.create_op_creation_methods
as
creation
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attr_type_pb2
as
attr_type_pb2
class
Test
OpCreationsMethod
s
(
unittest
.
TestCase
):
def
test_all
_protos
(
self
):
class
Test
GetAllProto
s
(
unittest
.
TestCase
):
def
test_all
(
self
):
all_protos
=
creation
.
get_all_op_protos
()
self
.
assertNotEqual
(
0
,
len
(
all_protos
))
...
...
@@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase):
self
.
assertTrue
(
each
.
IsInitialized
())
class
TestOpDescCreationMethod
(
unittest
.
TestCase
):
def
test_plain_input_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
comment
=
"not matter"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"Y"
ipt
.
comment
=
"not matter"
opt
=
op
.
outputs
.
add
()
opt
.
name
=
"Z"
opt
.
comment
=
"not matter"
op
.
comment
=
"not matter"
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
output
=
method
(
X
=
"a"
,
Y
=
"b"
,
Z
=
"c"
)
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
"a"
,
"b"
])
expected
.
outputs
.
append
(
"c"
)
self
.
assertEqual
(
expected
,
output
)
def
test_multiple_input_plain_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"fc"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
comment
=
""
ipt
.
multiple
=
True
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"W"
ipt
.
comment
=
""
ipt
.
multiple
=
True
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"b"
ipt
.
comment
=
""
out
=
op
.
outputs
.
add
()
out
.
name
=
"Y"
out
.
comment
=
""
op
.
comment
=
""
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
generated1
=
method
(
X
=
"x"
,
W
=
"w"
,
b
=
"b"
,
Y
=
"y"
)
expected1
=
op_desc_pb2
.
OpDesc
()
expected1
.
inputs
.
extend
([
'x'
,
'w'
,
'b'
])
expected1
.
outputs
.
extend
([
'y'
])
expected1
.
type
=
'fc'
attr
=
expected1
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
])
self
.
assertEqual
(
expected1
,
generated1
)
generated2
=
method
(
X
=
[
'x1'
,
'x2'
,
'x3'
],
b
=
'b'
,
W
=
[
'w1'
,
'w2'
,
'w3'
],
Y
=
'y'
)
expected2
=
op_desc_pb2
.
OpDesc
()
expected2
.
inputs
.
extend
([
'x1'
,
'x2'
,
'x3'
,
'w1'
,
'w2'
,
'w3'
,
'b'
])
expected2
.
outputs
.
extend
([
'y'
])
expected2
.
type
=
'fc'
attr
=
expected2
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
3
,
6
,
7
])
self
.
assertEqual
(
expected2
,
generated2
)
def
test_attrs
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
'X'
ipt
.
comment
=
""
def
__add_attr__
(
name
,
type
):
attr
=
op
.
attrs
.
add
()
attr
.
name
=
name
attr
.
comment
=
""
attr
.
type
=
type
__add_attr__
(
"int_attr"
,
attr_type_pb2
.
INT
)
__add_attr__
(
"float_attr"
,
attr_type_pb2
.
FLOAT
)
__add_attr__
(
"string_attr"
,
attr_type_pb2
.
STRING
)
__add_attr__
(
"ints_attr"
,
attr_type_pb2
.
INTS
)
__add_attr__
(
"floats_attr"
,
attr_type_pb2
.
FLOATS
)
__add_attr__
(
"strings_attr"
,
attr_type_pb2
.
STRINGS
)
op
.
comment
=
""
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
generated
=
method
(
X
=
"a"
,
int_attr
=
10
,
float_attr
=
3.2
,
string_attr
=
"test_str"
,
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
strings_attr
=
[
"a"
,
"b"
,
"c"
])
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
'a'
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"int_attr"
attr
.
type
=
attr_type_pb2
.
INT
attr
.
i
=
10
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"float_attr"
attr
.
type
=
attr_type_pb2
.
FLOAT
attr
.
f
=
3.2
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"string_attr"
attr
.
type
=
attr_type_pb2
.
STRING
attr
.
s
=
"test_str"
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"ints_attr"
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
,
4
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'floats_attr'
attr
.
type
=
attr_type_pb2
.
FLOATS
attr
.
floats
.
extend
([
0.2
,
3.2
,
4.5
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'strings_attr'
attr
.
type
=
attr_type_pb2
.
STRINGS
attr
.
strings
.
extend
([
'a'
,
'b'
,
'c'
])
self
.
assertEqual
(
expected
,
generated
)
def
test_input_temporary_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
out
=
op
.
outputs
.
add
()
out
.
name
=
"OUT"
out
.
comment
=
""
out
=
op
.
outputs
.
add
()
out
.
name
=
"TMP"
out
.
comment
=
""
out
.
temporary
=
True
out
=
op
.
outputs
.
add
()
out
.
name
=
"OUT2"
out
.
comment
=
""
op
.
comment
=
""
method
=
creation
.
OpDescCreationMethod
(
op
)
generated
=
method
(
OUT
=
"a"
,
OUT2
=
"b"
)
desc
=
op_desc_pb2
.
OpDesc
()
desc
.
outputs
.
extend
([
"a"
,
core
.
var_names
.
temp
(),
"b"
])
desc
.
type
=
"test"
attr
=
desc
.
attrs
.
add
()
attr
.
name
=
"temporary_index"
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
append
(
2
)
self
.
assertEqual
(
generated
,
desc
)
class
TestOpCreationDocStr
(
unittest
.
TestCase
):
def
test_all
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
op
.
comment
=
"""Test Op.
This op is used for unit test, not a real op.
"""
a
=
op
.
inputs
.
add
()
a
.
name
=
"a"
a
.
comment
=
"Input a for test op"
a
.
multiple
=
True
b
=
op
.
inputs
.
add
()
b
.
name
=
"b"
b
.
comment
=
"Input b for test op"
self
.
assertTrue
(
op
.
IsInitialized
())
o1
=
op
.
outputs
.
add
()
o1
.
name
=
"output"
o1
.
comment
=
"The output of test op"
o2
=
op
.
outputs
.
add
()
o2
.
name
=
"temp output"
o2
.
comment
=
"The temporary output of test op"
o2
.
temporary
=
True
test_str
=
op
.
attrs
.
add
()
test_str
.
name
=
"str_attr"
test_str
.
type
=
attr_type_pb2
.
STRING
test_str
.
comment
=
"A string attribute for test op"
actual
=
creation
.
get_docstring_from_op_proto
(
op
)
expected_docstring
=
'''Test Op.
This op is used for unit test, not a real op.
:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
self
.
assertEqual
(
expected_docstring
,
actual
)
class
TestOpCreations
(
unittest
.
TestCase
):
def
test_all
(
self
):
add_op
=
creation
.
op_creations
.
add_two
(
X
=
"a"
,
Y
=
"b"
,
Out
=
"z"
)
self
.
assertIsNotNone
(
add_op
)
# Invoke C++ DebugString()
self
.
assertEqual
(
'Op(add_two), inputs:(a, b), outputs:(z).'
,
str
(
add_op
))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/optimizer.py
浏览文件 @
df707d04
import
py_paddle.swig_paddle
as
swig_api
import
paddle.trainer_config_helpers.config_parser_utils
as
config_parser_utils
import
paddle.trainer_config_helpers.optimizers
as
v1_optimizers
"""
...
...
@@ -17,6 +16,7 @@ __all__ = [
class
Optimizer
(
object
):
def
__init__
(
self
,
**
kwargs
):
import
py_paddle.swig_paddle
as
swig_api
if
'batch_size'
in
kwargs
:
del
kwargs
[
'batch_size'
]
# not important for python library.
...
...
@@ -35,18 +35,22 @@ class Optimizer(object):
For each optimizer(SGD, Adam), GradientMachine should enable different
buffers.
"""
import
py_paddle.swig_paddle
as
swig_api
tmp
=
swig_api
.
ParameterOptimizer
.
create
(
self
.
__opt_conf__
)
assert
isinstance
(
tmp
,
swig_api
.
ParameterOptimizer
)
return
tmp
.
getParameterTypes
()
def
__create_local_updater__
(
self
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createLocalUpdater
(
self
.
__opt_conf__
)
def
__create_remote_updater__
(
self
,
pass_num
,
use_sparse_updater
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createRemoteUpdater
(
self
.
__opt_conf__
,
pass_num
,
use_sparse_updater
)
def
__create_new_remote_updater__
(
self
,
pserver_spec
,
use_etcd
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createNewRemoteUpdater
(
self
.
__opt_conf__
,
pserver_spec
,
use_etcd
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录