diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 2beedaa2f4490610f602a39356dccddc232cb4dd..8640639e80d81976c17c6c162af81b174306d7a2 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -59,6 +59,9 @@ PassStrategy *AnalysisConfig::pass_builder() const { } else if (use_ipu_) { LOG(INFO) << "Create IPU IR passes"; pass_builder_.reset(new IpuPassStrategy); + } else if (use_custom_device_) { + LOG(INFO) << "Create CUSTOM DEVICE IR passes"; + pass_builder_.reset(new CustomDevicePassStrategy); } else { LOG(INFO) << "Create CPU IR passes"; pass_builder_.reset(new CpuPassStrategy); @@ -555,6 +558,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { } else if (use_xpu_) { pass_builder_.reset(new XpuPassStrategy( *static_cast(other.pass_builder()))); + } else if (use_custom_device_) { + pass_builder_.reset(new CustomDevicePassStrategy( + *static_cast(other.pass_builder()))); } else if (use_npu_) { pass_builder_.reset(new NpuPassStrategy( *static_cast(other.pass_builder()))); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 41d5726c598863255a505910791cca9c725bff55..dd7d1bac697f0377eb20747ffb8ae599325e2762 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -239,10 +239,28 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, #else PADDLE_THROW(paddle::platform::errors::Fatal( "Not compile with XPU, should not reach here.")); +#endif + } else if (platform::is_custom_place(place)) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + paddle::platform::DeviceContextPool &pool = + paddle::platform::DeviceContextPool::Instance(); + auto custom_place = place; + auto *dev_ctx = static_cast( + pool.Get(custom_place)); + memory::Copy(custom_place, + static_cast(input_ptr), + platform::CPUPlace(), + pt.data.data(), + pt.data.length(), + dev_ctx->stream()); +#else + PADDLE_THROW(paddle::platform::errors::Fatal( + "Not compile with CUSTOM_DEVICE, should not reach here.")); #endif } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "The analysis predictor supports CPU, GPU and XPU now.")); + "The analysis predictor supports CPU, GPU, XPU and CUSTOM_DEVICE " + "now.")); } // TODO(Superjomn) Low performance, need optimization for heavy LoD copy. framework::LoD lod;