diff --git a/python/paddle/distributed/fleet/base/strategy_compiler.py b/python/paddle/distributed/fleet/base/strategy_compiler.py
index 29e10661888f8a7fd6e3c40ee356aad326c193a9..d598dd8ed4bbdd2d79268aae595a70d5f7209e1f 100644
--- a/python/paddle/distributed/fleet/base/strategy_compiler.py
+++ b/python/paddle/distributed/fleet/base/strategy_compiler.py
@@ -13,24 +13,120 @@
 # limitations under the License.
 
 
-def maximum_path_len_algo(optimizer_list):
-    max_idx = 0
-    max_len = 0
-    candidates = []
-    for idx, opt in enumerate(optimizer_list):
-        local_buffer = [opt]
-        for opt_inner in optimizer_list:
+def create_graph(optimizer_list):
+    """Build the "can update" graph of the meta optimizers.
+
+    Returns (edge, indegree): edge[i][j] is 1 when optimizer_list[i] reports
+    _can_update(optimizer_list[j]) and 0 otherwise; indegree[j] counts the
+    edges pointing at node j.
+    """
+    nsize = len(optimizer_list)
+
+    edge = [[0] * nsize for _ in range(nsize)]  # adjacency matrix
+    indegree = [0] * nsize
+    for i, opt in enumerate(optimizer_list):
+        for j, opt_inner in enumerate(optimizer_list):
             if opt._can_update(opt_inner):
-                local_buffer.append(opt_inner)
-        if len(local_buffer) > max_len:
-            max_idx = idx
-            max_len = len(local_buffer)
-        candidates.append(local_buffer)
-    if len(candidates) == 0:
+                edge[i][j] = 1  # weight
+                indegree[j] += 1
+
+    return edge, indegree
+
+
+def topo_sort(edge, indegree):
+    """Return a topological order of the nodes; ``indegree`` is consumed.
+
+    The assert fires when no zero-indegree node is left, i.e. on a ring.
+    """
+    nsize = len(indegree)
+
+    topo = [-1] * nsize
+    for i in range(nsize):
+        j = 0
+        while j < nsize and indegree[j] != 0:
+            j += 1
+        assert j < nsize, 'The combination of meta optimizers contains ring'
+
+        topo[i] = j
+        indegree[j] = -1
+        for k in range(nsize):
+            if edge[j][k] != 0:
+                indegree[k] -= 1
+
+    return topo
+
+
+def floyd(edge):
+    """Return the node indices of the longest path found in ``edge``.
+
+    ``edge`` is overwritten in place with the longest length found for each
+    pair. A detour i->k->j only counts when the direct edge i->j exists too.
+    Returns [0] when the graph has no edge at all.
+    """
+    nsize = len(edge)
+    max_len = -1
+    max_edge = [-1, -1]
+
+    max_path = [[[] for _ in range(nsize)] for _ in range(nsize)]
+    for i in range(nsize):
+        for j in range(nsize):
+            if edge[i][j] > 0:
+                max_path[i][j] = [j]
+
+                if edge[i][j] > max_len:
+                    max_len = edge[i][j]
+                    max_edge = [i, j]
+
+    # use floyd algorithm to find max_path
+    for k in range(nsize):
+        for i in range(nsize):
+            for j in range(nsize):
+                # if a-->b-->c, but a-/->c, can only apply a-->b or b-->c,
+                # however if a-->b-->c, and a-->c, can apply a->b->c
+                if edge[i][j] == 0:
+                    continue
+
+                if edge[i][k] == 0 or edge[k][j] == 0:
+                    continue
+
+                if edge[i][j] < edge[i][k] + edge[k][j]:
+                    edge[i][j] = edge[i][k] + edge[k][j]
+                    max_path[i][j] = max_path[i][k] + max_path[k][j]
+
+                    # only a strictly longer path may replace the recorded
+                    # maximum; a shorter relaxation must not overwrite it
+                    if edge[i][j] > max_len:
+                        max_len = edge[i][j]
+                        max_edge = [i, j]
+
+    if max_len == -1:
+        return [0]
+
+    return [max_edge[0]] + max_path[max_edge[0]][max_edge[1]]
+
+
+def maximum_path_len_algo(optimizer_list):
+    """Chain the longest run of compatible meta optimizers.
+
+    Returns None for an empty list; otherwise the chosen optimizers, each
+    linked to its successor through _update_inner_optimizer.
+    """
+    if len(optimizer_list) == 0:
         return None
-    for idx, opt in enumerate(candidates[max_idx][:-1]):
-        opt._update_inner_optimizer(candidates[max_idx][idx + 1])
-    return candidates[max_idx]
+
+    edge, indegree = create_graph(optimizer_list)
+    # return value unused: the call only asserts that the graph has no ring
+    topo_sort(edge, indegree)
+    max_path = floyd(edge)
+
+    candidate = []
+    for idx in max_path:
+        candidate.append(optimizer_list[idx])
+
+    for idx, opt in enumerate(candidate[:-1]):
+        opt._update_inner_optimizer(candidate[idx + 1])
+
+    return candidate
 
 
 class StrategyCompilerBase(object):
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
index 6bc1a310d0aea0b5e7af0b5536fad8e4403d892f..b4112e88860cd21ebd96400079c51e1676af8f63 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py
@@ -103,6 +103,51 @@ class TestFleetAMPOptimizer(TestFleetMetaOptimizer):
         # recompute
         self.assertIn('subprog', ''.join(outs))
 
+    def test_amp_recompute_lars_optimizer(self):
+        """ test amp + recompute + lars """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        self.set_strategy(strategy, 'recompute')
+        self.set_strategy(strategy, 'lars')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        strategy = fleet._final_strategy()
+
+        ops = [op.type for op in avg_cost.block.ops]
+        outs = [
+            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
+        ]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        # recompute
+        self.assertIn('subprog', ''.join(outs))
+
+        # lars
+        self.assertIn('lars_momentum', ops)
+
+    def test_amp_recompute_lamb_optimizer(self):
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        self.set_strategy(strategy, 'recompute')
+        self.set_strategy(strategy, 'lamb')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam')
+
+        ops = [op.type for op in avg_cost.block.ops]
+        outs = [
+            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
+        ]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        # recompute
+        self.assertIn('subprog', ''.join(outs))
+
+        # lamb
+        self.assertIn('lamb', ops)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
index 0faafd76a799d038c175e8ce5758f77374bfd37e..3a64c1818ccc6aedcdac1b6218af5396b98c1ab9 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
@@ -128,6 +128,36 @@ class TestFleetDGCOptimizer(TestFleetMetaOptimizer):
         # recompute
         self.assertIn('subprog', ''.join(outs))
 
+    def test_amp_recompute_lars_dgc_not_apply_optimizer(self):
+        """ test amp + recompute + lars + dgc,
+            amp -/-> dgc, max_path is amp-->recompute-->lars
+        """
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'dgc')
+        self.set_strategy(strategy, 'amp')
+        self.set_strategy(strategy, 'recompute')
+        self.set_strategy(strategy, 'lars')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+
+        strategy = fleet._final_strategy()
+
+        ops = [op.type for op in avg_cost.block.ops]
+        outs = [
+            op.output('Out')[0] for op in avg_cost.block.ops if op.type == 'mul'
+        ]
+        self.assertIn('cast', ops)
+        self.assertIn('check_finite_and_unscale', ops)
+
+        # recompute
+        self.assertIn('subprog', ''.join(outs))
+
+        # lars
+        self.assertIn('lars_momentum', ops)
+
+        # dgc not apply
+        self.assertFalse(strategy.dgc)
+
 
 if __name__ == "__main__":
     unittest.main()