未验证 提交 24873f4f 编写于 作者: T taixiurong 提交者: GitHub

dyngraph (#30892)

上级 71acde9a
......@@ -29,11 +29,11 @@ class XPURangeKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out");
framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
framework::TensorCopySync(*start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0];
framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
framework::TensorCopySync(*end_t, platform::CPUPlace(), &n);
T end = n.data<T>()[0];
framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
framework::TensorCopySync(*step_t, platform::CPUPlace(), &n);
T step = n.data<T>()[0];
int64_t size = 0;
......
......@@ -14,7 +14,7 @@
# TODO: define the functions to manipulate devices
import re
import os
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv
......@@ -137,7 +137,9 @@ def set_device(device):
raise ValueError(
"The device should not be 'xpu', " \
"since PaddlePaddle is not compiled with XPU")
place = core.XPUPlace(ParallelEnv().dev_id)
selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
device_id = int(selected_xpus[0])
place = core.XPUPlace(device_id)
else:
avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册