提交 1efd820d 编写于 作者: T Tianrun Li 提交者: TensorFlower Gardener

Add a basic test for ragged layout.

DTensor api.relayout and collectives need to recognize the new layout type.

PiperOrigin-RevId: 549477658
上级 b942a91f
......@@ -1107,6 +1107,7 @@ StatusOr<LayoutProto> Layout::ToProto() const {
bool Layout::IsEquivalent(const Layout& b) const {
if (this->rank() != b.rank()) return false;
if (this->mesh() != b.mesh()) return false;
if (this->type() != b.type()) return false;
for (int i = 0; i < this->rank(); ++i) {
if (this->sharding_specs_[i] != b.sharding_specs_[i]) {
if ((this->num_shards_for_dim(i) != 1) || (b.num_shards_for_dim(i) != 1))
......
......@@ -348,9 +348,9 @@ StatusOr<mlir::Value> EmitRelayout(
else
intermediate_specs_1[i] = src_layout.sharding_spec(i);
}
TF_ASSIGN_OR_RETURN(
Layout intermediate_layout_1,
Layout::GetLayout(intermediate_specs_1, src_layout.mesh()));
TF_ASSIGN_OR_RETURN(Layout intermediate_layout_1,
Layout::GetLayout(tgt_layout.type(), intermediate_specs_1,
src_layout.mesh()));
llvm::SmallPtrSet<mlir::Operation*, 4> local_newly_created_ops;
TF_ASSIGN_OR_RETURN(mlir::Value split_result,
......@@ -365,9 +365,9 @@ StatusOr<mlir::Value> EmitRelayout(
else
intermediate_specs_2[i] = intermediate_specs_1[i];
}
TF_ASSIGN_OR_RETURN(
Layout intermediate_layout_2,
Layout::GetLayout(intermediate_specs_2, src_layout.mesh()));
TF_ASSIGN_OR_RETURN(Layout intermediate_layout_2,
Layout::GetLayout(tgt_layout.type(), intermediate_specs_2,
src_layout.mesh()));
TF_ASSIGN_OR_RETURN(
mlir::Value concat_result,
......
......@@ -472,10 +472,9 @@ class RelayoutTest(test_util.DTensorBaseTest):
@combinations.generate(combinations.combine(is_graph=[False, True]))
def test_relayout_to_ragged(self, is_graph):
data = np.array([1, 2, 3, 4.0], dtype='f4')
inp = api.relayout(data, self.y_layout)
def do_relayout():
return api.relayout(inp, self.y_layout.to_ragged())
return api.relayout(data, self.y_layout.to_ragged())
if is_graph:
do_relayout = polymorphic_function.function(do_relayout)
......
......@@ -738,6 +738,29 @@ TEST_F(LayoutTest, RaggedLayoutToFromString) {
EXPECT_THAT(layout.ToProto(),
IsOkAndHolds(EqualsProto(layout_from_str_proto)));
}
TEST_F(LayoutTest, RaggedLayoutEqual) {
TF_ASSERT_OK_AND_ASSIGN(
Layout fully_sharded,
Layout::FromString("sharding_specs:x,y, mesh:|x=2,y=1|*TPU"));
TF_ASSERT_OK_AND_ASSIGN(
Layout x_sharded,
Layout::FromString("sharding_specs:x,unsharded, mesh:|x=2,y=1|*TPU"));
TF_ASSERT_OK_AND_ASSIGN(
Layout x_ragged,
Layout::FromString("ragged:x,unsharded, mesh:|x=2,y=1|*TPU"));
TF_ASSERT_OK_AND_ASSIGN(Layout x_y_ragged,
Layout::FromString("ragged:x,y, mesh:|x=2,y=1|*TPU"));
// Test that 'IsEquivalent' and '==' take layout type into account.
EXPECT_TRUE(x_ragged.IsEquivalent(x_y_ragged));
EXPECT_TRUE(x_y_ragged.IsEquivalent(x_ragged));
EXPECT_FALSE(x_sharded.IsEquivalent(x_ragged));
EXPECT_FALSE(fully_sharded.IsEquivalent(x_y_ragged));
EXPECT_FALSE(x_sharded == x_ragged);
EXPECT_FALSE(fully_sharded == x_y_ragged);
}
} // namespace
} // namespace dtensor
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册