未验证 提交 24873f4f 编写于 作者: T taixiurong 提交者: GitHub

dyngraph (#30892)

上级 71acde9a
...@@ -29,11 +29,11 @@ class XPURangeKernel : public framework::OpKernel<T> { ...@@ -29,11 +29,11 @@ class XPURangeKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
framework::Tensor n; framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n); framework::TensorCopySync(*start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0]; T start = n.data<T>()[0];
framework::TensorCopy(*end_t, platform::CPUPlace(), &n); framework::TensorCopySync(*end_t, platform::CPUPlace(), &n);
T end = n.data<T>()[0]; T end = n.data<T>()[0];
framework::TensorCopy(*step_t, platform::CPUPlace(), &n); framework::TensorCopySync(*step_t, platform::CPUPlace(), &n);
T step = n.data<T>()[0]; T step = n.data<T>()[0];
int64_t size = 0; int64_t size = 0;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# TODO: define the functions to manipulate devices # TODO: define the functions to manipulate devices
import re import re
import os
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
...@@ -137,7 +137,9 @@ def set_device(device): ...@@ -137,7 +137,9 @@ def set_device(device):
raise ValueError( raise ValueError(
"The device should not be 'xpu', " \ "The device should not be 'xpu', " \
"since PaddlePaddle is not compiled with XPU") "since PaddlePaddle is not compiled with XPU")
place = core.XPUPlace(ParallelEnv().dev_id) selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
device_id = int(selected_xpus[0])
place = core.XPUPlace(device_id)
else: else:
avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册