diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc index ced0a23d8a8681499a595f16d0f1d26a9f91abba..68fd9536707ae27e0b4a8ef57dbc66f6eda1b75d 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc @@ -160,6 +160,7 @@ MatmulSPMDRule::InferForward(const std::vector& input_specs, // Step2.3.1 Output Partial std::vector partial_on_dims = ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); + output_dist_attr_dst.set_partial_status(partial_on_dims); // Step2.3.2 handle input tensor partial (TODO) VLOG(4) << "MatmulSPMDRule InferForward: " diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc index ea5291356c16e743297caea1bd5884cf7ec199e8..7b97be299aa0767919bc3cb24f13f71a0288c933 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc @@ -88,13 +88,15 @@ ReductionSPMDRule::InferForward(const std::vector& input_specs, CopyTensorDistAttrForOutput(input_specs[0].dist_attr()); output_dist_attr.set_dims_mapping(output_dims_mapping); - std::vector output_dist_attrs; - output_dist_attrs.emplace_back(output_dist_attr); - // step2.4: handle partial // Step2.4.1 Output Partial std::vector partial_on_dims = ResoluteOutputPartialDimension(axis_to_dim_map, output_axes); + output_dist_attr.set_partial_status( + partial_on_dims /*, handle reduce_type in future */); + + std::vector output_dist_attrs; + output_dist_attrs.emplace_back(output_dist_attr); // Step2.4.2 handle input tensor partial (TODO) // If the op is a linear op, i.e. 
`linearity` is true, it supports diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 96c49b4170519e09eb375038aadfc30d83b59bf3..c8b2152388d4c62cd3615a6309429234b9d8d7d4 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -293,7 +293,11 @@ void BindAutoParallel(py::module *m) { return TensorDistAttr(self); }, py::arg("memo")) - .def("__str__", &TensorDistAttr::to_string); + .def("__str__", &TensorDistAttr::to_string) + .def("_is_partial", &TensorDistAttr::is_partial) + .def("_partial_dims", &TensorDistAttr::partial_dims) + .def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims) + .def("_clean_partial_status", &TensorDistAttr::clean_partial_status); py::class_(*m, "SPMDRuleBase") .def("infer_forward", &SPMDRuleBase::InferForward) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc index 6896d2c96c52f6146bbb0101e0dabcc89bde6c81..c4e10029a354826904bc14251ae57e29bc6c1303 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc @@ -24,6 +24,7 @@ namespace phi { namespace distributed { namespace auto_parallel { +// partial is not allowed to be annotated by the user for now.
std::vector TensorDistAttr::fields_{ "process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"}; @@ -44,6 +45,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) { std::swap(this->batch_dim_, tmp.batch_dim_); std::swap(this->dynamic_dims_, tmp.dynamic_dims_); std::swap(this->annotated_, tmp.annotated_); + std::swap(this->partial_status_, tmp.partial_status_); return *this; } @@ -53,6 +55,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) { set_batch_dim(dist_attr.batch_dim()); set_dynamic_dims(dist_attr.dynamic_dims()); set_annotated(dist_attr.annotated()); + set_partial_status(dist_attr.partial_status()); } void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) { @@ -77,6 +80,44 @@ void TensorDistAttr::set_annotated( annotated_ = annotated; } +const std::set TensorDistAttr::partial_dims() const { + std::set keys; + for (auto& kv : partial_status_) { + keys.emplace(kv.first); + } + return keys; +} + +void TensorDistAttr::set_partial_status( + const paddle::flat_hash_map& partial_status) { + partial_status_ = partial_status; +} + +void TensorDistAttr::set_partial_status(const std::vector& dims, + const ReduceType& type) { + for (const auto& dim : dims) { + if (partial_status_.count(dim) != 0) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Trying to Set dim %d as Partial which is already a Partial dim.", + dim)); + } + partial_status_.emplace(dim, type); + } +} + +void TensorDistAttr::clean_partial_status() { partial_status_.clear(); } + +void TensorDistAttr::clean_partial_dims(const std::vector& dims) { + for (const auto& dim : dims) { + if (partial_status_.count(dim) == 0) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Trying to clean Partial on dim %d but it is not Partial.", dim)); + } else { + partial_status_.erase(dim); + } + } +} + void TensorDistAttr::set_default_dims_mapping( const std::vector& tensor_shape) { if (!tensor_shape.empty()) { @@ -178,6 +219,30 @@ bool
TensorDistAttr::verify_annotated( return true; } +// Check that every partial entry names a valid mesh dim, carries a valid +// ReduceType and does not reuse a mesh dim that already shards a tensor axis. +bool TensorDistAttr::verify_partial_status() const { + VLOG(4) << "[TensorDistAttr verify_partial_status] " + << partial_status_string(); + for (auto& itr : partial_status_) { + if (itr.first < 0 || itr.first >= process_mesh_.ndim()) { + return false; + } + // valid reduce types lie within [ReduceType::SUM, ReduceType::ALL]. + if (itr.second < ReduceType::SUM || itr.second > ReduceType::ALL) { + return false; + } + // a mesh dim listed in dims_mapping_ shards a tensor axis, so it cannot + // be partial at the same time. + for (const auto& mesh_dim : dims_mapping_) { + if (mesh_dim == itr.first) { + return false; + } + } + } + return true; +} + bool TensorDistAttr::verify(const std::vector& tensor_shape) const { if (!verify_process_mesh(process_mesh_)) { return false; @@ -194,6 +259,9 @@ bool TensorDistAttr::verify(const std::vector& tensor_shape) const { if (!verify_annotated(annotated_)) { return false; } + if (!verify_partial_status()) { + return false; + } return true; } @@ -203,7 +271,8 @@ std::string TensorDistAttr::to_string() const { dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], "; dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", "; dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], "; - dist_str += "annotated: [" + str_join(annotated_) + "]}"; + dist_str += "annotated: [" + str_join(annotated_) + "], "; + dist_str += "partial: " + partial_status_string() + ".}"; return dist_str; } @@ -267,9 +336,23 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) { if (lhs.dynamic_dims() != rhs.dynamic_dims()) { return false; } + if (lhs.partial_status() != rhs.partial_status()) { + return false; + } return true; } +std::string TensorDistAttr::partial_status_string() const { + std::string partial_status_str = "["; + for (auto& itr : partial_status_) { + partial_status_str += "Partial(dims:" + std::to_string(itr.first) + ", " + + ReduceTypeStrings[static_cast<int>(itr.second)] + + "), "; + } + partial_status_str += "]"; + return partial_status_str; +} + } // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h
b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index c6e9c28612a44290dcd5ac5f80a25e6a1b93e2fd..fe410f24ed5fba1ebb749cfd6b6ed5c619a77b7e 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" #include "paddle/phi/core/enforce.h" +#include "paddle/utils/flat_hash_map.h" namespace phi { namespace distributed { @@ -32,13 +33,25 @@ namespace auto_parallel { constexpr const char* kDefault = "default"; +enum class ReduceType : std::uint8_t { + SUM = 0, + AVG, + MAX, + MIN, + PRODUCT, + ANY, + ALL +}; +constexpr const char* ReduceTypeStrings[] = { + "SUM", "AVG", "MAX", "MIN", "PRODUCT", "ANY", "ALL"}; + class TensorDistAttr { public: TensorDistAttr() = default; explicit TensorDistAttr(const std::vector& tensor_shape); - TensorDistAttr(const TensorDistAttr& tensor); + TensorDistAttr(const TensorDistAttr& dist_attr); TensorDistAttr& operator=(const TensorDistAttr& dist_attr); @@ -52,6 +65,29 @@ class TensorDistAttr { void set_dims_mapping(const std::vector& dims_mapping); + // true if tensor is partial on any mesh dim. 
+ bool is_partial() const { return !partial_status_.empty(); } + + // return the set of mesh dims on which this tensor is partial + const std::set partial_dims() const; + + const paddle::flat_hash_map& partial_status() const { + return partial_status_; + } + + // replace the whole partial status with the given map + void set_partial_status( + const paddle::flat_hash_map& partial_status); + + // mark every dim in dims as partial; throws if a dim is already partial + void set_partial_status(const std::vector& dims, + const ReduceType& type = ReduceType::SUM); + // clear the partial status on all dims + void clean_partial_status(); + + // clear partial on the given dims; throws if a dim is not partial + void clean_partial_dims(const std::vector& dims); + void set_default_dims_mapping(const std::vector& tensor_shape); int64_t batch_dim() const { return batch_dim_; } @@ -89,11 +125,17 @@ class TensorDistAttr { bool verify_annotated(const std::map& annotated) const; + bool verify_partial_status() const; + bool verify(const std::vector& tensor_shape) const; // TensorDistAttr from_string(const std::string& dist_str); std::string to_string() const; + std::string partial_status_string() const; + // in partial-support-stage-I partial is always a runtime attribute, so + // there is no need to serialize it. partial serialization will be supported + // in the future partial-support-stage-II. void from_proto(const TensorDistAttrProto& proto); TensorDistAttrProto to_proto() const; @@ -109,6 +151,10 @@ class TensorDistAttr { int64_t batch_dim_{0}; std::vector dynamic_dims_; std::map annotated_; + // the partial map would be small (less than mesh.size); + // iteration (copy and comparison) would be more frequent than random + // element access.
+ paddle::flat_hash_map partial_status_; }; inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) { diff --git a/test/auto_parallel/spmd_rules/test_matmul_rule.py b/test/auto_parallel/spmd_rules/test_matmul_rule.py index 85195ca4fd9b06bf79eb603dafe9af32c8fe30a0..a693307ff5e9191b64401fcdc4e0b230e6032957 100644 --- a/test/auto_parallel/spmd_rules/test_matmul_rule.py +++ b/test/auto_parallel/spmd_rules/test_matmul_rule.py @@ -60,6 +60,8 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[] self.x_dist_tensor_spec.set_dims_mapping([1, -1]) @@ -73,6 +75,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[] self.x_dist_tensor_spec.set_dims_mapping([1, -1]) @@ -85,6 +88,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # test n parallel: mk[-1, -1],kn[-1, 0] --> mk[-1, -1],kn[-1, 0] = nm[-1, 0] partial[] self.x_dist_tensor_spec.set_dims_mapping([-1, -1]) @@ -97,6 +101,7 @@ class
TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # test partial with propogation: mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0] self.x_dist_tensor_spec.set_dims_mapping([1, 0]) @@ -109,6 +114,8 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: self.x_dist_tensor_spec.set_dims_mapping([-1, -1]) @@ -121,6 +128,8 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1]) self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1}) # abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = abcmn[1, 0, -1, -1] partial[]: done self.x_dist_tensor_spec.shape = [512, 48, 64, 32] @@ -138,6 +147,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual( infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] ) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,-1, -1, -1] partial[0] self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) @@ -154,6 +164,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual( infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1] ) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[] self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) @@ -171,6 +183,7 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual( infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1] ) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # trans_y = True, abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = abcmn[-1, -1, -1, 1] partial[0]: done self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1]) @@ -189,6 +202,10 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual( infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1] ) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) + infered_output_dist_attrs[0]._clean_partial_dims([0]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # trans_y = True, trans_x = True, abcmk[-1, -1, 0, 1], kn[1, 0] --> abcmk[-1, -1, 0, 1]],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0] # multiple mesh dim shard same tensor axis @@ -208,6 +225,10 @@ class TestMatmulSPMDRule(unittest.TestCase): self.assertEqual( infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1] ) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) + infered_output_dist_attrs[0]._clean_partial_status() + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # trans_y = True, trans_x = True, abcmk[-1, -1, 1, 0], kn[1, 0] --> error: # one mesh dim shard multiple tensor axes diff --git a/test/auto_parallel/spmd_rules/test_reduction_rule.py
b/test/auto_parallel/spmd_rules/test_reduction_rule.py index a21528e781e66b8e4a50f38a6d22c8b456c84d7e..7037f78cb4366e872c179f0ce9df9f5b2e6a5ec8 100644 --- a/test/auto_parallel/spmd_rules/test_reduction_rule.py +++ b/test/auto_parallel/spmd_rules/test_reduction_rule.py @@ -62,6 +62,8 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # reduce on dim 0, keep_dim = true # [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0] @@ -76,6 +78,8 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # reduce on dim 1, keep_dim = false # [0, -1] --> [0, -1], [0], partial_on_dim:[] @@ -90,6 +94,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # reduce on dim 1, keep_dim = true # [0, -1] --> [0, -1], [0, -1], partial_on_dim:[] @@ -104,6 +109,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # reduce on dim 0 and 1, keep_dim = false # [0, -1] --> [0, -1], [], partial_on_dim:[0] @@ -118,6 +124,8 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, []) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) # reduce on dim 0 and 1, keep_dim = true # [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0] @@ -132,6 +140,8 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0}) def test_multi_mesh_dim(self): process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) @@ -170,6 +180,10 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {0, 1}) + infered_output_dist_attrs[0]._clean_partial_status() + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # reduction on dim 1, 2, keep_dim = false # [1, -1, -1] --> [1, -1, -1], [1], partial_on_dim:[] self.attrs['keep_dim'] = False @@ -183,6 +197,7 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # reduction on dim 1, 2, keep_dim = false # [0, 1, -1] --> [0, 1, -1], [0], partial_on_dim:[1] @@ -197,6 +212,10 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) +
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1}) + infered_output_dist_attrs[0]._clean_partial_status() + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False) # reduction on dim 1, 2, keep_dim = true # [0, 1, -1] --> [0, 1, -1], [0, -1, -1], partial_on_dim:[1] @@ -211,6 +230,8 @@ class TestReductionSPMDRule(unittest.TestCase): self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1]) self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True) + self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1}) if __name__ == "__main__": diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index 9487cd7465328072b1c0812603caa9d3a874f645..c2ae26f8a50cc4905eb53cf5dbf3a669b8f5e9b2 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -70,6 +70,7 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({-1, -1})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({1, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[] @@ -83,6 +84,7 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({-1, 0})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({-1, 0})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; // mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done @@ -96,6 +98,9 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({0, -1})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({1, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({0})); VLOG(4) << "test3 done."
<< std::endl << std::endl << std::endl; // mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done @@ -109,6 +114,9 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({1, 0})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({-1, 0})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({1})); VLOG(4) << "test4 done." << std::endl << std::endl << std::endl; // abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = @@ -124,6 +132,7 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({-1, -1})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({0, 1, -1, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); VLOG(4) << "test5 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1, @@ -138,6 +147,9 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({0, -1})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({1, -1, -1, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({0})); VLOG(4) << "test6 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = @@ -153,6 +165,7 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({-1, -1})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({1, -1, 0, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); VLOG(4) << "test7 done." 
<< std::endl << std::endl << std::endl; // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -169,6 +182,11 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({1, 0})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({-1, -1, -1, 1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({0})); + infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); VLOG(4) << "test8 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -185,6 +203,13 @@ TEST(MatmulSPMDRule, Ctor) { std::vector({-1, 0})); EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), std::vector({-1, -1, 1, -1})); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({0})); + VLOG(4) << infered_dist_attrs.second[0].to_string(); + infered_dist_attrs.second[0].clean_partial_status(); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + infered_dist_attrs.second[0].set_partial_status(std::vector({1})); + EXPECT_EQ(infered_dist_attrs.second[0].verify_partial_status(), false); VLOG(4) << "test9 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -197,6 +222,28 @@ TEST(MatmulSPMDRule, Ctor) { {x_dist_tensor_spec, y_dist_tensor_spec}, attrs)); // Error VLOG(4) << "test10 done." 
<< std::endl << std::endl << std::endl; + + // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = + // abcmn[-1, -1, -1, 1] partial[0]: + x_dist_tensor_spec.set_dims_mapping({-1, -1, 0, 1}); + y_dist_tensor_spec.set_dims_mapping({1, 0}); + attrs["trans_y"] = true; + attrs["trans_x"] = true; + infered_dist_attrs = matmul_rule->InferForward( + {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + EXPECT_ANY_THROW(infered_dist_attrs.second[0].clean_partial_dims( + std::vector({1}))); + infered_dist_attrs.second[0].set_partial_status(std::vector({1})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({0, 1})); + infered_dist_attrs.second[0].clean_partial_dims(std::vector({1})); + EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), + std::set({0})); + infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); + EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + + VLOG(4) << "test11 done." << std::endl << std::endl << std::endl; } TEST(LayerNormSPMDRule, Ctor) {