Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
754f0c68
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
754f0c68
编写于
7月 25, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix unittest
上级
b80590d7
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
14 addition
and
19 deletion
+14
-19
paddle/framework/scope.h
paddle/framework/scope.h
+8
-8
paddle/framework/scope_test.cc
paddle/framework/scope_test.cc
+3
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-9
python/paddle/v2/framework/tests/test_network.py
python/paddle/v2/framework/tests/test_network.py
+2
-2
未找到文件。
paddle/framework/scope.h
浏览文件 @
754f0c68
...
...
@@ -57,8 +57,8 @@ class Scope {
return
var
;
}
else
{
auto
ptr
=
new
Variable
();
vars
_
[
name
]
=
std
::
unique_ptr
<
Variable
>
(
ptr
);
var_
names
_
[
ptr
]
=
name
;
name_to_var
_
[
name
]
=
std
::
unique_ptr
<
Variable
>
(
ptr
);
var_
to_name
_
[
ptr
]
=
name
;
return
GetVariable
(
name
);
}
}
...
...
@@ -70,8 +70,8 @@ class Scope {
* from it's parent scope. Return nullptr if not found.
*/
Variable
*
GetVariable
(
const
std
::
string
&
name
)
const
{
auto
it
=
vars
_
.
find
(
name
);
if
(
it
!=
vars
_
.
end
())
{
auto
it
=
name_to_var
_
.
find
(
name
);
if
(
it
!=
name_to_var
_
.
end
())
{
return
it
->
second
.
get
();
}
else
if
(
parent_
!=
nullptr
)
{
return
parent_
->
GetVariable
(
name
);
...
...
@@ -86,21 +86,21 @@ class Scope {
* Find if there is a Variable in this scope and it's parent scope
*/
bool
HasVariable
(
const
std
::
string
&
name
)
const
{
return
(
vars_
.
find
(
name
)
!=
vars
_
.
end
()
||
return
(
name_to_var_
.
find
(
name
)
!=
name_to_var
_
.
end
()
||
(
parent_
&&
parent_
->
HasVariable
(
name
)));
}
std
::
string
GetVariableName
(
Variable
*
const
var
)
const
{
try
{
return
var_
names
_
.
at
(
var
);
return
var_
to_name
_
.
at
(
var
);
}
catch
(...)
{
return
""
;
}
}
private:
std
::
unordered_map
<
Variable
*
,
std
::
string
>
var_
names
_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
vars
_
;
std
::
unordered_map
<
Variable
*
,
std
::
string
>
var_
to_name
_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
name_to_var
_
;
std
::
shared_ptr
<
Scope
>
parent_
{
nullptr
};
};
...
...
paddle/framework/scope_test.cc
浏览文件 @
754f0c68
...
...
@@ -42,6 +42,9 @@ TEST(Scope, Create) {
EXPECT_EQ
(
var4
,
var2
);
EXPECT_EQ
(
"a"
,
scope
->
GetVariableName
(
var4
));
Scope
scope2
;
auto
var
=
scope2
.
CreateVariable
(
"tmp"
);
EXPECT_EQ
(
""
,
scope
->
GetVariableName
(
var
));
}
TEST
(
Scope
,
Parent
)
{
...
...
paddle/pybind/pybind.cc
浏览文件 @
754f0c68
...
...
@@ -15,14 +15,6 @@ limitations under the License. */
#include <Python.h>
#include <fstream>
#include <vector>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
...
...
@@ -160,7 +152,7 @@ All parameter, weight, gradient are variables in Paddle.
net
.
def_static
(
"create"
,
[]()
->
std
::
shared_ptr
<
pd
::
PlainNet
>
{
auto
retv
=
std
::
make_shared
<
pd
::
PlainNet
>
();
retv
->
type_
=
"
naive
_net"
;
retv
->
type_
=
"
plain
_net"
;
return
retv
;
})
.
def
(
"add_op"
,
&
pd
::
PlainNet
::
AddOp
)
...
...
python/paddle/v2/framework/tests/test_network.py
浏览文件 @
754f0c68
...
...
@@ -11,7 +11,7 @@ class TestNet(unittest.TestCase):
net
.
complete_add_op
()
self
.
assertTrue
(
isinstance
(
fc_out
,
core
.
Variable
))
self
.
assertEqual
(
'''Op(
naive
_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
'''Op(
plain
_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
...
...
@@ -23,7 +23,7 @@ class TestNet(unittest.TestCase):
self
.
assertTrue
(
isinstance
(
tmp
,
core
.
Variable
))
net2
.
complete_add_op
()
self
.
assertEqual
(
'''Op(
naive
_net), inputs:(X, Y), outputs:(add_two@OUT@2).
'''Op(
plain
_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
'''
,
str
(
net2
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录