未验证 提交 f839e821 编写于 作者: L LiYuRio 提交者: GitHub

reshard p to r (#56975)

上级 2d9de72f
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
...@@ -162,6 +163,10 @@ void BindAutoParallel(py::module *m) { ...@@ -162,6 +163,10 @@ void BindAutoParallel(py::module *m) {
*m, "RToPReshardFunction", ReshardFunction) *m, "RToPReshardFunction", ReshardFunction)
.def(py::init<>()); .def(py::init<>());
py::class_<phi::distributed::PToRReshardFunction>(
*m, "PToRReshardFunction", ReshardFunction)
.def(py::init<>());
py::class_<ProcessMesh>(*m, "ProcessMesh") py::class_<ProcessMesh>(*m, "ProcessMesh")
.def(py::init<>()) .def(py::init<>())
.def(py::init<const std::vector<int64_t> &, .def(py::init<const std::vector<int64_t> &,
...@@ -340,7 +345,8 @@ void BindAutoParallel(py::module *m) { ...@@ -340,7 +345,8 @@ void BindAutoParallel(py::module *m) {
}, },
py::arg("memo")) py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string) .def("__str__", &TensorDistAttr::to_string)
.def("_is_partial", &TensorDistAttr::is_partial) .def(
"_is_partial", &TensorDistAttr::is_partial, py::arg("mesh_axis") = -1)
.def("_partial_dims", &TensorDistAttr::partial_dims) .def("_partial_dims", &TensorDistAttr::partial_dims)
.def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims) .def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims)
.def("_set_partial_dims", .def("_set_partial_dims",
......
...@@ -105,8 +105,7 @@ namespace { ...@@ -105,8 +105,7 @@ namespace {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
phi::DenseTensor ReshardXToReplicated( phi::DenseTensor ReshardXToReplicated(
phi::distributed::DistTensor* dist_tensor) { phi::distributed::DistTensor* dist_tensor) {
if (!phi::distributed::IsDimsMappingReplicated( if (!dist_tensor->dist_attr().is_replicated()) {
dist_tensor->dist_attr().dims_mapping())) {
phi::distributed::TensorDistAttr dist_attr(dist_tensor->dist_attr()); phi::distributed::TensorDistAttr dist_attr(dist_tensor->dist_attr());
std::vector<int64_t> dims_mapping(dist_tensor->dims().size(), -1); std::vector<int64_t> dims_mapping(dist_tensor->dims().size(), -1);
dist_attr.set_dims_mapping(dims_mapping); dist_attr.set_dims_mapping(dims_mapping);
......
...@@ -14,4 +14,5 @@ collect_srcs( ...@@ -14,4 +14,5 @@ collect_srcs(
reshard_function.cc reshard_function.cc
r_to_s_reshard_function.cc r_to_s_reshard_function.cc
s_to_r_reshard_function.cc s_to_r_reshard_function.cc
r_to_p_reshard_function.cc) r_to_p_reshard_function.cc
p_to_r_reshard_function.cc)
...@@ -348,5 +348,69 @@ bool TensorDistAttr::empty() const { ...@@ -348,5 +348,69 @@ bool TensorDistAttr::empty() const {
return process_mesh_.empty() || dims_mapping_.empty(); return process_mesh_.empty() || dims_mapping_.empty();
} }
std::vector<std::shared_ptr<PlacementStatus>> TensorDistAttr::to_placement()
const {
auto ndim = process_mesh_.ndim();
std::vector<std::shared_ptr<PlacementStatus>> placement(
ndim, std::make_shared<ReplicatedStatus>());
for (size_t i = 0; i < dims_mapping_.size(); ++i) {
if (dims_mapping_[i] != -1) {
PADDLE_ENFORCE_LT(
dims_mapping_[i],
ndim,
errors::InvalidArgument(
"Split axis %ld can not exceed the ndim of process_mesh %ld",
dims_mapping_[i],
ndim));
placement[dims_mapping_[i]] = std::make_shared<ShardStatus>(i);
}
}
for (auto& itr : partial_status_) {
PADDLE_ENFORCE_LT(
itr.first,
ndim,
errors::InvalidArgument(
"Partial axis %ld can not exceed the ndim of process_mesh %ld",
itr.first,
ndim));
placement[itr.first] = std::make_shared<PartialStatus>(itr.second);
}
return placement;
}
bool TensorDistAttr::is_replicated(int64_t mesh_axis) const {
auto placement = to_placement();
if (mesh_axis == -1) {
return std::all_of(placement.begin(),
placement.end(),
[](std::shared_ptr<PlacementStatus> status) {
return status->is_replicated();
});
} else {
return placement[mesh_axis]->is_replicated();
}
}
bool TensorDistAttr::is_shard(int64_t mesh_axis, int64_t tensor_axis) const {
auto placement = to_placement();
if (mesh_axis == -1) {
return std::all_of(placement.begin(),
placement.end(),
[tensor_axis](std::shared_ptr<PlacementStatus> status) {
return status->is_shard(tensor_axis);
});
} else {
return placement[mesh_axis]->is_shard(tensor_axis);
}
}
bool TensorDistAttr::is_partial(int64_t mesh_axis) const {
if (mesh_axis == -1) {
return !partial_status_.empty();
} else {
return partial_status_.count(mesh_axis) > 0;
}
}
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -31,6 +31,46 @@ limitations under the License. */ ...@@ -31,6 +31,46 @@ limitations under the License. */
namespace phi { namespace phi {
namespace distributed { namespace distributed {
class PlacementStatus {
public:
virtual ~PlacementStatus() = default;
virtual bool is_shard(int64_t axis = -1) const { return false; }
virtual bool is_partial() const { return false; }
virtual bool is_replicated() const { return false; }
};
class ReplicatedStatus final : public PlacementStatus {
public:
bool is_replicated() const override { return true; }
};
class PartialStatus final : public PlacementStatus {
public:
PartialStatus(ReduceType type) : type_(type) {}
bool is_partial() const override { return true; }
ReduceType get_reduce_type() const { return type_; }
private:
ReduceType type_{ReduceType::kRedSum};
};
class ShardStatus final : public PlacementStatus {
public:
ShardStatus(int64_t axis) : axis_(axis) {}
bool is_shard(int64_t axis = -1) const override {
if (axis == -1) {
return true;
} else {
return axis == axis_;
}
}
int64_t get_axis() const { return axis_; }
private:
int64_t axis_{-1};
};
class TensorDistAttr { class TensorDistAttr {
public: public:
TensorDistAttr() = default; TensorDistAttr() = default;
...@@ -51,9 +91,6 @@ class TensorDistAttr { ...@@ -51,9 +91,6 @@ class TensorDistAttr {
void set_dims_mapping(const std::vector<int64_t>& dims_mapping); void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
// true if tensor is partial on any mesh dim.
bool is_partial() const { return !partial_status_.empty(); }
// return vector of mesh dims on which the this tensor is partial on // return vector of mesh dims on which the this tensor is partial on
const std::set<int64_t> partial_dims() const; const std::set<int64_t> partial_dims() const;
...@@ -132,6 +169,22 @@ class TensorDistAttr { ...@@ -132,6 +169,22 @@ class TensorDistAttr {
bool empty() const; bool empty() const;
std::vector<std::shared_ptr<PlacementStatus>> to_placement() const;
// if mesh_axis is -1, check if tensor is replicated on whole process_mesh
// if mesh_axis is not -1, check only on specific axis.
bool is_replicated(int64_t mesh_axis = -1) const;
// if mesh_axis is -1, check if tensor is shard on whole process_mesh
// if mesh_axis is not -1, check only on specific axis
// if tensor_axis is not -1, return true only if the shard axis equal to
// tensor_axis.
bool is_shard(int64_t mesh_axis = -1, int64_t tensor_axis = -1) const;
// if mesh_axis is -1, check if tensor is partial on whole process_mesh
// if mesh_axis is not -1, check only on specific axis.
bool is_partial(int64_t mesh_axis = -1) const;
private: private:
static std::vector<std::string> fields_; static std::vector<std::string> fields_;
ProcessMesh process_mesh_; ProcessMesh process_mesh_;
......
...@@ -34,11 +34,15 @@ inline void check_defined(const DistTensor& dist_tensor, ...@@ -34,11 +34,15 @@ inline void check_defined(const DistTensor& dist_tensor,
DistTensor::DistTensor(const phi::DenseTensor& global_value, DistTensor::DistTensor(const phi::DenseTensor& global_value,
const TensorDistAttr& dist_attr) const TensorDistAttr& dist_attr)
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) { : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
if (!IsDimsMappingReplicated(dist_attr_.dims_mapping())) { if (!dist_attr.is_replicated()) {
// 1. create replicated global tensor // 1. create replicated global tensor
int64_t dims_size = global_value.dims().size(); int64_t dims_size = global_value.dims().size();
std::vector<int64_t> dims_mapping(dims_size, -1); std::vector<int64_t> dims_mapping(dims_size, -1);
dist_attr_.set_dims_mapping(dims_mapping); dist_attr_.set_dims_mapping(dims_mapping);
if (dist_attr_.is_partial()) {
dist_attr_.clean_partial_status();
}
dist_attr_.set_dims_mapping(dims_mapping);
// 2. reshard from replicated to other state // 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr); auto* func = ChooseProperReshardFunction(*this, dist_attr);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"
namespace phi {
namespace distributed {
bool PToRReshardFunction::IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
bool flag = true;
flag &= in.dist_attr().is_partial();
flag &= out_dist_attr.is_replicated();
const auto& in_process_mesh = in.dist_attr().process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);
return flag;
}
void PToRReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
const auto& in_partial_status = in_dist_attr.partial_status();
auto dtype = in.dtype();
int64_t reduce_type = static_cast<int64_t>(in_partial_status.at(0));
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
AllReduce,
dtype,
in_process_ids,
in.value(),
reduce_type,
GetMutableTensor(out));
SetDistProps(out, in.dims(), out_dist_attr);
}
REGISTER_RESHARD_FUNC(PToRReshardFunction);
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
namespace phi {
namespace distributed {
class PToRReshardFunction final : public ReshardFunction {
public:
PToRReshardFunction() = default;
~PToRReshardFunction() = default;
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;
};
} // namespace distributed
} // namespace phi
...@@ -28,9 +28,7 @@ bool RToPReshardFunction::IsSuitable(const DistTensor& in, ...@@ -28,9 +28,7 @@ bool RToPReshardFunction::IsSuitable(const DistTensor& in,
bool flag = true; bool flag = true;
const auto& in_dist_attr = in.dist_attr(); const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping(); flag &= in_dist_attr.is_replicated();
flag &= IsDimsMappingReplicated(in_dims_mapping);
flag &= out_dist_attr.is_partial(); flag &= out_dist_attr.is_partial();
const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_mesh = in_dist_attr.process_mesh();
......
...@@ -28,11 +28,8 @@ bool RToSReshardFunction::IsSuitable(const DistTensor& in, ...@@ -28,11 +28,8 @@ bool RToSReshardFunction::IsSuitable(const DistTensor& in,
bool flag = true; bool flag = true;
const auto& in_dist_attr = in.dist_attr(); const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping(); flag &= in_dist_attr.is_replicated();
const auto& out_dims_mapping = out_dist_attr.dims_mapping(); flag &= out_dist_attr.is_shard();
flag &= IsDimsMappingReplicated(in_dims_mapping);
flag &= IsDimsMappingShard(out_dims_mapping);
const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh();
......
...@@ -40,18 +40,6 @@ std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) { ...@@ -40,18 +40,6 @@ std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
} }
} // namespace } // namespace
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(),
dims_mapping.end(),
[](int64_t value) { return value != -1; });
}
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
return std::all_of(dims_mapping.begin(),
dims_mapping.end(),
[](int64_t value) { return value == -1; });
}
std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) { std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape(); const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids(); const auto& process_ids = process_mesh.process_ids();
......
...@@ -30,10 +30,6 @@ class DeviceContext; ...@@ -30,10 +30,6 @@ class DeviceContext;
namespace distributed { namespace distributed {
class ProcessMesh; class ProcessMesh;
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);
// Get the coordinate of cur rank in process mesh. For example, the process mesh // Get the coordinate of cur rank in process mesh. For example, the process mesh
// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will // is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
// return [2, 0]; if the current rank is 3, then will return [1, 1]. // return [2, 0]; if the current rank is 3, then will return [1, 1].
......
...@@ -30,12 +30,10 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in, ...@@ -30,12 +30,10 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) { const TensorDistAttr& out_dist_attr) {
bool flag = true; bool flag = true;
const auto& in_dist_attr = in.dist_attr(); const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping(); const auto& in_dims_mapping = in_dist_attr.dims_mapping();
const auto& out_dims_mapping = out_dist_attr.dims_mapping();
flag &= IsDimsMappingShard(in_dims_mapping); flag &= in_dist_attr.is_shard();
flag &= IsDimsMappingReplicated(out_dims_mapping); flag &= out_dist_attr.is_replicated();
const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh();
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/phi/common/reduce_type.h" #include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi { namespace phi {
...@@ -25,4 +26,16 @@ void AllReduceKernel(const Context& dev_ctx, ...@@ -25,4 +26,16 @@ void AllReduceKernel(const Context& dev_ctx,
int reduce_type, int reduce_type,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void AllReduce(const Context& dev_ctx,
const DenseTensor& x,
int reduce_type,
DenseTensor* out) {
MetaTensor out_meta(*out);
MetaTensor* out_meta_ptr = &out_meta;
AllReduceInferMeta(phi::MetaTensor(x), out_meta_ptr);
AllReduceKernel<T, Context>(dev_ctx, x, reduce_type, out);
}
} // namespace phi } // namespace phi
...@@ -83,6 +83,7 @@ PD_REGISTER_KERNEL(all_reduce, ...@@ -83,6 +83,7 @@ PD_REGISTER_KERNEL(all_reduce,
bool, bool,
int8_t, int8_t,
uint8_t, uint8_t,
int16_t,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {}
......
...@@ -78,6 +78,7 @@ PD_REGISTER_KERNEL(all_reduce, ...@@ -78,6 +78,7 @@ PD_REGISTER_KERNEL(all_reduce,
bool, bool,
int8_t, int8_t,
uint8_t, uint8_t,
int16_t,
int64_t, int64_t,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) {} phi::dtype::float16) {}
...@@ -92,6 +93,7 @@ PD_REGISTER_KERNEL(all_reduce, ...@@ -92,6 +93,7 @@ PD_REGISTER_KERNEL(all_reduce,
bool, bool,
int8_t, int8_t,
uint8_t, uint8_t,
int16_t,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {}
#endif #endif
...@@ -88,6 +88,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -88,6 +88,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p) py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
set_tests_properties(test_reshard_r_to_p set_tests_properties(test_reshard_r_to_p
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_reshard_p_to_r MODULES test_reshard_p_to_r)
set_tests_properties(test_reshard_p_to_r
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_semi_auto_parallel_basic MODULES py_test_modules(test_semi_auto_parallel_basic MODULES
test_semi_auto_parallel_basic) test_semi_auto_parallel_basic)
set_tests_properties(test_semi_auto_parallel_basic set_tests_properties(test_semi_auto_parallel_basic
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.framework import core
class TestReshardSToR:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
place = paddle.CPUPlace()
elif self._backend == "gpu":
place = paddle.CUDAPlace(dist.get_rank())
dev_ctx = core.DeviceContext.create(place)
a = paddle.ones(self._shape)
in_shard_specs = [None for i in range(len(self._shape))]
out_shard_specs = [None for i in range(len(self._shape))]
dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=in_shard_specs
)
dist_attr._set_partial_dims([0])
out_dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=out_shard_specs
)
input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
reshard_func = core.PToRReshardFunction()
assert reshard_func.is_suitable(input_tensor, out_dist_attr)
out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
assert np.equal(out.shape, input_tensor.shape).all()
np.testing.assert_equal(out._local_value().numpy(), a.numpy())
if __name__ == '__main__':
TestReshardSToR().run_test_case()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import collective.test_communication_api_base as test_base
class TestReshardSToR(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"shape": "(10, 20)",
"dtype": "float32",
"seeds": str(self._seeds),
}
self._changeable_envs = {
"backend": ["cpu", "gpu"],
}
def test_reshard_s_to_r(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"reshard_p_to_r.py",
user_defined_envs=envs,
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册