Unverified · Commit 4569ae13 · Authored by: LiYuRio · Committed by: GitHub

Implement reshard from s to r with same process_mesh (#56039)

Parent: bfa65993
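For context, here is a minimal usage sketch of the new shard-to-replicate reshard path, assembled from the test added in this commit (reshard_s_to_r.py below); the shape, mesh, and GPU place are taken from that test and are illustrative only. It is meant to be launched with paddle.distributed.launch on two devices.

import paddle
import paddle.distributed as dist
from paddle.fluid import core

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
in_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])    # sharded along dim 0
out_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])  # fully replicated

# Each rank holds its local (10, 20) piece of the sharded tensor.
x = dist.shard_tensor(paddle.ones((10, 20)), dist_attr=in_attr)
dev_ctx = core.DeviceContext.create(paddle.CUDAPlace(dist.get_rank()))

reshard_func = core.SToRReshardFunction()
assert reshard_func.is_suitable(x, out_attr)
out = reshard_func.eval(dev_ctx, x, out_attr)  # all_gather; replicated on every rank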
@@ -18,16 +18,20 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/pybind.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#endif
namespace py = pybind11;
@@ -111,7 +115,43 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
void BindAutoParallel(py::module *m) {
#ifdef PADDLE_WITH_DISTRIBUTE
- py::class_<phi::distributed::RToSReshardFunction>(*m, "RToSReshardFunction")
-   .def(py::init<>());
auto ReshardFunction =
py::class_<phi::distributed::ReshardFunction>(*m, "ReshardFunction")
.def(
"is_suitable",
[](phi::distributed::ReshardFunction &self,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
return self.IsSuitable(*p_dist, dist_attr);
},
py::call_guard<py::gil_scoped_release>())
.def(
"eval",
[](phi::distributed::ReshardFunction &self,
phi::DeviceContext *dev_ctx,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
auto res_dist = self.Eval(dev_ctx, *p_dist, dist_attr);
return paddle::Tensor(res_dist);
},
py::call_guard<py::gil_scoped_release>());
py::class_<phi::distributed::RToSReshardFunction>(
*m, "RToSReshardFunction", ReshardFunction)
.def(py::init<>());
py::class_<phi::distributed::SToRReshardFunction>(
*m, "SToRReshardFunction", ReshardFunction)
.def(py::init<>());
#endif
......
@@ -3,8 +3,15 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")
if(WITH_DISTRIBUTE)
- list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
-   reshard_split_functor.cc r_to_s_reshard_function.cc)
list(
APPEND
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_all_gather_functor.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
endif()
collect_srcs(
......
@@ -48,7 +48,7 @@ bool RToSReshardFunction::IsSuitable(
}
std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
- const phi::DeviceContext& dev_ctx,
+ phi::DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
@@ -85,7 +85,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
num_of_process, in.dims()[split_axis] / num_of_process));
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
- dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
+ *dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
......
@@ -29,7 +29,7 @@ class RToSReshardFunction final : public ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
- const DeviceContext& dev_ctx,
+ DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/all_gather_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids) {
DenseTensor out;
int64_t world_size = process_ids.size();
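// Create (or fetch the cached) communicator for exactly these ranks and bind
// it to the device context before dispatching the all_gather kernel.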
auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);
dev_ctx->SetCommContext(comm_context);
if (phi::CPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const GPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
namespace phi {
class DenseTensor;
class DeviceContext;
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids);
} // namespace distributed
} // namespace phi
@@ -36,7 +36,7 @@ class ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
virtual std::shared_ptr<DistTensor> Eval(
- const DeviceContext& dev_ctx,
+ DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
};
......
@@ -16,7 +16,12 @@
#include <cstdlib>
#include "glog/logging.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
namespace distributed {
@@ -109,6 +114,21 @@ std::string GetMasterEndpoint() {
return master_endpoint;
}
std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = "ReshardGroup";
for (const auto& id : process_ids) {
unique_comm_key += "/" + std::to_string(id);
}
return unique_comm_key;
}
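// Index of the current global rank within process_ids; used as the local rank
// when creating the sub-group comm context below.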
int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
int64_t cur_global_rank = GetCurGlobalRank();
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
return iter - process_ids.begin();
}
} // namespace
std::string GetMasterAddr() {
@@ -133,5 +153,41 @@ std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
return store;
}
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = GenUniqueCommKey(process_ids);
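// Create the comm context only once per participant set; later calls reuse
// the instance cached in CommContextManager under this key.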
if (!CommContextManager::GetInstance().Has(unique_comm_key)) {
int64_t world_size = process_ids.size();
int64_t rank = GetLocalRankInParticipate(process_ids);
VLOG(3) << "local world size: " << world_size << " local rank: " << rank;
auto store = CreateOrGetGlobalTCPStore();
if (phi::CPUContext::classof(&dev_ctx)) {
#if defined(PADDLE_WITH_GLOO)
CommContextManager::CreateGlooCommContext(
store, unique_comm_key, rank, world_size);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (phi::GPUContext::classof(&dev_ctx)) {
CommContextManager::CreateNCCLCommContext(
store, unique_comm_key, rank, world_size);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"CommContext is only supported on CPU and GPU for now, other devices "
"will be supported later."));
#endif
}
}
auto* comm_context = CommContextManager::GetInstance().Get(unique_comm_key);
return comm_context;
}
} // namespace distributed
} // namespace phi
@@ -23,7 +23,11 @@
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
class DeviceContext;
namespace distributed {
class CommContext;
namespace auto_parallel {
class ProcessMesh;
@@ -48,6 +52,13 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);
// Create a comm context for the input process_ids. Once a new comm context is
// created, it is cached in the global instance and fetched from that cache on
// later calls. If the input dev_ctx is GPU, an nccl comm context is created;
// if it is CPU, a gloo comm context is created.
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);
int64_t GetCurGlobalRank();
std::string GetMasterAddr();
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
namespace distributed {
bool SToRReshardFunction::IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr->dims_mapping();
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
flag &= IsDimsMappingShard(in_dims_mapping);
flag &= IsDimsMappingReplicated(out_dims_mapping);
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& out_process_mesh = out_dist_attr->process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);
return flag;
}
std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
// TODO(liyurui): Only support transfer shard(0) to replicate for now.
// Concat is needed when transfer shard(x) to replicate, will be supported
// later.
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
// Since the precondition guarantees that out_process_ids equals
// in_process_ids, the participating process ids can be taken from either of
// them.
DenseTensor out_all_gather = ReshardAllGatherFunctor(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_all_gather), out_dist_attr);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
namespace phi {
namespace distributed {
class SToRReshardFunction final : public ReshardFunction {
public:
SToRReshardFunction() = default;
~SToRReshardFunction() = default;
bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
} // namespace distributed
} // namespace phi
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
@@ -24,4 +25,16 @@ void AllGatherKernel(const Context& dev_ctx,
int nranks,
DenseTensor* out);
template <typename T, typename Context>
void AllGather(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
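// Fill out's meta (shape, dtype) via AllGatherInferMeta, then run the device
// kernel to perform the actual gather.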
MetaTensor out_meta(*out);
MetaTensor* out_meta_ptr = &out_meta;
AllGatherInferMeta(phi::MetaTensor(x), nranks, out_meta_ptr);
AllGatherKernel<T, Context>(dev_ctx, x, nranks, out);
}
} // namespace phi
@@ -61,5 +61,6 @@ PD_REGISTER_KERNEL(all_gather,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
@@ -69,6 +69,7 @@ PD_REGISTER_KERNEL(all_gather,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
bool,
phi::dtype::bfloat16,
@@ -83,6 +84,7 @@ PD_REGISTER_KERNEL(all_gather,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
bool,
phi::dtype::float16) {}
......
@@ -77,6 +77,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
set_tests_properties(test_pass_quantization
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 60)
py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
set_tests_properties(test_reshard_s_to_r
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.fluid import core
class TestReshardSToR:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._shard = eval(os.getenv("shard"))
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
place = paddle.CPUPlace()
elif self._backend == "gpu":
place = paddle.CUDAPlace(dist.get_rank())
dev_ctx = core.DeviceContext.create(place)
a = paddle.ones(self._shape)
in_shard_specs = [None for i in range(len(self._shape))]
in_shard_specs[self._shard] = "x"
out_shard_specs = [None for i in range(len(self._shape))]
dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=in_shard_specs
)
out_dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=out_shard_specs
)
input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
reshard_func = core.SToRReshardFunction()
assert reshard_func.is_suitable(input_tensor, out_dist_attr)
out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
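# Two ranks each contribute their local shard, so the gathered output doubles
# in size along the sharded axis.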
out_shape = list(self._shape)
out_shape[self._shard] = out_shape[self._shard] * 2
assert np.equal(out.shape, out_shape).all()
if __name__ == '__main__':
TestReshardSToR().run_test_case()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import collective.test_communication_api_base as test_base
class TestReshardSToR(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"shape": "(10, 20)",
"dtype": "float32",
"seeds": str(self._seeds),
"shard": "0",
}
self._changeable_envs = {
"backend": ["cpu", "gpu"],
}
def test_reshard_s_to_r(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"reshard_s_to_r.py",
user_defined_envs=envs,
)
if __name__ == "__main__":
unittest.main()
@@ -34,6 +34,7 @@ class CommunicationTestDistBase(unittest.TestCase):
def run_test_case(self, script_file, user_defined_envs=None):
runtime_envs = os.environ
if user_defined_envs is not None:
runtime_envs.update(user_defined_envs)
runtime_envs["CUDA_VISIBLE_DEVICES"] = self._devices
start_command = f"{self._python_interp} -u -m paddle.distributed.launch --log_dir {self._log_dir.name} --devices {self._devices} {script_file}"
......
@@ -114,7 +114,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
- r_to_s_func.Eval(*context, *input, out_dist_attr);
+ r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 12);
@@ -136,7 +136,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
- std::vector<int64_t> out_dims_mapping = {0, -1};
+ std::vector<int64_t> out_dims_mapping = {0, -1, -1};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
@@ -145,7 +145,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
- r_to_s_func.Eval(*context, *input, out_dist_attr);
+ r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 32);
......