From d206582337b7bb110c849a9af2e83549fe704331 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Thu, 7 Mar 2019 10:36:51 +0800 Subject: [PATCH] add parallel graph dist test (#16076) * add parallel graph dist test=develop * update test=develop * update style test=develop --- .../details/parallel_ssa_graph_executor.cc | 7 ++++ .../tests/unittests/test_dist_mnist_pg.py | 40 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index 5b8ae8b6770..2afac32437d 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" +#include +#include #include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { @@ -29,6 +31,11 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) { auto &g = graphs.back(); g->Set(kGraphVars, new GraphVars(1UL)); g->Set(kGraphDepVars, new GraphDepVars); + auto &stale_ops = + graph->Get>(details::kStaleProgramOpDescs); + g->Erase(details::kStaleProgramOpDescs); + g->Set>(details::kStaleProgramOpDescs, + new std::vector(stale_ops)); } auto op_handles = ir::FilterByNodeWrapper(*graph); diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py new file mode 100644 index 00000000000..d063f8473e0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + + +class TestDistMnistNCCL2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "dist_mnist.py", + delta=1, + need_envs={ + "FLAGS_enable_parallel_graph": "1", + "FLAGS_sync_nccl_allreduce": "1" + }) + + +if __name__ == "__main__": + unittest.main() -- GitLab