未验证 提交 ae8b1c32 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13821 from panyx0718/fix

Make variable::GetMutable robust
...@@ -66,7 +66,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -66,7 +66,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>(); var->GetMutable<std::vector<framework::Scope*>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
......
...@@ -27,8 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input, ...@@ -27,8 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
// be created. // be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name); Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs = auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
if (index >= feed_inputs.size()) { if (index >= feed_inputs.size()) {
feed_inputs.resize(index + 1); feed_inputs.resize(index + 1);
} }
......
...@@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) { ...@@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>(); var->GetMutable<std::vector<framework::Scope *>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
......
...@@ -38,8 +38,12 @@ class Variable { ...@@ -38,8 +38,12 @@ class Variable {
template <typename T> template <typename T>
T* GetMutable() { T* GetMutable() {
if (!IsType<T>()) { if (!holder_) {
holder_.reset(new PlaceholderImpl<T>(new T())); holder_.reset(new PlaceholderImpl<T>(new T()));
} else {
PADDLE_ENFORCE(IsType<T>(),
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
} }
return static_cast<T*>(holder_->Ptr()); return static_cast<T*>(holder_->Ptr());
} }
......
...@@ -33,9 +33,10 @@ TEST(Variable, GetMutable) { ...@@ -33,9 +33,10 @@ TEST(Variable, GetMutable) {
const Tensor& tt = v->Get<Tensor>(); const Tensor& tt = v->Get<Tensor>();
EXPECT_EQ(1234, tt.content_); EXPECT_EQ(1234, tt.content_);
std::string* s = v->GetMutable<std::string>(); try {
*s = "hello"; v->GetMutable<std::string>();
} catch (std::exception& e) {
const std::string& ss = v->Get<std::string>(); return;
EXPECT_EQ("hello", ss); }
EXPECT_TRUE(false);
} }
...@@ -17,7 +17,6 @@ from __future__ import print_function ...@@ -17,7 +17,6 @@ from __future__ import print_function
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places from paddle.fluid.layers.device import get_places
from paddle.fluid.layers.control_flow import ParallelDo
import unittest import unittest
import os import os
import numpy as np import numpy as np
...@@ -84,18 +83,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True): ...@@ -84,18 +83,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
avg_cost, predict_word = __network__( avg_cost, predict_word = __network__(
[first_word, second_word, third_word, forth_word, next_word]) [first_word, second_word, third_word, forth_word, next_word])
else: else:
places = get_places() raise ValueError('is_parallel=True not implemented')
pd = ParallelDo(places)
with pd.do():
avg_cost, predict_word = __network__(
list(
map(pd.read_input, [
first_word, second_word, third_word, forth_word,
next_word
])))
pd.write_output(avg_cost)
avg_cost = fluid.layers.mean(pd())
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
...@@ -262,7 +250,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel): ...@@ -262,7 +250,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel):
for use_cuda in (False, True): for use_cuda in (False, True):
for is_sparse in (False, True): for is_sparse in (False, True):
for is_parallel in (False, True): for is_parallel in (False, ): # TODO(paddle-dev): Add parallel test.
inject_test_method(use_cuda, is_sparse, is_parallel) inject_test_method(use_cuda, is_sparse, is_parallel)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册