提交 a82671c4 编写于 作者: C chengmo

fix sgd flags

上级 8a9fc7c6
......@@ -43,7 +43,8 @@ class ClusterTrainer(TranspileTrainer):
if envs.get_platform() == "LINUX":
self.regist_context_processor('train_pass', self.dataset_train)
else:
self.regist_context_processor('train_pass', self.dataloader_train)
self.regist_context_processor(
'train_pass', self.dataloader_train)
self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self):
......@@ -70,6 +71,9 @@ class ClusterTrainer(TranspileTrainer):
def init(self, context):
self.model.train_net()
optimizer = self.model.optimizer()
optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
os.environ["FLAGS_communicator_is_sgd_optimizer"] = 0
strategy = self.build_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(self.model.get_cost_op())
......@@ -103,8 +107,8 @@ class ClusterTrainer(TranspileTrainer):
program = fluid.compiler.CompiledProgram(
fleet.main_program).with_data_parallel(
loss_name=self.model.get_cost_op().name,
build_strategy=self.strategy.get_build_strategy(),
exec_strategy=self.strategy.get_execute_strategy())
build_strategy=self.strategy.get_build_strategy(),
exec_strategy=self.strategy.get_execute_strategy())
metrics_varnames = []
metrics_format = []
......
......@@ -75,6 +75,9 @@ class TDMClusterTrainer(TranspileTrainer):
def init(self, context):
self.model.train_net()
optimizer = self.model.optimizer()
optimizer_name = envs.get_global_env("hyper_parameters.optimizer")
if optimizer_name in ['adam', 'ADAM', 'Adagrad', 'ADAGRAD']:
os.environ["FLAGS_communicator_is_sgd_optimizer"] = 0
strategy = self.build_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(self.model.get_cost_op())
......@@ -93,6 +96,7 @@ class TDMClusterTrainer(TranspileTrainer):
context['status'] = 'trainer_startup_pass'
def server(self, context):
namespace = "train.startup"
model_path = envs.get_global_env(
"cluster.model_path", "", namespace)
assert not model_path, "Cluster train must has init_model for TDM"
......@@ -103,7 +107,7 @@ class TDMClusterTrainer(TranspileTrainer):
def trainer_startup(self, context):
namespace = "train.startup"
load_tree = envs.get_global_env(
"cluster.load_tree", False, namespace)
"cluster.load_tree", True, namespace)
self.tree_layer_path = envs.get_global_env(
"cluster.tree_layer_path", "", namespace)
self.tree_travel_path = envs.get_global_env(
......
# -*- coding=utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册