提交 85d3662a 编写于 作者: S Skye Wanderman-Milne 提交者: TensorFlower Gardener

[PJRT] Rewrite PjRtBuffer::ToLiteralSync to use new shape getters when possible.

It's not possible for tuple buffers (yet?).

The eventual goal is for ML frameworks to only call individual getters
instead of using PjRtBuffer::{logical_}on_device_shape, since passing
around xla::Shapes is expensive and often includes more information
than is necessary or even meaningful. We'd like to eventually remove
PJRT_Buffer_OnDeviceTrimmedShape from the PJRT C API altogether
({logical_}on_device_shape will likely stay for non-ML framework
usage).

PiperOrigin-RevId: 549438520
上级 ea6b1678
......@@ -973,9 +973,28 @@ class PjRtBuffer {
// Convenience synchronous overload that allocates a literal with a default
// layout.
StatusOr<std::shared_ptr<Literal>> ToLiteralSync() {
Shape device_shape = on_device_shape();
if (device_shape.is_dynamic()) {
TF_ASSIGN_OR_RETURN(device_shape, logical_on_device_shape());
Shape device_shape;
if (!IsTuple()) {
absl::Span<const int64_t> literal_dims;
std::optional<std::vector<int64_t>> logical_dims_storage;
if (has_dynamic_dimensions()) {
TF_ASSIGN_OR_RETURN(std::vector<int64_t> logical_dims,
logical_dimensions());
logical_dims_storage.emplace(std::move(logical_dims));
literal_dims = *logical_dims_storage;
} else {
literal_dims = dimensions();
}
device_shape = ShapeUtil::MakeShape(element_type(), literal_dims);
*device_shape.mutable_layout() = layout();
} else {
// TODO(skyewm): does anything need to create tuple literals? The PJRT C
// API doesn't support tuples or {logical_}on_device_shape(), so we prefer
// to use the above non-tuple code path where possible.
device_shape = on_device_shape();
if (device_shape.is_dynamic()) {
TF_ASSIGN_OR_RETURN(device_shape, logical_on_device_shape());
}
}
auto literal = std::make_shared<Literal>(
ShapeUtil::DeviceShapeToHostShape(device_shape));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册