Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0ceeacbe
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0ceeacbe
编写于
7月 24, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make Scope can lookup variable name by variable
* Refine unittest also
上级
0ab678e9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
36 addition
and
12 deletion
+36
-12
paddle/framework/scope.h
paddle/framework/scope.h
+12
-1
paddle/framework/scope_test.cc
paddle/framework/scope_test.cc
+2
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+9
-1
python/paddle/v2/framework/network.py
python/paddle/v2/framework/network.py
+4
-10
python/paddle/v2/framework/tests/test_network.py
python/paddle/v2/framework/tests/test_network.py
+9
-0
未找到文件。
paddle/framework/scope.h
浏览文件 @
0ceeacbe
...
...
@@ -56,7 +56,9 @@ class Scope {
if
(
var
)
{
return
var
;
}
else
{
vars_
[
name
]
=
std
::
unique_ptr
<
Variable
>
(
new
Variable
());
auto
ptr
=
new
Variable
();
vars_
[
name
]
=
std
::
unique_ptr
<
Variable
>
(
ptr
);
var_names_
[
ptr
]
=
name
;
return
GetVariable
(
name
);
}
}
...
...
@@ -88,7 +90,16 @@ class Scope {
(
parent_
&&
parent_
->
HasVariable
(
name
)));
}
std
::
string
GetVariableName
(
Variable
*
const
var
)
const
{
try
{
return
var_names_
.
at
(
var
);
}
catch
(...)
{
return
""
;
}
}
private:
std
::
unordered_map
<
Variable
*
,
std
::
string
>
var_names_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
vars_
;
std
::
shared_ptr
<
Scope
>
parent_
{
nullptr
};
};
...
...
paddle/framework/scope_test.cc
浏览文件 @
0ceeacbe
...
...
@@ -40,6 +40,8 @@ TEST(Scope, Create) {
/// already exist.
Variable
*
var4
=
scope
->
CreateVariable
(
"a"
);
EXPECT_EQ
(
var4
,
var2
);
EXPECT_EQ
(
"a"
,
scope
->
GetVariableName
(
var4
));
}
TEST
(
Scope
,
Parent
)
{
...
...
paddle/pybind/pybind.cc
浏览文件 @
0ceeacbe
...
...
@@ -56,6 +56,11 @@ void ExposeOperator(ClassType& m) {
.
def
(
"__str__"
,
&
ClassType
::
type
::
DebugString
);
}
static
size_t
UniqueIntegerGenerator
()
{
static
std
::
atomic
<
size_t
>
generator
;
return
generator
.
fetch_add
(
1
);
}
PYBIND11_PLUGIN
(
core
)
{
py
::
module
m
(
"core"
,
"C++ core of PaddlePaddle"
);
...
...
@@ -106,7 +111,8 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
)
.
def
(
"create_var"
,
&
pd
::
Scope
::
CreateVariable
,
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
)
.
def
(
"get_var_name"
,
&
pd
::
Scope
::
GetVariableName
);
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
...
...
@@ -166,5 +172,7 @@ All parameter, weight, gradient are variables in Paddle.
.
def
(
"complete_add_op"
,
[](
PlainNetPtr
&
self
)
{
self
->
CompleteAddOp
();
});
ExposeOperator
(
net
);
m
.
def
(
"unique_integer"
,
UniqueIntegerGenerator
);
return
m
.
ptr
();
}
python/paddle/v2/framework/network.py
浏览文件 @
0ceeacbe
...
...
@@ -29,35 +29,31 @@ class NetworkFunctor(object):
if
ipt
in
kwargs
:
var
=
kwargs
[
ipt
]
if
isinstance
(
var
,
basestring
):
var_name
=
var
var
=
create_var
(
var
)
self
.
net
.
var_name_map
[
var
]
=
var_name
if
not
isinstance
(
var
,
core
.
Variable
):
raise
TypeError
(
"Input of op creation must be string or variable"
)
kwargs
[
ipt
]
=
self
.
net
.
var_name_map
[
var
]
kwargs
[
ipt
]
=
get_cur_scope
().
get_var_name
(
var
)
notemp_outputs
=
self
.
func
.
all_not_temp_output_args
for
name
in
notemp_outputs
:
if
name
not
in
kwargs
:
kwargs
[
name
]
=
self
.
func
.
__name__
+
"@OUT@%d"
%
self
.
net
.
generate_idx
self
.
net
.
generate_idx
+=
1
name
]
=
self
.
func
.
__name__
+
"@OUT@%d"
%
core
.
unique_integer
(
)
outputs
=
self
.
func
.
all_output_args
for
opt
in
outputs
:
if
opt
in
kwargs
:
var
=
kwargs
[
opt
]
if
isinstance
(
var
,
basestring
):
var_name
=
var
var
=
create_var
(
var
)
self
.
net
.
var_name_map
[
var
]
=
var_name
if
not
isinstance
(
var
,
core
.
Variable
):
raise
TypeError
(
"Output of op creation must be string or variable"
)
kwargs
[
opt
]
=
self
.
net
.
var_name_map
[
var
]
kwargs
[
opt
]
=
get_cur_scope
().
get_var_name
(
var
)
op
=
self
.
func
(
**
kwargs
)
...
...
@@ -93,8 +89,6 @@ class Network(object):
self
.
net
=
core
.
Net
.
create
()
funcs
=
(
func_name
for
func_name
in
dir
(
op_creations
)
if
not
func_name
.
startswith
(
"__"
))
self
.
generate_idx
=
0
self
.
var_name_map
=
dict
()
# TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime
...
...
python/paddle/v2/framework/tests/test_network.py
浏览文件 @
0ceeacbe
...
...
@@ -18,6 +18,15 @@ class TestNet(unittest.TestCase):
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
'''
,
str
(
net
))
net2
=
Network
()
tmp
=
net2
.
add_two
(
X
=
"X"
,
Y
=
"Y"
)
self
.
assertTrue
(
isinstance
(
tmp
,
core
.
Variable
))
net2
.
complete_add_op
()
self
.
assertEqual
(
'''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
'''
,
str
(
net2
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录