未验证 提交 4c576870 编写于 作者: H HongyuJia 提交者: GitHub

SetDevice when parse TensorBase (#49860)

上级 2f24b2d8
...@@ -22,6 +22,9 @@ limitations under the License. */ ...@@ -22,6 +22,9 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/string_tensor_utils.h" #include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
...@@ -54,6 +57,11 @@ bool HasAllocation(const phi::TensorBase& t) { ...@@ -54,6 +57,11 @@ bool HasAllocation(const phi::TensorBase& t) {
BackendSet GetTensorBackendSet(const phi::TensorBase& t) { BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) { if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (t.place().GetType() == AllocationType::CUSTOM) {
phi::DeviceManager::SetDevice(t.place());
}
#endif
phi::Backend backend_key = phi::TransToPhiBackend(t.place()); phi::Backend backend_key = phi::TransToPhiBackend(t.place());
BackendSet backend_set(backend_key); BackendSet backend_set(backend_key);
if (backend_key == Backend::GPU && phi::DenseTensor::classof(&t) && if (backend_key == Backend::GPU && phi::DenseTensor::classof(&t) &&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册