未验证 提交 cb1d6b50 编写于 作者: E engineer1109 提交者: GitHub

fix mutable of custom place (#51710)

remove namespace

codestyle

move setPlace to Public

fix devicetype
上级 4638a62e
...@@ -1740,7 +1740,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor( ...@@ -1740,7 +1740,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
static_cast<size_t>(PaddlePlace::kCUSTOM) + static_cast<size_t>(PaddlePlace::kCUSTOM) +
phi::CustomRegisteredDeviceMap::Instance() phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType())); .GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
res->SetPlace(paddleplace, custom_place.GetDeviceId()); res->SetPlace(
paddleplace, custom_place.GetDeviceId(), place_.GetDeviceType());
} else { } else {
auto gpu_place = place_; auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
...@@ -1796,7 +1797,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor( ...@@ -1796,7 +1797,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
static_cast<size_t>(PaddlePlace::kCUSTOM) + static_cast<size_t>(PaddlePlace::kCUSTOM) +
phi::CustomRegisteredDeviceMap::Instance() phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType())); .GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
res->SetPlace(paddleplace, custom_place.GetDeviceId()); res->SetPlace(
paddleplace, custom_place.GetDeviceId(), place_.GetDeviceType());
} else { } else {
auto gpu_place = place_; auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
......
...@@ -127,6 +127,10 @@ T *Tensor::mutable_data(PlaceType place) { ...@@ -127,6 +127,10 @@ T *Tensor::mutable_data(PlaceType place) {
case static_cast<int>(PlaceType::kNPU): { case static_cast<int>(PlaceType::kNPU): {
return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_)); return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
} }
case static_cast<int>(PlaceType::kCUSTOM): {
return tensor->mutable_data<T>(
paddle::platform::CustomPlace(device_type_, device_));
}
default: default:
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
"Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is " "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
...@@ -150,6 +154,8 @@ T *Tensor::data(PlaceType *place, int *size) const { ...@@ -150,6 +154,8 @@ T *Tensor::data(PlaceType *place, int *size) const {
*place = PlaceType::kXPU; *place = PlaceType::kXPU;
} else if (paddle::platform::is_npu_place(tensor->place())) { } else if (paddle::platform::is_npu_place(tensor->place())) {
*place = PlaceType::kNPU; *place = PlaceType::kNPU;
} else if (paddle::platform::is_custom_place(tensor->place())) {
*place = PlaceType::kCUSTOM;
} else { } else {
*place = PlaceType::kUNK; *place = PlaceType::kUNK;
} }
...@@ -741,9 +747,12 @@ void Tensor::SetName(const std::string &name) { name_ = name; } ...@@ -741,9 +747,12 @@ void Tensor::SetName(const std::string &name) { name_ = name; }
const std::string &Tensor::name() const { return name_; } const std::string &Tensor::name() const { return name_; }
void Tensor::SetPlace(PlaceType place, int device) { void Tensor::SetPlace(PlaceType place,
int device,
const std::string device_type) {
place_ = place; place_ = place;
device_ = device; device_ = device;
device_type_ = device_type;
} }
#ifdef PADDLE_WITH_ONNXRUNTIME #ifdef PADDLE_WITH_ONNXRUNTIME
......
...@@ -176,7 +176,10 @@ class PD_INFER_DECL Tensor { ...@@ -176,7 +176,10 @@ class PD_INFER_DECL Tensor {
template <typename T> template <typename T>
void* FindTensor() const; void* FindTensor() const;
void SetPlace(PlaceType place, int device = -1); void SetPlace(PlaceType place,
int device = -1,
const std::string device_type = "");
void SetName(const std::string& name); void SetName(const std::string& name);
template <typename T> template <typename T>
...@@ -195,6 +198,7 @@ class PD_INFER_DECL Tensor { ...@@ -195,6 +198,7 @@ class PD_INFER_DECL Tensor {
const void* device_contexs_{nullptr}; const void* device_contexs_{nullptr};
PlaceType place_; PlaceType place_;
int device_; int device_;
std::string device_type_;
#ifdef PADDLE_WITH_ONNXRUNTIME #ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false}; bool is_ort_tensor_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册