提交 4a7d4b38 编写于 作者: Y Yujing Zhang 提交者: TensorFlower Gardener

[XLA:GPU] Introduce CollectiveOpGroupMode and related helpers.

- Introduce a CollectiveOpGroupMode enum to describe various modes of interpreting
  replica groups attached to collective communication operations and
  GetCollectiveOpGroupMode() function to get the group formation mode
  implied by an HLO collective op based on whether it has channel_id and
  use_global_device_ids.
- Fix GetParticipatingDevices() to use this mode to correctly find participants in all
  group formatio...

PiperOrigin-RevId: 358313513
Change-Id: I9103d065c0a89149bcad3fa6b275696a257c97ae
上级 a8d2bce4
......@@ -5340,7 +5340,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core:lib_internal", # fixdeps: keep
"@com_google_absl//absl/types:optional",
],
)
......@@ -5354,8 +5353,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",
],
)
......
......@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
namespace xla {
......@@ -51,144 +50,59 @@ absl::optional<ReductionKind> MatchReductionComputation(
}
}
StatusOr<std::vector<int>> GetParticipatingIDs(
int current_id, absl::optional<int> total_participant_count,
absl::Span<const ReplicaGroup> groups) {
StatusOr<std::vector<int>> GetParticipatingReplicas(
int replica_id, int total_replica_count,
absl::Span<const ReplicaGroup> replica_groups) {
// Empty replica_groups() means that all replicas participate.
if (groups.empty()) {
TF_RET_CHECK(total_participant_count.has_value());
std::vector<int> all_participants(*total_participant_count);
absl::c_iota(all_participants, 0);
return all_participants;
if (replica_groups.empty()) {
std::vector<int> all_replicas(total_replica_count);
absl::c_iota(all_replicas, 0);
return all_replicas;
}
// Figure out the other replicas that go together with this one.
absl::optional<ReplicaGroup> group;
for (const ReplicaGroup& g : groups) {
if (absl::c_linear_search(g.replica_ids(), current_id)) {
TF_RET_CHECK(!group.has_value())
<< "ID " << current_id << " appears twice in replica groups";
group = g;
}
}
TF_RET_CHECK(group.has_value())
<< "ID " << current_id << " doesn't appear in replica groups";
return std::vector<int>(group->replica_ids().begin(),
group->replica_ids().end());
}
// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
bool has_channel_id, absl::optional<bool> use_global_device_ids) {
if (!has_channel_id) {
if (!use_global_device_ids.has_value() || !*use_global_device_ids) {
return CollectiveOpGroupMode::kCrossReplica;
} else {
return InvalidArgument(
"Invalid combination of has_channel_id and use_global_device_ids");
}
} else {
if (!use_global_device_ids.has_value()) {
return CollectiveOpGroupMode::kCrossPartition;
} else if (!*use_global_device_ids) {
return CollectiveOpGroupMode::kCrossReplicaAndPartition;
} else {
return CollectiveOpGroupMode::kFlattenedID;
absl::optional<ReplicaGroup> replica_group;
for (const ReplicaGroup& g : replica_groups) {
if (absl::c_linear_search(g.replica_ids(), replica_id)) {
TF_RET_CHECK(!replica_group.has_value())
<< "Replica " << replica_id << " appears twice in replica groups";
replica_group = g;
}
}
TF_RET_CHECK(replica_group.has_value())
<< "Replica " << replica_id << " doesn't appear in replica groups";
return std::vector<int>(replica_group->replica_ids().begin(),
replica_group->replica_ids().end());
}
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
absl::Span<const ReplicaGroup> replica_groups,
CollectiveOpGroupMode group_mode) {
int replica_count = device_assignment.replica_count();
int partition_count = device_assignment.computation_count();
int total_replica_count, absl::Span<const ReplicaGroup> replica_groups) {
std::vector<GlobalDeviceId> participants;
// Fast path for common case, avoiding logical IDs lookup.
if (replica_groups.empty() && device_assignment.computation_count() == 1) {
participants.reserve(total_replica_count);
for (int replica_id = 0; replica_id < total_replica_count; ++replica_id) {
participants.emplace_back(
device_assignment(replica_id, /*computation_id=*/0));
}
return participants;
}
std::pair<int, int> logical_ids;
TF_ASSIGN_OR_RETURN(logical_ids,
device_assignment.LogicalIdsForDevice(device_id));
int current_replica_id = logical_ids.first;
int current_partition_id = logical_ids.second;
std::vector<GlobalDeviceId> participants;
switch (group_mode) {
case CollectiveOpGroupMode::kCrossReplica: {
// This is a cross replica operation. replica group contains replica id.
// use current replica id to find the set of participating replicas. If
// replica groups are empty, assume a group with all replicas.
TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
GetParticipatingIDs(current_replica_id, replica_count,
replica_groups));
// The set of participating devices is the replicas from the current
// partition.
participants.reserve(participating_replicas.size());
for (int replica_id : participating_replicas) {
participants.emplace_back(
device_assignment(replica_id, current_partition_id));
}
return participants;
}
case CollectiveOpGroupMode::kCrossPartition: {
// replica_groups contain partition_id, group contains all partitions for
// the current replica.
TF_ASSIGN_OR_RETURN(std::vector<int> participating_partitions,
GetParticipatingIDs(current_partition_id,
partition_count, replica_groups));
participants.reserve(participating_partitions.size());
for (int partition_id : participating_partitions) {
participants.emplace_back(
device_assignment(current_replica_id, partition_id));
}
return participants;
}
case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
// replica_groups contain replica_ids. Group contains replicas for all
// partitions.
TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
GetParticipatingIDs(current_replica_id, replica_count,
replica_groups));
participants.reserve(participating_replicas.size() * partition_count);
for (int replica_id : participating_replicas) {
for (int partition_id = 0; partition_id < partition_count;
++partition_id) {
participants.emplace_back(
device_assignment(replica_id, partition_id));
}
}
return participants;
}
case CollectiveOpGroupMode::kFlattenedID: {
// replica groups contain flattened-ids and cannot be empty.
TF_RET_CHECK(!replica_groups.empty())
<< "replica groups cannot be empty for kFlattenedID mode";
int current_flattened_id =
current_replica_id * partition_count + current_partition_id;
// Find participants based on flattened id. replica_groups cannot be empty
// so no need to pass in total_participant_count.
TF_ASSIGN_OR_RETURN(
std::vector<int> participating_flattened_ids,
GetParticipatingIDs(current_flattened_id,
/*total_participant_count=*/absl::nullopt,
replica_groups));
participants.reserve(participating_flattened_ids.size());
for (int flattened_id : participating_flattened_ids) {
// Map from flattened id back to replica_id, partition_id.
int replica_id = flattened_id / partition_count;
int partition_id = flattened_id % partition_count;
participants.emplace_back(device_assignment(replica_id, partition_id));
}
return participants;
}
int replica_id = logical_ids.first;
int computation_id = logical_ids.second;
TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
GetParticipatingReplicas(replica_id, total_replica_count,
replica_groups));
participants.reserve(participating_replicas.size());
for (int replica_id : participating_replicas) {
participants.emplace_back(device_assignment(replica_id, computation_id));
}
return participants;
}
} // end namespace xla
......@@ -37,64 +37,17 @@ enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
absl::optional<ReductionKind> MatchReductionComputation(
const HloComputation* computation);
// Figures out which IDs are participating in the collective subgroup.
// An empty `groups` indicates that all [0, total_participant_count) IDs
// are participating. Note that for CollectiveOpGroupMode::kFlattenedID,
// groups cannot be empty, so `total_participant_count` is an optional.
StatusOr<std::vector<int>> GetParticipatingIDs(
int current_id, absl::optional<int> total_participant_count,
absl::Span<const ReplicaGroup> groups);
// There are broadly 4 modes that collective communication ops use to describe
// which sets of devices are participating with a given device in the operation.
// These modes are determined by the values of channel_id (optional) and
// use_global_device_ids (optional). The modes are as follows:
//
// kCrossReplica:
// implied by: no channel id, use_global_device_ids = false, or
// no channel_id, no use_global_device_ids:
// replica_groups contain replica_id, group contains all replicas for the
// current partition
//
// kCrossPartition:
// implied by: channel_id is set, no use_global_device_ids:
// replica_groups contain partition_id, group contains all partitions for the
// current replica.
//
// kCrossReplicaAndPartition:
// implied by: channel_id is set, use_global_device_ids = false:
// replica_groups contain replica_id, group contains all replicas for all
// partitions (as opposed to just current partition).
//
// kFlattenedID:
// implied by: channel_id is set, use_global_device_ids = true:
// replica_groups contain flattened-ids, group contains devices that are
// listed in the flattened-id list.
//
// Rest of the combinations are invalid.
//
// Since the actual value of channel_id does not matter, we use a bool argument
// `has_channel_id`, and optional<bool> for use_global_device_ids.
// Note that use_global_device_ids true requires channel_id to be set as well.
// Additionally, if use_global_device_ids = true, replica groups cannot be
// empty (verified in the HLO verifier).
enum class CollectiveOpGroupMode {
kCrossReplica,
kCrossPartition,
kCrossReplicaAndPartition,
kFlattenedID,
};
// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
bool has_channel_id, absl::optional<bool> use_global_device_ids);
// Figures out which replicas are participating in the collective subgroup.
// An empty `replica_groups` indicates that all replicas are participating.
StatusOr<std::vector<int>> GetParticipatingReplicas(
int replica_id, int total_replica_count,
absl::Span<const ReplicaGroup> replica_groups);
// Figures out which devices are participating in the collective subgroup.
// An empty `replica_groups` indicates that all replicas are participating.
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
absl::Span<const ReplicaGroup> replica_groups,
CollectiveOpGroupMode group_mode);
int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
......
......@@ -15,31 +15,24 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include <iterator>
#include <sstream>
#include <string>
#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
TEST(CollectiveOpsUtilsTest, GetParticipatingIDs_NoReplicaGroups) {
std::vector<int> actual = GetParticipatingIDs(
/*current_id=*/0, /*total_participant_count=*/3,
/*groups=*/{})
.ConsumeValueOrDie();
TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_NoReplicaGroups) {
std::vector<int> actual =
GetParticipatingReplicas(
/*replica_id=*/0, /*total_replica_count=*/3, /*replica_groups=*/{})
.ConsumeValueOrDie();
std::vector<int> expected = {0, 1, 2};
EXPECT_EQ(actual, expected);
}
TEST(CollectiveOpsUtilsTest, GetParticipatingIDs_ReplicaGroups) {
TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_ReplicaGroups) {
std::vector<ReplicaGroup> replica_groups(3);
replica_groups[0].add_replica_ids(0);
replica_groups[0].add_replica_ids(4);
......@@ -49,388 +42,68 @@ TEST(CollectiveOpsUtilsTest, GetParticipatingIDs_ReplicaGroups) {
replica_groups[2].add_replica_ids(3);
std::vector<int> actual =
GetParticipatingIDs(
/*current_id=*/1, /*total_participant_count=*/absl::nullopt,
replica_groups)
GetParticipatingReplicas(
/*replica_id=*/1, /*total_replica_count=*/6, replica_groups)
.ConsumeValueOrDie();
std::vector<int> expected = {1, 5};
EXPECT_EQ(actual, expected);
}
} // namespace
// Tests for GetCollectOpGroupMode
namespace GetCollectiveOpGroupModeTest {
struct TestCase {
bool has_channel_id;
absl::optional<bool> use_global_device_ids;
absl::optional<xla::CollectiveOpGroupMode> expected;
std::string ToString() const {
std::ostringstream s;
s << (has_channel_id ? "chnl" : "nochnl");
s << "_"
<< (use_global_device_ids
? (*use_global_device_ids ? "ugdi_true" : "ugdi_false")
: "nougdi");
return s.str();
}
};
std::vector<TestCase> GetTestCases() {
const std::vector<TestCase> test_cases = {
{.has_channel_id = false,
.use_global_device_ids = absl::nullopt,
.expected = CollectiveOpGroupMode::kCrossReplica},
{.has_channel_id = false,
.use_global_device_ids = false,
.expected = CollectiveOpGroupMode::kCrossReplica},
{.has_channel_id = false,
.use_global_device_ids = true,
.expected = absl::nullopt},
{
.has_channel_id = true,
.use_global_device_ids = absl::nullopt,
.expected = CollectiveOpGroupMode::kCrossPartition,
},
{.has_channel_id = true,
.use_global_device_ids = false,
.expected = CollectiveOpGroupMode::kCrossReplicaAndPartition},
{.has_channel_id = true,
.use_global_device_ids = true,
.expected = CollectiveOpGroupMode::kFlattenedID},
};
return test_cases;
}
class GetCollectOpGroupModeTest : public testing::TestWithParam<TestCase> {};
TEST_P(GetCollectOpGroupModeTest, Test) {
const TestCase &tc = GetParam();
StatusOr<CollectiveOpGroupMode> actual =
GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids);
if (tc.expected) {
TF_ASSERT_OK(actual.status());
EXPECT_EQ(*actual, *tc.expected);
} else {
EXPECT_FALSE(actual.ok());
}
}
INSTANTIATE_TEST_SUITE_P(GetCollectOpGroupMode, GetCollectOpGroupModeTest,
testing::ValuesIn(GetTestCases()));
} // namespace GetCollectiveOpGroupModeTest
TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_NoReplicaGroups) {
DeviceAssignment device_assignment(/*replica_count=*/3,
/*computation_count=*/1);
device_assignment(0, 0) = 42;
device_assignment(1, 0) = 43;
device_assignment(2, 0) = 44;
// Tests for GetParticipatingDevices
namespace GetParticipatingDevicesTest {
// Test case for GetParticipatingDevices. Describes all the inputs to the
// function and for a given "setup", multiple "current_id" values and the
// expected output corresponding to those values.
struct TestCase {
xla::Array2D<int> device_assignment;
std::vector<std::vector<int>> replica_groups;
bool has_channel_id;
absl::optional<bool> use_global_device_ids;
// For a given test case, its useful to test multiple 'current_id' inputs.
struct CurrentIdAndOutput {
int current_id;
std::vector<int> expected_output;
};
std::vector<CurrentIdAndOutput> subtests;
bool expected_failure;
std::string ToString() const;
};
// Please see the comment for GetParticipatingDevices() for a description of
// modes and their behavior.
std::string TestCase::ToString() const {
std::ostringstream s;
StatusOr<CollectiveOpGroupMode> group_mode =
GetCollectiveOpGroupMode(has_channel_id, use_global_device_ids);
if (group_mode.ok()) {
switch (*group_mode) {
case CollectiveOpGroupMode::kCrossReplica:
s << "kCrossReplica";
break;
case CollectiveOpGroupMode::kCrossReplicaAndPartition:
s << "kCrossReplicaAndPartition";
break;
case CollectiveOpGroupMode::kFlattenedID:
s << "kFlattenedID";
break;
case CollectiveOpGroupMode::kCrossPartition:
s << "kCrossPartition";
break;
}
} else {
s << "Invalid";
}
s << "_" << device_assignment.n1() << "x" << device_assignment.n2();
s << "_" << (replica_groups.empty() ? "NoRG" : "RG");
s << "_" << subtests.size() << "SubTests";
return s.str();
}
std::ostream &operator<<(std::ostream &os, const TestCase &tc) {
os << tc.ToString();
return os;
std::vector<GlobalDeviceId> actual =
GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
/*total_replica_count=*/3, /*replica_groups=*/{})
.ConsumeValueOrDie();
std::vector<GlobalDeviceId> expected = {
GlobalDeviceId(42), GlobalDeviceId(43), GlobalDeviceId(44)};
EXPECT_EQ(actual, expected);
}
std::vector<TestCase> GetTestCases() {
std::vector<TestCase> test_cases;
// clang-format off
const std::vector<TestCase> cross_replica_test_cases = {
// with empty replica groups, 1 partition.
{
.device_assignment = {{33}, {44}, {55}}, // 3 replicas, 1 partition.
.replica_groups = {},
.has_channel_id = false,
.use_global_device_ids = false,
.subtests = {
// for empty replica group, any id should return all ids.
{.current_id = 33, .expected_output = {33, 44, 55}},
{.current_id = 44, .expected_output = {33, 44, 55}},
},
.expected_failure = false
},
// empty replica groups, > 1 partition
{
.device_assignment = {{33, 34}, {44, 45}, {55, 56}}, // 3r, 2p
.replica_groups = {},
.has_channel_id = false,
.use_global_device_ids = false,
// for empty replica group, any id should return all replicas within that
// partition.
.subtests = {
{.current_id = 33, .expected_output = {33, 44, 55}},
{.current_id = 34, .expected_output = {34, 45, 56}},
{.current_id = 45, .expected_output = {34, 45, 56}},
},
.expected_failure = false
},
// non-empty replica groups, 1 partition.
{
.device_assignment = {{33}, {44}, {55}}, // 3r, 1p.
.replica_groups = {{0}, {1, 2}},
.has_channel_id = false,
.use_global_device_ids = false,
.subtests = {
// 33 is r0, so it's a singleton group.
{.current_id = 33, .expected_output = {33}},
// 44 is r1, so it should give {r1, r2}.
{.current_id = 44, .expected_output = {44, 55}},
},
.expected_failure = false
},
TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_ReplicaGroups) {
DeviceAssignment device_assignment(/*replica_count=*/4,
/*computation_count=*/1);
device_assignment(0, 0) = 42;
device_assignment(1, 0) = 43;
device_assignment(2, 0) = 44;
device_assignment(3, 0) = 45;
// non-empty, > 1 partition
{
.device_assignment = {{33, 34}, {44, 45}, {55, 56}}, // 3r, 2p
.replica_groups = {{0}, {1, 2}},
.has_channel_id = false,
.use_global_device_ids = false,
.subtests = {
// 33 is r0p0, so should be singleton.
{.current_id = 33, .expected_output = {33}},
// 34 is r0p1, so should be singleton.
{.current_id = 34, .expected_output = {34}},
// 45 is r1p1, so should get r1p1 and r2p1.
{.current_id = 45, .expected_output = {45, 56}},
},
.expected_failure = false
},
};
// replica groups contain partition ids.
const std::vector<TestCase> cross_partition_test_cases = {
{
// 3x4 device assignment
.device_assignment = {
{33, 34, 35, 36}, {44, 45, 46, 47}, {55, 56, 57, 58}
},
.replica_groups = {{0, 1}, {2, 3}},
.has_channel_id = true,
.use_global_device_ids = absl::nullopt,
.subtests = {
// 33 is r0p0, p0 group has p0, p1 so we get r0p0 and r0p1.
{.current_id = 33, .expected_output = {33, 34}},
// 35 is r0p2, so we get r0p2 and r0p3
{.current_id = 35, .expected_output = {35, 36}},
{.current_id = 45, .expected_output = {44, 45}},
{.current_id = 47, .expected_output = {46, 47}},
{.current_id = 58, .expected_output = {57, 58}},
},
.expected_failure = false
}
};
const std::vector<TestCase> cross_replica_and_partition_test_cases = {
{
.device_assignment = {{33, 34}, {44, 45}, {55, 56}}, // 3r, 2p
.replica_groups = {{0}, {1, 2}},
.has_channel_id = true,
.use_global_device_ids = false,
.subtests = {
// 33 is r0p0, so should get r0 from all partitions.
{.current_id = 33, .expected_output = {33, 34}},
// 34 is r0p1, so should get r0 from all partitions.
{.current_id = 34, .expected_output = {33, 34}},
// 45 is r1p1, so should get r1, r2
{.current_id = 45, .expected_output = {44, 45, 55, 56}},
// from all partitons.
},
.expected_failure = false
},
// empty replica group = all replicas, so we should get all devices.
{
.device_assignment = {{33, 34}, {44, 45}, {55, 56}}, // 3r, 2p
.replica_groups = {},
.has_channel_id = true,
.use_global_device_ids = false,
.subtests = {
{.current_id = 33, .expected_output = {33, 34, 44, 45, 55, 56}},
{.current_id = 34, .expected_output = {33, 34, 44, 45, 55, 56}},
{.current_id = 56, .expected_output = {33, 34, 44, 45, 55, 56}},
},
.expected_failure = false
},
};
// Replica groups are flattened ids. For a 3x2 device assignment
// used in these tests, the flattened ID and deviceId correspondence is as
// follows:
// r0p0 = f#0 = d#33
// r0p1 = f#1 = d#34
// r1p0 = f#2 = d#44
// r1p1 = f#3 = d#45
// r2p0 = f#4 = d#55
// r2p1 = f#5 = d#56
const std::vector<TestCase> flattened_id_test_cases = {
{
.device_assignment = {{33, 34}, {44, 45}, {55, 56}}, // 3r, 2p
.replica_groups = {{0}, {1, 2}, {3, 4, 5}},
.has_channel_id = true,
.use_global_device_ids = true,
.subtests = {
{.current_id = 33, .expected_output = {33}},
{.current_id = 34, .expected_output = {34, 44}},
{.current_id = 44, .expected_output = {34, 44}},
{.current_id = 45, .expected_output = {45, 55, 56}},
{.current_id = 55, .expected_output = {45, 55, 56}},
{.current_id = 56, .expected_output = {45, 55, 56}},
},
.expected_failure = false
},
{
.device_assignment = {{33}},
.replica_groups = {}, // empty replica groups not allowed.
.has_channel_id = true,
.use_global_device_ids = true,
.subtests = {
{.current_id = 33, .expected_output = {33}},
},
.expected_failure = true
},
};
const std::vector<TestCase> failure_test_cases = {
// No channel id, use_global_device_ids = true;
{
.device_assignment = {{33}, {44}, {55}}, // 3r, 1p
.replica_groups = {},
.has_channel_id = false,
.use_global_device_ids = true,
.subtests = {
{.current_id = 33, .expected_output = {}},
},
.expected_failure = true
},
};
// clang-format on
test_cases.insert(test_cases.end(), cross_replica_test_cases.begin(),
cross_replica_test_cases.end());
// When use_global_device_ifs is not present and channel_id is not present,
// that implies cross replica mode as well.
for (TestCase tc : cross_replica_test_cases) {
tc.use_global_device_ids = absl::nullopt;
test_cases.push_back(tc);
}
test_cases.insert(test_cases.end(), cross_partition_test_cases.begin(),
cross_partition_test_cases.end());
test_cases.insert(test_cases.end(),
cross_replica_and_partition_test_cases.begin(),
cross_replica_and_partition_test_cases.end());
test_cases.insert(test_cases.end(), flattened_id_test_cases.begin(),
flattened_id_test_cases.end());
test_cases.insert(test_cases.end(), failure_test_cases.begin(),
failure_test_cases.end());
std::vector<ReplicaGroup> replica_groups(2);
replica_groups[0].add_replica_ids(0);
replica_groups[0].add_replica_ids(3);
replica_groups[1].add_replica_ids(1);
replica_groups[1].add_replica_ids(2);
return test_cases;
std::vector<GlobalDeviceId> actual =
GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
/*total_replica_count=*/4, replica_groups)
.ConsumeValueOrDie();
std::vector<GlobalDeviceId> expected = {GlobalDeviceId(42),
GlobalDeviceId(45)};
EXPECT_EQ(actual, expected);
}
class GetParticipatingDevicesTest : public testing::TestWithParam<TestCase> {};
TEST_P(GetParticipatingDevicesTest, Test) {
const TestCase &tc = GetParam();
int64_t num_replicas = tc.device_assignment.n1();
int64_t num_partitions = tc.device_assignment.n2();
DeviceAssignment device_assignment(num_replicas, num_partitions);
for (int64_t replica_id = 0; replica_id < num_replicas; ++replica_id) {
for (int64_t partition_id = 0; partition_id < num_partitions;
++partition_id) {
device_assignment(replica_id, partition_id) =
tc.device_assignment(replica_id, partition_id);
}
}
std::vector<ReplicaGroup> replica_groups;
absl::c_transform(tc.replica_groups, std::back_inserter(replica_groups),
[](const std::vector<int> &ids) {
ReplicaGroup group;
for (int id : ids) {
group.add_replica_ids(id);
}
return group;
});
// Execute each sub-test.
for (const TestCase::CurrentIdAndOutput &subtest : tc.subtests) {
StatusOr<CollectiveOpGroupMode> group_mode =
GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids);
if (!group_mode.ok()) {
EXPECT_TRUE(tc.expected_failure);
continue;
}
StatusOr<std::vector<GlobalDeviceId>> actual =
GetParticipatingDevices(GlobalDeviceId(subtest.current_id),
device_assignment, replica_groups, *group_mode);
if (!actual.ok()) {
EXPECT_TRUE(tc.expected_failure);
continue;
}
std::vector<GlobalDeviceId> expected;
expected.reserve(subtest.expected_output.size());
absl::c_transform(subtest.expected_output, std::back_inserter(expected),
[](int id) { return GlobalDeviceId(id); });
EXPECT_EQ(*actual, expected);
}
TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_MultipleComputations) {
DeviceAssignment device_assignment(/*replica_count=*/2,
/*computation_count=*/2);
device_assignment(0, 0) = 42;
device_assignment(1, 0) = 43;
device_assignment(0, 1) = 44;
device_assignment(1, 1) = 45;
std::vector<GlobalDeviceId> actual =
GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
/*total_replica_count=*/2, /*replica_groups=*/{})
.ConsumeValueOrDie();
std::vector<GlobalDeviceId> expected = {GlobalDeviceId(44),
GlobalDeviceId(45)};
EXPECT_EQ(actual, expected);
}
INSTANTIATE_TEST_SUITE_P(GetParticipatingDevices, GetParticipatingDevicesTest,
testing::ValuesIn(GetTestCases()));
} // namespace GetParticipatingDevicesTest
} // namespace
} // namespace xla
......@@ -617,7 +617,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
......
......@@ -24,7 +24,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
......@@ -605,8 +604,8 @@ xla::RendezvousKey GetRendezvousKey(
: xla::RendezvousKey::kCrossReplica;
std::vector<xla::GlobalDeviceId> participating_devices =
xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
device_assignment, group,
xla::CollectiveOpGroupMode::kCrossReplica)
device_assignment,
device_assignment.replica_count(), group)
.ValueOrDie();
int num_local_participants = participating_devices.size();
return xla::RendezvousKey{run_options->run_id(),
......@@ -637,7 +636,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
run_options->stream());
participant.replica_id = replica_id;
participant.replica_ids_to_copy_to =
xla::GetParticipatingIDs(
xla::GetParticipatingReplicas(
replica_id, run_options->device_assignment()->replica_count(), group)
.ValueOrDie();
for (int i = 0; i < num_buffers; i++) {
......
......@@ -166,7 +166,7 @@ extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
// Perform all reduce on a CPU.
//
// participating_replicas: array of replica IDs participating in the reduction,
// cf. GetParticipatingIDs.
// cf. GetParticipatingReplicas.
// channel_id_present, op_id: whether op_id is a channel ID or a module ID.
// reduction_kind: operator used for a reduction, cf. ReductionKind.
// shape_ptr: shape of all input/output buffers.
......
......@@ -73,8 +73,7 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDeviceId> participants,
GetParticipatingDevices(global_device_id, *params.device_assn,
config().replica_groups,
CollectiveOpGroupMode::kCrossReplica));
config().replica_count, config().replica_groups));
if (IsGlobalNcclConfig() && (participants.size() != config().replica_count)) {
return InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册