Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9f816352
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9f816352
编写于
8月 07, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Follow comments
上级
5d074c91
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
16 deletion
+26
-16
python/paddle/v2/framework/op.py
python/paddle/v2/framework/op.py
+25
-15
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-1
未找到文件。
python/paddle/v2/framework/op.py
浏览文件 @
9f816352
...
...
@@ -145,6 +145,16 @@ class OpDescCreationMethod(object):
return
False
class
OpInfo
(
object
):
def
__init__
(
self
,
name
,
method
,
inputs
,
outputs
,
attrs
,
no_temp_outputs
):
self
.
name
=
name
self
.
method
=
method
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
attrs
=
attrs
self
.
no_temp_outputs
=
no_temp_outputs
def
create_op_creation_method
(
op_proto
):
"""
Generate op creation method for an OpProto
...
...
@@ -155,15 +165,15 @@ def create_op_creation_method(op_proto):
opdesc
=
method
(
*
args
,
**
kwargs
)
return
core
.
Operator
.
create
(
opdesc
.
SerializeToString
())
return
{
'method'
:
__impl__
,
'name'
:
op_proto
.
type
,
'all_inputs'
:
[
var
.
name
for
var
in
op_proto
.
inputs
],
'all_outputs'
:
[
var
.
name
for
var
in
op_proto
.
outputs
],
'all_attrs'
:
[
attr
.
name
for
attr
in
op_proto
.
attrs
],
'all_no_temp_outputs'
:
[
var
.
name
for
var
in
op_proto
.
outputs
if
not
var
.
temporary
]
}
return
OpInfo
(
method
=
__impl__
,
name
=
op_proto
.
type
,
inputs
=
[
var
.
name
for
var
in
op_proto
.
inputs
],
outputs
=
[
var
.
name
for
var
in
op_proto
.
outputs
],
attrs
=
[
attr
.
name
for
attr
in
op_proto
.
attrs
],
no_temp_outputs
=
[
var
.
name
for
var
in
op_proto
.
outputs
if
not
var
.
temporary
])
class
OperatorFactory
(
object
):
...
...
@@ -185,27 +195,27 @@ class OperatorFactory(object):
"argument except type"
)
t
=
args
[
0
]
return
self
.
get_op_
creation_info
(
t
)[
'method'
]
(
**
kwargs
)
return
self
.
get_op_
info
(
t
).
method
(
**
kwargs
)
def
types
(
self
):
return
self
.
op_methods
.
keys
()
def
get_op_
creation_
info
(
self
,
t
):
def
get_op_info
(
self
,
t
):
if
t
not
in
self
.
op_methods
:
raise
ValueError
(
"operator %s is not registered"
,
t
)
return
self
.
op_methods
.
get
(
t
)
def
get_op_input_names
(
self
,
type
):
return
self
.
get_op_
creation_info
(
type
)[
'all_inputs'
]
return
self
.
get_op_
info
(
type
).
inputs
def
get_op_output_names
(
self
,
type
):
return
self
.
get_op_
creation_info
(
type
)[
'all_outputs'
]
return
self
.
get_op_
info
(
type
).
outputs
def
get_op_attr_names
(
self
,
type
):
return
self
.
get_op_
creation_info
(
type
)[
'all_attrs'
]
return
self
.
get_op_
info
(
type
).
attrs
def
get_op_no_temp_output_names
(
self
,
type
):
return
self
.
get_op_
creation_info
(
type
)[
'all_no_temp_outputs'
]
return
self
.
get_op_
info
(
type
).
no_temp_outputs
Operator
=
OperatorFactory
()
# Default global factory
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
9f816352
...
...
@@ -20,4 +20,4 @@ py_test(gradient_checker SRCS gradient_checker.py)
py_test
(
test_rowwise_add_op SRCS test_rowwise_add_op.py
)
py_test
(
test_default_scope_funcs SRCS test_default_scope_funcs.py
)
py_test
(
test_operator SRCS test_operator.py
py_test
(
test_operator SRCS test_operator.py
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录