Commit a82671c4 authored by chengmo

fix sgd flags

Parent 8a9fc7c6
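For context: `FLAGS_communicator_is_sgd_optimizer` is a PaddlePaddle runtime flag picked up by the distributed communicator, and this commit clears it whenever the configured optimizer is Adam or Adagrad, before `fleet.distributed_optimizer` builds the strategy. Below is a minimal sketch of that toggle; the helper name and the case-insensitive check are illustrative, not PaddleRec's code:

```python
import os

def disable_sgd_communicator_if_needed(optimizer_name):
    """Hypothetical helper mirroring the toggle added in this commit.

    A case-insensitive check stands in for the literal list
    ['adam', 'ADAM', 'Adagrad', 'ADAGRAD'] and also catches
    mixed-case spellings such as 'Adam'.
    """
    if optimizer_name and optimizer_name.lower() in ("adam", "adagrad"):
        # os.environ values must be str; assigning an int here would
        # raise TypeError at runtime.
        os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"

disable_sgd_communicator_if_needed("Adam")
print(os.environ.get("FLAGS_communicator_is_sgd_optimizer"))  # -> "0"
```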
In ClusterTrainer, the long `regist_context_processor` call is wrapped, the communicator flag is set for non-SGD optimizers in `init`, and the `with_data_parallel` keyword arguments are re-indented:

```diff
@@ -43,7 +43,8 @@ class ClusterTrainer(TranspileTrainer):
         if envs.get_platform() == "LINUX":
             self.regist_context_processor('train_pass', self.dataset_train)
         else:
-            self.regist_context_processor('train_pass', self.dataloader_train)
+            self.regist_context_processor(
+                'train_pass', self.dataloader_train)
         self.regist_context_processor('terminal_pass', self.terminal)

     def build_strategy(self):
@@ -70,6 +71,9 @@ class ClusterTrainer(TranspileTrainer):
     def init(self, context):
         self.model.train_net()
         optimizer = self.model.optimizer()
+        optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
+        if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
+            os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"
         strategy = self.build_strategy()
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(self.model.get_cost_op())
@@ -103,8 +107,8 @@ class ClusterTrainer(TranspileTrainer):
         program = fluid.compiler.CompiledProgram(
             fleet.main_program).with_data_parallel(
                 loss_name=self.model.get_cost_op().name,
-            build_strategy=self.strategy.get_build_strategy(),
-            exec_strategy=self.strategy.get_execute_strategy())
+                build_strategy=self.strategy.get_build_strategy(),
+                exec_strategy=self.strategy.get_execute_strategy())

         metrics_varnames = []
         metrics_format = []
```
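For readers unfamiliar with the pattern: `regist_context_processor` maps a status name to a handler, and the trainer's main loop dispatches on the current status; the Linux branch registers the Dataset-based reader while other platforms fall back to DataLoader. A minimal sketch of that dispatch pattern, with illustrative names rather than PaddleRec's actual base class:

```python
# Hedged sketch of the status/handler registry the trainer uses:
# each pass name maps to a callable, and a loop keeps dispatching
# until a handler marks the context as finished.
class MiniTrainer:
    def __init__(self):
        self._processors = {}
        self.context = {'status': 'train_pass', 'is_exit': False}

    def regist_context_processor(self, status, handler):
        self._processors[status] = handler

    def run(self):
        while not self.context['is_exit']:
            self._processors[self.context['status']](self.context)

trainer = MiniTrainer()
trainer.regist_context_processor(
    'train_pass', lambda ctx: ctx.update(status='terminal_pass'))
trainer.regist_context_processor(
    'terminal_pass', lambda ctx: ctx.update(is_exit=True))
trainer.run()
```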
In TDMClusterTrainer, the same flag toggle is added, `server()` gains the missing `namespace` definition, and the `cluster.load_tree` default flips to True:

```diff
@@ -75,6 +75,9 @@ class TDMClusterTrainer(TranspileTrainer):
     def init(self, context):
         self.model.train_net()
         optimizer = self.model.optimizer()
+        optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
+        if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
+            os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"
         strategy = self.build_strategy()
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(self.model.get_cost_op())
@@ -93,6 +96,7 @@ class TDMClusterTrainer(TranspileTrainer):
         context['status'] = 'trainer_startup_pass'

     def server(self, context):
+        namespace = "train.startup"
         model_path = envs.get_global_env(
             "cluster.model_path", "", namespace)
         assert not model_path, "Cluster train must have init_model for TDM"
@@ -103,7 +107,7 @@ class TDMClusterTrainer(TranspileTrainer):
     def trainer_startup(self, context):
         namespace = "train.startup"
         load_tree = envs.get_global_env(
-            "cluster.load_tree", False, namespace)
+            "cluster.load_tree", True, namespace)
         self.tree_layer_path = envs.get_global_env(
             "cluster.tree_layer_path", "", namespace)
         self.tree_travel_path = envs.get_global_env(
```
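Two separate fixes land here: `server()` previously referenced `namespace` without defining it (a NameError as soon as the lookup ran), and the `cluster.load_tree` default now means the tree is loaded unless the config disables it. A minimal sketch of how `envs.get_global_env(name, default, namespace)` is assumed to behave in these lookups (simplified; PaddleRec's real helper does more):

```python
# Assumed semantics: join namespace and key, look it up in a global
# config dict, and fall back to the supplied default when unset.
_global_envs = {"train.startup.cluster.load_tree": True}

def get_global_env(name, default=None, namespace=None):
    key = "{}.{}".format(namespace, name) if namespace else name
    return _global_envs.get(key, default)

print(get_global_env("cluster.load_tree", True, "train.startup"))  # True
print(get_global_env("cluster.model_path", "", "train.startup"))   # "" (default)
```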
A third file gains an explicit source-encoding declaration ahead of the license header:

```diff
+# -*- coding=utf-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
```
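On that header: PEP 263's magic-comment pattern accepts either `coding:` or `coding=`, so `coding=utf-8` works even though the colon form is the more common spelling; the declaration only matters on Python 2, where source files otherwise default to ASCII. A small illustration:

```python
# -*- coding: utf-8 -*-
# Either spelling of the header above satisfies PEP 263 ("coding"
# followed by ':' or '='). Under Python 2 it lets non-ASCII literals
# like the one below parse; Python 3 already assumes UTF-8 source.
print(u"UTF-8 source: 树模型")
```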