Commit 840633cb authored by Skye Wanderman-Milne, committed by TensorFlower Gardener

[PJRT:C] Remove xla::Shape getter and related logic from PJRT C API

This change removes `PJRT_Buffer_OnDeviceTrimmedShape` and related
functionality, since passing around xla::Shapes is expensive and often
includes more information than is necessary or even meaningful, and we
now have more specific getters that should be used instead.

This means that `PjRtBuffer::on_device_shape` and
`PjRtBuffer::logical_on_device_shape` no longer work with the
PjRtCApiClient. Instead, C++ callers should use the new, more specific
getters:

* `PjRtBuffer::element_type()`
* `PjRtBuffer::dimensions()`
* `PjRtBuffer::layout()`
* `PjRtBuffer::has_dynamic_dimensions()`
* `PjRtBuffer::is_dynamic_dimension()`
* `PjRtBuffer::logical_dimensions()`

These all have corresponding C APIs.
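
For illustration, a minimal migration sketch for C++ callers is shown below. It assumes a valid `xla::PjRtBuffer*`, this tree's include path, and a hypothetical `DescribeBuffer` helper; none of this is part of the change itself.

```c++
// Hypothetical sketch: replace on_device_shape()/logical_on_device_shape()
// with the decomposed PjRtBuffer getters that work through the C API client.
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

void DescribeBuffer(xla::PjRtBuffer* buffer) {
  // Element type and static dimensions, instead of on_device_shape().
  xla::PrimitiveType type = buffer->element_type();
  absl::Span<const int64_t> dims = buffer->dimensions();

  // Dynamic-dimension queries, instead of Shape::is_dynamic_dimension().
  if (buffer->has_dynamic_dimensions()) {
    absl::Span<const bool> dynamic = buffer->is_dynamic_dimension();
    (void)dynamic;
  }

  // Layout, instead of Shape::layout(); `auto` avoids pinning the return type.
  const auto& layout = buffer->layout();

  // Logical (unpadded) dimensions, instead of logical_on_device_shape().
  xla::StatusOr<std::vector<int64_t>> logical_dims =
      buffer->logical_dimensions();

  (void)type;
  (void)dims;
  (void)layout;
  (void)logical_dims;
}
```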

PiperOrigin-RevId: 561148407
Parent 29bb1223
@@ -53,7 +53,7 @@ extern "C" {
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 24
#define PJRT_API_MINOR 25
// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
@@ -1441,80 +1441,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_GetMemoryLayout_Args, layout);
typedef PJRT_Error* PJRT_Buffer_GetMemoryLayout(
PJRT_Buffer_GetMemoryLayout_Args* args);
// Maximum number of array elements to inline into structs for performance.
#define PJRT_C_API_MAX_INLINED 6
typedef struct PJRT_IntList {
union {
int* heap; // owned
int inlined[PJRT_C_API_MAX_INLINED];
};
int64_t size;
} PJRT_IntList;
typedef struct PJRT_Int64List {
union {
int64_t* heap; // owned
int64_t inlined[PJRT_C_API_MAX_INLINED];
};
int64_t size;
} PJRT_Int64List;
typedef struct PJRT_BoolList {
union {
bool* heap; // owned
bool inlined[PJRT_C_API_MAX_INLINED];
};
int64_t size;
} PJRT_BoolList;
typedef struct PJRT_XLA_Tile {
PJRT_Int64List dimensions;
} PJRT_XLA_Tile;
typedef struct PJRT_XLA_TileList {
union {
PJRT_XLA_Tile* heap; // owned
PJRT_XLA_Tile inlined[PJRT_C_API_MAX_INLINED];
};
int64_t size;
} PJRT_XLA_TileList;
typedef struct PJRT_XLA_Layout {
PJRT_Int64List minor_to_major;
PJRT_IntList dim_level_types;
PJRT_IntList dim_unique;
PJRT_IntList dim_ordered;
PJRT_XLA_TileList tiles;
int index_primitive_type;
int pointer_primitive_type;
int64_t element_size_in_bits;
int64_t memory_space;
int64_t dynamic_shape_metadata_prefix_bytes;
} PJRT_XLA_Layout;
// This trimmed shape doesn't have any Tuple information. In case of Tuple,
// assert is triggered from the C API Client.
// TODO(b/238999986): This is a temporary solution. Remove this later.
struct PJRT_Buffer_OnDeviceTrimmedShape_Args {
size_t struct_size;
void* priv;
PJRT_Buffer* buffer;
int element_type; // out
PJRT_Int64List dimensions; // out
PJRT_BoolList dynamic_dimensions; // out
bool has_layout;
// Whether it calls logical_on_device_shape.
bool is_logical_on_device_shape;
PJRT_XLA_Layout layout; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_OnDeviceTrimmedShape_Args, layout);
// Return the trimmed shape from PjRtBuffer.
// TODO(b/238999986): Replace this with decomposed shape methods.
typedef PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
PJRT_Buffer_OnDeviceTrimmedShape_Args* args);
struct PJRT_Buffer_ToHostBuffer_Args {
size_t struct_size;
void* priv;
@@ -1991,7 +1917,6 @@ typedef struct {
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_UnpaddedDimensions);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_DynamicDimensionIndices);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_GetMemoryLayout);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceTrimmedShape);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceSizeInBytes);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_Device);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete);
@@ -748,4 +748,33 @@ xla::StatusOr<xla::Layout> ConvertToLayout(
return layout;
}
PJRT_Buffer_Type GetElementType(const PJRT_Api* api, PJRT_Buffer* buffer) {
PJRT_Buffer_ElementType_Args args;
args.struct_size = PJRT_Buffer_ElementType_Args_STRUCT_SIZE;
args.priv = nullptr;
args.buffer = buffer;
LogFatalIfPjrtError(api->PJRT_Buffer_ElementType(&args), api);
return args.type;
}
absl::Span<const int64_t> GetDimensions(const PJRT_Api* api,
PJRT_Buffer* buffer) {
PJRT_Buffer_Dimensions_Args args;
args.struct_size = PJRT_Buffer_Dimensions_Args_STRUCT_SIZE;
args.priv = nullptr;
args.buffer = buffer;
LogFatalIfPjrtError(api->PJRT_Buffer_Dimensions(&args), api);
return {args.dims, args.num_dims};
}
PJRT_Buffer_MemoryLayout GetMemoryLayout(const PJRT_Api* api,
PJRT_Buffer* buffer) {
PJRT_Buffer_GetMemoryLayout_Args args;
args.struct_size = PJRT_Buffer_GetMemoryLayout_Args_STRUCT_SIZE;
args.priv = nullptr;
args.buffer = buffer;
LogFatalIfPjrtError(api->PJRT_Buffer_GetMemoryLayout(&args), api);
return args.layout;
}
} // namespace pjrt
@@ -221,6 +221,12 @@ xla::StatusOr<BufferMemoryLayoutData> ConvertToBufferMemoryLayoutData(
xla::StatusOr<xla::Layout> ConvertToLayout(
const PJRT_Buffer_MemoryLayout_Tiled& c_tiled);
PJRT_Buffer_Type GetElementType(const PJRT_Api* api, PJRT_Buffer* buffer);
absl::Span<const int64_t> GetDimensions(const PJRT_Api* api,
PJRT_Buffer* buffer);
PJRT_Buffer_MemoryLayout GetMemoryLayout(const PJRT_Api* api,
PJRT_Buffer* buffer);
} // namespace pjrt
#endif // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_
@@ -293,22 +293,8 @@ class PjrtCApiTest : public PjrtCApiTestBase {
}
CHECK_EQ(args.dst_size, sizeof(float));
PJRT_Buffer_OnDeviceTrimmedShape_Args shape_args{
.struct_size = PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE,
.priv = nullptr,
.buffer = src_buffer,
.element_type = -1,
.dimensions = {},
.dynamic_dimensions = {},
.has_layout = false,
.layout = {},
};
error = api_->PJRT_Buffer_OnDeviceTrimmedShape(&shape_args);
if (error != nullptr) {
return ::pjrt::PjrtErrorToStatus(error, api_);
}
CHECK_EQ(shape_args.dimensions.size, 0);
CHECK_EQ(shape_args.element_type, xla::PrimitiveType::F32);
CHECK_EQ(::pjrt::GetDimensions(api_, src_buffer).size(), 0);
CHECK_EQ(::pjrt::GetElementType(api_, src_buffer), PJRT_Buffer_Type_F32);
float value;
args.dst = &value;
@@ -1336,101 +1336,7 @@ PJRT_Error* PJRT_LoadedExecutable_GetExecutable(
return nullptr;
}
namespace {
// Helper functions for copying data to possibly-inlined C arrays.
// 'Src' and 'Dst' are allowed to be different types to make this usable with
// memory-identical types, e.g. int64_t and int64_t. This should not be used
// with types that require a static_cast.
template <typename Src, typename Dst, typename DstList>
static void CreateVectorBase(const absl::Span<Src> src, DstList* dst) {
dst->size = src.size();
if (dst->size > PJRT_C_API_MAX_INLINED) {
dst->heap = new Dst[dst->size];
std::copy(src.begin(), src.end(), dst->heap);
} else {
std::copy(src.begin(), src.end(), dst->inlined);
}
}
void CreateVector(const absl::Span<const int64_t> src, PJRT_Int64List* dst) {
return CreateVectorBase<const int64_t, int64_t, PJRT_Int64List>(src, dst);
}
void CreateVector(const absl::Span<const bool> src, PJRT_BoolList* dst) {
return CreateVectorBase<const bool, bool, PJRT_BoolList>(src, dst);
}
static void CreateVector(const absl::Span<const xla::DimLevelType> src,
PJRT_IntList* dst) {
CreateVectorBase<const xla::DimLevelType, int, PJRT_IntList>(src, dst);
}
void CreateVector(const absl::Span<const bool> src, PJRT_IntList* dst) {
CreateVectorBase<const bool, int, PJRT_IntList>(src, dst);
}
void ToC(const xla::Tile& tile, PJRT_XLA_Tile* c_tile) {
CreateVector(tile.dimensions(), &c_tile->dimensions);
}
void CreateVector(const absl::Span<const xla::Tile> src,
PJRT_XLA_TileList* dst) {
dst->size = src.size();
PJRT_XLA_Tile* c_tiles;
if (dst->size > PJRT_C_API_MAX_INLINED) {
dst->heap = new PJRT_XLA_Tile[dst->size];
c_tiles = dst->heap;
} else {
c_tiles = dst->inlined;
}
for (int i = 0; i < dst->size; ++i) {
ToC(src[i], &c_tiles[i]);
}
}
void ToC(const xla::Layout& layout, PJRT_XLA_Layout* c_layout) {
CreateVector(layout.minor_to_major(), &c_layout->minor_to_major);
CreateVector(layout.dim_level_types(), &c_layout->dim_level_types);
CreateVector(layout.dim_unique(), &c_layout->dim_unique);
CreateVector(layout.dim_ordered(), &c_layout->dim_ordered);
c_layout->index_primitive_type = layout.index_primitive_type();
c_layout->pointer_primitive_type = layout.pointer_primitive_type();
c_layout->element_size_in_bits = layout.element_size_in_bits();
c_layout->memory_space = layout.memory_space();
c_layout->dynamic_shape_metadata_prefix_bytes =
layout.dynamic_shape_metadata_prefix_bytes();
CreateVector(layout.tiles(), &c_layout->tiles);
}
} // namespace
// ---------------------------------- Buffers ----------------------------------
// TODO(b/238999986): Replace this with decomposed shape methods.
PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
PJRT_Buffer_OnDeviceTrimmedShape_Args* args) {
PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
"PJRT_Buffer_OnDeviceTrimmedShape_Args",
PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE, args->struct_size));
xla::Shape shape;
if (args->is_logical_on_device_shape) {
PJRT_ASSIGN_OR_RETURN(shape,
args->buffer->buffer->logical_on_device_shape());
} else {
shape = args->buffer->buffer->on_device_shape();
}
args->element_type = shape.element_type();
CreateVector(shape.dimensions(), &args->dimensions);
CreateVector(shape.dynamic_dimensions(), &args->dynamic_dimensions);
if (shape.has_layout()) {
args->has_layout = true;
ToC(shape.layout(), &args->layout);
} else {
args->has_layout = false;
}
return nullptr;
}
PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args) {
PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
@@ -279,8 +279,6 @@ PJRT_Error* PJRT_Buffer_UnpaddedDimensions(
PJRT_Error* PJRT_Buffer_DynamicDimensionIndices(
PJRT_Buffer_DynamicDimensionIndices_Args* args);
PJRT_Error* PJRT_Buffer_GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args* args);
PJRT_Error* PJRT_Buffer_OnDeviceTrimmedShape(
PJRT_Buffer_OnDeviceTrimmedShape_Args* args);
PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes(
PJRT_Buffer_OnDeviceSizeInBytes_Args* args);
PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args);
@@ -486,8 +484,6 @@ constexpr PJRT_Api CreatePjrtApi(
pjrt::PJRT_Buffer_DynamicDimensionIndices,
/*PJRT_Buffer_GetMemoryLayout=*/
pjrt::PJRT_Buffer_GetMemoryLayout,
/*PJRT_Buffer_OnDeviceTrimmedShape=*/
pjrt::PJRT_Buffer_OnDeviceTrimmedShape,
/*PJRT_Buffer_OnDeviceSizeInBytes=*/
pjrt::PJRT_Buffer_OnDeviceSizeInBytes,
/*PJRT_Buffer_Device=*/pjrt::PJRT_Buffer_Device,
@@ -1522,7 +1522,6 @@ PjRtCApiBuffer::PjRtCApiBuffer(PjRtCApiClient* client, PJRT_Buffer* buffer)
buffer_(buffer, ::pjrt::MakeBufferDeleter(client->pjrt_c_api())),
readiness_event_(nullptr,
::pjrt::MakeEventDeleter(client->pjrt_c_api())) {
set_shape();
}
PrimitiveType PjRtCApiBuffer::element_type() const {
@@ -1595,138 +1594,6 @@ StatusOr<std::vector<int64_t>> PjRtCApiBuffer::logical_dimensions() {
args.unpadded_dims + args.num_dims);
}
const Shape& PjRtCApiBuffer::on_device_shape() const {
CHECK(shape_.has_value())
<< "Shape should be initialized in PjRtCApiBuffer constructor.";
return shape_.value();
}
namespace {
// TODO(b/238999986): these utilities exist only to serialize an XLA shape, and
// will likely be removed, in favor of a more targeted representation of shapes.
// Helper functions for creating a view of possibly-inlined C arrays.
// 'Src' and 'Dst' are allowed to be different types to make this usable with
// memory-identical types, e.g. int64_t and int64_t. This should not be used
// with types that require a static_cast.
template <typename Dst, typename Src, typename SrcList>
static absl::Span<const Dst> MakeSpanBase(const SrcList& src_list) {
static_assert(sizeof(Src) == sizeof(Dst), "Mismatched types");
const Src* src = src_list.size > PJRT_C_API_MAX_INLINED
? src_list.heap
: &src_list.inlined[0];
return absl::Span<const Dst>(reinterpret_cast<const Dst*>(src),
src_list.size);
}
absl::Span<const int> MakeSpan(const PJRT_IntList& src_list) {
return MakeSpanBase<int, int, PJRT_IntList>(src_list);
}
absl::Span<const int64_t> MakeSpan(const PJRT_Int64List& src_list) {
return MakeSpanBase<int64_t, int64_t, PJRT_Int64List>(src_list);
}
absl::Span<const bool> MakeSpan(const PJRT_BoolList& src_list) {
return MakeSpanBase<bool, bool, PJRT_BoolList>(src_list);
}
xla::Tile FromC(const PJRT_XLA_Tile* c_tile) {
absl::Span<const int64_t> dims = MakeSpan(c_tile->dimensions);
return xla::Tile(dims);
}
xla::Layout FromC(const PJRT_XLA_Layout* c_layout) {
absl::Span<const int64_t> minor_to_major = MakeSpan(c_layout->minor_to_major);
absl::Span<const int> dim_level_type_ints =
MakeSpan(c_layout->dim_level_types);
xla::DimLevelTypeVector dim_level_types;
dim_level_types.reserve(dim_level_type_ints.size());
for (int dim_level_type : dim_level_type_ints) {
dim_level_types.push_back(static_cast<xla::DimLevelType>(dim_level_type));
}
absl::Span<const int> dim_unique_ints = MakeSpan(c_layout->dim_unique);
absl::InlinedVector<bool, xla::InlineRank()> dim_unique(
dim_unique_ints.begin(), dim_unique_ints.end());
absl::Span<const int> dim_ordered_ints = MakeSpan(c_layout->dim_unique);
absl::InlinedVector<bool, xla::InlineRank()> dim_ordered(
dim_ordered_ints.begin(), dim_ordered_ints.end());
absl::InlinedVector<xla::Tile, 1> tiles;
const PJRT_XLA_Tile* c_tiles = c_layout->tiles.size > PJRT_C_API_MAX_INLINED
? c_layout->tiles.heap
: c_layout->tiles.inlined;
tiles.reserve(c_layout->tiles.size);
for (int i = 0; i < c_layout->tiles.size; ++i) {
tiles.push_back(FromC(&c_tiles[i]));
}
return xla::Layout(
minor_to_major, dim_level_types, dim_unique, dim_ordered, tiles,
static_cast<xla::PrimitiveType>(c_layout->index_primitive_type),
static_cast<xla::PrimitiveType>(c_layout->pointer_primitive_type),
c_layout->element_size_in_bits, c_layout->memory_space,
/*physical_shape=*/nullptr,
c_layout->dynamic_shape_metadata_prefix_bytes);
}
} // namespace
static Shape GetDeviceShape(PJRT_Buffer* c_buffer, const PJRT_Api* api,
bool is_logical_on_device_shape) {
PJRT_Buffer_OnDeviceTrimmedShape_Args args;
args.struct_size = PJRT_Buffer_OnDeviceTrimmedShape_Args_STRUCT_SIZE;
args.priv = nullptr;
args.buffer = c_buffer;
args.is_logical_on_device_shape = is_logical_on_device_shape;
pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_OnDeviceTrimmedShape(&args), api);
xla::PrimitiveType element_type =
static_cast<xla::PrimitiveType>(args.element_type);
CHECK_NE(element_type, xla::PrimitiveType::TUPLE);
absl::Span<const int64_t> dims = MakeSpan(args.dimensions);
absl::Span<const bool> dynamic_dims = MakeSpan(args.dynamic_dimensions);
Shape trimmed_shape = Shape(element_type, dims, dynamic_dims, {});
if (args.has_layout) {
*(trimmed_shape.mutable_layout()) = FromC(&args.layout);
}
// TODO(amangu): Refactor the deletion.
if (args.dimensions.size > PJRT_C_API_MAX_INLINED) {
delete[] args.dimensions.heap;
}
if (args.dynamic_dimensions.size > PJRT_C_API_MAX_INLINED) {
delete[] args.dynamic_dimensions.heap;
}
if (args.has_layout) {
if (args.layout.minor_to_major.size > PJRT_C_API_MAX_INLINED) {
delete[] args.layout.minor_to_major.heap;
}
if (args.layout.tiles.size > PJRT_C_API_MAX_INLINED) {
delete[] args.layout.tiles.heap;
}
}
return trimmed_shape;
}
void PjRtCApiBuffer::set_shape() {
shape_ = GetDeviceShape(buffer_.get(), client_->pjrt_c_api(),
/*is_logical_on_device_shape=*/false);
}
StatusOr<Shape> PjRtCApiBuffer::logical_on_device_shape() {
return GetDeviceShape(buffer_.get(), client_->pjrt_c_api(),
/*is_logical_on_device_shape=*/true);
}
PjRtFuture<Status> PjRtCApiBuffer::ToLiteral(MutableLiteralBase* literal) {
PJRT_Buffer_ToHostBuffer_Args args;
args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
@@ -366,7 +366,9 @@ class PjRtCApiBuffer : public PjRtBuffer {
// PJRT C API doesn't support tuple buffers.
bool IsTuple() const override { return false; }
const Shape& on_device_shape() const override;
const Shape& on_device_shape() const override {
LOG(FATAL) << "PjRtBuffer::on_device_shape() not implemented in PJRT C API";
}
bool has_dynamic_dimensions() const override;
@@ -378,7 +380,10 @@ class PjRtCApiBuffer : public PjRtBuffer {
StatusOr<std::vector<int64_t>> logical_dimensions() override;
StatusOr<Shape> logical_on_device_shape() override;
StatusOr<Shape> logical_on_device_shape() override {
LOG(FATAL) << "PjRtBuffer::on_logical_device_shape() not implemented in "
"PJRT C API";
}
PjRtMemorySpace* memory_space() const override;
@@ -434,9 +439,6 @@ class PjRtCApiBuffer : public PjRtBuffer {
const PJRT_Api* pjrt_c_api() const { return client_->pjrt_c_api(); }
private:
// TODO(b/238999986): Refactor or Remove.
void set_shape();
// Gets the raw pointer to `readiness_event_`. If `readiness_event_` has not
// yet been initialized, this function does so before returning the pointer.
PJRT_Event* GetReadyEvent();
@@ -448,7 +450,6 @@ class PjRtCApiBuffer : public PjRtBuffer {
PjRtCApiClient* client_;
std::unique_ptr<PJRT_Buffer, ::pjrt::PJRT_BufferDeleter> buffer_;
std::optional<xla::Shape> shape_;
std::unique_ptr<PJRT_Event, ::pjrt::PJRT_EventDeleter> readiness_event_;
// This is a shared_ptr to keep the underlying future alive even if
// `readiness_promise` is destroyed before `readiness_event`, and the callback