From a28e6f63a31c1ecee86818eb7778974660159cfa Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Mon, 4 Sep 2023 10:34:43 +0800
Subject: [PATCH] reshard r to p (#56833)

---
 paddle/fluid/pybind/auto_parallel_py.cc       |  9 +++
 paddle/fluid/pybind/eager_method.cc           | 71 +++++++++++++---
 .../distributed/auto_parallel/CMakeLists.txt  |  3 +-
 .../distributed/auto_parallel/dist_attr.cc    |  2 +-
 .../auto_parallel/r_to_p_reshard_function.cc  | 80 +++++++++++++++++++
 .../auto_parallel/r_to_p_reshard_function.h   | 34 ++++++++
 .../auto_parallel/s_to_r_reshard_function.cc  |  1 -
 paddle/phi/kernels/assign_kernel.h            |  8 ++
 test/auto_parallel/CMakeLists.txt             |  3 +
 test/auto_parallel/reshard_r_to_p.py          | 73 +++++++++++++++++
 test/auto_parallel/reshard_r_to_s.py          |  1 +
 test/auto_parallel/test_reshard_r_to_p.py     | 45 +++++++++++
 test/auto_parallel/test_reshard_r_to_s.py     |  1 -
 test/cpp/auto_parallel/spmd_rule_test.cc      | 10 +--
 14 files changed, 319 insertions(+), 22 deletions(-)
 create mode 100644 paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc
 create mode 100644 paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h
 create mode 100644 test/auto_parallel/reshard_r_to_p.py
 create mode 100644 test/auto_parallel/test_reshard_r_to_p.py

diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 6f639f145dc..bdd467fbefa 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -32,6 +32,7 @@
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

@@ -157,6 +158,10 @@ void BindAutoParallel(py::module *m) {
       *m, "SToRReshardFunction", ReshardFunction)
       .def(py::init<>());

+  py::class_<phi::distributed::RToPReshardFunction>(
+      *m, "RToPReshardFunction", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<ProcessMesh>(*m, "ProcessMesh")
       .def(py::init<>())
       .def(py::init<const std::vector<int64_t> &,
@@ -338,6 +343,10 @@ void BindAutoParallel(py::module *m) {
       .def("_is_partial", &TensorDistAttr::is_partial)
       .def("_partial_dims", &TensorDistAttr::partial_dims)
       .def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims)
+      .def("_set_partial_dims",
+           [](TensorDistAttr &self, const std::vector<int64_t> &dims) {
+             self.set_partial_status(dims);
+           })
       .def("_clean_partial_status", &TensorDistAttr::clean_partial_status);

   py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
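The two hunks above expose the new reshard function and a way to mark a
TensorDistAttr as partial from Python. A minimal sketch of how these bindings
fit together, mirroring the usage in test/auto_parallel/reshard_r_to_p.py
further down (the two-rank mesh and the 2-D shape are illustrative only):

    import paddle.distributed as dist
    from paddle.framework import core

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    # Replicated input: no tensor dimension is mapped onto the mesh.
    in_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
    # Output is marked partial on mesh dim 0 via the new _set_partial_dims binding.
    out_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
    out_attr._set_partial_dims([0])
    reshard_func = core.RToPReshardFunction()  # new binding from this patch
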
diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index 2a4699f9454..dd99770e05f 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -61,6 +61,8 @@ typedef SSIZE_T ssize_t;
 #include "paddle/phi/api/lib/data_transform.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
 #include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -99,6 +101,30 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
   }
 }

+namespace {
+#ifdef PADDLE_WITH_DISTRIBUTE
+phi::DenseTensor ReshardXToReplicated(
+    phi::distributed::DistTensor* dist_tensor) {
+  if (!phi::distributed::IsDimsMappingReplicated(
+          dist_tensor->dist_attr().dims_mapping())) {
+    phi::distributed::TensorDistAttr dist_attr(dist_tensor->dist_attr());
+    std::vector<int64_t> dims_mapping(dist_tensor->dims().size(), -1);
+    dist_attr.set_dims_mapping(dims_mapping);
+
+    // reshard to a replicated dist tensor
+    auto* func =
+        phi::distributed::ChooseProperReshardFunction(*dist_tensor, dist_attr);
+    auto* dev_ctx =
+        phi::DeviceContextPool::Instance().Get(dist_tensor->place());
+    auto out_tensor = func->Eval(dev_ctx, *dist_tensor, dist_attr);
+    return out_tensor->value();
+  } else {
+    return dist_tensor->value();
+  }
+}
+#endif
+}  // namespace
+
 PyDoc_STRVAR(tensor_method_numpy__doc__,  // NOLINT
              R"DOC(numpy($self, /)
--
@@ -145,15 +171,6 @@ static PyObject* tensor_method_numpy(TensorObject* self,
     return array;
   }
   auto tensor_dims = self->tensor.shape();
-#ifdef PADDLE_WITH_DISTRIBUTE
-  // Now the DistTensor's numpy() return the local tensor value
-  if (self->tensor.is_dist_tensor()) {
-    tensor_dims = phi::vectorize(
-        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
-            ->value()
-            .dims());
-  }
-#endif
   auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
   auto sizeof_dtype = phi::SizeOf(self->tensor.type());
   Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];  // NOLINT
@@ -258,12 +275,11 @@ static PyObject* tensor_method_numpy(TensorObject* self,
                             dense_tensor->Holder()->size());
     } else if (self->tensor.is_dist_tensor()) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-      // TODO(chenweihang): deal with DistTensor as local DenseTensor now,
-      // if the local DenseTensor is shard or partial, do gather or reduce?
       VLOG(6) << "Getting DistTensor's numpy value";
       auto* dist_tensor =
           static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
-      auto& dense_tensor = dist_tensor->value();
+      auto dense_tensor = ReshardXToReplicated(dist_tensor);
+
       cpu_tensor.set_meta(dense_tensor.meta());
       // deep copy
       auto tmp_allocation_ptr =
@@ -330,7 +346,8 @@ static PyObject* tensor_method_numpy(TensorObject* self,
       VLOG(6) << "Getting DistTensor's numpy value";
       auto* dist_tensor =
           static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
-      auto& dense_tensor = dist_tensor->value();
+      auto dense_tensor = ReshardXToReplicated(dist_tensor);
+
       cpu_tensor.set_meta(dense_tensor.meta());
       auto tmp_allocation_ptr =
           memory::Alloc(cpu_place, dense_tensor.Holder()->size());
@@ -2680,6 +2697,30 @@ static PyObject* tensor__grad_value(TensorObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }

+static PyObject* tensor__local_value(TensorObject* self,
+                                     PyObject* args,
+                                     PyObject* kwargs) {
+  EAGER_TRY
+  if (self->tensor.is_dist_tensor()) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    phi::distributed::DistTensor* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
+    paddle::Tensor result(
+        std::make_shared<phi::DenseTensor>(dist_tensor->value()));
+    return ToPyObject(result);
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The `_local_value` method of (Dist)Tensor is not supported "
+        "in the current PaddlePaddle, please recompile and install "
+        "PaddlePaddle "
+        "with the option of `WITH_DISTRIBUTE=ON`."));
+#endif
+  } else {
+    RETURN_PY_NONE
+  }
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
@@ -3131,6 +3172,10 @@ PyMethodDef variable_methods[] = {  // NOLINT
      (PyCFunction)(void (*)())tensor__grad_value,
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
+    {"_local_value",
+     (PyCFunction)(void (*)())tensor__local_value,
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"_unset_fake_empty",
      (PyCFunction)(void (*)())tensor__unset_fake_empty,
      METH_VARARGS | METH_KEYWORDS,
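With ReshardXToReplicated in place, numpy() on a DistTensor now reshards any
non-replicated (sharded or partial) value back to replicated before exporting
it, so every rank sees the global tensor; the new _local_value method returns
the raw local DenseTensor without any communication. A behavior sketch under a
two-rank distributed launch (shard spec and shapes follow the R-to-S test
further down; not runnable as a single process):

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])  # shard dim 0
    t = dist.shard_tensor(paddle.ones([10, 20]), dist_attr=attr)

    full = t.numpy()          # (10, 20): resharded to replicated first
    local = t._local_value()  # this rank's (5, 20) shard, no reshard
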
diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
index d9d1c27ed23..3068881c82a 100644
--- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
+++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -13,4 +13,5 @@ collect_srcs(
   inferspmd_utils.cc
   reshard_function.cc
   r_to_s_reshard_function.cc
-  s_to_r_reshard_function.cc)
+  s_to_r_reshard_function.cc
+  r_to_p_reshard_function.cc)
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
index 0e091e98b27..ae58402acb0 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -227,7 +227,7 @@ bool TensorDistAttr::verify_partial_status() const {
     if (itr.first < 0 || itr.first >= process_mesh_.ndim()) {
       return false;
     }
-    if (itr.second < ReduceType::kRedSum || itr.second <= ReduceType::kRedAll) {
+    if (itr.second < ReduceType::kRedSum || itr.second > ReduceType::kRedAll) {
       return false;
     }
   }
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc
new file mode 100644
index 00000000000..e941e82c98b
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/kernels/assign_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+
+namespace phi {
+namespace distributed {
+
+bool RToPReshardFunction::IsSuitable(const DistTensor& in,
+                                     const TensorDistAttr& out_dist_attr) {
+  bool flag = true;
+  const auto& in_dist_attr = in.dist_attr();
+
+  const auto& in_dims_mapping = in_dist_attr.dims_mapping();
+
+  flag &= IsDimsMappingReplicated(in_dims_mapping);
+  flag &= out_dist_attr.is_partial();
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  flag &= (in_process_mesh.ndim() == 1);
+  flag &= (out_process_mesh.ndim() == 1);
+  flag &= (in_process_mesh == out_process_mesh);
+
+  return flag;
+}
+
+void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+  int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0];
+  IntArray shape(in.dims().Get(), in.dims().size());
+
+  if (local_rank != 0) {
+    // reset the physical tensor to zero
+    RESHARD_FUNCTOR(dev_ctx, Full, in.dtype(), shape, 0, GetMutableTensor(out));
+  } else {
+    // assign the input value to output
+    if (phi::CPUContext::classof(dev_ctx)) {
+      Assign(static_cast<const CPUContext&>(*dev_ctx),
+             in.value(),
+             GetMutableTensor(out));
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    } else if (phi::GPUContext::classof(dev_ctx)) {
+      Assign(static_cast<const GPUContext&>(*dev_ctx),
+             in.value(),
+             GetMutableTensor(out));
+#endif
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "The assign in reshard is only supported on CPU and GPU for now."));
+    }
+  }
+  SetDistProps(out, in.dims(), out_dist_attr);
+}
+
+REGISTER_RESHARD_FUNC(RToPReshardFunction);
+
+}  // namespace distributed
+}  // namespace phi
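Eval above realizes replicated-to-partial without any communication: the rank
at coordinate 0 of the mesh keeps the input value and every other rank
zero-fills its local tensor, so a later sum-reduction over the mesh reproduces
the original value. A NumPy sketch of that invariant (mesh size 2 assumed,
matching the tests below):

    import numpy as np

    replicated = np.ones((10, 20), dtype="float32")
    # Local values after R-to-P: rank 0 keeps the data, the rest hold zeros.
    local_values = [replicated if rank == 0 else np.zeros_like(replicated)
                    for rank in range(2)]
    # Reducing the partial tensor restores the replicated value.
    np.testing.assert_equal(sum(local_values), replicated)
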
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h
new file mode 100644
index 00000000000..af3bdb41d78
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+class RToPReshardFunction final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
index 61f57e1d669..8e4cb877c95 100644
--- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
@@ -76,7 +76,6 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
                   in.value(),
                   in_process_ids.size(),
                   GetMutableTensor(out));
-
   std::map<int64_t, int64_t> split_axis_to_mesh_axis =
       GetSplitAxisWithDimsMapping(in_dims_mapping);
   int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
diff --git a/paddle/phi/kernels/assign_kernel.h b/paddle/phi/kernels/assign_kernel.h
index 7fa0350ad0e..fa331f76add 100644
--- a/paddle/phi/kernels/assign_kernel.h
+++ b/paddle/phi/kernels/assign_kernel.h
@@ -38,6 +38,14 @@ DenseTensor Assign(const Context& dev_ctx, const DenseTensor& x) {
   return out;
 }

+template <typename Context>
+void Assign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
+  MetaTensor meta_out(out);
+  MetaTensor meta_x(x);
+  UnchangedInferMeta(meta_x, &meta_out);
+  AssignKernel<Context>(dev_ctx, x, out);
+}
+
 // In order to be compatible with the `AsDispensable` input in the original
 // assign op maker, the input parameter here needs to be dispensable, but
 // this looks weird
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index eef02ecb28c..458c273951e 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -85,6 +85,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
   set_tests_properties(test_reshard_r_to_s
                        PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
+  py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
+  set_tests_properties(test_reshard_r_to_p
+                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
   # End of unittests WITH multi cards and timeout

   # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
diff --git a/test/auto_parallel/reshard_r_to_p.py b/test/auto_parallel/reshard_r_to_p.py
new file mode 100644
index 00000000000..13e899876e1
--- /dev/null
+++ b/test/auto_parallel/reshard_r_to_p.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.framework import core
+
+
+class TestReshardRToP:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._backend = os.getenv("backend")
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+        a = paddle.ones(self._shape)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs = [None for i in range(len(self._shape))]
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+        out_dist_attr._set_partial_dims([0])
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        reshard_func = core.RToPReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        if dist.get_rank() == 0:
+            np.testing.assert_equal(
+                out._local_value().numpy(), input_tensor.numpy()
+            )
+        else:
+            zeros = paddle.zeros(self._shape)
+            np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, input_tensor._local_shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardRToP().run_test_case()
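The script above is driven entirely by environment variables, so it can also
be launched by hand. A rough sketch of a manual two-process run (the seeds
value is arbitrary here, and the exact paddle.distributed.launch flags may
differ across Paddle versions; the unittest driver below normally does all of
this itself):

    import os
    import subprocess

    env = dict(os.environ, shape="(10, 20)", dtype="float32",
               seeds="2023", backend="gpu")
    # Assumed launcher invocation: one worker per device, ranks 0 and 1.
    subprocess.run(["python", "-m", "paddle.distributed.launch",
                    "--devices", "0,1", "reshard_r_to_p.py"],
                   env=env, check=True)
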
diff --git a/test/auto_parallel/reshard_r_to_s.py b/test/auto_parallel/reshard_r_to_s.py
index 814b0ef0dd7..690c42fa492 100644
--- a/test/auto_parallel/reshard_r_to_s.py
+++ b/test/auto_parallel/reshard_r_to_s.py
@@ -61,6 +61,7 @@ class TestReshardRToS:

         if out_shape[self._shard] % 2 == 0:
             out_shape[self._shard] = out_shape[self._shard] // 2
+            np.testing.assert_equal(out.numpy(), input_tensor.numpy())
         else:
             out_shape[self._shard] = (
                 out_shape[self._shard] // 2
diff --git a/test/auto_parallel/test_reshard_r_to_p.py b/test/auto_parallel/test_reshard_r_to_p.py
new file mode 100644
index 00000000000..a8619722e40
--- /dev/null
+++ b/test/auto_parallel/test_reshard_r_to_p.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestReshardRToP(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=120)
+        self._default_envs = {
+            "shape": "(10, 20)",
+            "dtype": "float32",
+            "seeds": str(self._seeds),
+        }
+        self._changeable_envs = {
+            "shape": ["(10, 20)"],
+            "backend": ["cpu", "gpu"],
+        }
+
+    def test_reshard_r_to_p(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "reshard_r_to_p.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
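gen_product_envs_list presumably builds the cross product of the changeable
env values merged over the defaults, so the case above runs once per backend.
A sketch of that expansion, under this assumption about the helper:

    import itertools

    default_envs = {"shape": "(10, 20)", "dtype": "float32", "seeds": "2023"}
    changeable_envs = {"shape": ["(10, 20)"], "backend": ["cpu", "gpu"]}

    keys = list(changeable_envs)
    envs_list = [{**default_envs, **dict(zip(keys, values))}
                 for values in itertools.product(*changeable_envs.values())]
    # -> one "cpu" case and one "gpu" case
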
diff --git a/test/auto_parallel/test_reshard_r_to_s.py b/test/auto_parallel/test_reshard_r_to_s.py
index 187fa40918d..68699885094 100644
--- a/test/auto_parallel/test_reshard_r_to_s.py
+++ b/test/auto_parallel/test_reshard_r_to_s.py
@@ -21,7 +21,6 @@ class TestReshardRToS(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=2, timeout=120)
         self._default_envs = {
-            "shape": "(10, 20)",
             "dtype": "float32",
             "seeds": str(self._seeds),
         }
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index 30907b707aa..079f2d84ea0 100644
--- a/test/cpp/auto_parallel/spmd_rule_test.cc
+++ b/test/cpp/auto_parallel/spmd_rule_test.cc
@@ -209,8 +209,8 @@ TEST(MatmulSPMDRule, Ctor) {
   EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
   VLOG(4) << "test8 done." << std::endl << std::endl << std::endl;

-  // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
-  // abcmn[-1, -1, -1, 1] partial[0]: done
+  // abcmk[-1, -1, 0, 1]+trans_x=true, kn[1, 0]+trans_y=true --> abcmk[-1, -1,
+  // 0, -1],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0]: done
   x_dist_attr.set_dims_mapping({-1, -1, 0, 1});
   y_dist_attr.set_dims_mapping({1, 0});
   x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr);
@@ -221,7 +221,8 @@
   EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
             std::vector<int64_t>({-1, -1, 0, 1}));
   EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
-            std::vector<int64_t>({-1, 0}));
+            std::vector<int64_t>(
+                {-1, 0}));  // conflict and should be changed to [-1, 0]
   EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
             std::vector<int64_t>({-1, -1, 1, -1}));
   EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
@@ -229,8 +230,7 @@
   VLOG(4) << infered_dist_attrs.second[0].to_string();
   infered_dist_attrs.second[0].clean_partial_status();
   EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
-  infered_dist_attrs.second[0].set_partial_status(std::vector<int64_t>({1}));
-  EXPECT_EQ(infered_dist_attrs.second[0].verify_partial_status(), false);
+  // EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status(std::vector<int64_t>({1})));
   VLOG(4) << "test9 done." << std::endl << std::endl << std::endl;

   // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
-- 
GitLab