Commit 4cd1a2cb authored by liangjianzhong

revise syntax

Parent 1dcb80ea
@@ -18,7 +18,7 @@ namespace paddle {
 namespace distributed {
 namespace auto_parallel {
-std::vector<DistTensorSpec> SPMDRuleBase::InferForward(
+std::vector<TensorDistAttr> SPMDRuleBase::InferForward(
     const std::vector<DistTensorSpec>& input_specs,
     const paddle::framework::AttributeMap& attrs) {
   PADDLE_THROW(
@@ -26,7 +26,7 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferForward(
       "derived class of SPMDRuleBase !"));
 }
-std::vector<DistTensorSpec> SPMDRuleBase::InferBackward(
+std::vector<TensorDistAttr> SPMDRuleBase::InferBackward(
     const std::vector<DistTensorSpec>& output_specs,
     const paddle::framework::AttributeMap& attrs) {
   PADDLE_THROW(
@@ -36,12 +36,12 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferBackward(
 std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
     const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
-        tensor_notation_to_dim_pairs) {
+        tensor_axes_to_dim_pairs) {
   std::unordered_map<std::string, int64_t> axis_to_dim_map;
   std::unordered_map<int64_t, std::string> dim_to_axis_map;
   int64_t merge_dim;
-  for (auto& pair : tensor_notation_to_dim_pairs) {
+  for (auto& pair : tensor_axes_to_dim_pairs) {
     for (int i = 0; i < pair.second.size(); i++) {
       auto tensor_axis = pair.first.substr(i, 1);
       auto mesh_dim = pair.second[i];
@@ -84,9 +84,9 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
 // Rule2: A tensor axis could at most be sharded by one mesh dimension.
 // (TODO: trigger heuristic cost model and reshard to handle the case of an
 // axis sharded by multiple dimensions.)
-int64_t ShardingMergeForAxis(const std::string axis,
-                             const int64_t mesh_dim1,
-                             const int64_t mesh_dim2) {
+int64_t ShardingMergeForAxis(const std::string& axis,
+                             const int64_t& mesh_dim1,
+                             const int64_t& mesh_dim2) {
   if (mesh_dim1 != mesh_dim2) {
     if (mesh_dim1 == -1) {
       return mesh_dim2;
@@ -118,8 +118,8 @@ TensorDistAttr CopyTensorDistAttrForOutput(
 }
 std::vector<int64_t> ResoluteOutputPartialDimension(
-    const std::unordered_map<std::string, int64_t>& in_axis_to_dim_map,
-    const std::string& out_axis) {
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
+    const std::string& tensor_axes) {
   std::vector<int64_t> partial_on_dims;
   for (auto& it : in_axis_to_dim_map) {
......
@@ -19,7 +19,9 @@ limitations under the License. */
 #include <string>
 #include <vector>
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
 #include "paddle/fluid/framework/type_defs.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 namespace paddle {
 namespace distributed {
@@ -29,11 +31,26 @@ class SPMDRuleBase {
  public:
   virtual ~SPMDRuleBase() {}
-  virtual std::vector<DistTensorSpec> InferForward(
+  // Merge the DistAttrs of the input tensors and infer the DistAttrs of the
+  // output tensors from the merged input information. The inputs are the
+  // DistAttr and shape (wrapped as DistTensorSpec) of each input tensor
+  // (tensors follow the same order defined in the op's Phi API) and the
+  // attributes of the current op. The outputs are the merged DistAttrs of
+  // the input tensors and the inferred DistAttrs of the output tensors. A
+  // merged DistAttr might differ from the original input DistAttr, which
+  // means the corresponding input tensor needs to be resharded.
+  virtual std::vector<TensorDistAttr> InferForward(
       const std::vector<DistTensorSpec>& input_specs,
       const paddle::framework::AttributeMap& attrs);
-  virtual std::vector<DistTensorSpec> InferBackward(
+  // Merge the DistAttrs of the output tensors and infer the DistAttrs of the
+  // input tensors from the merged output information. The inputs are the
+  // DistAttr and shape (wrapped as DistTensorSpec) of each output tensor and
+  // the attributes of the current op. The outputs are the merged DistAttrs
+  // of the output tensors and the inferred DistAttrs of the input tensors.
+  // This function is used in static graph mode only, where the whole
+  // computation graph is available for sharding propagation.
+  virtual std::vector<TensorDistAttr> InferBackward(
      const std::vector<DistTensorSpec>& output_specs,
      const paddle::framework::AttributeMap& attrs);
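
To make the contract concrete, here is a minimal hedged sketch of a derived rule; the IdentitySPMDRule class is hypothetical (not part of this commit) and assumes the includes and namespaces of this header:

```cpp
// Hypothetical rule, for illustration only: it performs no real merging and
// simply echoes each input's DistAttr, showing the return convention
// (merged input DistAttrs first, inferred output DistAttrs appended after).
class IdentitySPMDRule : public SPMDRuleBase {
 public:
  std::vector<TensorDistAttr> InferForward(
      const std::vector<DistTensorSpec>& input_specs,
      const paddle::framework::AttributeMap& attrs) override {
    std::vector<TensorDistAttr> results;
    for (const auto& spec : input_specs) {
      results.push_back(spec.DistAttr);  // "merged" input DistAttr, unchanged
    }
    // A real rule would append the inferred output DistAttr(s) here.
    return results;
  }
};
```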
@@ -44,9 +61,8 @@ class SPMDRuleBase {
     return PADDLE_GET_CONST(T, GetAttr(name, attrs));
   }
-  virtual const Attribute& GetAttr(
-      const std::string& name,
-      const paddle::framework::AttributeMap& attrs) const {
+  const Attribute& GetAttr(const std::string& name,
+                           const paddle::framework::AttributeMap& attrs) const {
     auto iter = attrs.find(name);
     PADDLE_ENFORCE_NE(
         iter,
@@ -56,23 +72,29 @@ class SPMDRuleBase {
   }
 };
+// Merge the sharding specifications (dims mapping) of the given tensors.
+// The same axes of different tensors will be merged.
 std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
     const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
-        tensor_notation_to_dim_pairs);
+        tensor_axes_to_dim_pairs);
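
A hedged usage sketch of ShardingMergeForTensors, mirroring how MatmulSPMDRule builds its input pairs below; the dims mappings are illustrative (-1 means replicated):

```cpp
// x has axes "mk" with dims_mapping [0, -1]; y has axes "kn" with [-1, 1].
std::pair<std::string, std::vector<int64_t>> x_pair("mk", {0, -1});
std::pair<std::string, std::vector<int64_t>> y_pair("kn", {-1, 1});
std::vector<std::pair<const std::string, const std::vector<int64_t>>>
    input_pairs;
input_pairs.push_back(x_pair);
input_pairs.push_back(y_pair);
auto axis_to_dim_map = ShardingMergeForTensors(input_pairs);
// Expected merge: {"m": 0, "k": -1, "n": 1}; the shared axis "k" is
// replicated in both tensors, so it stays replicated.
```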
+// Merge the sharding specification (dims mapping) for one tensor axis.
 // Rule1: A replicated dimension could be merged by any sharded dimension.
 // Rule2: A tensor axis could at most be sharded by one mesh dimension.
 // (TODO: trigger heuristic cost model and reshard to handle the case of an
 // axis sharded by multiple dimensions.)
-int64_t ShardingMergeForAxis(const std::string axis,
-                             const int64_t mesh_dim1,
-                             const int64_t mesh_dim2);
+int64_t ShardingMergeForAxis(const std::string& axis,
+                             const int64_t& mesh_dim1,
+                             const int64_t& mesh_dim2);
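
A hedged illustration of the two rules; the return values are inferred from Rule1/Rule2 rather than quoted from the (truncated) implementation:

```cpp
int64_t r1 = ShardingMergeForAxis("k", -1, 0);  // Rule1: replicated merges into sharded -> 0
int64_t r2 = ShardingMergeForAxis("k", 1, 1);   // identical shardings -> 1
// An axis sharded on two different mesh dims (e.g. 0 vs. 1) violates Rule2
// and is the conflicting case deferred to the cost-model/reshard TODO above.
```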
 TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);
+// Resolve the partial mesh dimensions of an output tensor, given the merged
+// sharding specification of the input tensors and the axis names of the
+// output tensor.
 std::vector<int64_t> ResoluteOutputPartialDimension(
-    const std::unordered_map<std::string, int64_t>& in_axis_to_dim_map,
-    const std::string& out_axis);
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
+    const std::string& tensor_axes);
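
A hedged sketch for matmul's notation "mk,kn -> mn": the contracted axis "k" does not appear in the output axes, so if "k" is sharded, the output is partial (pending a reduction) on that mesh dimension:

```cpp
std::unordered_map<std::string, int64_t> axis_to_dim_map = {
    {"m", -1}, {"n", 1}, {"k", 0}};  // "k" sharded on mesh dim 0
std::vector<int64_t> partial_on_dims =
    ResoluteOutputPartialDimension(axis_to_dim_map, "mn");
// Expected: partial_on_dims == {0}
```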
 } // namespace auto_parallel
 } // namespace distributed
......
@@ -18,10 +18,18 @@ namespace paddle {
 namespace distributed {
 namespace auto_parallel {
-std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
+std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
     const std::vector<DistTensorSpec>& input_specs,
     const paddle::framework::AttributeMap& attrs) {
   // step0: verify input args based on matmul logic
+  auto input_specs_size = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      input_specs_size,
+      2,
+      phi::errors::InvalidArgument(
+          "The size of InputSpec of matmul should be 2, but got [%d].",
+          input_specs_size));
   int x_ndim = input_specs[0].shape.size();
   int y_ndim = input_specs[1].shape.size();
   std::vector<int64_t> x_dims_mapping = input_specs[0].DistAttr.dims_mapping;
@@ -42,54 +50,44 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
   bool trans_x = ExtractAttr<bool>("trans_x");
   bool trans_y = ExtractAttr<bool>("trans_y");
-  auto input_specs_size = input_specs.size()
-  PADDLE_ENFORCE_EQ(
-      input_specs_size,
-      2,
-      phi::errors::InvalidArgument(
-          "The size of InputSpec of matmul should be 2, but got [%d].",
-          input_specs_size));
-  // step1: Einsum Notation
-  int max_ndim = std::max(x_ndim, y_ndim);
+  // step1: build Einsum Notation
   // reserve the char k, m, n for matrix product notation: mk,kn -> mn
+  int max_ndim = std::max(x_ndim, y_ndim);
   std::string alphabet = "abcdefghijlopqrstuvwxyz";
-  std::string x_string;
-  std::string y_string;
-  std::string out_string;
+  std::string x_axes;
+  std::string y_axes;
+  std::string out_axes;
+  // Handle 4 different matmul cases in Paddle
   // vector * vector = scalar
   if (x_ndim == 1 && y_ndim == 1) {
-    x_string = "k";
-    y_string = "k";
-    out_string = "";
+    x_axes = "k";
+    y_axes = "k";
+    out_axes = "";
     // vector * batched matrix
   } else if (x_ndim == 1 && y_ndim > 1) {
-    x_string = "k";
-    std::string y_broadcast_string =
-        GetBroadcastNotationString(y_ndim, max_ndim, alphabet);
-    y_string = y_broadcast_string + "kn";
-    out_string = y_broadcast_string + "n";
+    x_axes = "k";
+    std::string y_broadcast_axes = GetBroadcastAxes(y_ndim, max_ndim, alphabet);
+    y_axes = y_broadcast_axes + "kn";
+    out_axes = y_broadcast_axes + "n";
     // batched matrix * vector
   } else if (x_ndim > 1 && y_ndim == 1) {
-    y_string = "k";
-    std::string x_broadcast_string =
-        GetBroadcastNotationString(x_ndim, max_ndim, alphabet);
-    x_string = x_broadcast_string + "mk";
-    out_string = x_broadcast_string + "m";
+    y_axes = "k";
+    std::string x_broadcast_axes = GetBroadcastAxes(x_ndim, max_ndim, alphabet);
+    x_axes = x_broadcast_axes + "mk";
+    out_axes = x_broadcast_axes + "m";
     // batched matrix * batched matrix
   } else if (x_ndim > 1 && y_ndim > 1) {
-    std::string x_broadcast_string =
-        GetBroadcastNotationString(x_ndim, max_ndim, alphabet);
-    std::string y_broadcast_string =
-        GetBroadcastNotationString(y_ndim, max_ndim, alphabet);
-    x_string = x_broadcast_string + "mk";
-    y_string = y_broadcast_string + "kn";
+    std::string x_broadcast_axes = GetBroadcastAxes(x_ndim, max_ndim, alphabet);
+    std::string y_broadcast_axes = GetBroadcastAxes(y_ndim, max_ndim, alphabet);
+    x_axes = x_broadcast_axes + "mk";
+    y_axes = y_broadcast_axes + "kn";
     if (x_ndim > y_ndim) {
-      out_string = x_broadcast_string + "mn";
+      out_axes = x_broadcast_axes + "mn";
     } else {
-      out_string = y_broadcast_string + "mn";
+      out_axes = y_broadcast_axes + "mn";
     }
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
@@ -98,8 +96,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
       y_ndim));
   }
-  VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_string << ","
-          << y_string << " --> " << out_string << "].";
+  VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_axes << ","
+          << y_axes << " --> " << out_axes << "].";
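
For instance, with a rank-3 x and a rank-2 y this logs the notation "amk,kn --> amn", assuming GetBroadcastAxes yields one letter per leading batch dimension ("a" for x, the empty string for y): the batch axis "a" and the "m"/"n" axes carry their sharding straight to the output, while the contracted axis "k" only affects the partial semantics handled in step2.3 below.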
   // step2: Sharding Propagation
   if (trans_x) {
@@ -121,34 +119,34 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
     std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1);
   }
   // step2.1: Sharding Merge
-  std::pair<std::string, std::vector<int64_t>> x_pair(x_string, x_dims_mapping);
-  std::pair<std::string, std::vector<int64_t>> y_pair(y_string, y_dims_mapping);
+  std::pair<std::string, std::vector<int64_t>> x_pair(x_axes, x_dims_mapping);
+  std::pair<std::string, std::vector<int64_t>> y_pair(y_axes, y_dims_mapping);
   std::vector<std::pair<const std::string, const std::vector<int64_t>>>
       input_pairs;
   input_pairs.push_back(x_pair);
   input_pairs.push_back(y_pair);
   auto axis_to_dim_map = ShardingMergeForTensors(input_pairs);
-  // step2.2: fill output's dim mapping.
+  // step2.2: Infer Output's Dims Mapping.
   TensorDistAttr output_dist_attr_dst =
-      CopyTensorDistAttrForOutput(input_specs[0].DistAttr) std::vector<int64_t>
-      out_dims_mapping;
-  out_dims_mapping.reserve(out_string.size());
-  for (int i = 0; i < out_string.size(); ++i) {
-    out_dims_mapping.push_back(axis_to_dim_map[out_string.substr(i, 1)]);
+      CopyTensorDistAttrForOutput(input_specs[0].DistAttr);
+  std::vector<int64_t> out_dims_mapping;
+  out_dims_mapping.reserve(out_axes.size());
+  for (int i = 0; i < out_axes.size(); ++i) {
+    out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
   }
   output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
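
A hedged walk-through of steps 2.1 and 2.2 for a plain 2-D matmul with notation "mk,kn -> mn"; the dims mappings are illustrative and only standard <string>, <vector> and <unordered_map> types are assumed:

```cpp
// Result of step2.1 for x: [0, -1] on axes "mk" and y: [-1, 1] on axes "kn".
std::unordered_map<std::string, int64_t> axis_to_dim_map = {
    {"m", 0}, {"k", -1}, {"n", 1}};
std::string out_axes = "mn";
// step2.2: read the output's mapping off the merged map, axis by axis.
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
for (size_t i = 0; i < out_axes.size(); ++i) {
  out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
}
// out_dims_mapping == {0, 1}: rows sharded on mesh dim 0, columns on mesh
// dim 1; "k" is replicated here, so step2.3 will find no partial dims.
```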
-  // step2.3: fill input's dim mapping.
+  // step2.3: Merge and get Inputs' New Dims Mapping.
   TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
-      input_specs[0].DistAttr, input_specs[0].shape, x_string, axis_to_dim_map);
+      input_specs[0].DistAttr, input_specs[0].shape, x_axes, axis_to_dim_map);
   TensorDistAttr y_dist_attr_dst = GetInferedDistAttr(
-      input_specs[1].DistAttr, input_specs[1].shape, y_string, axis_to_dim_map);
+      input_specs[1].DistAttr, input_specs[1].shape, y_axes, axis_to_dim_map);
   // step2.3: Handle Partial
   // Step2.3.1 Output Partial
   std::vector<int64_t> partial_on_dims =
-      ResoluteOutputPartialDimension(axis_to_dim_map, out_string);
+      ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
   // Step2.3.2 handle input tensor partial (TODO)
@@ -161,6 +159,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
           << ", dst_dims_mapping: " << y_dist_attr_dst.dims_mapping
           << "; Output dims_mapping: " << out_dims_mapping
           << ", partial_on_dims: " << partial_on_dims;
+  return {x_dist_attr_dst, y_dist_attr_dst, output_dist_attr_dst};
 }
 TensorDistAttr GetInferedDistAttr(
@@ -184,7 +184,7 @@ TensorDistAttr GetInferedDistAttr(
     return dist_attr_;
 }
-std::vector<DistTensorSpec> MatmulSPMDRule::InferBackward(
+std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
     const std::vector<DistTensorSpec>& output_specs,
     const paddle::framework::AttributeMap& attrs) {}
......
@@ -28,7 +28,7 @@ namespace auto_parallel {
 TensorDistAttr GetInferedDistAttr(
     const TensorDistAttr& origin_dist_attr,
     const std::vector<int64_t>& shape,
-    const std::string& tensor_axis,
+    const std::string& tensor_axes,
     const std::unordered_map<std::string, int64_t>& axis_to_dim_map);
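
A hedged usage sketch of GetInferedDistAttr; x_spec and axis_to_dim_map are hypothetical stand-ins for the values MatmulSPMDRule::InferForward computes:

```cpp
// Project the merged axis-to-dim map back onto x's axes "mk". If the
// resulting dims_mapping differs from x_spec.DistAttr.dims_mapping, then
// x must be resharded before the op runs.
TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
    x_spec.DistAttr, x_spec.shape, /*tensor_axes=*/"mk", axis_to_dim_map);
```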
 class MatmulSPMDRule : public SPMDRuleBase {
......