未验证 提交 cb1d6b50 编写于 作者: E engineer1109 提交者: GitHub

fix mutable of custom place (#51710)

remove namespace

codestyle

move setPlace to Public

fix devicetype
上级 4638a62e
......@@ -1740,7 +1740,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
static_cast<size_t>(PaddlePlace::kCUSTOM) +
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
res->SetPlace(paddleplace, custom_place.GetDeviceId());
res->SetPlace(
paddleplace, custom_place.GetDeviceId(), place_.GetDeviceType());
} else {
auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
......@@ -1796,7 +1797,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
static_cast<size_t>(PaddlePlace::kCUSTOM) +
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
res->SetPlace(paddleplace, custom_place.GetDeviceId());
res->SetPlace(
paddleplace, custom_place.GetDeviceId(), place_.GetDeviceType());
} else {
auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
......
......@@ -127,6 +127,10 @@ T *Tensor::mutable_data(PlaceType place) {
case static_cast<int>(PlaceType::kNPU): {
return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
}
case static_cast<int>(PlaceType::kCUSTOM): {
return tensor->mutable_data<T>(
paddle::platform::CustomPlace(device_type_, device_));
}
default:
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
......@@ -150,6 +154,8 @@ T *Tensor::data(PlaceType *place, int *size) const {
*place = PlaceType::kXPU;
} else if (paddle::platform::is_npu_place(tensor->place())) {
*place = PlaceType::kNPU;
} else if (paddle::platform::is_custom_place(tensor->place())) {
*place = PlaceType::kCUSTOM;
} else {
*place = PlaceType::kUNK;
}
......@@ -741,9 +747,12 @@ void Tensor::SetName(const std::string &name) { name_ = name; }
const std::string &Tensor::name() const { return name_; }
void Tensor::SetPlace(PlaceType place, int device) {
void Tensor::SetPlace(PlaceType place,
int device,
const std::string device_type) {
place_ = place;
device_ = device;
device_type_ = device_type;
}
#ifdef PADDLE_WITH_ONNXRUNTIME
......
......@@ -176,7 +176,10 @@ class PD_INFER_DECL Tensor {
template <typename T>
void* FindTensor() const;
void SetPlace(PlaceType place, int device = -1);
void SetPlace(PlaceType place,
int device = -1,
const std::string device_type = "");
void SetName(const std::string& name);
template <typename T>
......@@ -195,6 +198,7 @@ class PD_INFER_DECL Tensor {
const void* device_contexs_{nullptr};
PlaceType place_;
int device_;
std::string device_type_;
#ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册