From e383ea20dcc350a7290a6c7689f0c0ca53d42eb1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 15 May 2018 21:00:42 +0800 Subject: [PATCH] fix fetch op handle --- paddle/fluid/framework/details/fetch_op_handle.cc | 15 ++++++++------- .../tests/unittests/test_parallel_executor.py | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index b1c9dd0d152..224e8e1f6ef 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() { WaitInputVarGenerated(platform::CPUPlace()); tensors_.resize(inputs_.size()); - auto *var_handle = static_cast(inputs_[0]); - auto &var_name = var_handle->name_; platform::CPUPlace cpu; auto &scopes = *local_scopes_; - for (size_t i = 0; i < scopes.size(); ++i) { - auto &scope = scopes[i]; - auto *var = - scope->FindVar(kLocalExecScopeName)->Get()->FindVar(var_name); + for (size_t i = 0; i < inputs_.size(); ++i) { + auto *var_handle = static_cast(inputs_[i]); + auto &scope = scopes.at(var_handle->scope_idx_); + auto *var = scope->FindVar(kLocalExecScopeName) + ->Get() + ->FindVar(var_handle->name_); PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", - var_name); + var_handle->name_); + auto &t = var->Get(); if (platform::is_gpu_place(t.place())) { #ifdef PADDLE_WITH_CUDA diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index a3be1a8db68..926c6bc28a1 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -764,7 +764,7 @@ class TestCRFModel(unittest.TestCase): def test_update_sparse_parameter_with_new_strategy(self): self.check_network_convergence( - is_sparse=False, balance_parameter_opt_between_cards=True) + is_sparse=True, balance_parameter_opt_between_cards=True) def test_update_dense_parameter_with_new_strategy(self): self.check_network_convergence( @@ -836,7 +836,7 @@ class TestFetchOp(unittest.TestCase): assert not math.isnan(np.sum(ret[i])) and \ not math.isinf(np.sum(ret[i])) - def test_update_sparse_parameter(self): + def test_fetch_op(self): tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16) tst_reader_iter = tst_reader() -- GitLab