未验证 提交 68b06ba6 作者: Leo Chen 提交者: GitHub

[new-exec] set cuda device before run (#44985)

* set cuda device before run

* add header file

* fix compile
上级 9c98ee3e
......@@ -28,6 +28,7 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
true,
......@@ -98,6 +99,11 @@ InterpreterCore::~InterpreterCore() {
interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(place_.device);
}
#endif
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
{
......@@ -122,6 +128,11 @@ interpreter::CostInfo InterpreterCore::DryRun(
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(place_.device);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
......@@ -153,6 +164,11 @@ paddle::framework::FetchList InterpreterCore::Run(
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names) {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(place_.device);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
......
......@@ -141,6 +141,7 @@ class CudaEvent {
#else
cudaEventCreateWithFlags(&event_, flags_);
#endif
VLOG(4) << "CudaEvent " << event_;
}
explicit CudaEvent(unsigned int flags) : flags_(flags) {
......@@ -149,6 +150,7 @@ class CudaEvent {
#else
cudaEventCreateWithFlags(&event_, flags_);
#endif
VLOG(4) << "CudaEvent " << event_;
}
~CudaEvent() {
......
......@@ -241,6 +241,7 @@ void SetDeviceId(int id) {
id,
GetGPUDeviceCount()));
PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
VLOG(4) << "SetDeviceId " << id;
}
void GpuMemcpyAsync(void *dst,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册