提交 643e43b1 编写于 作者: Y Yash Katariya 提交者: TensorFlower Gardener

Fix the checks that check if memory kinds between pjrt_buffer and user sharding is equal or not.

PiperOrigin-RevId: 564578779
上级 97a1fa35
......@@ -694,6 +694,7 @@ StatusOr<PyArray> PyArray::CopyToDeviceWithSharding(
CreateIfRtMemoryKindFromSharding(dst_sharding);
if (ifrt_array_ptr->sharding().devices().devices() == devices.devices() &&
(!dst_memory_kind.memory_kind().has_value() ||
!ifrt_array_ptr->sharding().memory_kind().memory_kind().has_value() ||
ifrt_array_ptr->sharding().memory_kind() == dst_memory_kind)) {
return *this;
}
......
......@@ -290,7 +290,8 @@ StatusOr<DevicePutResult> HandlePyArray(py::handle obj, ifrt::Client* client,
if (ifrt_array->sharding().devices().front() == to_device &&
(!to_memory_kind.memory_kind().has_value() ||
(ifrt_array->sharding().memory_kind() == to_memory_kind))) {
!ifrt_array->sharding().memory_kind().memory_kind().has_value() ||
ifrt_array->sharding().memory_kind() == to_memory_kind)) {
return DevicePutResult(
tsl::FormRef(ifrt_array), py_array.weak_type(),
/*owning_pybuffer=*/py::reinterpret_borrow<py::object>(obj));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册