From 2e5d44f102896b9ea357f9eca82d9955385ed094 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 7 May 2018 16:00:16 +0800 Subject: [PATCH] fix fetch op --- .../framework/details/fetch_op_handle.cc | 7 +- .../tests/unittests/test_parallel_executor.py | 80 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 1e8ca20b5..88c7caadb 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -49,7 +49,7 @@ void FetchOpHandle::RunImpl() { platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); for (auto *input : inputs_) { auto *var = static_cast(input); - var->generated_op_->Wait(cpu_ctx); + if (var->generated_op_) var->generated_op_->Wait(cpu_ctx); } tensors_.resize(inputs_.size()); auto *var_handle = static_cast(inputs_[0]); @@ -61,9 +61,14 @@ void FetchOpHandle::RunImpl() { auto &scope = scopes[i]; auto *var = scope->FindVar(kLocalExecScopeName)->Get()->FindVar(var_name); + if (var == nullptr) { + scope->FindVar(var_name); + } + PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", var_name); auto &t = var->Get(); + if (platform::is_gpu_place(t.place())) { #ifdef PADDLE_WITH_CUDA TensorCopySync(t, cpu, &tensors_[i]); diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 9056f5e66..5fbe35e20 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -721,3 +721,83 @@ class TestCRFModel(unittest.TestCase): def test_update_dense_parameter(self): self.check_network_convergence(is_sparse=False) + + +# test fetch op + +import paddle.dataset.flowers as flowers + + +def lenet(data, class_dim): + conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None) + bn1 = fluid.layers.batch_norm(conv1, act='relu') + pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2) + conv2 = fluid.layers.conv2d(pool1, 50, 5, 1, act=None) + bn2 = fluid.layers.batch_norm(conv2, act='relu') + pool2 = fluid.layers.pool2d(bn2, 2, 'max', 2) + + fc1 = fluid.layers.fc(pool2, size=500, act='relu') + fc2 = fluid.layers.fc(fc1, size=class_dim, act='softmax') + + return fc2 + + +class TestFetchOp(unittest.TestCase): + def parallel_exe(self, train_inputs, seed): + main = fluid.Program() + startup = fluid.Program() + startup.random_seed = seed + with fluid.program_guard(main, startup): + data = fluid.layers.data( + name='image', shape=[3, 224, 224], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = lenet(data, class_dim=102) + loss = fluid.layers.cross_entropy(input=out, label=label) + loss = fluid.layers.mean(loss) + + opt = fluid.optimizer.Momentum( + learning_rate=0.1, + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + + opt.minimize(loss) + + # TODO(zcd): I found that onece the memory optimizer is open, + # parallel_exe doesn't fetch some variable, such as conv2d_0.b_0@GRAD, conv2d_1.b_0@GRAD. + # fluid.memory_optimize(main) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup) + + feeder = fluid.DataFeeder(place=place, feed_list=[data, label]) + pe = fluid.ParallelExecutor( + use_cuda=True, loss_name=loss.name, main_program=main) + + fetch_list = [] + for data in train_inputs: + all_vars = main.global_block().vars + for k, v in all_vars.iteritems(): + if v.persistable and 'velocity' not in k: + fetch_list.append(k) + + ret = pe.run(fetch_list, feed=feeder.feed(data)) + result = {} + for i in range(len(fetch_list)): + result[fetch_list[i]] = np.sum(ret[i]) + + def test_update_sparse_parameter(self): + tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16) + tst_reader_iter = tst_reader() + + seed = 100 + iters = 4 + train_inputs = [] + for i in range(iters): + train_inputs.append(tst_reader_iter.next()) + + self.parallel_exe(train_inputs, seed) + + +if __name__ == '__main__': + unittest.main() -- GitLab