Commit f7e39d75 authored by J JZ-LIANG

add sharding axis merge

Parent 21b5a75a
@@ -66,30 +66,49 @@ class SPMDRuleBase {
 std::unordered_map<std::string, int64_t>
 ShardingMergeForTensors(
-    const std::vector<std::pair<std::string, const std::vector<int64_t>>>&
+    const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
         tensor_notation_to_dim_pairs) {
   std::unordered_map<std::string, int64_t> axis_to_dim_map;
   std::unordered_map<int64_t, std::string> dim_to_axis_map;
+  int64_t merge_dim;
 
   for (auto& pair : tensor_notation_to_dim_pairs) {
     for (int i = 0; i < pair.second.size(); i++) {
-      auto tensor_axis = pair.first[i];
+      auto tensor_axis = pair.first.substr(i, 1);
       auto mesh_dim = pair.second[i];
 
       if (axis_to_dim_map.count(tensor_axis) == 0) {
-        axis_to_dim_map.insert({tensor_axis, mesh_dim});
+        merge_dim = mesh_dim;
       } else {
-        int64_t merge_dim = ShardingMergeForAxis(
+        merge_dim = ShardingMergeForAxis(
             tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
-        axis_to_dim_map.insert({tensor_axis, merge_dim});
       }
-      if (dim_to_axis_map.count(mesh_dim) == 0) {
-        dim_to_axis_map.insert({tensor_axis, mesh_dim});
+      axis_to_dim_map.insert({tensor_axis, merge_dim});
+
+      if (dim_to_axis_map.count(merge_dim) == 0) {
+        dim_to_axis_map.insert({merge_dim, tensor_axis});
+      } else {
+        dim_to_axis_map[merge_dim] += tensor_axis;
       }
     }
   }
+
+  // Resolve the "one mesh_dim shards more than one tensor axis" conflict.
+  // For now we just naively pick the first axis.
+  // (TODO) use a local cost model to pick the axis with the lowest cost (in
+  // terms of memory, communication, or computation).
+  for (auto& it : dim_to_axis_map) {
+    if (it.second.size() > 1) {
+      VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
+              << "] is Sharding Multiple Tensor Axes: [" << it.second
+              << "]. The Axis: [" << it.second[0] << "] is Picked.";
+      for (int i = 1; i < it.second.size(); i++) {
+        axis_to_dim_map[it.second.substr(i, 1)] = -1;
+      }
+    }
+  }
+
+  return axis_to_dim_map;
 }
 
 // Rule1: A replicated dimension could be merged by any sharded dimension.
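For intuition, the function above walks every (einsum notation, dims_mapping) pair and records which mesh dimension shards each tensor axis. The snippet below is a minimal standalone sketch of that bookkeeping, not the Paddle API: the name MergeSketch is hypothetical, it assumes the notation length matches the dims_mapping length, and it only handles the "replicated axis adopts a sharded dim" case (real conflicts go through ShardingMergeForAxis in the patch).

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical, simplified restatement of the merge: a later pair only
// overwrites an axis that is still replicated (-1).
std::unordered_map<std::string, int64_t> MergeSketch(
    const std::vector<std::pair<std::string, std::vector<int64_t>>>& pairs) {
  std::unordered_map<std::string, int64_t> axis_to_dim;
  for (const auto& p : pairs) {
    for (size_t i = 0; i < p.second.size(); ++i) {
      std::string axis = p.first.substr(i, 1);  // one-character axis name
      int64_t dim = p.second[i];
      auto it = axis_to_dim.find(axis);
      if (it == axis_to_dim.end()) {
        axis_to_dim.emplace(axis, dim);
      } else if (it->second == -1) {
        it->second = dim;  // Rule1: a replicated axis adopts any sharded dim.
      }
    }
  }
  return axis_to_dim;
}

int main() {
  // Matmul-style inputs: x = [m, k] sharded as [0, -1], y = [k, n] as [-1, 1].
  auto merged = MergeSketch({{"mk", {0, -1}}, {"kn", {-1, 1}}});
  for (const auto& kv : merged) {
    std::cout << kv.first << " -> " << kv.second << "\n";  // m->0, k->-1, n->1
  }
  return 0;
}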
@@ -105,6 +124,7 @@ int64_t ShardingMergeForAxis(const std::string axis,
   } else if (mesh_dim2 == -1) {
     return mesh_dim1;
   } else {
+    // (TODO) local cost model here.
     PADDLE_THROW(
         phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two "
                                    "different mesh dimension [%d] and [%d].",
......
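The per-axis rule touched by this hunk is small enough to restate standalone. The sketch below (hypothetical MergeAxisSketch, not the Paddle function) mirrors the decision order visible in the context lines together with Rule1 above: equal dims merge trivially, a replicated -1 yields to a sharded dim, and two different sharded dims on one axis are a conflict; it throws a plain std::runtime_error where the real code raises PADDLE_THROW(phi::errors::Unimplemented(...)).

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical sketch of the per-axis sharding merge rule.
int64_t MergeAxisSketch(const std::string& axis,
                        int64_t mesh_dim1,
                        int64_t mesh_dim2) {
  if (mesh_dim1 == mesh_dim2) return mesh_dim1;  // same sharding, nothing to merge
  if (mesh_dim1 == -1) return mesh_dim2;         // Rule1: replicated yields to sharded
  if (mesh_dim2 == -1) return mesh_dim1;
  // Two different sharded dims on one axis: unresolved without a cost model.
  throw std::runtime_error("Tensor Axis[" + axis + "] is sharded by two mesh dims.");
}

int main() {
  std::cout << MergeAxisSketch("k", -1, 1) << "\n";  // prints 1
  return 0;
}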
@@ -123,11 +123,10 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
     std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1);
   }
   // step2.1: Sharding Merge
-  std::pair<std::string*, std::vector<int64_t>*> x_pair(&x_string,
-                                                         &x_dims_mapping);
-  std::pair<std::string*, std::vector<int64_t>*> y_pair(&y_string,
-                                                         &y_dims_mapping);
-  std::vector<std::pair<std::string, std::vector<int64_t>>> input_pairs;
+  std::pair<std::string, std::vector<int64_t>> x_pair(x_string, x_dims_mapping);
+  std::pair<std::string, std::vector<int64_t>> y_pair(y_string, y_dims_mapping);
+  std::vector<std::pair<const std::string, const std::vector<int64_t>>>
+      input_pairs;
   input_pairs.push_back(x_pair);
   input_pairs.push_back(y_pair);
   auto dim_to_sharding = ShardingMerge(input_pairs);
......
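After this hunk the input pairs hold copies of the notation strings and dims mappings rather than pointers to the locals. The sketch below shows only the call-site shape with made-up values; the notation strings and dims mappings in the real MatmulSPMDRule::InferForward come from the inputs, and the surrounding rule is not reproduced here.

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical values standing in for what InferForward derives earlier.
  std::string x_string = "mk";
  std::string y_string = "kn";
  std::vector<int64_t> x_dims_mapping = {0, -1};
  std::vector<int64_t> y_dims_mapping = {-1, 1};

  // Same pair/vector types as the hunk: copies, not pointers into the locals.
  std::pair<std::string, std::vector<int64_t>> x_pair(x_string, x_dims_mapping);
  std::pair<std::string, std::vector<int64_t>> y_pair(y_string, y_dims_mapping);

  std::vector<std::pair<const std::string, const std::vector<int64_t>>> input_pairs;
  input_pairs.push_back(x_pair);
  input_pairs.push_back(y_pair);

  // The rule then merges the shardings, as in the hunk:
  //   auto dim_to_sharding = ShardingMerge(input_pairs);
  return 0;
}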