未验证 提交 29de0d97 编写于 作者: L lilong12 提交者: GitHub

add the support to specify device index for device_guard (#24555)

* add the support of device index for device_guard.
上级 3016a4ac
...@@ -1050,7 +1050,16 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, ...@@ -1050,7 +1050,16 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
if (HasAttr("op_device")) { if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") { if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
} else if (Attr<std::string>("op_device") == "gpu") { } else if (Attr<std::string>("op_device").find("gpu") !=
std::string::npos) {
auto device = Attr<std::string>("op_device");
size_t pos = device.find(':');
if (pos != std::string::npos) {
device = device.substr(0, pos);
LOG_FIRST_N(WARNING, 1)
<< "Device index is only supported under pipeline parallelism, "
<< "so it will be ignored.";
}
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time. // will be executed and a warning will be given at the same time.
if (SupportGPU()) { if (SupportGPU()) {
......
...@@ -5455,10 +5455,17 @@ def device_guard(device=None): ...@@ -5455,10 +5455,17 @@ def device_guard(device=None):
result = exe.run(fetch_list=[out]) result = exe.run(fetch_list=[out])
""" """
index = None
if device and ':' in device:
device, index = device.split(':')
if device == 'cpu':
raise ValueError("Should not set device id for cpu.")
if device not in ['cpu', 'gpu', '', None]: if device not in ['cpu', 'gpu', '', None]:
raise ValueError( raise ValueError(
"The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None " "The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None "
"when there is no need to specify device. But received %s" % device) "when there is no need to specify device. But received %s" % device)
if index:
device = ":".join([device, index])
pre_device = switch_device(device) pre_device = switch_device(device)
try: try:
yield yield
......
...@@ -59,6 +59,31 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -59,6 +59,31 @@ class TestDeviceGuard(unittest.TestCase):
execute(main_program, startup_program) execute(main_program, startup_program)
def test_device_guard_with_id(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
data1 = fluid.layers.fill_constant(
shape=[1, 3, 8, 8], value=0.5, dtype='float32')
data2 = fluid.layers.fill_constant(
shape=[1, 3, 5, 5], value=0.5, dtype='float32')
shape = fluid.layers.shape(data2)
with fluid.device_guard("cpu"):
shape = fluid.layers.slice(
shape, axes=[0], starts=[0], ends=[4])
with fluid.device_guard("gpu:1"):
out = fluid.layers.crop_tensor(data1, shape=shape)
# check if the device attr is set correctly
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
if op.type == 'slice':
self.assertEqual(op.desc.attr(device_attr_name), "cpu")
if op.type == 'crop_tensor':
self.assertEqual(op.desc.attr(device_attr_name), "gpu:1")
execute(main_program, startup_program)
def test_cpu_only_op(self): def test_cpu_only_op(self):
main_program = fluid.Program() main_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
...@@ -123,7 +148,13 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -123,7 +148,13 @@ class TestDeviceGuard(unittest.TestCase):
out = fluid.layers.fill_constant( out = fluid.layers.fill_constant(
shape=[1], value=0.2, dtype='float32') shape=[1], value=0.2, dtype='float32')
def device_attr2():
with fluid.device_guard("cpu:1"):
out = fluid.layers.fill_constant(
shape=[1], value=0.2, dtype='float32')
self.assertRaises(ValueError, device_attr) self.assertRaises(ValueError, device_attr)
self.assertRaises(ValueError, device_attr2)
def test_warning(self): def test_warning(self):
main_program = fluid.Program() main_program = fluid.Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册