提交 c2f38110 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add CopyToMemorySpace to the PjRtBuffer API. This CL does not implement any...

Add CopyToMemorySpace to the PjRtBuffer API. This CL does not implement any instance of the method, but adds the ability to do so in followup CLs.

PiperOrigin-RevId: 564807735
上级 1d7dcd42
......@@ -135,6 +135,11 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer {
return PjRtFuture<Status>(Unimplemented("CopyRawToHost not implemented"));
}
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) override {
return Unimplemented("CopyToMemorySpace not implemented");
}
void Delete() override;
bool IsDeleted() override;
......
......@@ -424,6 +424,11 @@ class PjRtCApiBuffer : public PjRtBuffer {
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) override {
return Unimplemented("PJRT C API does not support CopyToMemorySpace");
}
void CopyToRemoteDevice(
PjRtFuture<StatusOr<std::string>> serialized_descriptor,
RemoteSendCallback on_done) override {
......
......@@ -1111,6 +1111,18 @@ class PjRtBuffer {
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) = 0;
// Copies the buffer to memory space `dst_memory_space`.
//
// The destination memory space may be attached to any client, but optimized
// implementations may apply when the copy is within the same client.
//
// Returns an error if the buffer is already in dst_memory_space.
//
// See note on semantics of cross-device copies in the class definition
// comment for PjRtClient.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) = 0;
// Prepares to send a copy of the buffer to a remote device. The destination
// device is encoded in `serialized_descriptor`, which must be fulfilled by
// the result of call to MakeCrossHostReceiveBuffers on the remote host's
......
......@@ -672,6 +672,11 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) override {
return Unimplemented("Implement CopyToMemorySpace");
}
void CopyToRemoteDevice(
PjRtFuture<StatusOr<std::string>> serialized_descriptor,
RemoteSendCallback on_done) override;
......
......@@ -73,6 +73,10 @@ class TfPjRtBuffer : public PjRtBuffer {
bool IsDeleted() override { return wrapped_->IsDeleted(); }
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToMemorySpace(
PjRtMemorySpace* dst_memory_space) override {
return Unimplemented("CopyToMemorySpace not implemented");
}
void CopyToRemoteDevice(
PjRtFuture<StatusOr<std::string>> serialized_descriptor,
RemoteSendCallback on_done) override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册