提交 2ee09b87 编写于 作者: M Mark Heffernan 提交者: TensorFlower Gardener

[XLA] Various improvements to ShapeTree.

Add support for holding non-copyable types, operator==, and a
CopySubtreeFrom method for copying a subtree from one ShapeTree to
another.

PiperOrigin-RevId: 157777636
上级 4f3ae769
......@@ -44,6 +44,7 @@ struct ShapeTreeNode {
// Children of this node.
std::vector<std::unique_ptr<ShapeTreeNode>> children;
ShapeTreeNode() = default;
explicit ShapeTreeNode(const T& data) : data(data) {}
ShapeTreeNode(const ShapeTreeNode& other)
......@@ -85,8 +86,9 @@ class ShapeTree {
public:
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
// Create ShapeTree with the given shape, and default T values for all nodes.
explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {}
// Create ShapeTree with the given shape, and default-constructed T values for
// all nodes.
explicit ShapeTree(const Shape& shape);
// Create ShapeTree with the given shape, and init_value for all nodes.
ShapeTree(const Shape& shape, const T& init_value);
......@@ -127,6 +129,19 @@ class ShapeTree {
const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
Status ForEachMutableElement(const MutableVisitorFunction& func);
// Copy the subtree of values from 'other' rooted at ShapeIndex
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
// 'target_base_index'.
//
// Precondition: The subshape of other.shape() at index source_base_index must
// be compatible with the subshape of shape() at index target_base_index.
void CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index);
bool operator==(const ShapeTree<T>& other) const;
bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
private:
using Node = internal::ShapeTreeNode<T>;
......@@ -134,6 +149,10 @@ class ShapeTree {
// the given 'init_value'.
void InitChildren(const Shape& shape, const T& init_value, Node* node);
// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
void InitChildren(const Shape& shape, Node* node);
// Helpers for traversing the shape via ForEachElement. The helpers
// recursively traverse the subtree rooted at "index" (defined as in
// ShapeUtil::GetSubshape).
......@@ -165,6 +184,24 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
}
}
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
if (ShapeUtil::IsTuple(shape)) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
node->children.emplace_back(new Node());
InitChildren(shape.tuple_shapes(i), node->children.back().get());
}
}
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape) : root_(), shape_(shape) {
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(&shape_);
InitChildren(shape_, &root_);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
: root_(init_value), shape_(shape) {
......@@ -240,6 +277,48 @@ Status ShapeTree<T>::ForEachMutableElement(const MutableVisitorFunction& func) {
return ForEachMutableHelper(func, &root_, &index);
}
template <typename T>
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& source_base_index,
const ShapeIndex& target_base_index) {
CHECK(ShapeUtil::Compatible(
ShapeUtil::GetSubshape(shape(), target_base_index),
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
ForEachMutableElement(
[this, &other, &source_base_index, &target_base_index](
const ShapeIndex& index, bool /*is_leaf*/, T* data) {
// Copy the data element only if index is in the
// subtree rooted at target_base_index.
for (int i = 0; i < target_base_index.size(); ++i) {
if (i >= index.size() || index[i] != target_base_index[i]) {
return Status::OK();
}
}
// Construct source element index to copy from.
ShapeIndex source_index = source_base_index;
for (int i = target_base_index.size(); i < index.size(); ++i) {
source_index.push_back(index[i]);
}
*data = other.element(source_index);
return Status::OK();
})
.IgnoreError();
}
template <typename T>
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
bool equal = true;
ForEachElement([this, &other, &equal](const ShapeIndex& index,
bool /*is_leaf*/, const T& data) {
if (data != other.element(index)) {
equal = false;
}
return Status::OK();
})
.IgnoreError();
return equal;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
......@@ -245,5 +245,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
EXPECT_DEATH(shape_tree.element({0, 0}), "");
}
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
*shape_tree.mutable_element({2}) = MakeUnique<int>(42);
EXPECT_EQ(*shape_tree.element({2}), 42);
}
TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) {
// Test CopySubtreeFrom method for a single value copied between array-shaped
// ShapeTrees.
ShapeTree<int> source(array_shape_);
*source.mutable_element(/*index=*/{}) = 42;
ShapeTree<int> destination(array_shape_, 123);
EXPECT_EQ(destination.element(/*index=*/{}), 123);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 42);
}
TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) {
// Test CopySubtreeFrom method for a copy of all elements from one
// tuple-shaped ShapeTree to another.
ShapeTree<int> source(tuple_shape_);
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
*source.mutable_element(/*index=*/{2}) = 13;
ShapeTree<int> destination(tuple_shape_, 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 10);
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
EXPECT_EQ(destination.element(/*index=*/{2}), 13);
}
TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) {
// Test CopySubtreeFrom method for a copy of a single element from one
// tuple-shaped ShapeTree to another.
ShapeTree<int> source(tuple_shape_);
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
*source.mutable_element(/*index=*/{2}) = 13;
ShapeTree<int> destination(tuple_shape_, 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{0},
/*target_base_index=*/{1});
EXPECT_EQ(destination.element(/*index=*/{}), 0);
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1}), 11);
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
}
TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) {
// Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a
// nested-tuple-shaped ShapeTree.
ShapeTree<int> source(
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}));
*source.mutable_element(/*index=*/{}) = 10;
*source.mutable_element(/*index=*/{0}) = 11;
*source.mutable_element(/*index=*/{1}) = 12;
ShapeTree<int> destination(nested_tuple_shape_, 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{},
/*target_base_index=*/{2, 0});
EXPECT_EQ(destination.element(/*index=*/{}), 0);
EXPECT_EQ(destination.element(/*index=*/{0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1}), 0);
EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0);
EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0);
EXPECT_EQ(destination.element(/*index=*/{2}), 0);
EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10);
EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11);
EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12);
EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0);
}
TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) {
// Test CopySubtreeFrom method for a copy from a nested-tuple-shape.
ShapeTree<int> source(nested_tuple_shape_, 42);
*source.mutable_element(/*index=*/{1}) = 10;
*source.mutable_element(/*index=*/{1, 0}) = 11;
*source.mutable_element(/*index=*/{1, 1}) = 12;
ShapeTree<int> destination(
ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0);
destination.CopySubtreeFrom(source, /*source_base_index=*/{1},
/*target_base_index=*/{});
EXPECT_EQ(destination.element(/*index=*/{}), 10);
EXPECT_EQ(destination.element(/*index=*/{0}), 11);
EXPECT_EQ(destination.element(/*index=*/{1}), 12);
}
TEST_F(ShapeTreeTest, OperatorEquals) {
{
ShapeTree<int> a(array_shape_, 123);
ShapeTree<int> b(array_shape_, 42);
ShapeTree<int> c(array_shape_, 42);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a != b);
EXPECT_TRUE(b == c);
}
{
ShapeTree<int> a(tuple_shape_);
*a.mutable_element(/*index=*/{}) = 10;
*a.mutable_element(/*index=*/{0}) = 11;
*a.mutable_element(/*index=*/{1}) = 12;
ShapeTree<int> b(tuple_shape_);
*b.mutable_element(/*index=*/{}) = 10;
*b.mutable_element(/*index=*/{0}) = 42;
*b.mutable_element(/*index=*/{1}) = 11;
ShapeTree<int> c(tuple_shape_);
*c.mutable_element(/*index=*/{}) = 10;
*c.mutable_element(/*index=*/{0}) = 42;
*c.mutable_element(/*index=*/{1}) = 11;
EXPECT_FALSE(a == b);
EXPECT_TRUE(a != b);
EXPECT_TRUE(b == c);
EXPECT_FALSE(b != c);
}
}
} // namespace
} // namespace xla
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册