Unverified · Commit 3483398c authored by JZ-LIANG, committed by GitHub

[Semi Auto] Matmul & Embedding InferBackward Rule (#56257)

* add embedding backward rule

* update backward api

* revert api

* matmul inferbackward

* update unittest
Parent 459ddf90
...@@ -33,6 +33,16 @@ SPMDRuleBase::InferForward(const std::vector<DistTensorSpec>& input_specs,
"derived class of SPMDRuleBase !"));
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferBackward should be called from a "
"derived class of SPMDRuleBase !"));
}
// deprecated
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
...@@ -210,7 +220,8 @@ GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes,
std::vector<int64_t> GetDimsMappingForAxes(
const std::string& axes,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const bool unsharded_miss_axis) {
std::vector<int64_t> dims_mapping;
for (int64_t i = 0, n = axes.size(); i < n; i++) {
std::string axis = axes.substr(i, 1);
...@@ -219,10 +230,15 @@ std::vector<int64_t> GetDimsMappingForAxes(
} else {
auto iter = axis_to_dim_map.find(axis);
if (iter == axis_to_dim_map.end()) {
if (unsharded_miss_axis) {
dims_mapping.emplace_back(-1);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Tensor axis [%s] is not in axis_to_dim_map.", axis));
}
} else {
dims_mapping.emplace_back(iter->second);
}
}
}
return dims_mapping;
...
...@@ -51,7 +51,7 @@ class SPMDRuleBase {
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs);
// Based on the information of Input & Output Tensors and Op Attribute:
// 1. Merge the Sharding (dims_mapping) among Output Tensors.
// 2. Infer the Sharding (dims_mapping) for Input Tensors.
// The Info of output tensors (Shape and DistAttr) is wrapped as
...@@ -60,6 +60,12 @@ class SPMDRuleBase {
// 1. The first vector: the inferred DistAttr of input tensors.
// 2. The second vector: the merged DistAttr of output tensors.
virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs);
// deprecated, to be removed in the future
virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs);
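For orientation, a minimal caller sketch of the new overload (rule, input_specs, output_specs, and attrs are assumed names, not part of the patch); as the unit tests below confirm, the inputs come first in the returned pair:

auto res = rule->InferBackward(input_specs, output_specs, attrs);
const std::vector<TensorDistAttr>& infered_inputs = res.first;    // one per input
const std::vector<TensorDistAttr>& merged_outputs = res.second;   // one per output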
...@@ -141,9 +147,12 @@ GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes,
// the annotated axes after inferring forward or backward. The parameter axis
// stores the axes of the tensor. "1" is a special axis, for the axis "1", set
// its dims mapping to -1.
// if unsharded_miss_axis, "-1" is assigned to axes that have no key in
// axis_to_dim_map.
std::vector<int64_t> GetDimsMappingForAxes(
const std::string& axes,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const bool unsharded_miss_axis = false);
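A small worked illustration of the new unsharded_miss_axis flag (an editorial sketch, not part of the patch; the behaviour follows the implementation above):

std::unordered_map<std::string, int64_t> axis_to_dim_map = {{"x", 1}};
std::vector<int64_t> dims =
    GetDimsMappingForAxes("xk", axis_to_dim_map, /*unsharded_miss_axis=*/true);
// dims == {1, -1}: "k" has no entry in the map, so it is treated as unsharded.
// With unsharded_miss_axis = false the same call raises InvalidArgument.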
// The static map that stores and initializes all the registered SPMD rules.
class SPMDRuleMap {
...
...@@ -91,8 +91,7 @@ EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
phi::errors::InvalidArgument(
"Row-wise parallel of embedding table does NOT support Sparse, but "
"row axis of embedding table is sharded by mesh dimension [%d].",
weight_row_axis_mapping));
}
VLOG(6) << "EmbeddingSPMDRule InferForward Inputs: "
...@@ -125,11 +124,12 @@ EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// step3.1: Handle Partial
// (TODO) support case where embedding table is partial at the very beginning.
std::vector<int64_t> partial_on_dims;
if (weight_row_axis_mapping > -1) {
partial_on_dims.push_back(weight_row_axis_mapping);
}
output_dist_attr_dst.set_partial_status(partial_on_dims);
// step4: merge potential conflict in inputs
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
...@@ -156,10 +156,69 @@ EmbeddingSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
EmbeddingSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
// InferBackward is called after InferForward, so we skip some checks.
auto output_specs_size = output_specs.size();
PADDLE_ENFORCE_EQ(
output_specs_size,
1,
phi::errors::InvalidArgument(
"The size of OutputSpec of embedding should be 1, but got [%d].",
output_specs_size));
auto x_shape = input_specs[0].shape();
int x_ndim = x_shape.size();
auto out_shape = output_specs[0].shape();
int out_ndim = out_shape.size();
PADDLE_ENFORCE_EQ(x_ndim,
out_ndim - 1,
phi::errors::InvalidArgument(
"There should be x_ndim + 1 = out_ndim in Embedding, "
"but got x_ndim: [%d] and out_ndim: [%d].",
x_ndim,
out_ndim));
auto out_dist_attr_src = output_specs[0].dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
// step1: build Einsum Notation
std::string alphabet = "abcdefghilmnopqrstuvwxyz";
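// NOTE "j" and "k" are deliberately absent from this alphabet: they are
// reserved for the embedding-table axes "jk" built just below.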
std::string x_axes = GetBroadcastAxes(out_ndim - 1, out_ndim - 1, alphabet);
std::string weight_axes = "jk";
std::string out_axes = x_axes + "k";
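// illustration: for out_ndim = 3 the notation above is x_axes = "ab",
// weight_axes = "jk", out_axes = "abk".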
// step2: Sharding Propagation
// should not use input dims mapping for backward sharding merge
auto axis_to_dim_map =
ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false);
TensorDistAttr x_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
x_dist_attr_dst.set_dims_mapping(GetDimsMappingForAxes(
x_axes, axis_to_dim_map, /*unsharded_miss_axis=*/true));
TensorDistAttr weight_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[1].dist_attr());
weight_dist_attr_dst.set_dims_mapping(GetDimsMappingForAxes(
weight_axes, axis_to_dim_map, /*unsharded_miss_axis=*/true));
// step3: Handle Partial
// NOTE we skip the partial backward inference in Partial Stage-I.
// output partial --> weight sharded on first axis.
VLOG(4) << "EmbeddingSPMDRule InferBackward: "
<< "Einsum notation: [" << x_axes << "," << weight_axes << " --> "
<< out_axes << "]. " << std::endl
<< "Out shape: [" << str_join(out_shape) << "], src_dims_mapping: ["
<< str_join(out_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(out_dims_mapping) << "]; Input X dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping())
<< "], Input Weight dims_mapping:["
<< str_join(weight_dist_attr_dst.dims_mapping()) << "].";
return {{x_dist_attr_dst, weight_dist_attr_dst}, {out_dist_attr_src}};
}
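As a usage sketch (mirroring the unit tests further down; the spec variables are assumptions): the rule is fetched from the registry and queried with only the output sharding filled in:

SPMDRuleBase* rule = SPMDRuleMap::Instance().Get("lookup_table_v2");
auto infered = rule->InferBackward({x_spec, weight_spec}, {out_spec}, attrs);
// infered.first  == {x_dist_attr_dst, weight_dist_attr_dst}
// infered.second == {out_dist_attr_src}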
} // namespace auto_parallel
...
...@@ -33,7 +33,8 @@ class EmbeddingSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
...
...@@ -20,6 +20,91 @@ namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
TensorDistAttr GetInferedDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axis,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const bool trans_axis) {
TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr);
std::vector<int64_t> infered_dims_mapping;
infered_dims_mapping.reserve(tensor_axis.size());
for (size_t i = 0; i < tensor_axis.size(); ++i) {
if (shape.size() > i && shape[i] == 1) {
infered_dims_mapping.push_back(-1);
} else {
auto itr = axis_to_dim_map.find(tensor_axis.substr(i, 1));
if (itr == axis_to_dim_map.end()) {
// infer the k axis as -1 in inferbackward.
infered_dims_mapping.push_back(-1);
} else {
infered_dims_mapping.push_back(itr->second);
}
}
}
if (trans_axis) {
std::iter_swap(infered_dims_mapping.end() - 2,
infered_dims_mapping.end() - 1);
}
dist_attr_.set_dims_mapping(infered_dims_mapping);
return dist_attr_;
}
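A hypothetical call illustrating the fallback behaviour above (all values are invented for illustration; origin_dist_attr stands for any source TensorDistAttr):

std::unordered_map<std::string, int64_t> axis_to_dim_map = {{"m", 0}};
TensorDistAttr attr = GetInferedDistAttr(origin_dist_attr,
                                         /*shape=*/{1, 64, 32},
                                         /*tensor_axis=*/"bmk",
                                         axis_to_dim_map,
                                         /*trans_axis=*/false);
// attr.dims_mapping() == {-1, 0, -1}: the size-1 axis "b" and the unmapped
// axis "k" both fall back to -1; trans_axis = true would swap the last two
// entries, giving {-1, -1, 0}.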
void FillMatmulOperandNotation(const int x_ndim,
const int y_ndim,
std::string* x_axes,
std::string* y_axes,
std::string* out_axes) {
int max_ndim = std::max(x_ndim, y_ndim);
// reserve the char k, m, n for matrix product notation: mk,kn -> mn
std::string alphabet = "abcdefghijlopqrstuvwxyz";
// Handle 4 different matmul cases in Paddle
// vector * vector = scalar
if (x_ndim == 1 && y_ndim == 1) {
*x_axes = "k";
*y_axes = "k";
*out_axes = "";
// vector * batched matrix
} else if (x_ndim == 1 && y_ndim > 1) {
*x_axes = "k";
std::string y_broadcast_axes =
GetBroadcastAxes(y_ndim - 2, y_ndim - 2, alphabet);
*y_axes = y_broadcast_axes + "kn";
*out_axes = y_broadcast_axes + "n";
// batched matrix * vector
} else if (x_ndim > 1 && y_ndim == 1) {
*y_axes = "k";
std::string x_broadcast_axes =
GetBroadcastAxes(x_ndim - 2, x_ndim - 2, alphabet);
*x_axes = x_broadcast_axes + "mk";
*out_axes = x_broadcast_axes + "m";
// batched matrix * batched matrix
} else if (x_ndim > 1 && y_ndim > 1) {
std::string x_broadcast_axes =
GetBroadcastAxes(x_ndim - 2, max_ndim - 2, alphabet);
std::string y_broadcast_axes =
GetBroadcastAxes(y_ndim - 2, max_ndim - 2, alphabet);
*x_axes = x_broadcast_axes + "mk";
*y_axes = y_broadcast_axes + "kn";
if (x_ndim > y_ndim) {
*out_axes = x_broadcast_axes + "mn";
} else {
*out_axes = y_broadcast_axes + "mn";
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"MatmulSPMDRule Receive Unsupported x_dim [%d] and y_dim [%d].",
x_ndim,
y_ndim));
}
}
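For reference, a few concrete instances of the notation built above (assuming GetBroadcastAxes right-aligns the batch axes against the alphabet, as its other call sites suggest):

// x_ndim = 1, y_ndim = 1:  x "k",   y "k"     -->  out ""
// x_ndim = 1, y_ndim = 3:  x "k",   y "akn"   -->  out "an"
// x_ndim = 3, y_ndim = 4:  x "bmk", y "abkn"  -->  out "abmn"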
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
...@@ -67,54 +152,10 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
<< "[" << (trans_y ? "true" : "false") << "]; ";
// step1: build Einsum Notation
std::string x_axes;
std::string y_axes;
std::string out_axes;
FillMatmulOperandNotation(x_ndim, y_ndim, &x_axes, &y_axes, &out_axes);
// step2: Sharding Propagation
if (trans_x) {
...@@ -180,46 +221,72 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
MatmulSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
// extract & verify input
auto output_specs_size = output_specs.size();
PADDLE_ENFORCE_EQ(
output_specs_size,
1,
phi::errors::InvalidArgument(
"The size of OutputSpec of matmul should be 1, but got [%d].",
output_specs_size));
auto out_shape = output_specs[0].shape();
int out_ndim = out_shape.size();
auto x_shape = input_specs[0].shape();
auto y_shape = input_specs[1].shape();
int x_ndim = x_shape.size();
int y_ndim = y_shape.size();
int max_ndim = std::max(x_ndim, y_ndim);
PADDLE_ENFORCE_EQ(max_ndim,
out_ndim,
phi::errors::InvalidArgument(
"The max ndim of inputs should be equal to out_ndim in "
"Matmul, but got max ndim: [%d] and out_ndim: [%d].",
max_ndim,
out_ndim));
bool trans_x = ExtractAttr<bool>("trans_x", attrs);
bool trans_y = ExtractAttr<bool>("trans_y", attrs);
auto out_dist_attr_src = output_specs[0].dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
// step1: build Einsum Notation
std::string x_axes;
std::string y_axes;
std::string out_axes;
FillMatmulOperandNotation(x_ndim, y_ndim, &x_axes, &y_axes, &out_axes);
// step2: Sharding Propagation
// should not use input dims mapping for backward sharding merge
auto axis_to_dim_map =
ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false);
TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
input_specs[0].dist_attr(), x_shape, x_axes, axis_to_dim_map, trans_x);
TensorDistAttr y_dist_attr_dst = GetInferedDistAttr(
input_specs[1].dist_attr(), y_shape, y_axes, axis_to_dim_map, trans_y);
// step3: Handle Partial
// NOTE we skip the partial backward inference in Partial Stage-I.
// output partial --> axis k is sharded.
VLOG(4) << "MatmulSPMDRule InferBackward: "
<< "Einsum notation: [" << x_axes << "," << y_axes << " --> "
<< out_axes << "]. " << std::endl
<< "Out shape: [" << str_join(out_shape) << "], src_dims_mapping: ["
<< str_join(out_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(out_dims_mapping) << "]; Input X dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping())
<< "], Input Y dims_mapping:["
<< str_join(y_dist_attr_dst.dims_mapping()) << "].";
return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr_src}};
}
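Tracing the simplest case from the unit tests below through the code above:

// out_axes "mn", out dims_mapping [1, 0]  -->  axis_to_dim_map {m: 1, n: 0}
// x_axes "mk"  -->  x dims_mapping [1, -1]  (k has no entry, left unsharded)
// y_axes "kn"  -->  y dims_mapping [-1, 0]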
} // namespace auto_parallel
...
...@@ -32,6 +32,12 @@ TensorDistAttr GetInferedDistAttr(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const bool trans_axis);
void FillMatmulOperandNotation(const int x_ndim,
const int y_ndim,
std::string* x_axes,
std::string* y_axes,
std::string* out_axes);
class MatmulSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
...@@ -39,7 +45,8 @@ class MatmulSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
...
...@@ -338,7 +338,14 @@ void BindAutoParallel(py::module *m) {
py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
.def("infer_forward", &SPMDRuleBase::InferForward)
.def("infer_backward",
static_cast<std::pair<std::vector<TensorDistAttr>,
std::vector<TensorDistAttr>> (SPMDRuleBase::*)(
const std::vector<DistTensorSpec> &,
const std::vector<DistTensorSpec> &,
const paddle::framework::AttributeMap &)>(
&SPMDRuleBase::InferBackward));
// .def("infer_backward", &SPMDRuleBase::InferBackward) [revert in future]
py::class_<DistTensorSpec>(*m, "DistTensorSpec")
.def(py::init<>())
...
# file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
# string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
add_subdirectory(spmd_rules)
if(WITH_DISTRIBUTE AND WITH_GPU)
# NOTE(zyl): unittests WITH multi cards and timeout
...
# file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
# string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if(WITH_DISTRIBUTE)
# NOTE(zyl): unittests WITH single card and WITHOUT timeout
py_test_modules(test_matmul_rule MODULES test_matmul_rule)
py_test_modules(test_embedding_rule MODULES test_embedding_rule)
py_test_modules(test_replicated_rule MODULES test_replicated_rule)
py_test_modules(test_softmax_rule MODULES test_softmax_rule)
py_test_modules(test_split_rule MODULES test_split_rule)
py_test_modules(test_transpose_rule MODULES test_transpose_rule)
py_test_modules(test_elementwise_rule MODULES test_elementwise_rule)
py_test_modules(test_cross_entropy_with_softmax_rule MODULES
test_cross_entropy_with_softmax_rule)
py_test_modules(test_reduction_rule MODULES test_reduction_rule)
py_test_modules(test_reshape_rule MODULES test_reshape_rule)
# End of unittests WITH single card WITHOUT timeout
...
...@@ -26,6 +26,8 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
def setUp(self):
self.rule1 = get_spmd_rule("lookup_table_v2")
def test_embedding_infer_forward(self):
# forward setup
x_shape = [4, 1024]  # [B,S]
table_shape = [512, 768]  # [V,H]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
...@@ -45,7 +47,6 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
'sparse': False,
}
# data parallel
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.table_dist_tensor_spec.set_dims_mapping([-1, -1])
...@@ -88,6 +89,8 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# table row-wise parallel & padding_idx
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
...@@ -110,6 +113,89 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
self.attrs,
)
def test_embedding_infer_backward(self):
# backward setup
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
x_shape = [4, 1024] # [B,S]
table_shape = [512, 768] # [V,H]
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.process_mesh = (
process_mesh  # not setting the dims mapping is OK
)
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
table_tensor_dist_attr = TensorDistAttr()
table_tensor_dist_attr.process_mesh = (
process_mesh  # not setting the dims mapping is OK
)
self.table_dist_tensor_spec = DistTensorSpec(
table_shape, table_tensor_dist_attr
)
out_shape = [4, 1024, 768]  # [B, S, H]
out_tensor_dist_attr = TensorDistAttr()
out_tensor_dist_attr.process_mesh = process_mesh
self.out_dist_tensor_spec = DistTensorSpec(
out_shape, out_tensor_dist_attr
)
self.attrs = {
'padding_idx': -1,
'sparse': False,
}
# data parallel
self.out_dist_tensor_spec.set_dims_mapping([1, -1, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])
# table col-wise parallel & dp
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, 1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1])
# sharded on multiple broadcast axes
self.out_dist_tensor_spec.set_dims_mapping([1, 0, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.table_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1])
# table row-wise parallel
# skipped
if __name__ == "__main__":
unittest.main()
...@@ -26,6 +26,13 @@ class TestMatmulSPMDRule(unittest.TestCase):
def setUp(self):
self.rule = get_spmd_rule("matmul")
self.attrs = {
'trans_x': False,
'trans_y': False,
}
def test_matmul_infer_forward(self):
# forward setup
x_shape = [64, 32]
y_shape = [32, 48]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
...@@ -40,12 +47,6 @@ class TestMatmulSPMDRule(unittest.TestCase):
y_tensor_dist_attr.process_mesh = process_mesh
self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr)
# TODO test partial: mk[1, 0],kn[0, -1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
...@@ -61,7 +62,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
...@@ -115,7 +116,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]:
self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
...@@ -129,7 +130,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1})
# abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = abcmn[1, 0, -1, -1] partial[]: done
self.x_dist_tensor_spec.shape = [512, 48, 64, 32]
...@@ -165,7 +166,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
...@@ -203,7 +204,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
infered_output_dist_attrs[0]._clean_partial_dims([0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
...@@ -226,14 +227,14 @@ class TestMatmulSPMDRule(unittest.TestCase):
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
infered_output_dist_attrs[0]._clean_partial_status()
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# trans_y = True, trans_x = True, abcmk[-1, -1, 1, -1], kn[-1, 0] --> error:
# one tensor axis sharded on multiple mesh dims
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, -1])
self.y_dist_tensor_spec.set_dims_mapping([-1, 0])
self.attrs['trans_x'] = True
self.attrs['trans_y'] = True
with self.assertRaises(NotImplementedError):
...@@ -241,6 +242,143 @@ class TestMatmulSPMDRule(unittest.TestCase):
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
def test_matmul_infer_backward(self):
# backward setup
x_shape = [64, 32]
y_shape = [32, 48]
out_shape = [64, 48]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [-1, -1]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
y_tensor_dist_attr = TensorDistAttr()
y_tensor_dist_attr.dims_mapping = [-1, -1]
y_tensor_dist_attr.process_mesh = process_mesh
self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr)
out_tensor_dist_attr = TensorDistAttr()
out_tensor_dist_attr.dims_mapping = [1, 0]
out_tensor_dist_attr.process_mesh = process_mesh
self.out_dist_tensor_spec = DistTensorSpec(
out_shape, out_tensor_dist_attr
)
# mn[1, 0] --> mk[1, -1],kn[-1, 0]
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0])
self.assertEqual(infered_input_dist_attrs[0]._is_partial(), False)
self.assertEqual(infered_input_dist_attrs[1]._is_partial(), False)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# test on broadcast axes propagation
# abmn[1, 0, -1, -1] --> 1mk[-1, -1, -1], abkn[1, 0, -1, -1]
self.out_dist_tensor_spec.shape = [512, 48, 64, 48]
self.x_dist_tensor_spec.shape = [1, 64, 32]
self.y_dist_tensor_spec.shape = [512, 48, 32, 48]
self.x_dist_tensor_spec.set_dims_mapping(
[0, -1, 1]
) # dims mapping of input should not influence inferbackward
self.y_dist_tensor_spec.set_dims_mapping(
[
-1,
-1,
1,
0,
]
) # dims mapping of input should not influence inferbackward
self.out_dist_tensor_spec.set_dims_mapping([1, 0, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1])
self.assertEqual(
infered_input_dist_attrs[1].dims_mapping, [1, 0, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, 0, -1, -1]
)
# abmn[-1, 0, -1, 1] --> abmk[-1, 0, -1, -1], a1kn[-1, -1, -1, 1]
self.out_dist_tensor_spec.shape = [512, 48, 64, 48]
self.x_dist_tensor_spec.shape = [512, 48, 64, 32]
self.y_dist_tensor_spec.shape = [512, 1, 32, 48]
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1]
)
self.assertEqual(
infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1, 1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
# trans_x = true, trans_y = true, abmn[-1, -1, 0, 1] --> abmk[-1, -1, -1, 0], a1kn[-1, -1, 1, -1]
self.out_dist_tensor_spec.shape = [512, 48, 64, 48]
self.x_dist_tensor_spec.shape = [512, 48, 32, 64]
self.y_dist_tensor_spec.shape = [512, 1, 48, 32]
self.out_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
self.attrs['trans_x'] = True
self.attrs['trans_y'] = True
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0]
)
self.assertEqual(
infered_input_dist_attrs[1].dims_mapping, [-1, -1, 1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
# trans_x = true, trans_y = true, abmn[-1, 1, 0, 1] --> error:
# one mesh dim shards multiple tensor axes
self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0, 1])
with self.assertRaises(RuntimeError):
self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
if __name__ == "__main__":
unittest.main()
...@@ -63,7 +63,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# reduce on dim 0, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
...@@ -79,7 +79,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# reduce on dim 1, keep_dim = false
# [0, -1] --> [0, -1], [0], partial_on_dim:[]
...@@ -125,7 +125,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
# reduce on dim 0 and 1, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
...@@ -141,7 +141,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0})
def test_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
...@@ -181,7 +181,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0, 1})
infered_output_dist_attrs[0]._clean_partial_status()
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# reduction on dim 1, 2, keep_dim = false
...@@ -213,7 +213,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1})
infered_output_dist_attrs[0]._clean_partial_status()
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
...@@ -231,7 +231,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1})
if __name__ == "__main__":
...
...@@ -343,6 +343,73 @@ TEST(LayerNormSPMDRule, Ctor) {
VLOG(4) << "test2 done.";
}
TEST(MatmulSPMDRuleInferBackward, Ctor) {
// build input data class
std::vector<int64_t> x_shape = {512, 1024, 64, 32};
std::vector<int64_t> y_shape = {512, 1, 32, 48};
std::vector<int64_t> out_shape = {512, 1024, 64, 48};
std::vector<int64_t> mesh_shape = {2, 3};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
TensorDistAttr x_dist_attr = TensorDistAttr();
x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping(
std::vector<int64_t>({-1, 1, 0, -1})); // no effect
x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
TensorDistAttr y_dist_attr = TensorDistAttr();
y_dist_attr.set_process_mesh(process_mesh);
y_dist_attr.set_dims_mapping(
std::vector<int64_t>({0, 1, -1, -1})); // no effect
y_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
TensorDistAttr out_dist_attr = TensorDistAttr();
out_dist_attr.set_process_mesh(process_mesh);
out_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, -1, 1, -1}));
out_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
out_dist_attr.set_partial_status(std::vector<int64_t>({0}));
DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr);
DistTensorSpec y_dist_tensor_spec = DistTensorSpec(y_shape, y_dist_attr);
DistTensorSpec out_dist_tensor_spec =
DistTensorSpec(out_shape, out_dist_attr);
paddle::framework::AttributeMap attrs;
attrs["trans_x"] = false;
attrs["trans_y"] = false;
SPMDRuleBase* matmul_rule = SPMDRuleMap::Instance().Get("matmul");
// TODO(zyc) update in future: propagate the partial in inferbackward
// abmn[-1, -1, 1, -1] + partial[0] --> abmk[-1, -1, 1, -1], a1kn[-1, -1, -1,
// -1]
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infered_dist_attrs =
matmul_rule->InferBackward({x_dist_tensor_spec, y_dist_tensor_spec},
{out_dist_tensor_spec},
attrs);
size_t input_size = 2;
size_t output_size = 1;
EXPECT_EQ(infered_dist_attrs.first.size(), input_size);
EXPECT_EQ(infered_dist_attrs.second.size(), output_size);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 1, -1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, -1, -1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 1, -1}));
EXPECT_EQ(infered_dist_attrs.first[0].is_partial(), false);
EXPECT_EQ(infered_dist_attrs.first[1].is_partial(), false);
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
VLOG(4) << "test1 done." << std::endl << std::endl << std::endl;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle