Unverified commit a28e6f63, authored by LiYuRio, committed by GitHub

reshard r to p (#56833)

Parent 413ca989
@@ -32,6 +32,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
...@@ -157,6 +158,10 @@ void BindAutoParallel(py::module *m) { ...@@ -157,6 +158,10 @@ void BindAutoParallel(py::module *m) {
*m, "SToRReshardFunction", ReshardFunction) *m, "SToRReshardFunction", ReshardFunction)
.def(py::init<>()); .def(py::init<>());
py::class_<phi::distributed::RToPReshardFunction>(
*m, "RToPReshardFunction", ReshardFunction)
.def(py::init<>());
py::class_<ProcessMesh>(*m, "ProcessMesh")
.def(py::init<>())
.def(py::init<const std::vector<int64_t> &,
@@ -338,6 +343,10 @@ void BindAutoParallel(py::module *m) {
.def("_is_partial", &TensorDistAttr::is_partial)
.def("_partial_dims", &TensorDistAttr::partial_dims)
.def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims)
.def("_set_partial_dims",
[](TensorDistAttr &self, const std::vector<int64_t> &dims) {
self.set_partial_status(dims);
})
.def("_clean_partial_status", &TensorDistAttr::clean_partial_status); .def("_clean_partial_status", &TensorDistAttr::clean_partial_status);
py::class_<SPMDRuleBase>(*m, "SPMDRuleBase") py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
......
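The `_set_partial_dims` binding added above is what the Python side uses to mark a mesh dimension as partial. A minimal sketch of that flow (the 2-card mesh and the dim list are illustrative values, mirroring the reshard_r_to_p.py test later in this commit):

import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
# start from a fully replicated attribute, then mark mesh dim 0 as partial
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
dist_attr._set_partial_dims([0])
assert dist_attr._is_partial()
# _partial_dims() reports the partial mesh dims; _clean_partial_dims([0]) or
# _clean_partial_status() reverts the attribute to a non-partial state.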
@@ -61,6 +61,8 @@ typedef SSIZE_T ssize_t;
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
...@@ -99,6 +101,30 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) { ...@@ -99,6 +101,30 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
} }
} }
namespace {
#ifdef PADDLE_WITH_DISTRIBUTE
phi::DenseTensor ReshardXToReplicated(
phi::distributed::DistTensor* dist_tensor) {
if (!phi::distributed::IsDimsMappingReplicated(
dist_tensor->dist_attr().dims_mapping())) {
phi::distributed::TensorDistAttr dist_attr(dist_tensor->dist_attr());
std::vector<int64_t> dims_mapping(dist_tensor->dims().size(), -1);
dist_attr.set_dims_mapping(dims_mapping);
// reshard to replicate dist tensor
auto* func =
phi::distributed::ChooseProperReshardFunction(*dist_tensor, dist_attr);
auto* dev_ctx =
phi::DeviceContextPool::Instance().Get(dist_tensor->place());
auto out_tensor = func->Eval(dev_ctx, *dist_tensor, dist_attr);
return out_tensor->value();
} else {
return dist_tensor->value();
}
}
#endif
} // namespace
PyDoc_STRVAR(tensor_method_numpy__doc__,  // NOLINT
R"DOC(numpy($self, /)
--
@@ -145,15 +171,6 @@ static PyObject* tensor_method_numpy(TensorObject* self,
return array;
}
auto tensor_dims = self->tensor.shape();
#ifdef PADDLE_WITH_DISTRIBUTE
// Now the DistTensor's numpy() return the local tensor value
if (self->tensor.is_dist_tensor()) {
tensor_dims = phi::vectorize(
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
->value()
.dims());
}
#endif
auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
auto sizeof_dtype = phi::SizeOf(self->tensor.type());
Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];  // NOLINT
@@ -258,12 +275,11 @@ static PyObject* tensor_method_numpy(TensorObject* self,
dense_tensor->Holder()->size());
} else if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(chenweihang): deal with DistTensor as local DenseTensor now,
// if the local DenseTensor is shard or partial, do gather or reduce?
VLOG(6) << "Getting DistTensor's numpy value"; VLOG(6) << "Getting DistTensor's numpy value";
auto* dist_tensor = auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get()); static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
auto& dense_tensor = dist_tensor->value(); auto dense_tensor = ReshardXToReplicated(dist_tensor);
cpu_tensor.set_meta(dense_tensor.meta()); cpu_tensor.set_meta(dense_tensor.meta());
// deep copy // deep copy
auto tmp_allocation_ptr = auto tmp_allocation_ptr =
@@ -330,7 +346,8 @@ static PyObject* tensor_method_numpy(TensorObject* self,
VLOG(6) << "Getting DistTensor's numpy value";
auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
auto& dense_tensor = dist_tensor->value();
auto dense_tensor = ReshardXToReplicated(dist_tensor);
cpu_tensor.set_meta(dense_tensor.meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor.Holder()->size());
@@ -2680,6 +2697,30 @@ static PyObject* tensor__grad_value(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__local_value(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
paddle::Tensor result(
std::make_shared<phi::DenseTensor>(dist_tensor->value()));
return ToPyObject(result);
#else
PADDLE_THROW(platform::errors::Unavailable(
"The `_local_value` method of (Dist)Tensor is not supported "
"in the current PaddlePaddle, please recompile and install "
"PaddlePaddle "
"with the option of `WITH_DISTRIBUTE=ON`."));
#endif
} else {
RETURN_PY_NONE
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__unset_fake_empty(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
@@ -3131,6 +3172,10 @@ PyMethodDef variable_methods[] = {  // NOLINT
(PyCFunction)(void (*)())tensor__grad_value,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_local_value",
(PyCFunction)(void (*)())tensor__local_value,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_unset_fake_empty", {"_unset_fake_empty",
(PyCFunction)(void (*)())tensor__unset_fake_empty, (PyCFunction)(void (*)())tensor__unset_fake_empty,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
......
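Taken together, the hunks above change what `numpy()` returns for a DistTensor: instead of dumping the local DenseTensor, the value is first resharded to a replicated layout via `ReshardXToReplicated` whenever its dims_mapping is not already replicated, while the new `_local_value` method still exposes the raw per-rank shard. A rough sketch of the difference, assuming a two-GPU launch where the shard-to-replicated reshard path is available (mesh and shapes are illustrative):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])
x = dist.shard_tensor(paddle.ones([10, 20]), dist_attr=dist_attr)

# numpy() now goes through ReshardXToReplicated, so every rank sees the
# full (10, 20) array rather than its local shard
assert x.numpy().shape == (10, 20)
# _local_value() returns the untouched local DenseTensor, (5, 20) per rank here
assert tuple(x._local_value().shape) == (5, 20)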
@@ -13,4 +13,5 @@ collect_srcs(
inferspmd_utils.cc
reshard_function.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
s_to_r_reshard_function.cc
r_to_p_reshard_function.cc)
@@ -227,7 +227,7 @@ bool TensorDistAttr::verify_partial_status() const {
if (itr.first < 0 || itr.first >= process_mesh_.ndim()) {
return false;
}
if (itr.second < ReduceType::kRedSum || itr.second <= ReduceType::kRedAll) {
if (itr.second < ReduceType::kRedSum || itr.second > ReduceType::kRedAll) {
return false;
}
}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/kernels/assign_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace phi {
namespace distributed {
bool RToPReshardFunction::IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();
flag &= IsDimsMappingReplicated(in_dims_mapping);
flag &= out_dist_attr.is_partial();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);
return flag;
}
void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const auto& out_process_mesh = out_dist_attr.process_mesh();
int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0];
IntArray shape(in.dims().Get(), in.dims().size());
if (local_rank != 0) {
// reset the physical tensor to zero
RESHARD_FUNCTOR(dev_ctx, Full, in.dtype(), shape, 0, GetMutableTensor(out));
} else {
// assign the input value to output
if (phi::CPUContext::classof(dev_ctx)) {
Assign(static_cast<const CPUContext&>(*dev_ctx),
in.value(),
GetMutableTensor(out));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (phi::GPUContext::classof(dev_ctx)) {
Assign(static_cast<const GPUContext&>(*dev_ctx),
in.value(),
GetMutableTensor(out));
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"The assign in reshard only supported on CPU and GPU for now."));
}
}
SetDistProps(out, in.dims(), out_dist_attr);
}
REGISTER_RESHARD_FUNC(RToPReshardFunction);
} // namespace distributed
} // namespace phi
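The rank check above is the core of the replicated-to-partial reshard: the rank at coordinate 0 of the mesh keeps the original value (Assign), every other rank is zero-filled (Full), so the element-wise sum of the local shards across the mesh reproduces the replicated tensor. A numpy-only illustration of that invariant (not part of the patch; the shape is arbitrary):

import numpy as np

x = np.ones((10, 20), dtype="float32")   # the replicated value every rank holds
local_rank0 = x                          # mesh coordinate 0 keeps the data
local_rank1 = np.zeros_like(x)           # all other ranks are zero-filled
# partial(sum) invariant: reducing the local shards restores the replicated value
np.testing.assert_array_equal(local_rank0 + local_rank1, x)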
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
namespace phi {
namespace distributed {
class RToPReshardFunction final : public ReshardFunction {
public:
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;
};
} // namespace distributed
} // namespace phi
@@ -76,7 +76,6 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
in.value(),
in_process_ids.size(),
GetMutableTensor(out));
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
......
@@ -38,6 +38,14 @@ DenseTensor Assign(const Context& dev_ctx, const DenseTensor& x) {
return out;
}
template <typename Context>
void Assign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
MetaTensor meta_out(out);
MetaTensor meta_x(x);
UnchangedInferMeta(meta_x, &meta_out);
AssignKernel<Context>(dev_ctx, x, out);
}
// In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, but
// this looks weird
......
@@ -85,6 +85,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
set_tests_properties(test_reshard_r_to_s
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
set_tests_properties(test_reshard_r_to_p
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.framework import core
class TestReshardRToP:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
place = paddle.CPUPlace()
elif self._backend == "gpu":
place = paddle.CUDAPlace(dist.get_rank())
dev_ctx = core.DeviceContext.create(place)
a = paddle.ones(self._shape)
in_shard_specs = [None for i in range(len(self._shape))]
out_shard_specs = [None for i in range(len(self._shape))]
dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=in_shard_specs
)
out_dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=out_shard_specs
)
out_dist_attr._set_partial_dims([0])
input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
reshard_func = core.RToPReshardFunction()
assert reshard_func.is_suitable(input_tensor, out_dist_attr)
out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
if dist.get_rank() == 0:
np.testing.assert_equal(
out._local_value().numpy(), input_tensor.numpy()
)
else:
zeros = paddle.zeros(self._shape)
np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())
assert np.equal(out.shape, input_tensor.shape).all()
assert np.equal(out._local_shape, input_tensor._local_shape).all()
if __name__ == '__main__':
TestReshardRToP().run_test_case()
@@ -61,6 +61,7 @@ class TestReshardRToS:
if out_shape[self._shard] % 2 == 0:
out_shape[self._shard] = out_shape[self._shard] // 2
np.testing.assert_equal(out.numpy(), input_tensor.numpy())
else:
out_shape[self._shard] = (
out_shape[self._shard] // 2
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import collective.test_communication_api_base as test_base
class TestReshardRToP(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"shape": "(10, 20)",
"dtype": "float32",
"seeds": str(self._seeds),
}
self._changeable_envs = {
"shape": ["(10, 20)"],
"backend": ["cpu", "gpu"],
}
def test_reshard_r_to_p(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"reshard_r_to_p.py",
user_defined_envs=envs,
)
if __name__ == "__main__":
unittest.main()
@@ -21,7 +21,6 @@ class TestReshardRToS(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"shape": "(10, 20)",
"dtype": "float32", "dtype": "float32",
"seeds": str(self._seeds), "seeds": str(self._seeds),
} }
......
@@ -209,8 +209,8 @@ TEST(MatmulSPMDRule, Ctor) {
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test8 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmn[-1, -1, -1, 1] partial[0]: done
// abcmk[-1, -1, 0, 1]+trans_x=true, kn[1, 0]+trans_y=true --> abcmk[-1, -1,
// 0, -1],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0]: done
x_dist_attr.set_dims_mapping({-1, -1, 0, 1});
y_dist_attr.set_dims_mapping({1, 0});
x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr);
@@ -221,7 +221,8 @@ TEST(MatmulSPMDRule, Ctor) {
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 0, 1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, 0}));
std::vector<int64_t>(
{-1, 0}));  // conflict and should be changed to [-1, 0]
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
@@ -229,8 +230,7 @@ TEST(MatmulSPMDRule, Ctor) {
VLOG(4) << infered_dist_attrs.second[0].to_string();
infered_dist_attrs.second[0].clean_partial_status();
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
infered_dist_attrs.second[0].set_partial_status(std::vector<int64_t>({1}));
EXPECT_EQ(infered_dist_attrs.second[0].verify_partial_status(), false);
// EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status(std::vector<int64_t>({1})));
VLOG(4) << "test9 done." << std::endl << std::endl << std::endl; VLOG(4) << "test9 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
......