diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
index 43f401931224d2c46b78f20107ec161c830a559d..bffccbfc3ab518c17e0357a533b80f20a4364c9a 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -62,7 +62,62 @@ class SPMDRuleBase {
         platform::errors::NotFound("(%s) is not found in AttributeMap."));
     return iter->second;
   }
 };
+
+// Merge the sharding specification (dims mapping) for one tensor axis.
+// Rule1: A replicated dimension could be merged by any sharded dimension.
+// Rule2: A tensor axis could at most be sharded by one mesh dimension.
+// (TODO: trigger heuristic cost model and reshard to handle the case of an
+// axis sharded by multiple mesh dimensions.)
+int64_t ShardingMergeForAxis(const std::string& axis,
+                             const int64_t mesh_dim1,
+                             const int64_t mesh_dim2) {
+  if (mesh_dim1 != mesh_dim2) {
+    if (mesh_dim1 == -1) {
+      return mesh_dim2;
+    } else if (mesh_dim2 == -1) {
+      return mesh_dim1;
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Tensor Axis[%s] is Sharded by two "
+          "different mesh dimensions [%d] and [%d].",
+          axis,
+          mesh_dim1,
+          mesh_dim2));
+    }
+  } else {
+    return mesh_dim1;
+  }
+}
+
+// Merge the sharding specifications (dims mappings) of the given tensors:
+// occurrences of the same einsum axis across tensors are merged into a
+// single mesh-dimension assignment for that axis.
+std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
+    const std::vector<
+        std::pair<const std::string*, const std::vector<int64_t>*>>&
+        tensor_notation_to_dim_pairs) {
+  std::unordered_map<std::string, int64_t> axis_to_dim_map;
+  std::unordered_map<int64_t, std::string> dim_to_axis_map;
+
+  for (auto& pair : tensor_notation_to_dim_pairs) {
+    for (size_t i = 0; i < pair.second->size(); ++i) {
+      auto tensor_axis = pair.first->substr(i, 1);
+      auto mesh_dim = (*pair.second)[i];
+
+      if (axis_to_dim_map.count(tensor_axis) == 0) {
+        axis_to_dim_map.insert({tensor_axis, mesh_dim});
+      } else {
+        int64_t merge_dim = ShardingMergeForAxis(
+            tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
+        axis_to_dim_map[tensor_axis] = merge_dim;
+      }
+
+      if (dim_to_axis_map.count(mesh_dim) == 0) {
+        dim_to_axis_map.insert({mesh_dim, tensor_axis});
+      }
+    }
+  }
+
+  return axis_to_dim_map;
+}
 
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
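For intuition, the two merge rules can be exercised on the canonical matmul notation `mk,kn -> mn`. The sketch below is a standalone illustration, not Paddle code: the hypothetical `MergeAxis` re-implements the logic of `ShardingMergeForAxis` on plain standard-library types, and the loop mirrors what `ShardingMergeForTensors` does with the per-tensor dims mappings.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Mirrors ShardingMergeForAxis: -1 means "replicated", any non-negative
// value names the mesh dimension the axis is sharded on.
int64_t MergeAxis(const std::string& axis, int64_t dim1, int64_t dim2) {
  if (dim1 == dim2) return dim1;
  if (dim1 == -1) return dim2;  // Rule1: replicated yields to sharded.
  if (dim2 == -1) return dim1;
  // Rule2: an axis must not be sharded by two different mesh dimensions.
  throw std::runtime_error("axis " + axis + " sharded by two mesh dims");
}

int main() {
  // x is [m, k] row-sharded on mesh dim 0; y is [k, n] column-sharded on
  // mesh dim 1; the contracted axis k is replicated in both.
  std::vector<std::pair<std::string, std::vector<int64_t>>> inputs = {
      {"mk", {0, -1}},
      {"kn", {-1, 1}},
  };

  std::unordered_map<std::string, int64_t> axis_to_dim;
  for (const auto& [notation, mapping] : inputs) {
    for (size_t i = 0; i < mapping.size(); ++i) {
      std::string axis = notation.substr(i, 1);
      auto it = axis_to_dim.find(axis);
      int64_t merged = (it == axis_to_dim.end())
                           ? mapping[i]
                           : MergeAxis(axis, mapping[i], it->second);
      axis_to_dim[axis] = merged;
    }
  }

  // Prints m -> 0, k -> -1, n -> 1 (in unordered_map order): the output
  // notation mn therefore maps to [0, 1], and k stays replicated.
  for (const auto& [axis, dim] : axis_to_dim) {
    std::cout << axis << " -> " << dim << "\n";
  }
}
```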
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.cc
deleted file mode 100644
index e6dad05bfe44cae74eb36aa941d76720d4056e83..0000000000000000000000000000000000000000
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.h"
-
-namespace paddle {
-namespace distributed {
-namespace auto_parallel {
-
-std::vector<TensorDistAttr> MatmulRule::InferForward(
-    const std::vector<DistTensorSpec>& input_specs,
-    const paddle::framework::AttributeMap& attrs) {
-  // step0: verify input args based on matmul logic
-  bool trans_x = attrs.Get<bool>("trans_x");
-  bool trans_y = attrs.Get<bool>("trans_y");
-
-  // step1: Einsum Notation
-  // step1.1: generate base notations for each input tensor
-
-  // step1.2: modify input notations base on matmul logic
-
-  // step1.3: generate notations for each output tensor
-
-  // step1.4: final Einsum Notaion
-
-  // step1: Sharding Propogation
-}
-
-std::vector<TensorDistAttr> MatmulRule::InferBackward(
-    const std::vector<DistTensorSpec>& output_specs,
-    const paddle::framework::AttributeMap& attrs) {}
-
-}  // namespace auto_parallel
-}  // namespace distributed
-}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
new file mode 100644
index 0000000000000000000000000000000000000000..297c6d5361f7992963a7d388a16c666ae63f24a2
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
@@ -0,0 +1,182 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
+    const std::vector<DistTensorSpec>& input_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  // step0: verify input args based on matmul logic
+  auto input_specs_size = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      input_specs_size,
+      2,
+      phi::errors::InvalidArgument(
+          "The size of InputSpec of matmul should be 2, but got [%d].",
+          input_specs_size));
+
+  int x_ndim = input_specs[0].shape.size();
+  int y_ndim = input_specs[1].shape.size();
+  std::vector<int64_t> x_dims_mapping = input_specs[0].DistAttr.dims_mapping;
+  std::vector<int64_t> y_dims_mapping = input_specs[1].DistAttr.dims_mapping;
+  PADDLE_ENFORCE_EQ(
+      x_ndim,
+      x_dims_mapping.size(),
+      phi::errors::InvalidArgument(
+          "Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].",
+          x_ndim,
+          x_dims_mapping.size()));
+  PADDLE_ENFORCE_EQ(
+      y_ndim,
+      y_dims_mapping.size(),
+      phi::errors::InvalidArgument(
+          "Mismatch of Y's tensor size: [%d] and Y's dims_mapping size [%d].",
+          y_ndim,
+          y_dims_mapping.size()));
+
+  bool trans_x = ExtractAttr<bool>("trans_x", attrs);
+  bool trans_y = ExtractAttr<bool>("trans_y", attrs);
+
+  // step1: Einsum Notation
+  int max_ndim = std::max(x_ndim, y_ndim);
+
+  // reserve the chars k, m, n for the matrix product notation: mk,kn -> mn
+  std::string alphabet = "abcdefghijlopqrstuvwxyz";
+  std::string x_string;
+  std::string y_string;
+  std::string out_string;
+
+  // vector * vector = scalar
+  if (x_ndim == 1 && y_ndim == 1) {
+    x_string = "k";
+    y_string = "k";
+    out_string = "";
+    // vector * batched matrix
+  } else if (x_ndim == 1 && y_ndim > 1) {
+    x_string = "k";
+    std::string y_broadcast_string =
+        GetBroadcastNotationString(y_ndim, max_ndim, alphabet);
+    y_string = y_broadcast_string + "kn";
+    out_string = y_broadcast_string + "n";
+    // batched matrix * vector
+  } else if (x_ndim > 1 && y_ndim == 1) {
+    y_string = "k";
+    std::string x_broadcast_string =
+        GetBroadcastNotationString(x_ndim, max_ndim, alphabet);
+    x_string = x_broadcast_string + "mk";
+    out_string = x_broadcast_string + "m";
+    // batched matrix * batched matrix
+  } else if (x_ndim > 1 && y_ndim > 1) {
+    std::string x_broadcast_string =
+        GetBroadcastNotationString(x_ndim, max_ndim, alphabet);
+    std::string y_broadcast_string =
+        GetBroadcastNotationString(y_ndim, max_ndim, alphabet);
+    x_string = x_broadcast_string + "mk";
+    y_string = y_broadcast_string + "kn";
+
+    if (x_ndim > y_ndim) {
+      out_string = x_broadcast_string + "mn";
+    } else {
+      out_string = y_broadcast_string + "mn";
+    }
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "MatmulSPMDRule Receive Unsupported x_dim [%d] and y_dim [%d].",
+        x_ndim,
+        y_ndim));
+  }
+
+  VLOG(4) << "MatmulSPMDRule built Einsum notation: [" << x_string << ","
+          << y_string << " --> " << out_string << "].";
+
+  // step2: Sharding Propagation
+  if (trans_x) {
+    PADDLE_ENFORCE_GE(
+        x_ndim,
+        2,
+        phi::errors::InvalidArgument("When trans_x is True, the size of X "
+                                     "tensor should be at least 2, but got "
+                                     "[%d].",
+                                     x_ndim));
+    std::iter_swap(x_dims_mapping.end() - 2, x_dims_mapping.end() - 1);
+  }
+  if (trans_y) {
+    PADDLE_ENFORCE_GE(
+        y_ndim,
+        2,
+        phi::errors::InvalidArgument("When trans_y is True, the size of Y "
+                                     "tensor should be at least 2, but got "
+                                     "[%d].",
+                                     y_ndim));
+    std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1);
+  }
+
+  // step2.1: Sharding Merge
+  std::pair<const std::string*, const std::vector<int64_t>*> x_pair(
+      &x_string, &x_dims_mapping);
+  std::pair<const std::string*, const std::vector<int64_t>*> y_pair(
+      &y_string, &y_dims_mapping);
+  std::vector<std::pair<const std::string*, const std::vector<int64_t>*>>
+      input_pairs;
+  input_pairs.push_back(x_pair);
+  input_pairs.push_back(y_pair);
+  auto axis_to_dim_map = ShardingMergeForTensors(input_pairs);
+
+  // step2.2: Handle Broadcast
+  // step2.3: Handle Partial
+}
+
+std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& output_specs,
+    const paddle::framework::AttributeMap& attrs) {}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
+
+/// @brief scratch notes for the generic notation generation (kept for
+/// reference while the rule is work in progress)
+// int max_dim = 0;
+// int ndim = 0;
+// std::vector<int> intput_ndims;
+// for (auto& input_spec : input_specs) {
+//   ndim = input_spec.shape().size();
+//   intput_ndims.push_back(ndim);
+//   if (ndim > max_dim) {
+//     max_dim = ndim;
+//   }
+// }
+
+// std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+// std::vector<std::string> input_dim_chars;
+// for (auto& intput_ndim : intput_ndims) {
+//   input_dim_chars.push_back(alphabet.substr(max_dim - intput_ndim,
+//                                             intput_ndim));
+// }
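The new file calls `GetBroadcastNotationString(ndim, max_ndim, alphabet)` but its definition is not part of this diff. Going by the call sites and the `alphabet.substr(max_dim - intput_ndim, intput_ndim)` pattern in the scratch notes above, a plausible reconstruction, purely an assumption and not the actual Paddle helper, looks like this:

```cpp
#include <iostream>
#include <string>

// Assumed semantics: batch (broadcast) axes align from the right, so an
// operand with fewer dims reuses the trailing batch axes of the largest
// operand. The helper returns only the batch part of the notation; the
// caller appends "mk" / "kn" / "k" itself.
std::string GetBroadcastNotationString(int ndim,
                                       int max_ndim,
                                       const std::string& alphabet) {
  int max_batch = max_ndim - 2;  // batch axes of the largest operand
  int batch = ndim - 2;          // batch axes of this operand
  return alphabet.substr(max_batch - batch, batch);
}

int main() {
  const std::string alphabet = "abcdefghijlopqrstuvwxyz";  // k, m, n reserved
  // x: [B0, B1, m, k] (ndim 4) and y: [B1, k, n] (ndim 3), max_ndim = 4:
  std::cout << GetBroadcastNotationString(4, 4, alphabet) << "\n";  // "ab"
  std::cout << GetBroadcastNotationString(3, 4, alphabet) << "\n";  // "b"
}
```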
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
similarity index 96%
rename from paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.h
rename to paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
index 8832aac420a17a759b9d3c9a40ef9682bec4d376..86ec9a3992a4927e84324ec752bd1a04c10d4a13 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_rule.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
@@ -25,7 +25,7 @@ namespace paddle {
 namespace distributed {
 namespace auto_parallel {
 
-class MatmulRule : public SPMDRuleBase {
+class MatmulSPMDRule : public SPMDRuleBase {
  public:
   std::vector<TensorDistAttr> InferForward(
       const std::vector<DistTensorSpec>& input_specs,
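Putting the pieces together, the sketch below is a standalone replay of step1's notation construction for a few rank combinations; the hypothetical `Broadcast` stands in for the undefined `GetBroadcastNotationString` from the sketch above, and `BuildNotation` is an illustrative name, not part of the rule's API.

```cpp
#include <algorithm>
#include <iostream>
#include <string>

// Stand-in for GetBroadcastNotationString (see the reconstruction above).
std::string Broadcast(int ndim, int max_ndim, const std::string& alphabet) {
  return alphabet.substr((max_ndim - 2) - (ndim - 2), ndim - 2);
}

// Replays the step1 branching of MatmulSPMDRule::InferForward.
void BuildNotation(int x_ndim, int y_ndim) {
  const std::string alphabet = "abcdefghijlopqrstuvwxyz";
  int max_ndim = std::max(x_ndim, y_ndim);
  std::string x, y, out;
  if (x_ndim == 1 && y_ndim == 1) {  // vector * vector = scalar
    x = "k";
    y = "k";
    out = "";
  } else if (x_ndim == 1) {  // vector * batched matrix
    x = "k";
    y = Broadcast(y_ndim, max_ndim, alphabet) + "kn";
    out = Broadcast(y_ndim, max_ndim, alphabet) + "n";
  } else if (y_ndim == 1) {  // batched matrix * vector
    y = "k";
    x = Broadcast(x_ndim, max_ndim, alphabet) + "mk";
    out = Broadcast(x_ndim, max_ndim, alphabet) + "m";
  } else {  // batched matrix * batched matrix
    x = Broadcast(x_ndim, max_ndim, alphabet) + "mk";
    y = Broadcast(y_ndim, max_ndim, alphabet) + "kn";
    out = Broadcast(max_ndim, max_ndim, alphabet) + "mn";
  }
  std::cout << x << "," << y << " -> " << out << "\n";
}

int main() {
  BuildNotation(1, 1);  // k,k ->            (scalar output)
  BuildNotation(4, 2);  // abmk,kn -> abmn
  BuildNotation(3, 3);  // amk,akn -> amn
}
```

Once `trans_x` / `trans_y` have swapped the last two entries of the corresponding dims mapping, pairing these notation strings with the dims mappings is all step2.1 needs to merge shardings axis by axis via `ShardingMergeForTensors`.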