Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
abf1005b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
abf1005b
编写于
10月 13, 2018
作者:
X
Xin Pan
提交者:
GitHub
10月 13, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13866 from PaddlePaddle/revert-13821-fix
Revert "Make variable::GetMutable robust"
上级
ae8b1c32
d852be7c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
24 addition
and
16 deletion
+24
-16
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+1
-1
paddle/fluid/framework/feed_fetch_method.cc
paddle/fluid/framework/feed_fetch_method.cc
+2
-1
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+1
-1
paddle/fluid/framework/variable.h
paddle/fluid/framework/variable.h
+1
-5
paddle/fluid/framework/variable_test.cc
paddle/fluid/framework/variable_test.cc
+5
-6
python/paddle/fluid/tests/book/test_word2vec.py
python/paddle/fluid/tests/book/test_word2vec.py
+14
-2
未找到文件。
paddle/fluid/framework/executor.cc
浏览文件 @
abf1005b
...
...
@@ -66,7 +66,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
}
else
if
(
var_type
==
proto
::
VarType
::
FETCH_LIST
)
{
var
->
GetMutable
<
FeedFetchList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
STEP_SCOPES
)
{
var
->
GetMutable
<
std
::
vector
<
framework
::
Scope
*
>>
();
var
->
GetMutable
<
std
::
vector
<
framework
::
Scope
>>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_RANK_TABLE
)
{
var
->
GetMutable
<
LoDRankTable
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
...
...
paddle/fluid/framework/feed_fetch_method.cc
浏览文件 @
abf1005b
...
...
@@ -27,7 +27,8 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
// be created.
VLOG
(
3
)
<<
"SetFeedVariable name="
<<
var_name
<<
" index="
<<
index
;
Variable
*
g_feed_value
=
scope
->
Var
(
var_name
);
auto
&
feed_inputs
=
*
(
g_feed_value
->
GetMutable
<
FeedFetchList
>
());
auto
&
feed_inputs
=
*
(
g_feed_value
->
GetMutable
<
std
::
vector
<
paddle
::
framework
::
LoDTensor
>>
());
if
(
index
>=
feed_inputs
.
size
())
{
feed_inputs
.
resize
(
index
+
1
);
}
...
...
paddle/fluid/framework/naive_executor.cc
浏览文件 @
abf1005b
...
...
@@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
}
else
if
(
var_type
==
proto
::
VarType
::
FETCH_LIST
)
{
var
->
GetMutable
<
FeedFetchList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
STEP_SCOPES
)
{
var
->
GetMutable
<
std
::
vector
<
framework
::
Scope
*
>>
();
var
->
GetMutable
<
std
::
vector
<
framework
::
Scope
>>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_RANK_TABLE
)
{
var
->
GetMutable
<
LoDRankTable
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
...
...
paddle/fluid/framework/variable.h
浏览文件 @
abf1005b
...
...
@@ -38,12 +38,8 @@ class Variable {
template
<
typename
T
>
T
*
GetMutable
()
{
if
(
!
holder_
)
{
if
(
!
IsType
<
T
>
()
)
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
new
T
()));
}
else
{
PADDLE_ENFORCE
(
IsType
<
T
>
(),
"Variable must be type %s, the holding type is %s"
,
typeid
(
T
).
name
(),
holder_
->
Type
().
name
());
}
return
static_cast
<
T
*>
(
holder_
->
Ptr
());
}
...
...
paddle/fluid/framework/variable_test.cc
浏览文件 @
abf1005b
...
...
@@ -33,10 +33,9 @@ TEST(Variable, GetMutable) {
const
Tensor
&
tt
=
v
->
Get
<
Tensor
>
();
EXPECT_EQ
(
1234
,
tt
.
content_
);
try
{
v
->
GetMutable
<
std
::
string
>
();
}
catch
(
std
::
exception
&
e
)
{
return
;
}
EXPECT_TRUE
(
false
);
std
::
string
*
s
=
v
->
GetMutable
<
std
::
string
>
();
*
s
=
"hello"
;
const
std
::
string
&
ss
=
v
->
Get
<
std
::
string
>
();
EXPECT_EQ
(
"hello"
,
ss
);
}
python/paddle/fluid/tests/book/test_word2vec.py
浏览文件 @
abf1005b
...
...
@@ -17,6 +17,7 @@ from __future__ import print_function
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.layers.device
import
get_places
from
paddle.fluid.layers.control_flow
import
ParallelDo
import
unittest
import
os
import
numpy
as
np
...
...
@@ -83,7 +84,18 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
avg_cost
,
predict_word
=
__network__
(
[
first_word
,
second_word
,
third_word
,
forth_word
,
next_word
])
else
:
raise
ValueError
(
'is_parallel=True not implemented'
)
places
=
get_places
()
pd
=
ParallelDo
(
places
)
with
pd
.
do
():
avg_cost
,
predict_word
=
__network__
(
list
(
map
(
pd
.
read_input
,
[
first_word
,
second_word
,
third_word
,
forth_word
,
next_word
])))
pd
.
write_output
(
avg_cost
)
avg_cost
=
fluid
.
layers
.
mean
(
pd
())
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd_optimizer
.
minimize
(
avg_cost
)
...
...
@@ -250,7 +262,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel):
for
use_cuda
in
(
False
,
True
):
for
is_sparse
in
(
False
,
True
):
for
is_parallel
in
(
False
,
):
# TODO(paddle-dev): Add parallel test.
for
is_parallel
in
(
False
,
True
):
inject_test_method
(
use_cuda
,
is_sparse
,
is_parallel
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录