diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
index f1b8a60a169f9ea4073d49879fbcf8b8d61299f1..dc740da070bfb8b46292dec8b83c41a598b05c6e 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <glog/logging.h>
+
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
 
 namespace paddle {
@@ -42,7 +44,7 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
   int64_t merge_dim;
 
   for (auto& pair : tensor_axes_to_dim_pairs) {
-    for (int i = 0; i < pair.second.size(); i++) {
+    for (size_t i = 0; i < pair.second.size(); ++i) {
       auto tensor_axis = pair.first.substr(i, 1);
       auto mesh_dim = pair.second[i];
 
@@ -71,7 +73,7 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
     VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
             << "] are Sharding Multiple Tensor Axis: [" << it.second
             << "]. The Axis: [" << it.second[0] << "] is Picked.";
-    for (int i = 1; i < it.second.size(); i++) {
+    for (size_t i = 1; i < it.second.size(); ++i) {
       axis_to_dim_map[it.second.substr(i, 1)] = -1;
     }
   }
@@ -113,7 +115,7 @@ TensorDistAttr CopyTensorDistAttrForOutput(
   new_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
   new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
   new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
-  new_dist_attr.set_annotated(false);
+  // new_dist_attr.set_annotated(false);  TODO: an unset field is false by default.
   return new_dist_attr;
 }
 
@@ -122,8 +124,8 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
     const std::string& tensor_axes) {
   std::vector<int64_t> partial_on_dims;
 
-  for (auto& it : in_axis_to_dim_map) {
-    if (out_axis.find(it.first) != std::string::npos) {
+  for (auto& it : axis_to_dim_map) {
+    if (tensor_axes.find(it.first) == std::string::npos) {
       if (it.second > -1) {
         partial_on_dims.push_back(it.second);
       }
@@ -132,6 +134,20 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
   return partial_on_dims;
 }
 
+std::string GetBroadcastAxes(const int64_t& tensor_ndim,
+                             const int64_t& broadcast_ndim,
+                             const std::string& alphabet) {
+  PADDLE_ENFORCE_GE(
+      alphabet.size(),
+      broadcast_ndim,
+      phi::errors::InvalidArgument(
+          "size of alphabet [%d] is less than broadcast ndim [%d]",
+          alphabet.size(),
+          broadcast_ndim));
+  return alphabet.substr(0, broadcast_ndim)
+      .substr(broadcast_ndim - tensor_ndim, tensor_ndim);
+}
+
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
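Note on the ResoluteOutputPartialDimension fix above: the old body referenced parameters (`in_axis_to_dim_map`, `out_axis`) that no longer exist, and the `!=` test inverted the intent. After the fix, an input axis that is sharded (mesh dim > -1) but absent from the output tensor's axes marks that mesh dim as partial (pending reduction). A minimal standalone sketch of the corrected logic, compilable outside Paddle — the `PartialDims` name and the example einsum mapping are illustrative only, not part of the patch:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Sketch of the corrected partial-dimension rule: any input axis that is
// sharded (mesh dim > -1) but absent from the given tensor's axes becomes a
// partial (pending-reduction) mesh dimension.
std::vector<int64_t> PartialDims(
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
    const std::string& tensor_axes) {
  std::vector<int64_t> partial_on_dims;
  for (const auto& it : axis_to_dim_map) {
    if (tensor_axes.find(it.first) == std::string::npos && it.second > -1) {
      partial_on_dims.push_back(it.second);
    }
  }
  return partial_on_dims;
}

int main() {
  // Matmul-like einsum "ij,jk->ik": the contracted axis "j" is sharded on
  // mesh dim 1, so the output "ik" is partial along mesh dim 1.
  const std::unordered_map<std::string, int64_t> axis_to_dim = {
      {"i", 0}, {"j", 1}, {"k", -1}};
  for (int64_t dim : PartialDims(axis_to_dim, "ik")) {
    std::cout << "partial on mesh dim " << dim << "\n";  // prints 1
  }
  return 0;
}
```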
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
index 2465998c9229125a54c49d77de4d86bdeba6707f..41d031106a4c348c278c3589c68baf6dfc142df3 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -19,14 +19,18 @@ limitations under the License.
  */
 #include <string>
 #include <vector>
-#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
 #include "paddle/fluid/framework/type_defs.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+
 namespace paddle {
 namespace distributed {
 namespace auto_parallel {
 
+using paddle::framework::Attribute;
+
 class SPMDRuleBase {
  public:
   virtual ~SPMDRuleBase() {}
@@ -64,10 +68,10 @@ class SPMDRuleBase {
   const Attribute& GetAttr(const std::string& name,
                            const paddle::framework::AttributeMap& attrs) const {
     auto iter = attrs.find(name);
-    PADDLE_ENFORCE_NE(
-        iter,
-        attrs.end(),
-        platform::errors::NotFound("(%s) is not found in AttributeMap."));
+    PADDLE_ENFORCE_NE(iter,
+                      attrs.end(),
+                      paddle::platform::errors::NotFound(
+                          "(%s) is not found in AttributeMap.", name));
     return iter->second;
   }
 };
@@ -96,6 +100,15 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
     const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
     const std::string& tensor_axes);
 
+// Generate the axis notation of a tensor for the einsum notation of a
+// broadcast operation (alignment starts from the rightmost axis).
+// tensor_ndim: the rank of the tensor. broadcast_ndim: the maximum rank of
+// the tensors in this broadcast operation. alphabet: the characters used to
+// represent the tensor axes; its length should be >= broadcast_ndim.
+std::string GetBroadcastAxes(const int64_t& tensor_ndim,
+                             const int64_t& broadcast_ndim,
+                             const std::string& alphabet);
+
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
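For context on GetBroadcastAxes: it takes the first broadcast_ndim characters of alphabet as the full einsum notation, then returns the rightmost tensor_ndim of them, so lower-rank tensors align to the trailing axes, as in numpy-style broadcasting. A small standalone usage sketch — the `BroadcastAxes` re-implementation below is illustrative only and omits the PADDLE_ENFORCE_GE check:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Illustrative re-implementation: take the first broadcast_ndim characters
// as the full notation, then keep the rightmost tensor_ndim of them.
std::string BroadcastAxes(int64_t tensor_ndim,
                          int64_t broadcast_ndim,
                          const std::string& alphabet) {
  return alphabet.substr(0, broadcast_ndim)
      .substr(broadcast_ndim - tensor_ndim, tensor_ndim);
}

int main() {
  // A 2-D tensor in a 4-way broadcast aligns to the rightmost axes of "abcd".
  assert(BroadcastAxes(2, 4, "abcdefgh") == "cd");
  // A full-rank tensor keeps the whole notation.
  assert(BroadcastAxes(4, 4, "abcdefgh") == "abcd");
  return 0;
}
```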