From f7e39d754c8d479f4c65214c5c530be6dc35bc29 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 19 May 2023 15:25:40 +0800 Subject: [PATCH] add sharding axis merge --- .../auto_parallel/spmd_rules/common.h | 36 ++++++++++++++----- .../spmd_rules/matmul_spmd_rule.cc | 9 +++-- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h index bffccbfc3ab..f6deca94d7b 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h @@ -66,30 +66,49 @@ class SPMDRuleBase { std::unordered_map ShardingMergeForTensors( - const std::vector>>& + const std::vector>>& tensor_notation_to_dim_pairs) { std::unordered_map axis_to_dim_map; std::unordered_map dim_to_axis_map; + int64_t merge_dim; for (auto& pair : tensor_notation_to_dim_pairs) { for (int i = 0; i < pair.second.size(); i++) { - auto tensor_axis = pair.first[i]; + auto tensor_axis = pair.first.substr(i, 1); auto mesh_dim = pair.second[i]; if (axis_to_dim_map.count(tensor_axis) == 0) { - axis_to_dim_map.insert({tensor_axis, mesh_dim}); - + merge_dim = mesh_dim; } else { - int64_t merge_dim = ShardingMergeForAxis( + merge_dim = ShardingMergeForAxis( tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]); - axis_to_dim_map.insert({tensor_axis, merge_dim}); } + axis_to_dim_map.insert({tensor_axis, merge_dim}); - if (dim_to_axis_map.count(mesh_dim) == 0) { - dim_to_axis_map.insert({tensor_axis, mesh_dim}); + if (dim_to_axis_map.count(merge_dim) == 0) { + dim_to_axis_map.insert({merge_dim, tensor_axis}); + } else { + dim_to_axis_map[merge_dim] += tensor_axis; } } } + + // Resolve the "mesh_dim sharded by more than one axis" conflict. // For now we just naively pick the first axis. // (TODO) use a local cost model to pick the axis with the lowest cost (in terms of // memory, communication, or computation). 
+ for (auto& it : dim_to_axis_map) { if (it.second.size() > 1) { VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first << "] are Sharding Multiple Tensor Axis: [" << it.second << "]. The Axis: [" << it.second[0] << "] is Picked."; for (int i = 1; i < it.second.size(); i++) { axis_to_dim_map[it.second.substr(i, 1)] = -1; } } } + + return axis_to_dim_map; } // Rule1: A replicated dimension could be merged by any sharded dimension. @@ -105,6 +124,7 @@ int64_t ShardingMergeForAxis(const std::string axis, } else if (mesh_dim2 == -1) { return mesh_dim1; } else { + // (TODO) local cost model here. PADDLE_THROW( phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two " "different mesh dimension [%d] and [%d].", diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc index 297c6d5361f..5bc63cd8dd5 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc @@ -123,11 +123,10 @@ std::vector MatmulSPMDRule::InferForward( std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1); } // step2.1: Sharding Merge - std::pair*> x_pair(&x_string, - &x_dims_mapping); - std::pair*> y_pair(&y_string, - &y_dims_mapping); - std::vector>> input_pairs; + std::pair> x_pair(x_string, x_dims_mapping); + std::pair> y_pair(y_string, y_dims_mapping); + std::vector>> + input_pairs; input_pairs.push_back(x_pair); input_pairs.push_back(y_pair); auto dim_to_sharding = ShardingMerge(input_pairs); -- GitLab