Unverified commit 739043c6, authored by WangXi, committed by GitHub

【paddle.fleet】fleet add _get_applied_meta_list and _get_applied_graph_list, test=develop (#27952) (#28053)
Parent 39745d44
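This commit exposes two introspection helpers on the fleet module: `_get_applied_meta_list()` returns the class names of the meta optimizers that were actually applied during `minimize`, and `_get_applied_graph_list()` does the same for graph optimizers. A minimal usage sketch, pieced together from the tests added below (the network is illustrative, and the script assumes the launcher environment variables that the new test's setUp configures):

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    # A trivial static-graph network, mirroring the added test.
    input_x = paddle.fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
    fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
    prediction = paddle.fluid.layers.fc(input=fc, size=2, act='softmax')
    cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
    avg_cost = paddle.fluid.layers.mean(x=cost)

    strategy = fleet.DistributedStrategy()
    optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)

    # After minimize, the helpers report which optimizers were applied.
    print(fleet._get_applied_meta_list())   # e.g. [] with a plain strategy
    print(fleet._get_applied_graph_list())  # e.g. one graph optimizer name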
@@ -34,6 +34,8 @@ __all__ = [
 fleet = Fleet()
 _final_strategy = fleet._final_strategy
+_get_applied_meta_list = fleet._get_applied_meta_list
+_get_applied_graph_list = fleet._get_applied_graph_list
 init = fleet.init
 is_first_worker = fleet.is_first_worker
 worker_index = fleet.worker_index
@@ -925,6 +925,24 @@ class Fleet(object):
         else:
             return self._context["valid_strategy"]
 
+    def _get_applied_meta_list(self):
+        if "applied_meta_list" not in self._context:
+            print(
+                "WARNING: You may need to call the minimize function "
+                "before _get_applied_meta_list is called")
+            return []
+        else:
+            return self._context["applied_meta_list"]
+
+    def _get_applied_graph_list(self):
+        if "applied_graph_list" not in self._context:
+            print(
+                "WARNING: You may need to call the minimize function "
+                "before _get_applied_graph_list is called")
+            return []
+        else:
+            return self._context["applied_graph_list"]
+
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -1043,6 +1061,12 @@ class Fleet(object):
         context["valid_strategy"] = copy.deepcopy(valid_strategy)
 
+        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
+        applied_graph_list = self.strategy_compiler._get_applied_graph_list()
+
+        context['applied_meta_list'] = applied_meta_list
+        context['applied_graph_list'] = applied_graph_list
+
         self._context = context
 
         self.valid_strategy = valid_strategy
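Both getters guard on `self._context`, which stays empty until `minimize` runs, so calling them early warns and returns an empty list rather than raising. A sketch of the expected behavior (output paraphrased):

    fleet.init(is_collective=True)
    meta_list = fleet._get_applied_meta_list()
    # prints: WARNING: You may need to call the minimize function
    #         before _get_applied_meta_list is called
    assert meta_list == []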
@@ -122,13 +122,19 @@ class StrategyCompiler(StrategyCompilerBase):
     def __init__(self):
         super(StrategyCompiler, self).__init__()
-        self._meta_optimizer = None
-        self._graph_optimizer = None
+        self._meta_optimizers = []
+        self._graph_optimizers = []
         self._valid_optimizer_list = None
         self._user_defined_strategy = None
         self._meta_optimizer_candidates = []
         self._graph_optimizer_candidates = []
 
+    def _get_applied_meta_list(self):
+        return [type(opt).__name__ for opt in self._meta_optimizers]
+
+    def _get_applied_graph_list(self):
+        return [type(opt).__name__ for opt in self._graph_optimizers]
+
     def _get_valid_strategy(self, dist_strategy, can_not_apply_optimizer_list):
         import copy
         valid_strategy = copy.deepcopy(dist_strategy)
@@ -178,8 +184,8 @@ class StrategyCompiler(StrategyCompilerBase):
         # and graph_optimizer, the corresponding distributed strategy
         # should be updated.
-        self._meta_optimizers = meta_optimizers
-        self._graph_optimizers = graph_optimizers
+        self._meta_optimizers = [] if meta_optimizers is None else meta_optimizers
+        self._graph_optimizers = [] if graph_optimizers is None else graph_optimizers
 
         return_meta = None if meta_optimizers == None else meta_optimizers[0]
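The compiler-side helpers simply map each applied optimizer instance to its class name. For instance, with the amp, recompute, and lamb strategies enabled (as in the AMP test below), the list holds three names; the exact strings here are illustrative, inferred from the meta-optimizer class naming convention:

    applied = fleet._get_applied_meta_list()
    # e.g. ['AMPOptimizer', 'RecomputeOptimizer', 'LambOptimizer']
    assert len(applied) == 3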
@@ -72,7 +72,7 @@ class AMPOptimizer(MetaOptimizerBase):
             "incr_every_n_steps": 1000,
             "decr_every_n_nan_or_inf": 2,
             "incr_ratio": 2.0,
-            "decr_ratio": 8.0,
+            "decr_ratio": 0.8,
             "use_dynamic_loss_scaling": True
         }
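The `decr_ratio` change is a genuine bug fix: under dynamic loss scaling, the loss scale is multiplied by `decr_ratio` once too many nan/inf batches occur, so the ratio must be below 1 to shrink the scale; 8.0 would have grown it on every overflow. A simplified sketch of the update rule (not the actual Paddle kernel; the function and counter names are illustrative):

    def update_loss_scale(scale, found_inf, bad_steps, good_steps,
                          incr_every_n_steps=1000, decr_every_n_nan_or_inf=2,
                          incr_ratio=2.0, decr_ratio=0.8):
        # Shrink the scale after repeated overflows, grow it after a long
        # run of clean steps; decr_ratio < 1 is what makes the shrink work.
        if found_inf:
            bad_steps += 1
            if bad_steps >= decr_every_n_nan_or_inf:
                scale, bad_steps = scale * decr_ratio, 0
            return scale, bad_steps, 0
        good_steps += 1
        if good_steps >= incr_every_n_steps:
            scale, good_steps = scale * incr_ratio, 0
        return scale, bad_steps, good_steps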
@@ -133,8 +133,14 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         self.set_strategy(strategy, 'amp')
         self.set_strategy(strategy, 'recompute')
         self.set_strategy(strategy, 'lamb')
         self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam')
 
+        applied_meta_list = fleet._get_applied_meta_list()
+        applied_graph_list = fleet._get_applied_graph_list()
+        print(applied_meta_list, applied_graph_list)
+        self.assertEqual(len(applied_meta_list), 3)
+
         ops = [op.type for op in avg_cost.block.ops]
         outs = [
             op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
@@ -48,6 +48,9 @@ class TestDistributedStrategyAuto(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
 
+        applied_meta_list = fleet._get_applied_meta_list()
+        print("applied_meta_list: {}".format(applied_meta_list))
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -18,6 +18,7 @@ import paddle
 import paddle.distributed.fleet as fleet
 import paddle.distributed.fleet.base.role_maker as role_maker
 import paddle.fluid as fluid
+paddle.enable_static()
 
 
 class TestFleetBase(unittest.TestCase):
@@ -48,5 +48,44 @@ class TestFleetBase(unittest.TestCase):
         optimizer.minimize(avg_cost)
 
 
+class TestFleetBase(unittest.TestCase):
+    def setUp(self):
+        os.environ["POD_IP"] = "127.0.0.1"
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
+        os.environ["PADDLE_TRAINERS_NUM"] = "2"
+        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
+            "127.0.0.1:36001,127.0.0.2:36001"
+
+    def test_fleet_get_applied_optimizer(self):
+        input_x = paddle.fluid.layers.data(
+            name="x", shape=[32], dtype='float32')
+        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
+
+        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
+        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
+        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
+        cost = paddle.fluid.layers.cross_entropy(
+            input=prediction, label=input_y)
+        avg_cost = paddle.fluid.layers.mean(x=cost)
+
+        fleet.init(is_collective=True)
+
+        meta_list = fleet._get_applied_meta_list()
+        graph_list = fleet._get_applied_graph_list()
+        # minimize has not been called yet, so both lists are empty
+        self.assertEqual(len(meta_list), 0)
+        self.assertEqual(len(graph_list), 0)
+
+        strategy = fleet.DistributedStrategy()
+        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        meta_list = fleet._get_applied_meta_list()
+        graph_list = fleet._get_applied_graph_list()
+        self.assertEqual(len(meta_list), 0)
+        self.assertEqual(len(graph_list), 1)
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -16,6 +16,9 @@ import unittest
 import paddle
 import os
 import paddle.fluid as fluid
+import paddle.distributed.fleet as fleet
+
+paddle.enable_static()
 
 
 class TestFleetBase(unittest.TestCase):
@@ -27,7 +30,6 @@ class TestFleetBase(unittest.TestCase):
             "127.0.0.1:36001,127.0.0.2:36001"
 
     def test_fleet_init(self):
-        import paddle.distributed.fleet as fleet
         os.environ["TRAINING_ROLE"] = "PSERVER"
         os.environ["POD_IP"] = "127.0.0.1"