From 8ee23da846075f902d29e2c6bd10cb27bb0fd489 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Thu, 3 May 2018 10:17:11 -0700 Subject: [PATCH] Fluid new API: dist train without modifying code Works with 1 trainer 1 pserver. 2 trainer 1 pserver will stuck at the end of first step, still investigating. The user only need to set envrionment variables to enable distributed training. run pserver: PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 python no_test_word2vec_new_api.py run trainer: PADDLE_TRAINING_ROLE=TRAINER PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_TRAINER_ID=0 python no_test_word2vec_new_api.py --- python/paddle/fluid/trainer.py | 56 +++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 0aada3deb..5385d798e 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import core import framework import executor @@ -20,6 +21,7 @@ import contextlib # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module import optimizer as opt_module +import distribute_transpiler __all__ = [ 'Trainer', @@ -76,22 +78,61 @@ class Trainer(object): raise TypeError( "The optimizer should be an instance of Optimizer") - optimizer.minimize(loss) + optimize_ops, params_grads = optimizer.minimize(loss) self.place = Trainer._check_and_get_place(place) + self.dist_transpile_if_necessary(optimize_ops, params_grads) + # 2. move the default_main_program to self.program and run the # default_startup program on an empty core.Scope() # Run startup program - exe = executor.Executor(place) - exe.run(self.startup_program, scope=self.scope) + with self._prog_and_scope_guard(): + exe = executor.Executor(place) + exe.run(self.startup_program) if param_path: # load params from param_path into scope # TODO(yuyang): This depends on parameters implementation. pass - # TODO(helin): support distributed training + def dist_transpile_if_necessary(self, optimize_ops, params_grads): + if "PADDLE_TRAINING_ROLE" not in os.environ: + return + + # the port of all pservers, needed by both trainer and pserver + port = os.getenv("PADDLE_PSERVER_PORT", "6174") + # comma separated ips of all pservers, needed by trainer and + # pserver + pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "") + eplist = [] + for ip in pserver_ips.split(","): + eplist.append(':'.join([ip, port])) + pserver_endpoints = ",".join(eplist) + # total number of workers/trainers in the job, needed by + # trainer and pserver + trainers = int(os.getenv("PADDLE_TRAINERS")) + # the IP of the local machine, needed by pserver only + current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port + # the unique trainer id, starting from 0, needed by trainer + # only + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + # the role, should be either PSERVER or TRAINER + training_role = os.getenv("PADDLE_TRAINING_ROLE") + with self._prog_and_scope_guard(): + t = distribute_transpiler.DistributeTranspiler() + t.transpile( + trainer_id, pservers=pserver_endpoints, trainers=trainers) + if training_role == "PSERVER": + self.train_program = t.get_pserver_program(current_endpoint) + self.startup_program = t.get_startup_program(current_endpoint, + self.train_program) + elif training_role == "TRAINER": + self.train_program = t.get_trainer_program() + else: + raise ValueError( + 'TRAINING_ROLE environment variable must be either TRAINER or PSERVER' + ) def train(self, num_epochs, @@ -117,6 +158,13 @@ class Trainer(object): raise NotImplementedError( "Parallel Executor version of trainer is not implemented") + training_role = os.getenv("PADDLE_TRAINING_ROLE", "") + if training_role == "PSERVER": + with self._prog_and_scope_guard(): + exe = executor.Executor(self.place) + exe.run() + return + self._train_by_executor(num_epochs, event_handler, reader, feed_order) def test(self, reader): -- GitLab