diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index 2583332f7101273912de412612cf956aa6592a89..3b3d1f516b4c53e28963da012d104bc53e8be026 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -276,12 +276,12 @@ HloSharding HloSharding::Subgroup( absl::InlinedVector transposed_shape = merged_shape; std::vector merged_types; static constexpr std::array - kOrderedTypes = {OpSharding::MAXIMAL, OpSharding::TUPLE, - OpSharding::OTHER, OpSharding::MANUAL, - OpSharding::REPLICATED}; + kOrderedTypes = {OpSharding::MAXIMAL, OpSharding::TUPLE, + OpSharding::OTHER, OpSharding::MANUAL, + OpSharding::REPLICATED, OpSharding::UNKNOWN}; static_assert(kOrderedTypes[0] == 1 && kOrderedTypes[1] == 2 && kOrderedTypes[2] == 3 && kOrderedTypes[3] == 4 && - kOrderedTypes[4] == 0); + kOrderedTypes[4] == 0 && kOrderedTypes[5] == 5); for (OpSharding::Type type : kOrderedTypes) { auto& dims = type_to_dims[type]; if (dims.empty()) continue; @@ -426,6 +426,15 @@ void HloSharding::Print(Printer* printer, bool include_metadata) const { printer->Append("}"); return; } + + if (unknown_) { + printer->Append("{unknown"); + print_shard_group(); + print_metadata(); + printer->Append("}"); + return; + } + if (maximal_) { AppendCat(printer, "{maximal device=", static_cast(*tile_assignment_.array().begin())); @@ -510,6 +519,7 @@ std::map HloSharding::UsedDevices(int64_t* count) const { std::vector HloSharding::TileIndexForDevice(int64_t device) const { CHECK(!maximal_); CHECK(!IsManual()); + CHECK(!IsUnknown()); CHECK(!IsTuple()); std::vector ret_index; tile_assignment_.Each([&](absl::Span index, int64_t d) { @@ -525,6 +535,7 @@ std::vector HloSharding::TileIndexForDevice(int64_t device) const { int64_t HloSharding::DeviceForTileIndex(absl::Span index) const { CHECK(!replicated_); CHECK(!IsManual()); + CHECK(!IsUnknown()); CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.array().begin(); @@ -545,6 +556,7 @@ std::vector HloSharding::TileOffsetForDevice(const Shape& shape, int64_t device) const { CHECK(!IsTuple()); CHECK(!IsManual()); + CHECK(!IsUnknown()); if (maximal_) { return std::vector(shape.dimensions_size(), 0); @@ -563,6 +575,7 @@ std::vector HloSharding::TileLimitForDevice(const Shape& shape, int64_t device) const { CHECK(!IsTuple()); CHECK(!IsManual()); + CHECK(!IsUnknown()); if (maximal_) { return std::vector(shape.dimensions().begin(), @@ -744,7 +757,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, tile_assignment_.iota_->num_elements() == *num_devices; } - if (IsTileMaximal() || IsManual()) { + if (IsTileMaximal() || IsManual() || IsUnknown()) { return OkStatus(); } @@ -798,6 +811,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Replicate(metadata).SetShardGroupFromProto(proto); } else if (proto.type() == OpSharding::MANUAL) { return Manual(metadata).SetShardGroupFromProto(proto); + } else if (proto.type() == OpSharding::UNKNOWN) { + return Unknown(metadata).SetShardGroupFromProto(proto); } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0), metadata) .SetShardGroupFromProto(proto); @@ -921,6 +936,9 @@ OpSharding HloSharding::ToProto() const { } else if (IsManual()) { result.set_type(OpSharding::MANUAL); result.clear_tile_assignment_dimensions(); + } else if (IsUnknown()) { + result.set_type(OpSharding::UNKNOWN); + result.clear_tile_assignment_dimensions(); } else { result.set_type(OpSharding::OTHER); result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim()); @@ -942,7 +960,7 @@ OpSharding HloSharding::ToProto() const { } Shape HloSharding::TileShape(const Shape& shape) const { - if (IsTileMaximal() || IsManual()) { + if (IsTileMaximal() || IsManual() || IsUnknown()) { return shape; } Shape result_shape = shape; @@ -954,7 +972,7 @@ Shape HloSharding::TileShape(const Shape& shape) const { } Shape HloSharding::TileShape(const Shape& shape, int64_t device) const { - if (IsTileMaximal() || IsManual()) { + if (IsTileMaximal() || IsManual() || IsUnknown()) { return shape; } @@ -977,6 +995,7 @@ int64_t HloSharding::TotalNumTiles() const { return 1; } CHECK(!IsManual()); + CHECK(!IsUnknown()); return Product(absl::Span(tile_assignment_.dimensions())); } @@ -985,6 +1004,7 @@ int64_t HloSharding::NumTiles() const { return 1; } CHECK(!IsManual()); + CHECK(!IsUnknown()); return Product(absl::Span(tile_assignment_.dimensions()) .subspan(0, TiledDataRank())); } diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.h b/third_party/xla/xla/hlo/ir/hlo_sharding.h index 53fcd4343a9fd42ab55f6e63e88bb7bdf5498a50..e30d2acc1ba3a2c34f1dd8616ce6d24965d39e0b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.h @@ -45,12 +45,20 @@ class HloSharding { // Creates a trivial sharding that replicates a maximal tile across all // devices. static HloSharding Replicate(absl::Span metadata = {}) { - return HloSharding(/*manual=*/false, /*replicated=*/true, metadata); + return HloSharding(/*manual=*/false, /*replicated=*/true, /*unknown=*/false, + metadata); } // Creates a sharding that represents the op is manually partitioned. static HloSharding Manual(absl::Span metadata = {}) { - return HloSharding(/*manual=*/true, /*replicated=*/false, metadata); + return HloSharding(/*manual=*/true, /*replicated=*/false, /*unknown=*/false, + metadata); + } + + // Creates a sharding that represents the op has a placeholder sharding. + static HloSharding Unknown(absl::Span metadata = {}) { + return HloSharding(/*manual=*/false, /*replicated=*/false, /*unknown=*/true, + metadata); } // Creates a sharding that emulates device placement; a tile shape equal to @@ -189,6 +197,15 @@ class HloSharding { [](const HloSharding& s) { return s.IsManual(); }); } + // Returns whether the sharding represents a placeholder sharding. + bool IsUnknown() const { + if (!IsTuple()) { + return unknown_; + } + return absl::c_all_of(tuple_elements_, + [](const HloSharding& s) { return s.IsUnknown(); }); + } + bool IsShardGroup() const { if (!IsTuple()) { return shard_group_.shard_group_id != -1 && @@ -226,7 +243,9 @@ class HloSharding { // Returns weather the sharding represents a tiled sharding where the mapping // between devices and tiles is represented through 'tile_assignment()'. - bool IsTiled() const { return !IsTileMaximal() && !IsManual(); } + bool IsTiled() const { + return !IsTileMaximal() && !IsManual() && !IsUnknown(); + } // Returns if the sharding has partial replication and partial sharding. If // true, data is sharded according to other dimensions of tile_assignment(), @@ -334,7 +353,7 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - manual_ == other.manual_ && + manual_ == other.manual_ && unknown_ == other.unknown_ && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_ && replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_ && @@ -349,7 +368,7 @@ class HloSharding { return H::combine(std::move(h), sharding.tuple_elements_); } return H::combine(std::move(h), sharding.replicated_, sharding.manual_, - sharding.tile_assignment_.array(), + sharding.unknown_, sharding.tile_assignment_.array(), sharding.replicate_on_last_tile_dim_, sharding.shard_group_.ToString()); } @@ -393,7 +412,7 @@ class HloSharding { std::vector& metadata() { return metadata_; } const std::vector& metadata() const { return metadata_; } - // Returns the replication subgroiup dim, or -1 if it doesn't exist. + // Returns the replication subgroup dim, or -1 if it doesn't exist. int64_t SubgroupReplicationDim() const { auto it = absl::c_find(subgroup_types_, OpSharding::REPLICATED); if (it != subgroup_types_.end()) { @@ -498,13 +517,14 @@ class HloSharding { const ShardGroup& GetShardGroup() const { return shard_group_; } private: - explicit HloSharding(bool manual, bool replicated, + explicit HloSharding(bool manual, bool replicated, bool unknown, absl::Span metadata) : metadata_(metadata.begin(), metadata.end()), replicated_(replicated), maximal_(replicated), tuple_(false), manual_(manual), + unknown_(unknown), replicate_on_last_tile_dim_(false) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning @@ -519,6 +539,7 @@ class HloSharding { maximal_(true), tuple_(false), manual_(false), + unknown_(false), replicate_on_last_tile_dim_(false) {} explicit HloSharding(TileAssignment tile_assignment, bool replicate_on_last_tile_dim, @@ -529,6 +550,7 @@ class HloSharding { maximal_(false), tuple_(false), manual_(false), + unknown_(false), replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {} explicit HloSharding(TileAssignment tile_assignment, absl::Span subgroup_types, @@ -540,6 +562,7 @@ class HloSharding { maximal_(false), tuple_(false), manual_(false), + unknown_(false), replicate_on_last_tile_dim_(false) {} explicit HloSharding(const std::vector& tuple_shardings) : tuple_elements_(tuple_shardings), @@ -547,6 +570,7 @@ class HloSharding { maximal_(false), tuple_(true), manual_(false), + unknown_(false), replicate_on_last_tile_dim_(false) {} // Test-only constructor for sharding format code coverage. Copies the @@ -560,6 +584,7 @@ class HloSharding { maximal_(other.maximal_), tuple_(other.tuple_), manual_(other.manual_), + unknown_(other.unknown_), replicate_on_last_tile_dim_(other.replicate_on_last_tile_dim_) { CHECK(tile_assignment_ == other.tile_assignment_) << tile_assignment_.ToString() << " v.s. " @@ -614,6 +639,7 @@ class HloSharding { bool maximal_; bool tuple_; bool manual_; + bool unknown_; // This flag is to support partial replication and partial sharding. If it is // true, tile_assignment_ will have an extra dimension in addition to the data // shape rank, and the added last dimension represents the subgroups of diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index 7b30a85451995b9874b8f2e0e455dbda29f0fc53..4c18c7ba72b4a5688616d2d113921b0b9e714c70 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -987,7 +987,8 @@ void BuildXlaCompilerSubmodule(py::module& m) { .value("MAXIMAL", OpSharding::MAXIMAL) .value("MANUAL", OpSharding::MANUAL) .value("TUPLE", OpSharding::TUPLE) - .value("OTHER", OpSharding::OTHER); + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); py::enum_ op_sharding_shard_group_type( m, "OpSharding_ShardGroupType"); @@ -1068,12 +1069,14 @@ void BuildXlaCompilerSubmodule(py::module& m) { py::arg("subgroup_types") = absl::Span()) .def_static("manual", [] { return HloSharding::Manual(); }) .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) .def("__eq__", [](const xla::HloSharding& a, const xla::HloSharding& b) { return a == b; }) .def("__hash__", [](const xla::HloSharding& self) { return absl::HashOf(self); }) .def("is_replicated", &xla::HloSharding::IsReplicated) .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unknown", &xla::HloSharding::IsUnknown) .def("is_tiled", &xla::HloSharding::IsTiled) .def("tile", [](const xla::HloSharding& self, xla::Shape shape) { return self.TileShape(shape); }) diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index d5ec9dadd4849518c2b4b444dd223611e3f7a4ea..6ea7e598f43681fbe043643954a60f68a4ac9a1d 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -308,6 +308,7 @@ class OpSharding_Type(enum.IntEnum): TUPLE: int OTHER: int MANUAL: int + UNKNOWN: int class OpSharding_ShardGroupType(enum.IntEnum): AS: int @@ -346,12 +347,15 @@ class HloSharding: def replicate() -> HloSharding: ... @staticmethod def manual() -> HloSharding: ... + @staticmethod + def unknown() -> HloSharding: ... def __eq__(self, other: HloSharding) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... def tile(self, shape: Shape) -> Shape: ... def is_replicated(self) -> bool: ... def is_manual(self) -> bool: ... + def is_unknown(self) -> bool: ... def is_tiled(self) -> bool: ... def tuple_elements(self) -> List[HloSharding]: ... def num_devices(self) -> int: ... diff --git a/third_party/xla/xla/service/cpu/tests/cpu_literal_caching_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_literal_caching_test.cc index cec5492a13cffc073eb0051f2f3cfffedc6f0c66..c010cc7cf79ab8b690cac0dba2740aa1a13e10b4 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -40,7 +40,7 @@ while_cond { arg_cond = f32[2,3,2] parameter(0) token0 = token[] after-all() infeed = (pred[], token[]) infeed(token0) - ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 + ROOT cond = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -89,7 +89,7 @@ while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) token0 = token[] after-all() infeed = (pred[], token[]) infeed(token0) - ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 + ROOT cond = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { diff --git a/third_party/xla/xla/service/hlo_lexer.cc b/third_party/xla/xla/service/hlo_lexer.cc index 3fdf7888599e0597bfc705ae5d958818bbe29c40..0e53e6f844e99f676e9d46bbb14a0cd963c67dea 100644 --- a/third_party/xla/xla/service/hlo_lexer.cc +++ b/third_party/xla/xla/service/hlo_lexer.cc @@ -328,6 +328,7 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(last_tile_dim_replicate); KEYWORD(shard_as); KEYWORD(shard_like); + KEYWORD(unknown); #undef KEYWORD @@ -612,6 +613,8 @@ std::string TokKindToString(TokKind kind) { return "kw_shard_as"; case TokKind::kw_shard_like: return "kw_shard_like"; + case TokKind::kw_unknown: + return "kw_unknown"; case TokKind::kw_inf: return "kw_inf"; case TokKind::kNegInf: diff --git a/third_party/xla/xla/service/hlo_lexer.h b/third_party/xla/xla/service/hlo_lexer.h index ac498a90868825a62eb17255dbcdaa3821255a1e..031ec1ae295330c7c6f25c6348e7a68ac31885d7 100644 --- a/third_party/xla/xla/service/hlo_lexer.h +++ b/third_party/xla/xla/service/hlo_lexer.h @@ -69,6 +69,7 @@ enum class TokKind { kw_last_tile_dim_replicate, kw_shard_as, kw_shard_like, + kw_unknown, kw_inf, kNegInf, // -inf diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index 6c44438a6a93eeb456f0bca9675748e481658dc2..faae4f7aa0000c2e2b6bb76b5668cc693724abc6 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -3223,7 +3223,7 @@ bool HloParserImpl::ParseStatisticsViz(StatisticsViz* statistics_viz) { return ParseToken(TokKind::kRbrace, "expects '}' at the end of statistics"); } -// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape? +// ::= '{' 'replicated'? 'manual'? 'maximal'? 'unknown'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? // (('shard_like' | 'shard_as') int)* '}' // ('metadata=' metadata)* @@ -3245,6 +3245,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, bool maximal = false; bool replicated = false; bool manual = false; + bool unknown = false; bool last_tile_dim_replicate = false; bool last_tile_dims = false; bool shard_like = false; @@ -3269,6 +3270,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, manual = true; lexer_.Lex(); break; + case TokKind::kw_unknown: + unknown = true; + lexer_.Lex(); + break; case TokKind::kAttributeName: { if (lexer_.GetStrVal() == "device") { if (lexer_.Lex() != TokKind::kInt) { @@ -3421,6 +3426,12 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, "manual shardings should not have any devices assigned"); } sharding->set_type(OpSharding::MANUAL); + } else if (unknown) { + if (!devices.empty()) { + return Error(loc, + "unknown shardings should not have any devices assigned"); + } + sharding->set_type(OpSharding::UNKNOWN); } else { if (tile_assignment_dimensions.empty()) { return Error( diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 9adf66000c51a9a2487189037fa10dc64c9bc57d..20220bd38034b2398481952a88762acf1737e96e 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -3625,6 +3625,13 @@ TEST_F(HloParserTest, ParseShardLike) { original); } +TEST_F(HloParserTest, ParseUnknownSharding) { + const std::string original = "{unknown}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); + EXPECT_EQ(HloSharding::Unknown().ToString(), original); +} + TEST_F(HloParserTest, ParseFrontendAttributes) { const std::string original = R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})"; diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 3cd6b87322d4c8633c51d863a22877c23bb5aba0..46f4dba1279444a0b65a105f64228878df225a76 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -795,6 +795,9 @@ message OpSharding { // This op is manually sharded: the shapes are already partitioned and the // partitioner should not change this op. MANUAL = 4; + // This sharding is a placeholder sharding with lowest precedence, it can be + // overwriten by any other shardings. + UNKNOWN = 5; } Type type = 1; // The shape of the sharded tile.