diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc index 408f98772268ac5d3ab49c440817debf7eda28f8..f1b8a60a169f9ea4073d49879fbcf8b8d61299f1 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc @@ -18,7 +18,7 @@ namespace paddle { namespace distributed { namespace auto_parallel { -std::vector SPMDRuleBase::InferForward( +std::vector SPMDRuleBase::InferForward( const std::vector& input_specs, const paddle::framework::AttributeMap& attrs) { PADDLE_THROW( @@ -26,7 +26,7 @@ std::vector SPMDRuleBase::InferForward( "derived class of SPMDRuleBase !")); } -std::vector SPMDRuleBase::InferBackward( +std::vector SPMDRuleBase::InferBackward( const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) { PADDLE_THROW( @@ -36,12 +36,12 @@ std::vector SPMDRuleBase::InferBackward( std::unordered_map ShardingMergeForTensors( const std::vector>>& - tensor_notation_to_dim_pairs) { + tensor_axes_to_dim_pairs) { std::unordered_map axis_to_dim_map; std::unordered_map dim_to_axis_map; int64_t merge_dim; - for (auto& pair : tensor_notation_to_dim_pairs) { + for (auto& pair : tensor_axes_to_dim_pairs) { for (int i = 0; i < pair.second.size(); i++) { auto tensor_axis = pair.first.substr(i, 1); auto mesh_dim = pair.second[i]; @@ -84,9 +84,9 @@ std::unordered_map ShardingMergeForTensors( // Rule2: A tensor axis could at most be sharded by one mesh dimension. // (TODO trigger heuristics cost model and reshard to handle axis sharded by // multiple dimension case.) 
-int64_t ShardingMergeForAxis(const std::string axis, - const int64_t mesh_dim1, - const int64_t mesh_dim2) { +int64_t ShardingMergeForAxis(const std::string& axis, + const int64_t& mesh_dim1, + const int64_t& mesh_dim2) { if (mesh_dim1 != mesh_dim2) { if (mesh_dim1 == -1) { return mesh_dim2; @@ -118,8 +118,8 @@ TensorDistAttr CopyTensorDistAttrForOutput( } std::vector ResoluteOutputPartialDimension( - const std::unordered_map& in_axis_to_dim_map, - const std::string& out_axis) { + const std::unordered_map& axis_to_dim_map, + const std::string& tensor_axes) { std::vector partial_on_dims; - for (auto& it : in_axis_to_dim_map) { + for (auto& it : axis_to_dim_map) { diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h index 90cbffa5b1386c31e85b40a760d921f2534992c8..2465998c9229125a54c49d77de4d86bdeba6707f 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h @@ -19,7 +19,9 @@ limitations under the License. */ #include #include +#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" #include "paddle/fluid/framework/type_defs.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" namespace paddle { namespace distributed { @@ -29,11 +31,26 @@ class SPMDRuleBase { public: virtual ~SPMDRuleBase() {} - virtual std::vector InferForward( + // Merge the DistAttr of input tensors and infer the DistAttr of the output + // tensors from the merged input information. The inputs are DistAttr and Shape + // (wrapped as DistTensorSpec) of the input tensors (tensors follow the same + // order defined in Op's Phi API) and Op Attribute of the current op. The outputs + // are the Merged DistAttr of input tensors and the inferred DistAttr of the + // output tensors. The Merged DistAttr might be different from the original + // Input DistAttrs, which means that the corresponding input tensor needs to + // be resharded. 
+ virtual std::vector InferForward( const std::vector& input_specs, const paddle::framework::AttributeMap& attrs); - virtual std::vector InferBackward( + // Merge the DistAttr of output tensors and infer the DistAttr of the input + // tensors from the merged output information. The inputs are DistAttr and + // Shape (wrapped as DistTensorSpec) of the input tensors and Op Attribute of the + // current op. The outputs are the Merged DistAttr of output tensors and the + // inferred DistAttr of the input tensors. This function will be used in Static + // Graph mode only, where we have the whole computation graph for sharding + // propagation. + virtual std::vector InferBackward( const std::vector& output_specs, const paddle::framework::AttributeMap& attrs); @@ -44,9 +61,8 @@ class SPMDRuleBase { return PADDLE_GET_CONST(T, GetAttr(name, attrs)); } - virtual const Attribute& GetAttr( - const std::string& name, - const paddle::framework::AttributeMap& attrs) const { + const Attribute& GetAttr(const std::string& name, + const paddle::framework::AttributeMap& attrs) const { auto iter = attrs.find(name); PADDLE_ENFORCE_NE( iter, @@ -56,23 +72,29 @@ class SPMDRuleBase { } }; +// Merge sharding specification (dims mapping) of given tensors. +// The same axes of different tensors will be merged. std::unordered_map ShardingMergeForTensors( const std::vector>>& - tensor_notation_to_dim_pairs); + tensor_axes_to_dim_pairs); +// Merge the sharding specification (dims mapping) for one tensor Axis. // Rule1: A repicated dimension could be merged by any sharded dimension. // Rule2: A tensor axis could at most be sharded by one mesh dimension. // (TODO trigger heuristics cost model and reshard to handle axis sharded by // multiple dimension case.) 
-int64_t ShardingMergeForAxis(const std::string axis, - const int64_t mesh_dim1, - const int64_t mesh_dim2); +int64_t ShardingMergeForAxis(const std::string& axis, + const int64_t& mesh_dim1, + const int64_t& mesh_dim2); TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr); +// Resolve the partial mesh dimension of an output tensor, given the +// merged sharding specification of input tensors and the axis names of output +// tensor. std::vector ResoluteOutputPartialDimension( - const std::unordered_map& in_axis_to_dim_map, - const std::string& out_axis); + const std::unordered_map& axis_to_dim_map, + const std::string& tensor_axes); } // namespace auto_parallel } // namespace distributed diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc index 74be7c143ef8f3d5cf33fa2bc0eb5c000691aae7..899548659dfbab861180e64c1d970f04ec877764 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc @@ -18,10 +18,18 @@ namespace paddle { namespace distributed { namespace auto_parallel { -std::vector MatmulSPMDRule::InferForward( +std::vector MatmulSPMDRule::InferForward( const std::vector& input_specs, const paddle::framework::AttributeMap& attrs) { // step0: verify input args based on matmul logic + auto input_specs_size = input_specs.size(); + PADDLE_ENFORCE_EQ( + input_specs_size, + 2, + phi::errors::InvalidArgument( + "The size of InputSpec of matmul should be 2, but got [%d].", + input_specs_size)); + int x_ndim = input_specs[0].shape.size(); int y_ndim = input_specs[1].shape.size(); std::vector x_dims_mapping = input_specs[0].DistAttr.dims_mapping; @@ -42,54 +50,44 @@ std::vector MatmulSPMDRule::InferForward( bool trans_x = ExtractAttr("trans_x"); bool trans_y = ExtractAttr("trans_y"); - auto input_specs_size = input_specs.size() 
PADDLE_ENFORCE_EQ( - input_specs_size, - 2, - phi::errors::InvalidArgument( - "The size of InputSpec of matmul should be 2, but got [%d].", - input_specs_size)); - - // step1: Einsum Notation - int max_ndim = std::max(x_ndim, y_ndim); + // step1: build Einsum Notation // reserve the char k, m, n for matrix product notation: mk,kn -> mn + int max_ndim = std::max(x_ndim, y_ndim); std::string alphabet = "abcdefghijlopqrstuvwxyz"; - std::string x_string; - std::string y_string; - std::string out_string; + std::string x_axes; + std::string y_axes; + std::string out_axes; + // Handle 4 different matmul cases in Paddle // vector * vector = scala if (x_ndim == 1 && y_ndim == 1) { - x_string = "k"; - y_string = "k"; - out_string = ""; + x_axes = "k"; + y_axes = "k"; + out_axes = ""; // vector * batched matrix } else if (x_ndim == 1 && y_ndim > 1) { - x_string = "k"; - std::string y_broadcast_string = - GetBroadcastNotationString(y_ndim, max_ndim, alphabet); - y_string = y_broadcast_string + "kn"; - out_string = y_broadcast_string + "n"; + x_axes = "k"; + std::string y_broadcast_axes = GetBroadcastAxes(y_ndim, max_ndim, alphabet); + y_axes = y_broadcast_axes + "kn"; + out_axes = y_broadcast_axes + "n"; // batched matrix * vector } else if (x_ndim > 1 && y_ndim == 1) { - y_string = "k"; - std::string x_broadcast_string = - GetBroadcastNotationString(x_ndim, max_ndim, alphabet); - x_string = x_broadcast_string + "mk"; - out_string = x_broadcast_string + "m"; + y_axes = "k"; + std::string x_broadcast_axes = GetBroadcastAxes(x_ndim, max_ndim, alphabet); + x_axes = x_broadcast_axes + "mk"; + out_axes = x_broadcast_axes + "m"; // batched matrix * batched matrix } else if (x_ndim > 1 && y_ndim > 1) { - std::string x_broadcast_string = - GetBroadcastNotationString(x_ndim, max_ndim, alphabet); - std::string y_broadcast_string = - GetBroadcastNotationString(y_ndim, max_ndim, alphabet); - x_string = x_broadcast_string + "mk"; - y_string = y_broadcast_string + "kn"; + std::string 
x_broadcast_axes = GetBroadcastAxes(x_ndim, max_ndim, alphabet); + std::string y_broadcast_axes = GetBroadcastAxes(y_ndim, max_ndim, alphabet); + x_axes = x_broadcast_axes + "mk"; + y_axes = y_broadcast_axes + "kn"; if (x_ndim > y_ndim) { - out_string = x_broadcast_string + "mn"; + out_axes = x_broadcast_axes + "mn"; } else { - out_string = y_broadcast_string + "mn"; + out_axes = y_broadcast_axes + "mn"; } } else { PADDLE_THROW(phi::errors::InvalidArgument( @@ -98,8 +96,8 @@ std::vector MatmulSPMDRule::InferForward( y_ndim)); } - VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_string << "," - << y_string << " --> " << out_string << "]."; + VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_axes << "," + << y_axes << " --> " << out_axes << "]."; // step2: Sharding Propogation if (trans_x) { @@ -121,34 +119,34 @@ std::vector MatmulSPMDRule::InferForward( std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1); } // step2.1: Sharding Merge - std::pair> x_pair(x_string, x_dims_mapping); - std::pair> y_pair(y_string, y_dims_mapping); + std::pair> x_pair(x_axes, x_dims_mapping); + std::pair> y_pair(y_axes, y_dims_mapping); std::vector>> input_pairs; input_pairs.push_back(x_pair); input_pairs.push_back(y_pair); auto axis_to_dim_map = ShardingMergeForTensors(input_pairs); - // step2.2: fill output's dim mapping. + // step2.2: Infer Output's Dims Mapping. 
TensorDistAttr output_dist_attr_dst = - CopyTensorDistAttrForOutput(input_specs[0].DistAttr) std::vector - out_dims_mapping; - out_dims_mapping.reserve(out_string.size()); - for (int i = 0; i < out_string.size(); ++i) { - out_dims_mapping.push_back(axis_to_dim_map[out_string.substr(i, 1)]); + CopyTensorDistAttrForOutput(input_specs[0].DistAttr); + std::vector out_dims_mapping; + out_dims_mapping.reserve(out_axes.size()); + for (int i = 0; i < out_axes.size(); ++i) { + out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]); } output_dist_attr_dst.set_dims_mapping(out_dims_mapping); - // step2.2: fill input's dim mapping. + // step2.3: Merge and get Inputs' New Dims Mapping. TensorDistAttr x_dist_attr_dst = GetInferedDistAttr( - input_specs[0].DistAttr, input_specs[0].shape, x_string, axis_to_dim_map); + input_specs[0].DistAttr, input_specs[0].shape, x_axes, axis_to_dim_map); TensorDistAttr y_dist_attr_dst = GetInferedDistAttr( - input_specs[1].DistAttr, input_specs[1].shape, y_string, axis_to_dim_map); + input_specs[1].DistAttr, input_specs[1].shape, y_axes, axis_to_dim_map); // step2.3: Handle Partial // Step2.3.1 Output Partial std::vector partial_on_dims = - ResoluteOutputPartialDimension(axis_to_dim_map, out_string); + ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); // Step2.3.2 handle input tensor partial (TODO) @@ -161,6 +159,8 @@ std::vector MatmulSPMDRule::InferForward( << ", dst_dims_mapping: " << y_dist_attr_dst.dims_mapping << "; Output dims_mapping: " << out_dims_mapping << ", partial_on_dims: " << partial_on_dims; + + return { x_dist_attr_dst, y_dist_attr_dst, output_dist_attr_dst }; } TensorDistAttr GetInferedDistAttr( @@ -184,7 +184,7 @@ TensorDistAttr GetInferedDistAttr( return dist_attr_; } -std::vector MatmulSPMDRule::InferBackward( +std::vector MatmulSPMDRule::InferBackward( const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) {} diff --git 
a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h index 65b0600eae5c24731c197a745ffcc5fbf45dfdae..2afee21c322a206bb554a37799fdf9fd52f61770 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h @@ -28,7 +28,7 @@ namespace auto_parallel { TensorDistAttr GetInferedDistAttr( const TensorDistAttr& origin_dist_attr, const std::vector& shape, - const std::string& tensor_axis, + const std::string& tensor_axes, const std::unordered_map& axis_to_dim_map); class MatmulSPMDRule : public SPMDRuleBase {