From d06849305a67d6645699384ae87ec1870e5756e3 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 15 Aug 2018 21:17:14 +0800 Subject: [PATCH] parameter dispather. (#12666) --- paddle/fluid/framework/threadpool.cc | 7 ++ .../distributed/variable_response.cc | 7 +- paddle/fluid/operators/listen_and_serv_op.cc | 5 +- python/paddle/fluid/__init__.py | 2 +- python/paddle/fluid/initializer.py | 1 - .../fluid/tests/unittests/CMakeLists.txt | 4 +- .../fluid/tests/unittests/test_dist_train.py | 17 +++ .../tests/unittests/test_dist_transpiler.py | 50 ++++++--- .../fluid/transpiler/distribute_transpiler.py | 100 +++++++++++++++--- 9 files changed, 162 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index f26f212d4d..18cdca3a65 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -20,6 +20,9 @@ DEFINE_int32(io_threadpool_size, 100, "number of threads used for doing IO, default 100"); +DEFINE_int32(dist_threadpool_size, 0, + "number of threads used for distributed executed."); + namespace paddle { namespace framework { @@ -35,6 +38,10 @@ void ThreadPool::Init() { if (threadpool_.get() == nullptr) { // TODO(Yancey1989): specify the max threads number int num_threads = std::thread::hardware_concurrency(); + if (FLAGS_dist_threadpool_size > 0) { + num_threads = FLAGS_dist_threadpool_size; + VLOG(1) << "set dist_threadpool_size to " << num_threads; + } PADDLE_ENFORCE_GT(num_threads, 0); threadpool_.reset(new ThreadPool(num_threads)); } diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 466bce18af..8e38b3713f 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -190,12 +190,15 @@ bool VariableResponse::ProcSerializedField( #endif } + VLOG(7) << "ProcSerializedField:" << meta_.varname() + << ", type:" << meta_.type() << std::endl; framework::DDim dims = GetDims(meta_.dims()); if (meta_.type() == sendrecv::LOD_TENSOR) { PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!"); if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) { return false; } + return true; } @@ -206,7 +209,9 @@ bool VariableResponse::ProcSerializedField( return true; } - return true; + PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type()); + + return false; } }; // namespace distributed diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index b194807696..f196e18fe1 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -123,8 +123,11 @@ void ListenAndServOp::RunSyncLoop( optimize_prepared.begin(), std::shared_ptr(nullptr)); + // Trainers will get all parameters from pserver in the + // startup program, so we will wait RequestGet first + rpc_service_->SetCond(distributed::kRequestGet); + rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->ResetBarrierCounter(); - while (true) { rpc_service_->Profiler().OneStep(); // Get from multiple trainers, we don't care about the order in which diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 1ae05dec8d..9aac3c7fc1 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -122,7 +122,7 @@ def __bootstrap__(): 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', - 'cpu_deterministic' + "dist_threadpool_size", 'cpu_deterministic' ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 3f740dd7c5..6dedbae7a6 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -15,7 +15,6 @@ from . import framework import numpy as np import contextlib -from .framework import convert_np_dtype_to_dtype_ from .core import VarDesc __all__ = [ diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a6a911721d..e7dd85ef5c 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -59,8 +59,8 @@ py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=$ if(WITH_DISTRIBUTE) py_test_modules(test_dist_train MODULES test_dist_train SERIAL) set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20) - set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180) - set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180) + set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200) + set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) endif() py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index aab8969a96..55aa923f5a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -26,6 +26,12 @@ from paddle.fluid.layers.io import ListenAndServ from paddle.fluid.layers.io import Recv from paddle.fluid.layers.io import Send +from paddle.fluid import core + +RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( +) +RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC + class TestSendOp(unittest.TestCase): def test_send(self): @@ -89,18 +95,29 @@ class TestSendOp(unittest.TestCase): def init_client(self, place, port): main = fluid.Program() with fluid.program_guard(main): + main.global_block().append_op( + type="fetch_barrier", + inputs={}, + outputs={}, + attrs={ + "endpoints": ["127.0.0.1:{0}".format(port)], + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + x = layers.data( shape=[32, 32], dtype='float32', name='X', append_batch_size=False) fluid.initializer.Constant(value=2.3)(x, main.global_block()) + get_var = main.global_block().create_var( name="scale_0.tmp_0", # server side var dtype="float32", persistable=False, shape=[32, 32]) fluid.initializer.Constant(value=2.3)(get_var, main.global_block()) + Send("127.0.0.1:%d" % port, [x]) o = Recv("127.0.0.1:%d" % port, [get_var]) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 55f8b3eff8..124abf4ccd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -18,6 +18,7 @@ import unittest import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import delete_ops import traceback +import collections class TranspilerTest(unittest.TestCase): @@ -53,9 +54,18 @@ class TranspilerTest(unittest.TestCase): self.origin_prog = main.clone() return main - def get_trainer(self, config=None, sync_mode=True): - t = self._transpiler_instance(config, sync_mode) - return t.get_trainer_program() + def get_trainer(self, config=None): + src = fluid.default_startup_program().clone() + + t = self._transpiler_instance(config) + + trainer_main = t.get_trainer_program() + trainer_startup = fluid.default_startup_program() + + assert (src.num_blocks == 1) + assert (trainer_startup.num_blocks == src.num_blocks) + + return trainer_main, trainer_startup def get_pserver(self, ep, config=None, sync_mode=True): t = self._transpiler_instance(config, sync_mode) @@ -91,7 +101,21 @@ class TestBasicModel(TranspilerTest): pserver, startup = self.get_pserver(self.pserver1_ep) pserver2, startup2 = self.get_pserver(self.pserver2_ep) - trainer = self.get_trainer() + trainer, trainer_startup = self.get_trainer() + + # splited var blocks should be in startup program + self.assertTrue("fc_w.block0" in trainer_startup.global_block().vars) + self.assertTrue("fc_w.block1" in trainer_startup.global_block().vars) + self.assertTrue("fc_w" in trainer_startup.global_block().vars) + self.assertTrue("fc_b" in trainer_startup.global_block().vars) + self.assertTrue("fc_w@GRAD" not in trainer_startup.global_block().vars) + self.assertTrue("fc_b@GRAD" not in trainer_startup.global_block().vars) + + src = [op.type for op in trainer_startup.global_block().ops] + dst = ['fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', \ + 'fetch_barrier', 'concat'] + + self.assertEqual(src, dst) self.assertEqual([op.type for op in trainer.global_block().ops], [ 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', @@ -142,7 +166,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest): pserver, startup = self.get_pserver(self.pserver1_ep, config) pserver2, startup2 = self.get_pserver(self.pserver2_ep, config) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual([op.type for op in trainer.global_block().ops], [ 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', @@ -226,7 +250,7 @@ class TestLRDecay(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 4) lr_decay_ops = [op.type for op in pserver.blocks[1].ops] @@ -256,7 +280,7 @@ class TestLRDecayConditional(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() serv_op = pserver.blocks[0].ops[0] sub_blocks = [] @@ -305,7 +329,7 @@ class TestL2Decay(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 3) self.assertEqual([op.type for op in pserver.blocks[1].ops], @@ -340,7 +364,7 @@ class TestL2DecayWithPiecewise(TranspilerTest): def transpiler_test_impl(self): pserver, startup = self.get_pserver(self.pserver1_ep) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(pserver.blocks), 9) self.assertEqual([op.type for op in pserver.blocks[1].ops], [ @@ -415,7 +439,7 @@ class TestLocalLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sum", "adam", "scale", "scale"]) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', @@ -453,7 +477,7 @@ class TestDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer = self.get_trainer() + trainer, _ = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', @@ -486,7 +510,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["adam", "scale", "scale"]) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', @@ -525,7 +549,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): # 5 save table self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) - trainer = self.get_trainer(config) + trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index c97beea1b3..ce4709f23b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -195,6 +195,9 @@ class DistributeTranspiler(object): if program is None: program = default_main_program() self.origin_program = program + self.origin_startup_program = default_startup_program().clone() + + self.startup_program = default_startup_program() self.trainer_num = trainers self.sync_mode = sync_mode self.trainer_id = trainer_id @@ -205,10 +208,10 @@ class DistributeTranspiler(object): ps_dispatcher = self.config.split_method(self.pserver_endpoints) self.has_distributed_lookup_table = self._has_distributed_lookup_table() - # split and create vars, then put splited vars in dicts for later use. + # step 1: split and create vars, then put splited vars in dicts for later use. self._init_splited_vars() - # step 3.1: insert send op to send gradient vars to parameter servers + # step 2: insert send op to send gradient vars to parameter servers ps_dispatcher.reset() send_vars = [] @@ -265,7 +268,7 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - # step 3.2: insert recv op to receive parameters from parameter server + # step 3: insert recv op to receive parameters from parameter server recv_vars = [] for _, var in enumerate(send_vars): recv_vars.append(self.grad_param_mapping[var]) @@ -312,6 +315,8 @@ class DistributeTranspiler(object): outputs={"Out": [orig_param]}, attrs={"axis": 0}) + self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) + if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, pserver_endpoints) @@ -328,8 +333,78 @@ class DistributeTranspiler(object): # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay? delete_ops(self.origin_program.global_block(), self.optimize_ops) self.origin_program.__str__() + return self.origin_program + def _get_trainer_startup_program(self, + recv_vars, + eplist, + startup_program=None): + """ + Get transpiled trainer side startup program. + + Args: + startup_program(Program): Startup program. + + Returns: + Program: trainer side startup program. + """ + if startup_program is None: + startup_program = self.startup_program + + # FIXME(gongwb): delete not need ops. + # note that: some parameter is not trainable and those ops can't be deleted. + + for varname, splited_var in self.param_var_mapping.iteritems(): + # Get the eplist of recv vars + eps = [] + for var in splited_var: + index = [v.name for v in recv_vars].index(var.name) + eps.append(eplist[index]) + + for var in splited_var: + if startup_program.global_block().has_var(var.name): + continue + + startup_program.global_block().create_var( + name=var.name, + persistable=False, + type=var.type, + dtype=var.dtype, + shape=var.shape, + lod_level=var.lod_level) + + op = startup_program.global_block().append_op( + type="recv", + inputs={}, + outputs={"Out": splited_var}, + attrs={ + "epmap": eps, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + startup_program.global_block().append_op( + type="fetch_barrier", + inputs={}, + outputs={}, + attrs={ + "endpoints": self.pserver_endpoints, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + for varname, splited_var in self.param_var_mapping.iteritems(): + #add concat ops to merge splited parameters received from parameter servers. + if len(splited_var) <= 1: + continue + orig_param = startup_program.global_block().vars[varname] + startup_program.global_block().append_op( + type="concat", + inputs={"X": splited_var}, + outputs={"Out": [orig_param]}, + attrs={"axis": 0}) + + return startup_program + def get_pserver_program(self, endpoint): """ Get parameter server side program. @@ -576,14 +651,16 @@ class DistributeTranspiler(object): new_outputs = dict() # do not append startup op if var is not on this pserver op_on_pserver = False - for key in op.output_names: - newname, _ = _get_splited_name_and_shape(op.output(key)[0]) - if newname: - op_on_pserver = True - new_outputs[key] = created_var_map[newname] - elif op.output(key)[0] in pserver_vars: - op_on_pserver = True - new_outputs[key] = pserver_vars[op.output(key)[0]] + # TODO(gongwb): remove this line. + if op.type not in ["recv", "fetch_barrier", "concat"]: + for key in op.output_names: + newname, _ = _get_splited_name_and_shape(op.output(key)[0]) + if newname: + op_on_pserver = True + new_outputs[key] = created_var_map[newname] + elif op.output(key)[0] in pserver_vars: + op_on_pserver = True + new_outputs[key] = pserver_vars[op.output(key)[0]] if op_on_pserver: # most startup program ops have no inputs @@ -1022,7 +1099,6 @@ class DistributeTranspiler(object): var_mapping[varname] = \ [program.global_block().var(orig_var.name)] continue - var_mapping[varname] = [] orig_shape = orig_var.shape orig_dim1_flatten = 1 -- GitLab