From 180edcc2be9957f56e1537a636ad3d205f3f9cee Mon Sep 17 00:00:00 2001
From: liangjianzhong
Date: Tue, 23 May 2023 16:12:24 +0800
Subject: [PATCH] matmul main logic done

---
 .../auto_parallel/spmd_rules/common.h         | 88 +++--------------
 .../spmd_rules/matmul_spmd_rule.cc            | 94 +++++++++++--------
 .../spmd_rules/matmul_spmd_rule.h             |  6 ++
 3 files changed, 71 insertions(+), 117 deletions(-)

diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
index f6deca94d7b..90cbffa5b13 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -31,19 +31,11 @@ class SPMDRuleBase {
 
   virtual std::vector<TensorDistAttr> InferForward(
       const std::vector<DistTensorSpec>& input_specs,
-      const paddle::framework::AttributeMap& attrs) {
-    PADDLE_THROW(
-        phi::errors::Unimplemented("InferForward should be called from a "
-                                   "derived class of SPMDRuleBase !"));
-  }
+      const paddle::framework::AttributeMap& attrs);
 
   virtual std::vector<TensorDistAttr> InferBackward(
       const std::vector<DistTensorSpec>& output_specs,
-      const paddle::framework::AttributeMap& attrs) {
-    PADDLE_THROW(
-        phi::errors::Unimplemented("InferBackward should be called from a "
-                                   "derived class of SPMDRuleBase !"));
-  }
+      const paddle::framework::AttributeMap& attrs);
 
   template <typename T>
   inline const T& ExtractAttr(
@@ -62,54 +54,11 @@ class SPMDRuleBase {
         platform::errors::NotFound("(%s) is not found in AttributeMap."));
     return iter->second;
   }
-}
+};
 
-std::unordered_map<std::string, int64_t>
-ShardingMergeForTensors(
+std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
     const std::vector<std::pair<std::string, std::vector<int64_t>>>&
-        tensor_notation_to_dim_pairs) {
-  std::unordered_map<std::string, int64_t> axis_to_dim_map;
-  std::unordered_map<int64_t, std::string> dim_to_axis_map;
-  int64_t merge_dim;
-
-  for (auto& pair : tensor_notation_to_dim_pairs) {
-    for (int i = 0; i < pair.second.size(); i++) {
-      auto tensor_axis = pair.first.substr(i, 1);
-      auto mesh_dim = pair.second[i];
-
-      if (axis_to_dim_map.count(tensor_axis) == 0) {
-        merge_dim = mesh_dim;
-      } else {
-        merge_dim = ShardingMergeForAxis(
-            tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
-      }
-      axis_to_dim_map.insert({tensor_axis, merge_dim});
-
-      if (dim_to_axis_map.count(merge_dim) == 0) {
-        dim_to_axis_map.insert({merge_dim, tensor_axis});
-      } else {
-        dim_to_axis_map[merge_dim] += tensor_axis;
-      }
-    }
-  }
-
-  // Resolute "mesh_dim shard by more than one axis" confict.
-  // Now we just naive pick the first axis naively.
-  // (TODO) use local cost model to pick the axis with lowest cost(in concern of
-  // memory or communication or computation).
-  for (auto& it : dim_to_axis_map) {
-    if (it.second.size() > 1) {
-      VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
-              << "] are Sharding Multiple Tensor Axis: [" << it.second
-              << "]. The Axis: [" << it.second[0] << "] is Picked.";
-      for (int i = 1; i < it.second.size(); i++) {
-        axis_to_dim_map[it.second.substr(i, 1)] = -1;
-      }
-    }
-  }
-
-  return axis_to_dim_map;
-}
+        tensor_notation_to_dim_pairs);
 
 // Rule1: A repicated dimension could be merged by any sharded dimension.
 // Rule2: A tensor axis could at most be sharded by one mesh dimension. (TODO: Add
 // multiple dimension case.)
 int64_t ShardingMergeForAxis(const std::string axis,
                              const int64_t mesh_dim1,
-                             const int64_t mesh_dim2) {
-  if (mesh_dim1 != mesh_dim2) {
-    if (mesh_dim1 == -1) {
-      return mesh_dim2;
-    } else if (mesh_dim2 == -1) {
-      return mesh_dim1;
-    } else {
-      // (TODO) local cost model here.
-      PADDLE_THROW(
-          phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two "
-                                     "different mesh dimension [%d] and [%d].",
-                                     axis,
-                                     mesh_dim1,
-                                     mesh_dim2));
-    }
-
-  } else {
-    return mesh_dim1;
-  }
-}
+                             const int64_t mesh_dim2);
+
+TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);
+
+std::vector<int64_t> ResoluteOutputPartialDimension(
+    const std::unordered_map<std::string, int64_t>& in_axis_to_dim_map,
+    const std::string& out_axis);
 
 }  // namespace auto_parallel
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
index 5bc63cd8dd5..9d6a973323e 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#pragma once
-
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
 
 namespace paddle {
@@ -129,10 +127,61 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
       input_pairs;
   input_pairs.push_back(x_pair);
   input_pairs.push_back(y_pair);
-  auto dim_to_sharding = ShardingMerge(input_pairs);
+  auto axis_to_dim_map = ShardingMergeForTensors(input_pairs);
+
+  // step2.2: fill output's dim mapping.
+  TensorDistAttr output_dist_attr_dst =
+      CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
+  std::vector<int64_t> out_dims_mapping;
+  out_dims_mapping.reserve(out_string.size());
+  for (int i = 0; i < out_string.size(); ++i) {
+    out_dims_mapping.push_back(axis_to_dim_map[out_string.substr(i, 1)]);
+  }
+  output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
+
+  // step2.3: fill input's dim mapping.
+  TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
+      input_specs[0].dist_attr(), input_specs[0].shape(), x_string, axis_to_dim_map);
+  TensorDistAttr y_dist_attr_dst = GetInferedDistAttr(
+      input_specs[1].dist_attr(), input_specs[1].shape(), y_string, axis_to_dim_map);
 
-  // step2.3: Handle Broadcast
   // step2.3: Handle Partial
+  // Step2.3.1 Output Partial
+  std::vector<int64_t> partial_on_dims =
+      ResoluteOutputPartialDimension(axis_to_dim_map, out_string);
+
+  // Step2.3.2 handle input tensor partial (TODO)
+
+  VLOG(4) << "MatmulSPMDRule InferForward: "
+          << "X shape: [" << str_join(input_specs[0].shape()) << "], "
+          << "src_dims_mapping: [" << str_join(x_dims_mapping) << "], "
+          << "dst_dims_mapping: [" << str_join(x_dist_attr_dst.dims_mapping())
+          << "]; Y shape: [" << str_join(input_specs[1].shape()) << "], "
+          << "src_dims_mapping: [" << str_join(y_dims_mapping) << "], "
+          << "dst_dims_mapping: [" << str_join(y_dist_attr_dst.dims_mapping())
+          << "]; Output dims_mapping: [" << str_join(out_dims_mapping) << "], "
+          << "partial_on_dims: [" << str_join(partial_on_dims) << "]";
+
+  return {x_dist_attr_dst, y_dist_attr_dst, output_dist_attr_dst};
+}
+
+TensorDistAttr GetInferedDistAttr(
+    const TensorDistAttr& origin_dist_attr,
+    const std::vector<int64_t>& shape,
+    const std::string& tensor_axis,
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map) {
+  TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr);
+  std::vector<int64_t> infered_dims_mapping;
+  infered_dims_mapping.reserve(tensor_axis.size());
+
+  for (int i = 0; i < tensor_axis.size(); ++i) {
+    if (shape.size() > i && shape[i] == 1) {
+      infered_dims_mapping.push_back(-1);
+    } else {
+      infered_dims_mapping.push_back(
+          axis_to_dim_map.at(tensor_axis.substr(i, 1)));
+    }
+  }
+
+  dist_attr_.set_dims_mapping(infered_dims_mapping);
+  return dist_attr_;
 }
 
 std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
@@ -142,40 +191,3 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
-
-/// @brief
-// int max_dim = 0;
-// int ndim = 0;
-// std::vector<int64_t> intput_ndims;
-// for (auto& input_spec : input_specs){
-//   ndim = input_spec.shape().size();
-//   intput_ndims.push_back(ndim);
-//   if (ndim > max_dim) {
-//     max_dim = ndim;
-//   }
-// }
-
-// std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
-// std::vector<std::string> input_dim_chars;
-// for (auto& intput_ndim : intput_ndims){
-//   input_dim_chars.push_back(alphabet.substr(max_dim - intput_ndim,
-//   intput_ndim));
-// }
-
-// int max_dim = 0;
-// int ndim = 0;
-// std::vector<int64_t> intput_ndims;
-// for (auto& input_spec : input_specs){
-//   ndim = input_spec.shape().size();
-//   intput_ndims.push_back(ndim);
-//   if (ndim > max_dim) {
-//     max_dim = ndim;
-//   }
-// }
-
-// std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
-// std::vector<std::string> input_dim_chars;
-// for (auto& intput_ndim : intput_ndims){
-//   input_dim_chars.push_back(alphabet.substr(max_dim - intput_ndim,
-//   intput_ndim));
-// }
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
index 86ec9a3992a..89677e8e4e2 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
@@ -25,6 +25,12 @@ namespace paddle {
 namespace distributed {
 namespace auto_parallel {
 
+TensorDistAttr GetInferedDistAttr(
+    const TensorDistAttr& origin_dist_attr,
+    const std::vector<int64_t>& shape,
+    const std::string& tensor_axis,
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map);
+
 class MatmulSPMDRule : public SPMDRuleBase {
  public:
   std::vector<TensorDistAttr> InferForward(
--
GitLab
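For readers outside the Paddle tree, the core idea this patch lands, merging each tensor's einsum-notation dims mapping into one axis-to-mesh-dim map and deriving the output's partial dims from contracted axes, can be illustrated with a small standalone C++ sketch. This is not part of the patch: MergeAxis mirrors the conflict-free behavior of ShardingMergeForAxis, the final loop mirrors ResoluteOutputPartialDimension, and the example sharding ([-1, 0] for X, [0, -1] for Y on a 2-D mesh) is an illustrative assumption.

// Standalone sketch (not from the patch): einsum-notation sharding merge for
// matmul "mk,kn->mn". X's dims_mapping is [-1, 0] (k sharded on mesh dim 0);
// Y's is [0, -1]. Axis "k" merges consistently to mesh dim 0, and because
// "k" is contracted away, the output becomes partial on that mesh dim.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Mirrors ShardingMergeForAxis: -1 (replicated) yields to a sharded dim;
// two different sharded dims for one axis is a conflict.
int64_t MergeAxis(int64_t dim1, int64_t dim2) {
  if (dim1 == dim2) return dim1;
  if (dim1 == -1) return dim2;
  if (dim2 == -1) return dim1;
  throw std::runtime_error("tensor axis sharded by two different mesh dims");
}

int main() {
  // (einsum notation, dims_mapping) pairs, as InferForward builds them.
  std::vector<std::pair<std::string, std::vector<int64_t>>> inputs = {
      {"mk", {-1, 0}}, {"kn", {0, -1}}};

  // Mirrors ShardingMergeForTensors: merge every tensor axis across inputs.
  std::unordered_map<std::string, int64_t> axis_to_dim;
  for (const auto& pair : inputs) {
    for (size_t i = 0; i < pair.second.size(); ++i) {
      const std::string axis = pair.first.substr(i, 1);
      auto it = axis_to_dim.find(axis);
      axis_to_dim[axis] = (it == axis_to_dim.end())
                              ? pair.second[i]
                              : MergeAxis(it->second, pair.second[i]);
    }
  }

  // Output dims mapping: read the merged map for each output axis.
  const std::string out = "mn";
  for (size_t i = 0; i < out.size(); ++i) {
    std::cout << "output axis '" << out[i] << "' -> mesh dim "
              << axis_to_dim[out.substr(i, 1)] << "\n";  // -1 and -1 here
  }

  // Mirrors ResoluteOutputPartialDimension: a sharded axis missing from the
  // output (the contracted "k") leaves the output partial on that mesh dim,
  // to be finalized by an all-reduce over it.
  for (const auto& kv : axis_to_dim) {
    if (kv.second != -1 && out.find(kv.first) == std::string::npos) {
      std::cout << "output is partial on mesh dim " << kv.second << "\n";
    }
  }
  return 0;
}

Running it prints mesh dim -1 for both output axes and "partial on mesh dim 0": exactly the contracted-axis matmul pattern whose output awaits an all-reduce, which is what the rule's partial_on_dims is meant to capture.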