From 3483398c59238e871ba18fde4318eeb5cca5f9a1 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 25 Aug 2023 20:38:48 +0800 Subject: [PATCH] [Semi Auto] Matmul & Embedding InferBackward Rule (#56257) * add embedding backward rule * update backward api * revert api * matmul inferbackward * update unitest --- .../auto_parallel/spmd_rules/common.cc | 24 +- .../auto_parallel/spmd_rules/common.h | 13 +- .../spmd_rules/embedding_spmd_rule.cc | 73 +++++- .../spmd_rules/embedding_spmd_rule.h | 3 +- .../spmd_rules/matmul_spmd_rule.cc | 227 ++++++++++++------ .../spmd_rules/matmul_spmd_rule.h | 9 +- paddle/fluid/pybind/auto_parallel_py.cc | 9 +- test/auto_parallel/CMakeLists.txt | 2 + test/auto_parallel/spmd_rules/CMakeLists.txt | 12 +- .../spmd_rules/test_embedding_rule.py | 88 ++++++- .../spmd_rules/test_matmul_rule.py | 168 +++++++++++-- .../spmd_rules/test_reduction_rule.py | 14 +- test/cpp/auto_parallel/spmd_rule_test.cc | 67 ++++++ 13 files changed, 586 insertions(+), 123 deletions(-) diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc index 8c71bf111a9..9bef518850b 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc @@ -33,6 +33,16 @@ SPMDRuleBase::InferForward(const std::vector& input_specs, "derived class of SPMDRuleBase !")); } +std::pair, std::vector> +SPMDRuleBase::InferBackward(const std::vector& input_specs, + const std::vector& output_specs, + const paddle::framework::AttributeMap& attrs) { + PADDLE_THROW( + phi::errors::Unimplemented("InferBackward should be called from a " + "derived class of SPMDRuleBase !")); +} + +// deprecated std::pair, std::vector> SPMDRuleBase::InferBackward(const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) { @@ -210,7 +220,8 @@ GetAxesDimsMappingPair(const std::vector& tensor_axes, std::vector GetDimsMappingForAxes( const std::string& axes, - const std::unordered_map& axis_to_dim_map) { + const std::unordered_map& axis_to_dim_map, + const bool unsharded_miss_axis) { std::vector dims_mapping; for (int64_t i = 0, n = axes.size(); i < n; i++) { std::string axis = axes.substr(i, 1); @@ -219,10 +230,15 @@ std::vector GetDimsMappingForAxes( } else { auto iter = axis_to_dim_map.find(axis); if (iter == axis_to_dim_map.end()) { - phi::errors::InvalidArgument( - "Tensor axis [%s] of not in axis_to_dim_map.", axis); + if (unsharded_miss_axis) { + dims_mapping.emplace_back(-1); + } else { + phi::errors::InvalidArgument( + "Tensor axis [%s] of not in axis_to_dim_map.", axis); + } + } else { + dims_mapping.emplace_back(iter->second); } - dims_mapping.emplace_back(iter->second); } } return dims_mapping; diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h index 26c421eb27e..dd493276548 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h @@ -51,7 +51,7 @@ class SPMDRuleBase { InferForward(const std::vector& input_specs, const paddle::framework::AttributeMap& attrs); - // Based on the information of Output Tensors and Op Attribute: + // Based on the information of Input & Output Tensors and Op Attribute: // 1. Merge the Sharding (dims_mapping) among Output Tensors. // 2. Infer the Sharding (dims_mapping) for Input Tensors. // The Info of output tensors (Shape and DistAttr) are wrapped as @@ -60,6 +60,12 @@ class SPMDRuleBase { // 1. The first vector: the merged DistAttr of output tensors. // 2. The infered DistAttr of Input tensors. virtual std::pair, std::vector> + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, + const paddle::framework::AttributeMap& attrs); + + // deprecated, to be remove in future + virtual std::pair, std::vector> InferBackward(const std::vector& output_specs, const paddle::framework::AttributeMap& attrs); @@ -141,9 +147,12 @@ GetAxesDimsMappingPair(const std::vector& tensor_axes, // the annotated axes after inferring forward or backward. The parameter axis // stores the axes of the tensor. "1" is a special axis, for the axis "1", set // its dims mapping to -1. +// if unsharded_miss_axis, "-1" is assigend to axes that has no key in +// axis_to_dim_map. std::vector GetDimsMappingForAxes( const std::string& axes, - const std::unordered_map& axis_to_dim_map); + const std::unordered_map& axis_to_dim_map, + const bool unsharded_miss_axis = false); // The static map that stores and initializes all the registered SPMD rules. class SPMDRuleMap { diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.cc index cbf6bb94af5..b64afd92300 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.cc @@ -91,8 +91,7 @@ EmbeddingSPMDRule::InferForward(const std::vector& input_specs, phi::errors::InvalidArgument( "Row-wise parallel of embedding table does NOT support Sparse, but " "row axis of embedding table is sharded by mesh dimension [%d].", - padding_idx, - weight_ndim)); + weight_row_axis_mapping)); } VLOG(6) << "EmbeddingSPMDRule InferForward Inputs: " @@ -125,11 +124,12 @@ EmbeddingSPMDRule::InferForward(const std::vector& input_specs, output_dist_attr_dst.set_dims_mapping(out_dims_mapping); // step3.1: Handle Partial - // (TODO) support case where embedding table is partial in very beginning. + // (TODO) support case where embedding table is partial at very beginning. std::vector partial_on_dims; if (weight_row_axis_mapping > -1) { partial_on_dims.push_back(weight_row_axis_mapping); } + output_dist_attr_dst.set_partial_status(partial_on_dims); // step4: merge potential conflict in inputs TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); @@ -156,10 +156,69 @@ EmbeddingSPMDRule::InferForward(const std::vector& input_specs, } std::pair, std::vector> -EmbeddingSPMDRule::InferBackward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW(phi::errors::Unimplemented( - "InferBackward of EmbeddingSPMDRule is NOT implemented yet.")); +EmbeddingSPMDRule::InferBackward( + const std::vector& input_specs, + const std::vector& output_specs, + const paddle::framework::AttributeMap& attrs) { + // InferBackward is called after InferForward, so we skip some checks. + auto output_specs_size = output_specs.size(); + PADDLE_ENFORCE_EQ( + output_specs_size, + 1, + phi::errors::InvalidArgument( + "The size of OutputSpec of embedding should be 1, but got [%d].", + output_specs_size)); + + auto x_shape = input_specs[0].shape(); + int x_ndim = x_shape.size(); + auto out_shape = output_specs[0].shape(); + int out_ndim = out_shape.size(); + + PADDLE_ENFORCE_EQ(x_ndim, + out_ndim - 1, + phi::errors::InvalidArgument( + "There should be x_ndim + 1 = out_ndim in Embedding, " + "but got x_ndim: [%d] and out_ndim: [%d].", + x_ndim, + out_ndim)); + + auto out_dist_attr_src = output_specs[0].dist_attr(); + std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); + + // step1: build Einsum Notation + std::string alphabet = "abcdefghilmnopqrstuvwxyz"; + std::string x_axes = GetBroadcastAxes(out_ndim - 1, out_ndim - 1, alphabet); + std::string weight_axes = "jk"; + std::string out_axes = x_axes + "k"; + + // step2: Sharding Propogation + // should not use input dims mapping for backward sharding merge + auto axis_to_dim_map = + ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false); + TensorDistAttr x_dist_attr_dst = + CopyTensorDistAttrForOutput(input_specs[0].dist_attr()); + x_dist_attr_dst.set_dims_mapping(GetDimsMappingForAxes( + x_axes, axis_to_dim_map, /*unsharded_miss_axis=*/true)); + TensorDistAttr weight_dist_attr_dst = + CopyTensorDistAttrForOutput(input_specs[1].dist_attr()); + weight_dist_attr_dst.set_dims_mapping(GetDimsMappingForAxes( + weight_axes, axis_to_dim_map, /*unsharded_miss_axis=*/true)); + + // step3: Handle Partial + // NOTE we skip the partial backward inference in Partial Stage-I. + // output partial --> weight sharded on first axis. + + VLOG(4) << "EmbeddingSPMDRule InferBackward: " + << "Einsum notation: [" << x_axes << "," << weight_axes << " --> " + << out_axes << "]. " << std::endl + << "Out shape: [" << str_join(out_shape) << "], src_dims_mapping: [" + << str_join(out_dims_mapping) << "], dst_dims_mapping: [" + << str_join(out_dims_mapping) << "]; Input X dims_mapping: [" + << str_join(x_dist_attr_dst.dims_mapping()) + << "], Input Weight dims_mapping:[" + << str_join(weight_dist_attr_dst.dims_mapping()) << "]."; + + return {{x_dist_attr_dst, weight_dist_attr_dst}, {out_dist_attr_src}}; } } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h index 58a2d34d2a2..cf90a9de0e0 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h @@ -33,7 +33,8 @@ class EmbeddingSPMDRule : public SPMDRuleBase { const paddle::framework::AttributeMap& attrs) override; std::pair, std::vector> - InferBackward(const std::vector& output_specs, + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) override; }; } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc index 68fd9536707..d280ccec37d 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc @@ -20,6 +20,91 @@ namespace paddle { namespace distributed { namespace auto_parallel { using phi::distributed::auto_parallel::str_join; + +TensorDistAttr GetInferedDistAttr( + const TensorDistAttr& origin_dist_attr, + const std::vector& shape, + const std::string& tensor_axis, + const std::unordered_map& axis_to_dim_map, + const bool trans_axis) { + TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr); + std::vector infered_dims_mapping; + infered_dims_mapping.reserve(tensor_axis.size()); + + for (size_t i = 0; i < tensor_axis.size(); ++i) { + if (shape.size() > i && shape[i] == 1) { + infered_dims_mapping.push_back(-1); + } else { + auto itr = axis_to_dim_map.find(tensor_axis.substr(i, 1)); + if (itr == axis_to_dim_map.end()) { + // infer the k axis as -1 in inferbackward. + infered_dims_mapping.push_back(-1); + } else { + infered_dims_mapping.push_back(itr->second); + } + } + } + + if (trans_axis) { + std::iter_swap(infered_dims_mapping.end() - 2, + infered_dims_mapping.end() - 1); + } + + dist_attr_.set_dims_mapping(infered_dims_mapping); + return dist_attr_; +} + +void FillMatmulOperandNotation(const int x_ndim, + const int y_ndim, + std::string* x_axes, + std::string* y_axes, + std::string* out_axes) { + int max_ndim = std::max(x_ndim, y_ndim); + // reserve the char k, m, n for matrix product notation: mk,kn -> mn + std::string alphabet = "abcdefghijlopqrstuvwxyz"; + + // Handle 4 different matmul cases in Paddle + // vector * vector = scala + if (x_ndim == 1 && y_ndim == 1) { + *x_axes = "k"; + *y_axes = "k"; + *out_axes = ""; + // vector * batched matrix + } else if (x_ndim == 1 && y_ndim > 1) { + *x_axes = "k"; + std::string y_broadcast_axes = + GetBroadcastAxes(y_ndim - 2, y_ndim - 2, alphabet); + *y_axes = y_broadcast_axes + "kn"; + *out_axes = y_broadcast_axes + "n"; + // batched matrix * vector + } else if (x_ndim > 1 && y_ndim == 1) { + *y_axes = "k"; + std::string x_broadcast_axes = + GetBroadcastAxes(x_ndim - 2, x_ndim - 2, alphabet); + *x_axes = x_broadcast_axes + "mk"; + *out_axes = x_broadcast_axes + "m"; + // batched matrix * batched matrix + } else if (x_ndim > 1 && y_ndim > 1) { + std::string x_broadcast_axes = + GetBroadcastAxes(x_ndim - 2, max_ndim - 2, alphabet); + std::string y_broadcast_axes = + GetBroadcastAxes(y_ndim - 2, max_ndim - 2, alphabet); + *x_axes = x_broadcast_axes + "mk"; + *y_axes = y_broadcast_axes + "kn"; + + if (x_ndim > y_ndim) { + *out_axes = x_broadcast_axes + "mn"; + } else { + *out_axes = y_broadcast_axes + "mn"; + } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "MatmulSPMDRule Receive Unsupported x_dim [%d] and y_dim [%d].", + x_ndim, + y_ndim)); + } +} + std::pair, std::vector> MatmulSPMDRule::InferForward(const std::vector& input_specs, const paddle::framework::AttributeMap& attrs) { @@ -67,54 +152,10 @@ MatmulSPMDRule::InferForward(const std::vector& input_specs, << "[" << (trans_y ? "true" : "false") << "]; "; // step1: build Einsum Notation - - // reserve the char k, m, n for matrix product notation: mk,kn -> mn - int max_ndim = std::max(x_ndim, y_ndim); - std::string alphabet = "abcdefghijlopqrstuvwxyz"; std::string x_axes; std::string y_axes; std::string out_axes; - - // Handle 4 different matmul cases in Paddle - // vector * vector = scala - if (x_ndim == 1 && y_ndim == 1) { - x_axes = "k"; - y_axes = "k"; - out_axes = ""; - // vector * batched matrix - } else if (x_ndim == 1 && y_ndim > 1) { - x_axes = "k"; - std::string y_broadcast_axes = - GetBroadcastAxes(y_ndim - 2, y_ndim - 2, alphabet); - y_axes = y_broadcast_axes + "kn"; - out_axes = y_broadcast_axes + "n"; - // batched matrix * vector - } else if (x_ndim > 1 && y_ndim == 1) { - y_axes = "k"; - std::string x_broadcast_axes = - GetBroadcastAxes(x_ndim - 2, x_ndim - 2, alphabet); - x_axes = x_broadcast_axes + "mk"; - out_axes = x_broadcast_axes + "m"; - // batched matrix * batched matrix - } else if (x_ndim > 1 && y_ndim > 1) { - std::string x_broadcast_axes = - GetBroadcastAxes(x_ndim - 2, max_ndim - 2, alphabet); - std::string y_broadcast_axes = - GetBroadcastAxes(y_ndim - 2, max_ndim - 2, alphabet); - x_axes = x_broadcast_axes + "mk"; - y_axes = y_broadcast_axes + "kn"; - - if (x_ndim > y_ndim) { - out_axes = x_broadcast_axes + "mn"; - } else { - out_axes = y_broadcast_axes + "mn"; - } - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "MatmulSPMDRule Receive Unsupported x_dim [%d] and y_dim [%d].", - x_ndim, - y_ndim)); - } + FillMatmulOperandNotation(x_ndim, y_ndim, &x_axes, &y_axes, &out_axes); // step2: Sharding Propogation if (trans_x) { @@ -180,46 +221,72 @@ MatmulSPMDRule::InferForward(const std::vector& input_specs, return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}}; } -TensorDistAttr GetInferedDistAttr( - const TensorDistAttr& origin_dist_attr, - const std::vector& shape, - const std::string& tensor_axis, - const std::unordered_map& axis_to_dim_map, - const bool trans_axis) { - TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr); - std::vector infered_dims_mapping; - infered_dims_mapping.reserve(tensor_axis.size()); +std::pair, std::vector> +MatmulSPMDRule::InferBackward(const std::vector& input_specs, + const std::vector& output_specs, + const paddle::framework::AttributeMap& attrs) { + // extra & verify input + auto output_specs_size = output_specs.size(); + PADDLE_ENFORCE_EQ( + output_specs_size, + 1, + phi::errors::InvalidArgument( + "The size of OutputSpec of matmul should be 1, but got [%d].", + output_specs_size)); - for (size_t i = 0; i < tensor_axis.size(); ++i) { - if (shape.size() > i && shape[i] == 1) { - infered_dims_mapping.push_back(-1); - } else { - auto itr = axis_to_dim_map.find(tensor_axis.substr(i, 1)); - if (itr == axis_to_dim_map.end()) { - phi::errors::InvalidArgument( - "Tensor axis [%s] of not in axis_to_dim_map.", - tensor_axis.substr(i, 1)); - } - infered_dims_mapping.push_back(itr->second); - } - } + auto out_shape = output_specs[0].shape(); + int out_ndim = out_shape.size(); - if (trans_axis) { - std::iter_swap(infered_dims_mapping.end() - 2, - infered_dims_mapping.end() - 1); - } + auto x_shape = input_specs[0].shape(); + auto y_shape = input_specs[1].shape(); + int x_ndim = x_shape.size(); + int y_ndim = y_shape.size(); + int max_ndim = std::max(x_ndim, y_ndim); + PADDLE_ENFORCE_EQ(max_ndim, + out_ndim, + phi::errors::InvalidArgument( + "The max ndim of inputs should be equal out_ndim in " + "Matmul, but got max ndim: [%d] and out_ndim: [%d].", + max_ndim, + out_ndim)); - dist_attr_.set_dims_mapping(infered_dims_mapping); - return dist_attr_; -} + bool trans_x = ExtractAttr("trans_x", attrs); + bool trans_y = ExtractAttr("trans_y", attrs); -std::pair, std::vector> -MatmulSPMDRule::InferBackward(const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW(phi::errors::Unimplemented( - "InferBackward of MatmulSPMDRule is NOT implemented yet.")); + auto out_dist_attr_src = output_specs[0].dist_attr(); + std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); + + // step1: build Einsum Notation + std::string x_axes; + std::string y_axes; + std::string out_axes; + FillMatmulOperandNotation(x_ndim, y_ndim, &x_axes, &y_axes, &out_axes); + + // step2: Sharding Propogation + // should not use input dims mapping for backward sharding merge + auto axis_to_dim_map = + ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false); + + TensorDistAttr x_dist_attr_dst = GetInferedDistAttr( + input_specs[0].dist_attr(), x_shape, x_axes, axis_to_dim_map, trans_x); + TensorDistAttr y_dist_attr_dst = GetInferedDistAttr( + input_specs[1].dist_attr(), y_shape, y_axes, axis_to_dim_map, trans_y); + + // step3: Handle Partial + // NOTE we skip the partial backward inference in Partial Stage-I. + // output partial --> axis k is sharded. + + VLOG(4) << "MatmulSPMDRule InferBackward: " + << "Einsum notation: [" << x_axes << "," << y_axes << " --> " + << out_axes << "]. " << std::endl + << "Out shape: [" << str_join(out_shape) << "], src_dims_mapping: [" + << str_join(out_dims_mapping) << "], dst_dims_mapping: [" + << str_join(out_dims_mapping) << "]; Input X dims_mapping: [" + << str_join(x_dist_attr_dst.dims_mapping()) + << "], Input Y dims_mapping:[" + << str_join(y_dist_attr_dst.dims_mapping()) << "]."; - return {}; + return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr_src}}; } } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h index 6ce43a314d4..70d603e509c 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h @@ -32,6 +32,12 @@ TensorDistAttr GetInferedDistAttr( const std::unordered_map& axis_to_dim_map, const bool trans_axis); +void FillMatmulOperandNotation(const int x_ndim, + const int y_ndim, + std::string* x_axes, + std::string* y_axes, + std::string* out_axes); + class MatmulSPMDRule : public SPMDRuleBase { public: std::pair, std::vector> @@ -39,7 +45,8 @@ class MatmulSPMDRule : public SPMDRuleBase { const paddle::framework::AttributeMap& attrs) override; std::pair, std::vector> - InferBackward(const std::vector& output_specs, + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) override; }; } // namespace auto_parallel diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 860aac80ba7..977583daf73 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -338,7 +338,14 @@ void BindAutoParallel(py::module *m) { py::class_(*m, "SPMDRuleBase") .def("infer_forward", &SPMDRuleBase::InferForward) - .def("infer_backward", &SPMDRuleBase::InferBackward); + .def("infer_backward", + static_cast, + std::vector> (SPMDRuleBase::*)( + const std::vector &, + const std::vector &, + const paddle::framework::AttributeMap &)>( + &SPMDRuleBase::InferBackward)); + // .def("infer_backward", &SPMDRuleBase::InferBackward) [revert in future] py::class_(*m, "DistTensorSpec") .def(py::init<>()) diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 2bfdb89ffa0..817508e5703 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -1,6 +1,8 @@ # file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") # string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +add_subdirectory(spmd_rules) + if(WITH_DISTRIBUTE AND WITH_GPU) # NOTE(zyl): unittests WITH multi cards and timeout diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index c981aee6f83..96c14a7a2a8 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -1,15 +1,19 @@ # file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") # string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -if(WITH_DISTRIBUTE AND WITH_GPU) +if(WITH_DISTRIBUTE) # NOTE(zyl): unittests WITH single card and WITHOUT timeout py_test_modules(test_matmul_rule MODULES test_matmul_rule) - py_test_modules(test_matmul_rule MODULES test_embedding_rule) - py_test_modules(test_matmul_rule MODULES test_replicated_rule) - py_test_modules(test_matmul_rule MODULES test_softmax_rule) + py_test_modules(test_embedding_rule MODULES test_embedding_rule) + py_test_modules(test_replicated_rule MODULES test_replicated_rule) + py_test_modules(test_softmax_rule MODULES test_softmax_rule) py_test_modules(test_split_rule MODULES test_split_rule) py_test_modules(test_transpose_rule MODULES test_transpose_rule) + py_test_modules(test_elementwise_rule MODULES test_elementwise_rule) + py_test_modules(test_cross_entropy_with_softmax_rule MODULES + test_cross_entropy_with_softmax_rule) + py_test_modules(test_reduction_rule MODULES test_reduction_rule) py_test_modules(test_reshape_rule MODULES test_reshape_rule) # End of unittests WITH single card WITHOUT timeout diff --git a/test/auto_parallel/spmd_rules/test_embedding_rule.py b/test/auto_parallel/spmd_rules/test_embedding_rule.py index b55c423ffff..747b0b7c19e 100644 --- a/test/auto_parallel/spmd_rules/test_embedding_rule.py +++ b/test/auto_parallel/spmd_rules/test_embedding_rule.py @@ -26,6 +26,8 @@ class TestEmbeddingSPMDRule(unittest.TestCase): def setUp(self): self.rule1 = get_spmd_rule("lookup_table_v2") + def test_embedding_infer_forward(self): + # forward setup x_shape = [4, 1024] # [B,S] table_shape = [512, 768] # [V,H] process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) @@ -45,7 +47,6 @@ class TestEmbeddingSPMDRule(unittest.TestCase): 'sparse': False, } - def test_embedding_infer_forward(self): # data parallel self.x_dist_tensor_spec.set_dims_mapping([1, -1]) self.table_dist_tensor_spec.set_dims_mapping([-1, -1]) @@ -88,6 +89,8 @@ class TestEmbeddingSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # table row-wise parallel & padding_idx self.x_dist_tensor_spec.set_dims_mapping([1, -1]) @@ -110,6 +113,89 @@ class TestEmbeddingSPMDRule(unittest.TestCase): self.attrs, ) + def test_embedding_infer_backward(self): + # backward setup + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_shape = [4, 1024] # [B,S] + table_shape = [512, 768] # [V,H] + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.process_mesh = ( + process_mesh # not set the dims mapping is ok. + ) + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + + table_tensor_dist_attr = TensorDistAttr() + table_tensor_dist_attr.process_mesh = ( + process_mesh # not set the dims mapping is ok. + ) + self.table_dist_tensor_spec = DistTensorSpec( + table_shape, table_tensor_dist_attr + ) + + out_shape = [4, 1024, 768] # [B,S, H] + out_tensor_dist_attr = TensorDistAttr() + out_tensor_dist_attr.process_mesh = process_mesh + self.out_dist_tensor_spec = DistTensorSpec( + out_shape, out_tensor_dist_attr + ) + + self.attrs = { + 'padding_idx': -1, + 'sparse': False, + } + + # data parallel + self.out_dist_tensor_spec.set_dims_mapping([1, -1, -1]) + result_dist_attrs = self.rule1.infer_backward( + [self.x_dist_tensor_spec, self.table_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 2) + self.assertEqual(len(infered_output_dist_attrs), 1) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1]) + + # table col-wise parallel & dp + self.out_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + result_dist_attrs = self.rule1.infer_backward( + [self.x_dist_tensor_spec, self.table_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # sharded on multiple broadcast axes + self.out_dist_tensor_spec.set_dims_mapping([1, 0, -1]) + + result_dist_attrs = self.rule1.infer_backward( + [self.x_dist_tensor_spec, self.table_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]) + + # table row-wise parallel + # skiped + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_matmul_rule.py b/test/auto_parallel/spmd_rules/test_matmul_rule.py index a693307ff5e..59e47113302 100644 --- a/test/auto_parallel/spmd_rules/test_matmul_rule.py +++ b/test/auto_parallel/spmd_rules/test_matmul_rule.py @@ -26,6 +26,13 @@ class TestMatmulSPMDRule(unittest.TestCase): def setUp(self): self.rule = get_spmd_rule("matmul") + self.attrs = { + 'trans_x': False, + 'trans_y': False, + } + + def test_matmul_infer_forward(self): + # forward setup x_shape = [64, 32] y_shape = [32, 48] process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) @@ -40,12 +47,6 @@ class TestMatmulSPMDRule(unittest.TestCase): y_tensor_dist_attr.process_mesh = process_mesh self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr) - self.attrs = { - 'trans_x': False, - 'trans_y': False, - } - - def test_matmul_infer_forward(self): # TODO test partial: mk[1, 0],kn[0, -1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0] result_dist_attrs = self.rule.infer_forward( [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs @@ -61,7 +62,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[] self.x_dist_tensor_spec.set_dims_mapping([1, -1]) @@ -115,7 +116,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: self.x_dist_tensor_spec.set_dims_mapping([-1, -1]) @@ -129,7 +130,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [1]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1}) # abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = abcmn[1, 0, -1, -1] partial[]: done self.x_dist_tensor_spec.shape = [512, 48, 64, 32] @@ -165,7 +166,7 @@ class TestMatmulSPMDRule(unittest.TestCase): infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1] ) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[] self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) @@ -203,7 +204,7 @@ class TestMatmulSPMDRule(unittest.TestCase): infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1] ) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) infered_output_dist_attrs[0]._clean_partial_dims([0]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) @@ -226,14 +227,14 @@ class TestMatmulSPMDRule(unittest.TestCase): infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] ) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) infered_output_dist_attrs[0]._clean_partial_status() self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # trans_y = True, trans_x = True, abcmk[-1, -1, 1, 0], kn[1, 0] --> error: - # one mesh dim shard multiple tensor axes - self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0]) - self.y_dist_tensor_spec.set_dims_mapping([1, 0]) + # one tensor axis shard multiple mesh dim + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, -1]) + self.y_dist_tensor_spec.set_dims_mapping([-1, 0]) self.attrs['trans_x'] = True self.attrs['trans_y'] = True with self.assertRaises(NotImplementedError): @@ -241,6 +242,143 @@ class TestMatmulSPMDRule(unittest.TestCase): [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs ) + def test_matmul_infer_backward(self): + # backward setup + x_shape = [64, 32] + y_shape = [32, 48] + out_shape = [64, 48] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + + y_tensor_dist_attr = TensorDistAttr() + y_tensor_dist_attr.dims_mapping = [-1, -1] + y_tensor_dist_attr.process_mesh = process_mesh + self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr) + + out_tensor_dist_attr = TensorDistAttr() + out_tensor_dist_attr.dims_mapping = [1, 0] + out_tensor_dist_attr.process_mesh = process_mesh + self.out_dist_tensor_spec = DistTensorSpec( + out_shape, out_tensor_dist_attr + ) + + # mn[1, 0] --> mk[1, -1],kn[-1, 0] + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 2) + self.assertEqual(len(infered_output_dist_attrs), 1) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1]) + self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, 0]) + self.assertEqual(infered_input_dist_attrs[0]._is_partial(), False) + self.assertEqual(infered_input_dist_attrs[1]._is_partial(), False) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) + + # test on broadcast axes propogation + # abmn[1, 0, -1, -1] --> 1mk[-1, -1, -1], abkn[1, 0, -1, -1] + self.out_dist_tensor_spec.shape = [512, 48, 64, 48] + self.x_dist_tensor_spec.shape = [1, 64, 32] + self.y_dist_tensor_spec.shape = [512, 48, 32, 48] + self.x_dist_tensor_spec.set_dims_mapping( + [0, -1, 1] + ) # dims mapping of input should not influence inferbackward + self.y_dist_tensor_spec.set_dims_mapping( + [ + -1, + -1, + 1, + 0, + ] + ) # dims mapping of input should not influence inferbackward + self.out_dist_tensor_spec.set_dims_mapping([1, 0, -1, -1]) + + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1]) + self.assertEqual( + infered_input_dist_attrs[1].dims_mapping, [1, 0, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, 0, -1, -1] + ) + + # abmn[-1, 0, -1, 1] --> abmk[-1, 0, -1, -1], a1kn[-1, -1, -1, 1] + self.out_dist_tensor_spec.shape = [512, 48, 64, 48] + self.x_dist_tensor_spec.shape = [512, 48, 64, 32] + self.y_dist_tensor_spec.shape = [512, 1, 32, 48] + self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1] + ) + self.assertEqual( + infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + + # trans_x = true, trans_y = true, abmn[-1, -1, 0, 1] --> abmk[-1, -1, -1, 0], a1kn[-1, -1, 1, -1] + self.out_dist_tensor_spec.shape = [512, 48, 64, 48] + self.x_dist_tensor_spec.shape = [512, 48, 32, 64] + self.y_dist_tensor_spec.shape = [512, 1, 48, 32] + self.out_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) + self.attrs['trans_x'] = True + self.attrs['trans_y'] = True + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0] + ) + self.assertEqual( + infered_input_dist_attrs[1].dims_mapping, [-1, -1, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, 1] + ) + + # # trans_x = true, trans_y = true, abmn[-1, 1, 0, 1] --> error: + # one mesh dim shard multiple tensor axes + self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0, 1]) + with self.assertRaises(RuntimeError): + self.rule.infer_backward( + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + [self.out_dist_tensor_spec], + self.attrs, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/auto_parallel/spmd_rules/test_reduction_rule.py b/test/auto_parallel/spmd_rules/test_reduction_rule.py index 7037f78cb43..e3b4012b6bd 100644 --- a/test/auto_parallel/spmd_rules/test_reduction_rule.py +++ b/test/auto_parallel/spmd_rules/test_reduction_rule.py @@ -63,7 +63,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # reduce on dim 0, keep_dim = true # [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0] @@ -79,7 +79,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # reduce on dim 1, keep_dim = false # [0, -1] --> [0, -1], [0], partial_on_dim:[] @@ -125,7 +125,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, []) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # reduce on dim 0 and 1, keep_dim = true # [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0] @@ -141,7 +141,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) def test_multi_mesh_dim(self): process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) @@ -181,7 +181,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0, 1]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0, 1}) infered_output_dist_attrs[0]._clean_partial_status() self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # reduction on dim 1, 2, keep_dim = false @@ -213,7 +213,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [1]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1}) infered_output_dist_attrs[0]._clean_partial_status() self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) @@ -231,7 +231,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1]) self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) - self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [1]) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1}) if __name__ == "__main__": diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index c2ae26f8a50..dfd8394faa1 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -343,6 +343,73 @@ TEST(LayerNormSPMDRule, Ctor) { VLOG(4) << "test2 done."; } +TEST(MatmulSPMDRuleInferBackward, Ctor) { + // build input data class + std::vector x_shape = {512, 1024, 64, 32}; + std::vector y_shape = {512, 1, 32, 48}; + std::vector out_shape = {512, 1024, 64, 48}; + + std::vector mesh_shape = {2, 3}; + std::vector process_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x", "y"}; + ProcessMesh process_mesh(mesh_shape, process_ids, dim_names); + + TensorDistAttr x_dist_attr = TensorDistAttr(); + x_dist_attr.set_process_mesh(process_mesh); + x_dist_attr.set_dims_mapping( + std::vector({-1, 1, 0, -1})); // no affect + x_dist_attr.set_dynamic_dims(std::vector({false, false})); + + TensorDistAttr y_dist_attr = TensorDistAttr(); + y_dist_attr.set_process_mesh(process_mesh); + y_dist_attr.set_dims_mapping( + std::vector({0, 1, -1, -1})); // no affect + y_dist_attr.set_dynamic_dims(std::vector({false, false})); + + TensorDistAttr out_dist_attr = TensorDistAttr(); + out_dist_attr.set_process_mesh(process_mesh); + out_dist_attr.set_dims_mapping(std::vector({-1, -1, 1, -1})); + out_dist_attr.set_dynamic_dims(std::vector({false, false})); + out_dist_attr.set_partial_status(std::vector({0})); + + DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr); + DistTensorSpec y_dist_tensor_spec = DistTensorSpec(y_shape, y_dist_attr); + DistTensorSpec out_dist_tensor_spec = + DistTensorSpec(out_shape, out_dist_attr); + + paddle::framework::AttributeMap attrs; + attrs["trans_x"] = false; + attrs["trans_y"] = false; + + SPMDRuleBase* matmul_rule = SPMDRuleMap::Instance().Get("matmul"); + + // TODO(zyc) update in future: propogate the partial in inferbackward + // abmn[-1, -1, 1, -1] + partial[0] --> abmk[-1, -1, 1, -1], a1kn[-1, -1, -1, + // -1] + std::pair, std::vector> + infered_dist_attrs = + matmul_rule->InferBackward({x_dist_tensor_spec, y_dist_tensor_spec}, + {out_dist_tensor_spec}, + attrs); + + size_t input_size = 2; + size_t output_size = 1; + EXPECT_EQ(infered_dist_attrs.first.size(), input_size); + EXPECT_EQ(infered_dist_attrs.second.size(), output_size); + + EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), + std::vector({-1, -1, 1, -1})); + EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), + std::vector({-1, -1, -1, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), + std::vector({-1, -1, 1, -1})); + EXPECT_EQ(infered_dist_attrs.first[0].is_partial(), false); + EXPECT_EQ(infered_dist_attrs.first[1].is_partial(), false); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); + + VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; +} + } // namespace auto_parallel } // namespace distributed } // namespace paddle -- GitLab