提交 4cd1a2cb 编写于 作者: L liangjianzhong

revise syntax

上级 1dcb80ea
......@@ -18,7 +18,7 @@ namespace paddle {
namespace distributed {
namespace auto_parallel {
std::vector<DistTensorSpec> SPMDRuleBase::InferForward(
std::vector<TensorDistAttr> SPMDRuleBase::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
......@@ -26,7 +26,7 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferForward(
"derived class of SPMDRuleBase !"));
}
std::vector<DistTensorSpec> SPMDRuleBase::InferBackward(
std::vector<TensorDistAttr> SPMDRuleBase::InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
......@@ -36,12 +36,12 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferBackward(
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_notation_to_dim_pairs) {
tensor_axes_to_dim_pairs) {
std::unordered_map<std::string, int64_t> axis_to_dim_map;
std::unordered_map<int64_t, std::string> dim_to_axis_map;
int64_t merge_dim;
for (auto& pair : tensor_notation_to_dim_pairs) {
for (auto& pair : tensor_axes_to_dim_pairs) {
for (int i = 0; i < pair.second.size(); i++) {
auto tensor_axis = pair.first.substr(i, 1);
auto mesh_dim = pair.second[i];
......@@ -84,9 +84,9 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// multiple dimension case.)
int64_t ShardingMergeForAxis(const std::string axis,
const int64_t mesh_dim1,
const int64_t mesh_dim2) {
int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2) {
if (mesh_dim1 != mesh_dim2) {
if (mesh_dim1 == -1) {
return mesh_dim2;
......@@ -118,8 +118,8 @@ TensorDistAttr CopyTensorDistAttrForOutput(
}
std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& in_axis_to_dim_map,
const std::string& out_axis) {
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes) {
std::vector<int64_t> partial_on_dims;
for (auto& it : in_axis_to_dim_map) {
......
......@@ -19,7 +19,9 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
namespace paddle {
namespace distributed {
......@@ -29,11 +31,26 @@ class SPMDRuleBase {
public:
virtual ~SPMDRuleBase() {}
virtual std::vector<DistTensorSpec> InferForward(
// Merge the DistAttr of input tensors and infer the DistAttr of the output
// tensors from the merged input information. The input are DistAttr and Shape
// (wrapp as DistTensorSpec) of the input tensors (tensors follow the same
// order defined in Op's Phi API) and Op Attribue of the current op. The ouput
// are the Merged DistAttr of input tensors and the infered DistAttr of the
// output tensors. The Merged DistAttr might be different from the original
// Intput DistAttrs, which means that the corressponding input tensor need to
// be reshard.
virtual std::vector<TensorDistAttr> InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs);
virtual std::vector<DistTensorSpec> InferBackward(
// Merge the DistAttr of output tensors and infer the DistAttr of the input
// tensors from the merged output information. The input are DistAttr and
// Shape (wrapp as DistTensorSpec) of the input tensors and Op Attribue of the
// current op. The ouput are the Merged DistAttr of output tensors and the
// infered DistAttr of the input tensors. This function will be use in Static
// Graph mode only, where we have the whole computation graph for sharding
// propogation.
virtual std::vector<TensorDistAttr> InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs);
......@@ -44,8 +61,7 @@ class SPMDRuleBase {
return PADDLE_GET_CONST(T, GetAttr(name, attrs));
}
virtual const Attribute& GetAttr(
const std::string& name,
const Attribute& GetAttr(const std::string& name,
const paddle::framework::AttributeMap& attrs) const {
auto iter = attrs.find(name);
PADDLE_ENFORCE_NE(
......@@ -56,23 +72,29 @@ class SPMDRuleBase {
}
};
// Merge sharding specification (dims mapping) of given tensors.
// The same axes of different tensors will be merged.
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_notation_to_dim_pairs);
tensor_axes_to_dim_pairs);
// Merge the sharding specification (dims mapping) for one tensor Axis.
// Rule1: A repicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// multiple dimension case.)
int64_t ShardingMergeForAxis(const std::string axis,
const int64_t mesh_dim1,
const int64_t mesh_dim2);
int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2);
TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);
// Resolute the partial mesh dimension of a output tensor, giving the
// merged sharding specifcation of input tensors and the axis names of output
// tensor. Input are
std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& in_axis_to_dim_map,
const std::string& out_axis);
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes);
} // namespace auto_parallel
} // namespace distributed
......
......@@ -18,10 +18,18 @@ namespace paddle {
namespace distributed {
namespace auto_parallel {
std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
std::vector<TensorDistAttr> MatmulSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: verify input args based on matmul logic
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
2,
phi::errors::InvalidArgument(
"The size of InputSpec of matmul should be 2, but got [%d].",
input_specs_size));
int x_ndim = input_specs[0].shape.size();
int y_ndim = input_specs[1].shape.size();
std::vector<int64_t> x_dims_mapping = input_specs[0].DistAttr.dims_mapping;
......@@ -42,54 +50,44 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
bool trans_x = ExtractAttr<bool>("trans_x");
bool trans_y = ExtractAttr<bool>("trans_y");
auto input_specs_size = input_specs.size() PADDLE_ENFORCE_EQ(
input_specs_size,
2,
phi::errors::InvalidArgument(
"The size of InputSpec of matmul should be 2, but got [%d].",
input_specs_size));
// step1: Einsum Notation
int max_ndim = std::max(x_ndim, y_ndim);
// step1: build Einsum Notation
// reserve the char k, m, n for matrix product notation: mk,kn -> mn
int max_ndim = std::max(x_ndim, y_ndim);
std::string alphabet = "abcdefghijlopqrstuvwxyz";
std::string x_string;
std::string y_string;
std::string out_string;
std::string x_axes;
std::string y_axes;
std::string out_axes;
// Handle 4 different matmul cases in Paddle
// vector * vector = scala
if (x_ndim == 1 && y_ndim == 1) {
x_string = "k";
y_string = "k";
out_string = "";
x_axes = "k";
y_axes = "k";
out_axes = "";
// vector * batched matrix
} else if (x_ndim == 1 && y_ndim > 1) {
x_string = "k";
std::string y_broadcast_string =
GetBroadcastNotationString(y_ndim, max_ndim, alphabet);
y_string = y_broadcast_string + "kn";
out_string = y_broadcast_string + "n";
x_axes = "k";
std::string y_broadcast_axes = GetBroadcastAxes(y_ndim, max_ndim, alphabet);
y_axes = y_broadcast_axes + "kn";
out_axes = y_broadcast_axes + "n";
// batched matrix * vector
} else if (x_ndim > 1 && y_ndim == 1) {
y_string = "k";
std::string x_broadcast_string =
GetBroadcastNotationString(x_ndim, max_ndim, alphabet);
x_string = x_broadcast_string + "mk";
out_string = x_broadcast_string + "m";
y_axes = "k";
std::string x_broadcast_axes = GetBroadcastAxes(x_ndim, max_ndim, alphabet);
x_axes = x_broadcast_axes + "mk";
out_axes = x_broadcast_axes + "m";
// batched matrix * batched matrix
} else if (x_ndim > 1 && y_ndim > 1) {
std::string x_broadcast_string =
GetBroadcastNotationString(x_ndim, max_ndim, alphabet);
std::string y_broadcast_string =
GetBroadcastNotationString(y_ndim, max_ndim, alphabet);
x_string = x_broadcast_string + "mk";
y_string = y_broadcast_string + "kn";
std::string x_broadcast_axes = GetBroadcastAxes(x_ndim, max_ndim, alphabet);
std::string y_broadcast_axes = GetBroadcastAxes(y_ndim, max_ndim, alphabet);
x_axes = x_broadcast_axes + "mk";
y_axes = y_broadcast_axes + "kn";
if (x_ndim > y_ndim) {
out_string = x_broadcast_string + "mn";
out_axes = x_broadcast_axes + "mn";
} else {
out_string = y_broadcast_string + "mn";
out_axes = y_broadcast_axes + "mn";
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
......@@ -98,8 +96,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
y_ndim));
}
VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_string << ","
<< y_string << " --> " << out_string << "].";
VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_axes << ","
<< y_axes << " --> " << out_axes << "].";
// step2: Sharding Propogation
if (trans_x) {
......@@ -121,34 +119,34 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1);
}
// step2.1: Sharding Merge
std::pair<std::string, std::vector<int64_t>> x_pair(x_string, x_dims_mapping);
std::pair<std::string, std::vector<int64_t>> y_pair(y_string, y_dims_mapping);
std::pair<std::string, std::vector<int64_t>> x_pair(x_axes, x_dims_mapping);
std::pair<std::string, std::vector<int64_t>> y_pair(y_axes, y_dims_mapping);
std::vector<std::pair<const std::string, const std::vector<int64_t>>>
input_pairs;
input_pairs.push_back(x_pair);
input_pairs.push_back(y_pair);
auto axis_to_dim_map = ShardingMergeForTensors(input_pairs);
// step2.2: fill output's dim mapping.
// step2.2: Infer Output's Dims Mapping.
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[0].DistAttr) std::vector<int64_t>
out_dims_mapping;
out_dims_mapping.reserve(out_string.size());
for (int i = 0; i < out_string.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_string.substr(i, 1)]);
CopyTensorDistAttrForOutput(input_specs[0].DistAttr);
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
for (int i = 0; i < out_axes.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// step2.3: fill input's dim mapping.
// step2.3: Merge and get Inputs' New Dims Mapping.
TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
input_specs[0].DistAttr, input_specs[0].shape, x_string, axis_to_dim_map);
input_specs[0].DistAttr, input_specs[0].shape, x_axes, axis_to_dim_map);
TensorDistAttr y_dist_attr_dst = GetInferedDistAttr(
input_specs[1].DistAttr, input_specs[1].shape, y_string, axis_to_dim_map);
input_specs[1].DistAttr, input_specs[1].shape, y_axes, axis_to_dim_map);
// step2.3: Handle Partial
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, out_string);
ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
// Step2.3.2 handle input tensor partial (TODO)
......@@ -161,6 +159,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
<< ", dst_dims_mapping: " << y_dist_attr_dst.dims_mapping
<< "; Output dims_mapping: " << out_dims_mapping
<< ", partial_on_dims: " << partial_on_dims;
return { x_dist_attr_dst, y_dist_attr_dst, output_dist_attr_dst }
}
TensorDistAttr GetInferedDistAttr(
......@@ -184,7 +184,7 @@ TensorDistAttr GetInferedDistAttr(
return dist_attr_;
}
std::vector<DistTensorSpec> MatmulSPMDRule::InferBackward(
std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {}
......
......@@ -28,7 +28,7 @@ namespace auto_parallel {
TensorDistAttr GetInferedDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axis,
const std::string& tensor_axes,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map);
class MatmulSPMDRule : public SPMDRuleBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册