Commit ebdbaa3f authored by Tongfei Guo, committed by TensorFlower Gardener


[SPMD] Create a placeholder sharding of type UNKNOWN with the lowest precedence that can be overwritten by any other sharding. (2/N)

The UNKNOWN type is not supported in subgroup sharding.

PiperOrigin-RevId: 564456004
Parent 73e39990
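
For orientation, here is a minimal sketch (not part of the commit) of how the new placeholder sharding behaves, based on the factory, predicate, and printer changes in the diff below. The include path and the use of plain asserts are assumptions.

```cpp
// Sketch only: exercises the HloSharding::Unknown() APIs added in this change.
#include <cassert>

#include "xla/hlo/ir/hlo_sharding.h"  // header path is an assumption

int main() {
  xla::HloSharding s = xla::HloSharding::Unknown();
  assert(s.IsUnknown());                 // new predicate added in this change
  assert(!s.IsTiled());                  // UNKNOWN is excluded from IsTiled()
  assert(!s.IsReplicated() && !s.IsManual());
  assert(s.ToString() == "{unknown}");   // matches the new printer/parser form
  return 0;
}
```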
......@@ -278,10 +278,10 @@ HloSharding HloSharding::Subgroup(
static constexpr std::array<OpSharding::Type, OpSharding::Type_ARRAYSIZE>
kOrderedTypes = {OpSharding::MAXIMAL, OpSharding::TUPLE,
OpSharding::OTHER, OpSharding::MANUAL,
OpSharding::REPLICATED};
OpSharding::REPLICATED, OpSharding::UNKNOWN};
static_assert(kOrderedTypes[0] == 1 && kOrderedTypes[1] == 2 &&
kOrderedTypes[2] == 3 && kOrderedTypes[3] == 4 &&
kOrderedTypes[4] == 0);
kOrderedTypes[4] == 0 && kOrderedTypes[5] == 5);
for (OpSharding::Type type : kOrderedTypes) {
auto& dims = type_to_dims[type];
if (dims.empty()) continue;
......@@ -426,6 +426,15 @@ void HloSharding::Print(Printer* printer, bool include_metadata) const {
printer->Append("}");
return;
}
if (unknown_) {
printer->Append("{unknown");
print_shard_group();
print_metadata();
printer->Append("}");
return;
}
if (maximal_) {
AppendCat(printer, "{maximal device=",
static_cast<int64_t>(*tile_assignment_.array().begin()));
......@@ -510,6 +519,7 @@ std::map<int64_t, int64_t> HloSharding::UsedDevices(int64_t* count) const {
std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {
CHECK(!maximal_);
CHECK(!IsManual());
CHECK(!IsUnknown());
CHECK(!IsTuple());
std::vector<int64_t> ret_index;
tile_assignment_.Each([&](absl::Span<const int64_t> index, int64_t d) {
......@@ -525,6 +535,7 @@ std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {
int64_t HloSharding::DeviceForTileIndex(absl::Span<const int64_t> index) const {
CHECK(!replicated_);
CHECK(!IsManual());
CHECK(!IsUnknown());
CHECK(!IsTuple());
if (maximal_) {
return *tile_assignment_.array().begin();
......@@ -545,6 +556,7 @@ std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
int64_t device) const {
CHECK(!IsTuple());
CHECK(!IsManual());
CHECK(!IsUnknown());
if (maximal_) {
return std::vector<int64_t>(shape.dimensions_size(), 0);
......@@ -563,6 +575,7 @@ std::vector<int64_t> HloSharding::TileLimitForDevice(const Shape& shape,
int64_t device) const {
CHECK(!IsTuple());
CHECK(!IsManual());
CHECK(!IsUnknown());
if (maximal_) {
return std::vector<int64_t>(shape.dimensions().begin(),
......@@ -744,7 +757,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
tile_assignment_.iota_->num_elements() == *num_devices;
}
if (IsTileMaximal() || IsManual()) {
if (IsTileMaximal() || IsManual() || IsUnknown()) {
return OkStatus();
}
......@@ -798,6 +811,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return Replicate(metadata).SetShardGroupFromProto(proto);
} else if (proto.type() == OpSharding::MANUAL) {
return Manual(metadata).SetShardGroupFromProto(proto);
} else if (proto.type() == OpSharding::UNKNOWN) {
return Unknown(metadata).SetShardGroupFromProto(proto);
} else if (proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0), metadata)
.SetShardGroupFromProto(proto);
......@@ -921,6 +936,9 @@ OpSharding HloSharding::ToProto() const {
} else if (IsManual()) {
result.set_type(OpSharding::MANUAL);
result.clear_tile_assignment_dimensions();
} else if (IsUnknown()) {
result.set_type(OpSharding::UNKNOWN);
result.clear_tile_assignment_dimensions();
} else {
result.set_type(OpSharding::OTHER);
result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
......@@ -942,7 +960,7 @@ OpSharding HloSharding::ToProto() const {
}
Shape HloSharding::TileShape(const Shape& shape) const {
if (IsTileMaximal() || IsManual()) {
if (IsTileMaximal() || IsManual() || IsUnknown()) {
return shape;
}
Shape result_shape = shape;
......@@ -954,7 +972,7 @@ Shape HloSharding::TileShape(const Shape& shape) const {
}
Shape HloSharding::TileShape(const Shape& shape, int64_t device) const {
if (IsTileMaximal() || IsManual()) {
if (IsTileMaximal() || IsManual() || IsUnknown()) {
return shape;
}
......@@ -977,6 +995,7 @@ int64_t HloSharding::TotalNumTiles() const {
return 1;
}
CHECK(!IsManual());
CHECK(!IsUnknown());
return Product(absl::Span<const int64_t>(tile_assignment_.dimensions()));
}
......@@ -985,6 +1004,7 @@ int64_t HloSharding::NumTiles() const {
return 1;
}
CHECK(!IsManual());
CHECK(!IsUnknown());
return Product(absl::Span<const int64_t>(tile_assignment_.dimensions())
.subspan(0, TiledDataRank()));
}
......
......@@ -45,12 +45,20 @@ class HloSharding {
// Creates a trivial sharding that replicates a maximal tile across all
// devices.
static HloSharding Replicate(absl::Span<const OpMetadata> metadata = {}) {
return HloSharding(/*manual=*/false, /*replicated=*/true, metadata);
return HloSharding(/*manual=*/false, /*replicated=*/true, /*unknown=*/false,
metadata);
}
// Creates a sharding that represents the op is manually partitioned.
static HloSharding Manual(absl::Span<const OpMetadata> metadata = {}) {
return HloSharding(/*manual=*/true, /*replicated=*/false, metadata);
return HloSharding(/*manual=*/true, /*replicated=*/false, /*unknown=*/false,
metadata);
}
// Creates a sharding that represents the op has a placeholder sharding.
static HloSharding Unknown(absl::Span<const OpMetadata> metadata = {}) {
return HloSharding(/*manual=*/false, /*replicated=*/false, /*unknown=*/true,
metadata);
}
// Creates a sharding that emulates device placement; a tile shape equal to
......@@ -189,6 +197,15 @@ class HloSharding {
[](const HloSharding& s) { return s.IsManual(); });
}
// Returns whether the sharding represents a placeholder sharding.
bool IsUnknown() const {
if (!IsTuple()) {
return unknown_;
}
return absl::c_all_of(tuple_elements_,
[](const HloSharding& s) { return s.IsUnknown(); });
}
bool IsShardGroup() const {
if (!IsTuple()) {
return shard_group_.shard_group_id != -1 &&
......@@ -226,7 +243,9 @@ class HloSharding {
// Returns whether the sharding represents a tiled sharding where the mapping
// between devices and tiles is represented through 'tile_assignment()'.
bool IsTiled() const { return !IsTileMaximal() && !IsManual(); }
bool IsTiled() const {
return !IsTileMaximal() && !IsManual() && !IsUnknown();
}
// Returns if the sharding has partial replication and partial sharding. If
// true, data is sharded according to other dimensions of tile_assignment(),
......@@ -334,7 +353,7 @@ class HloSharding {
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
manual_ == other.manual_ &&
manual_ == other.manual_ && unknown_ == other.unknown_ &&
tile_assignment_ == other.tile_assignment_ &&
tuple_elements_ == other.tuple_elements_ &&
replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ &&
......@@ -349,7 +368,7 @@ class HloSharding {
return H::combine(std::move(h), sharding.tuple_elements_);
}
return H::combine(std::move(h), sharding.replicated_, sharding.manual_,
sharding.tile_assignment_.array(),
sharding.unknown_, sharding.tile_assignment_.array(),
sharding.replicate_on_last_tile_dim_,
sharding.shard_group_.ToString());
}
......@@ -393,7 +412,7 @@ class HloSharding {
std::vector<OpMetadata>& metadata() { return metadata_; }
const std::vector<OpMetadata>& metadata() const { return metadata_; }
// Returns the replication subgroiup dim, or -1 if it doesn't exist.
// Returns the replication subgroup dim, or -1 if it doesn't exist.
int64_t SubgroupReplicationDim() const {
auto it = absl::c_find(subgroup_types_, OpSharding::REPLICATED);
if (it != subgroup_types_.end()) {
......@@ -498,13 +517,14 @@ class HloSharding {
const ShardGroup& GetShardGroup() const { return shard_group_; }
private:
explicit HloSharding(bool manual, bool replicated,
explicit HloSharding(bool manual, bool replicated, bool unknown,
absl::Span<const OpMetadata> metadata)
: metadata_(metadata.begin(), metadata.end()),
replicated_(replicated),
maximal_(replicated),
tuple_(false),
manual_(manual),
unknown_(unknown),
replicate_on_last_tile_dim_(false) {}
// device_id values:
// -2: magic number to mean unassigned device, used by spatial partitioning
......@@ -519,6 +539,7 @@ class HloSharding {
maximal_(true),
tuple_(false),
manual_(false),
unknown_(false),
replicate_on_last_tile_dim_(false) {}
explicit HloSharding(TileAssignment tile_assignment,
bool replicate_on_last_tile_dim,
......@@ -529,6 +550,7 @@ class HloSharding {
maximal_(false),
tuple_(false),
manual_(false),
unknown_(false),
replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {}
explicit HloSharding(TileAssignment tile_assignment,
absl::Span<const OpSharding::Type> subgroup_types,
......@@ -540,6 +562,7 @@ class HloSharding {
maximal_(false),
tuple_(false),
manual_(false),
unknown_(false),
replicate_on_last_tile_dim_(false) {}
explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
: tuple_elements_(tuple_shardings),
......@@ -547,6 +570,7 @@ class HloSharding {
maximal_(false),
tuple_(true),
manual_(false),
unknown_(false),
replicate_on_last_tile_dim_(false) {}
// Test-only constructor for sharding format code coverage. Copies the
......@@ -560,6 +584,7 @@ class HloSharding {
maximal_(other.maximal_),
tuple_(other.tuple_),
manual_(other.manual_),
unknown_(other.unknown_),
replicate_on_last_tile_dim_(other.replicate_on_last_tile_dim_) {
CHECK(tile_assignment_ == other.tile_assignment_)
<< tile_assignment_.ToString() << " v.s. "
......@@ -614,6 +639,7 @@ class HloSharding {
bool maximal_;
bool tuple_;
bool manual_;
bool unknown_;
// This flag is to support partial replication and partial sharding. If it is
// true, tile_assignment_ will have an extra dimension in addition to the data
// shape rank, and the added last dimension represents the subgroups of
......
......@@ -987,7 +987,8 @@ void BuildXlaCompilerSubmodule(py::module& m) {
.value("MAXIMAL", OpSharding::MAXIMAL)
.value("MANUAL", OpSharding::MANUAL)
.value("TUPLE", OpSharding::TUPLE)
.value("OTHER", OpSharding::OTHER);
.value("OTHER", OpSharding::OTHER)
.value("UNKNOWN", OpSharding::UNKNOWN);
py::enum_<OpSharding::ShardGroupType> op_sharding_shard_group_type(
m, "OpSharding_ShardGroupType");
......@@ -1068,12 +1069,14 @@ void BuildXlaCompilerSubmodule(py::module& m) {
py::arg("subgroup_types") = absl::Span<const xla::OpSharding::Type>())
.def_static("manual", [] { return HloSharding::Manual(); })
.def_static("replicate", [] { return HloSharding::Replicate(); })
.def_static("unknown", [] { return HloSharding::Unknown(); })
.def("__eq__", [](const xla::HloSharding& a,
const xla::HloSharding& b) { return a == b; })
.def("__hash__",
[](const xla::HloSharding& self) { return absl::HashOf(self); })
.def("is_replicated", &xla::HloSharding::IsReplicated)
.def("is_manual", &xla::HloSharding::IsManual)
.def("is_unknown", &xla::HloSharding::IsUnknown)
.def("is_tiled", &xla::HloSharding::IsTiled)
.def("tile", [](const xla::HloSharding& self,
xla::Shape shape) { return self.TileShape(shape); })
......
......@@ -308,6 +308,7 @@ class OpSharding_Type(enum.IntEnum):
TUPLE: int
OTHER: int
MANUAL: int
UNKNOWN: int
class OpSharding_ShardGroupType(enum.IntEnum):
AS: int
......@@ -346,12 +347,15 @@ class HloSharding:
def replicate() -> HloSharding: ...
@staticmethod
def manual() -> HloSharding: ...
@staticmethod
def unknown() -> HloSharding: ...
def __eq__(self, other: HloSharding) -> bool: ...
def __hash__(self) -> int: ...
def __repr__(self) -> str: ...
def tile(self, shape: Shape) -> Shape: ...
def is_replicated(self) -> bool: ...
def is_manual(self) -> bool: ...
def is_unknown(self) -> bool: ...
def is_tiled(self) -> bool: ...
def tuple_elements(self) -> List[HloSharding]: ...
def num_devices(self) -> int: ...
......
......@@ -40,7 +40,7 @@ while_cond {
arg_cond = f32[2,3,2] parameter(0)
token0 = token[] after-all()
infeed = (pred[], token[]) infeed(token0)
ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
ROOT cond = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
ENTRY main {
......@@ -89,7 +89,7 @@ while_cond {
arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
token0 = token[] after-all()
infeed = (pred[], token[]) infeed(token0)
ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
ROOT cond = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
ENTRY main {
......
......@@ -328,6 +328,7 @@ TokKind HloLexer::LexIdentifier() {
KEYWORD(last_tile_dim_replicate);
KEYWORD(shard_as);
KEYWORD(shard_like);
KEYWORD(unknown);
#undef KEYWORD
......@@ -612,6 +613,8 @@ std::string TokKindToString(TokKind kind) {
return "kw_shard_as";
case TokKind::kw_shard_like:
return "kw_shard_like";
case TokKind::kw_unknown:
return "kw_unknown";
case TokKind::kw_inf:
return "kw_inf";
case TokKind::kNegInf:
......
......@@ -69,6 +69,7 @@ enum class TokKind {
kw_last_tile_dim_replicate,
kw_shard_as,
kw_shard_like,
kw_unknown,
kw_inf,
kNegInf, // -inf
......
......@@ -3223,7 +3223,7 @@ bool HloParserImpl::ParseStatisticsViz(StatisticsViz* statistics_viz) {
return ParseToken(TokKind::kRbrace, "expects '}' at the end of statistics");
}
// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
// ::= '{' 'replicated'? 'manual'? 'maximal'? 'unknown'? ('device=' int)? shape?
// ('devices=' ('[' dims ']')* device_list)?
// (('shard_like' | 'shard_as') int)* '}'
// ('metadata=' metadata)*
......@@ -3245,6 +3245,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
bool maximal = false;
bool replicated = false;
bool manual = false;
bool unknown = false;
bool last_tile_dim_replicate = false;
bool last_tile_dims = false;
bool shard_like = false;
......@@ -3269,6 +3270,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
manual = true;
lexer_.Lex();
break;
case TokKind::kw_unknown:
unknown = true;
lexer_.Lex();
break;
case TokKind::kAttributeName: {
if (lexer_.GetStrVal() == "device") {
if (lexer_.Lex() != TokKind::kInt) {
......@@ -3421,6 +3426,12 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
"manual shardings should not have any devices assigned");
}
sharding->set_type(OpSharding::MANUAL);
} else if (unknown) {
if (!devices.empty()) {
return Error(loc,
"unknown shardings should not have any devices assigned");
}
sharding->set_type(OpSharding::UNKNOWN);
} else {
if (tile_assignment_dimensions.empty()) {
return Error(
......
......@@ -3625,6 +3625,13 @@ TEST_F(HloParserTest, ParseShardLike) {
original);
}
TEST_F(HloParserTest, ParseUnknownSharding) {
const std::string original = "{unknown}";
TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
EXPECT_EQ(sharding.ToString(), original);
EXPECT_EQ(HloSharding::Unknown().ToString(), original);
}
TEST_F(HloParserTest, ParseFrontendAttributes) {
const std::string original =
R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})";
......
......@@ -795,6 +795,9 @@ message OpSharding {
// This op is manually sharded: the shapes are already partitioned and the
// partitioner should not change this op.
MANUAL = 4;
// This sharding is a placeholder sharding with the lowest precedence; it can
// be overwritten by any other sharding.
UNKNOWN = 5;
}
Type type = 1;
// The shape of the sharded tile.
......
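
As a sanity check on the proto plumbing above, a hedged sketch of round-tripping the new UNKNOWN type through the existing HloSharding::FromProto / ToProto entry points; include paths are assumptions.

```cpp
// Sketch only: mirrors the FromProto/ToProto branches added in this change.
#include <cassert>

#include "xla/hlo/ir/hlo_sharding.h"  // header paths are assumptions
#include "xla/xla_data.pb.h"

int main() {
  xla::OpSharding proto;
  proto.set_type(xla::OpSharding::UNKNOWN);

  // FromProto maps UNKNOWN to the placeholder sharding created by Unknown().
  auto sharding = xla::HloSharding::FromProto(proto);  // StatusOr<HloSharding>
  assert(sharding.ok());
  assert(sharding->IsUnknown());

  // ToProto sets the type back to UNKNOWN and clears tile dimensions.
  xla::OpSharding round_trip = sharding->ToProto();
  assert(round_trip.type() == xla::OpSharding::UNKNOWN);
  assert(round_trip.tile_assignment_dimensions().empty());
  return 0;
}
```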