From a82671c44e58a9df0fabbc3e224a588d73f17632 Mon Sep 17 00:00:00 2001
From: chengmo
Date: Wed, 6 May 2020 20:59:20 +0800
Subject: [PATCH] fix sgd flags

---
 fleet_rec/core/trainers/cluster_trainer.py     | 10 +++++++---
 fleet_rec/core/trainers/tdm_cluster_trainer.py |  6 +++++-
 fleet_rec/core/trainers/tdm_single_trainer.py  |  1 +
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/fleet_rec/core/trainers/cluster_trainer.py b/fleet_rec/core/trainers/cluster_trainer.py
index 42d8d0f3..0bc1d550 100644
--- a/fleet_rec/core/trainers/cluster_trainer.py
+++ b/fleet_rec/core/trainers/cluster_trainer.py
@@ -43,7 +43,8 @@ class ClusterTrainer(TranspileTrainer):
         if envs.get_platform() == "LINUX":
             self.regist_context_processor('train_pass', self.dataset_train)
         else:
-            self.regist_context_processor('train_pass', self.dataloader_train)
+            self.regist_context_processor(
+                'train_pass', self.dataloader_train)
         self.regist_context_processor('terminal_pass', self.terminal)
 
     def build_strategy(self):
@@ -70,6 +71,9 @@ class ClusterTrainer(TranspileTrainer):
     def init(self, context):
         self.model.train_net()
         optimizer = self.model.optimizer()
+        optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
+        if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
+            os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"
         strategy = self.build_strategy()
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(self.model.get_cost_op())
@@ -103,8 +107,8 @@ class ClusterTrainer(TranspileTrainer):
         program = fluid.compiler.CompiledProgram(
             fleet.main_program).with_data_parallel(
             loss_name=self.model.get_cost_op().name,
-            build_strategy=self.strategy.get_build_strategy(),
-            exec_strategy=self.strategy.get_execute_strategy())
+                build_strategy=self.strategy.get_build_strategy(),
+                exec_strategy=self.strategy.get_execute_strategy())
 
         metrics_varnames = []
         metrics_format = []
diff --git a/fleet_rec/core/trainers/tdm_cluster_trainer.py b/fleet_rec/core/trainers/tdm_cluster_trainer.py
index 3dc925ec..dad5ba99 100644
--- a/fleet_rec/core/trainers/tdm_cluster_trainer.py
+++ b/fleet_rec/core/trainers/tdm_cluster_trainer.py
@@ -75,6 +75,9 @@ class TDMClusterTrainer(TranspileTrainer):
     def init(self, context):
         self.model.train_net()
         optimizer = self.model.optimizer()
+        optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
+        if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
+            os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"
         strategy = self.build_strategy()
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(self.model.get_cost_op())
@@ -93,6 +96,7 @@ class TDMClusterTrainer(TranspileTrainer):
         context['status'] = 'trainer_startup_pass'
 
     def server(self, context):
+        namespace = "train.startup"
         model_path = envs.get_global_env(
             "cluster.model_path", "", namespace)
         assert not model_path, "Cluster train must has init_model for TDM"
@@ -103,7 +107,7 @@ class TDMClusterTrainer(TranspileTrainer):
     def trainer_startup(self, context):
         namespace = "train.startup"
         load_tree = envs.get_global_env(
-            "cluster.load_tree", False, namespace)
+            "cluster.load_tree", True, namespace)
         self.tree_layer_path = envs.get_global_env(
             "cluster.tree_layer_path", "", namespace)
         self.tree_travel_path = envs.get_global_env(
diff --git a/fleet_rec/core/trainers/tdm_single_trainer.py b/fleet_rec/core/trainers/tdm_single_trainer.py
index 11874645..98d9f790 100644
--- a/fleet_rec/core/trainers/tdm_single_trainer.py
+++ b/fleet_rec/core/trainers/tdm_single_trainer.py
@@ -1,3 +1,4 @@
+# -*- coding=utf-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
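
Note on the change: judging by its name, FLAGS_communicator_is_sgd_optimizer marks whether the distributed communicator may assume a plain-SGD optimizer, so the patch switches it off when Adam or Adagrad is configured. os.environ accepts only string values, which is why the flag is set to the string "0"; assigning the integer 0 would raise "TypeError: str expected, not int". A minimal standalone sketch of the guard added to both trainers' init() (the helper name set_communicator_sgd_flag is illustrative, not part of the patch):

    import os

    # Spellings mirror the whitelist used in the patch.
    NON_SGD_OPTIMIZERS = ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']

    def set_communicator_sgd_flag(optimizer_name):
        # Environment values must be str, hence "0" rather than 0.
        if optimizer_name in NON_SGD_OPTIMIZERS:
            os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"

    # In the patch the check runs before fleet.distributed_optimizer()
    # builds the communicator, so the flag is visible at startup.
    set_communicator_sgd_flag("adam")
    assert os.environ["FLAGS_communicator_is_sgd_optimizer"] == "0"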