提交 e383ea20 编写于 作者: C chengduoZH

fix fetch op handle

上级 5f6fd26f
...@@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() { ...@@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() {
WaitInputVarGenerated(platform::CPUPlace()); WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var_handle->name_;
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto &scopes = *local_scopes_; auto &scopes = *local_scopes_;
for (size_t i = 0; i < scopes.size(); ++i) { for (size_t i = 0; i < inputs_.size(); ++i) {
auto &scope = scopes[i]; auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto *var = auto &scope = scopes.at(var_handle->scope_idx_);
scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name); auto *var = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_name); var_handle->name_);
auto &t = var->Get<framework::LoDTensor>(); auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) { if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -764,7 +764,7 @@ class TestCRFModel(unittest.TestCase): ...@@ -764,7 +764,7 @@ class TestCRFModel(unittest.TestCase):
def test_update_sparse_parameter_with_new_strategy(self): def test_update_sparse_parameter_with_new_strategy(self):
self.check_network_convergence( self.check_network_convergence(
is_sparse=False, balance_parameter_opt_between_cards=True) is_sparse=True, balance_parameter_opt_between_cards=True)
def test_update_dense_parameter_with_new_strategy(self): def test_update_dense_parameter_with_new_strategy(self):
self.check_network_convergence( self.check_network_convergence(
...@@ -836,7 +836,7 @@ class TestFetchOp(unittest.TestCase): ...@@ -836,7 +836,7 @@ class TestFetchOp(unittest.TestCase):
assert not math.isnan(np.sum(ret[i])) and \ assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i])) not math.isinf(np.sum(ret[i]))
def test_update_sparse_parameter(self): def test_fetch_op(self):
tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16) tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
tst_reader_iter = tst_reader() tst_reader_iter = tst_reader()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册