未验证 提交 a6a49855 编写于 作者: M Meteor Liu 提交者: GitHub

[dy2static]implement tensor.cuda() in static graph (#56092)

* [dy2static]implement tensor.cuda() in static graph

* [dy2static]implement tensor.cuda() in static graph - change the patch place

* [dy2static]implement tensor.cuda() in static graph - fix code-block in comment

* [dy2static]implement tensor.cuda() in static graph - add ut for warning branch
上级 6eaed2da
......@@ -476,6 +476,56 @@ void ApplyDeviceGuard(const OperatorBase* op_base,
}
}
platform::DeviceContext* ConstructDeviceContext(const OperatorBase* op,
const platform::Place& place) {
auto& pool = platform::DeviceContextPool::Instance();
auto* default_dev_ctx = pool.Get(place);
// Replace the default_dev_ctx according to dst_place_type for memcpy op if
// needed.
// NOTE(liudongxue01):
// Please apply the following logic in other Executor/Interpreter modules
// likewise.
//
// NOTE(liudongxue01):
// The following code aims to fixup the memcpy kernel which does not handle
// some rare case. The case is:
// 1. The default place in the current execution context is not CUDAPlace,
// such as CPUPlace,
// 2. The dst_place_type is 1 which means CUDAPlace,
// 3. The expected result place is CUDAPlace but the actual result is
// CPUPlace.
// When the default place is CPUPlace, we call the tensor.cuda() would
// simply hit such case.
//
// Q: Why we do not add such logic in the memcpy kernel?
// A: (1) To fixup the memcpy kernel, we need to construct a CUDAPlace() and
// corresponding DeviceContext instance which used by the phi::Copy(...)
// api to perform the real memcpy action. (2) We should not access the
// singleton of the DeviceContextPool object in the PHI framework which
// is designed as a standalone module and all context data should passed
// into the kernel API through arguments. (3) So we have no way to
// construct a CUDAPlace() in the memcpy kernel and then pass it
// to the phi::Copy(...) api.
if (!platform::is_gpu_place(place)) {
const auto& op_type = op->Type();
if (op_type == "memcpy") {
int dst_place_type = op->Attr<int>("dst_place_type");
if (dst_place_type == 1) { // 1 : CUDAPlace
auto dev_ctx = pool.Get(paddle::DefaultGPUPlace());
VLOG(4) << "Change the device context for memcpy OP: ("
<< default_dev_ctx->type_info().name() << ") -> ("
<< dev_ctx->type_info().name() << ")";
return dev_ctx;
}
}
}
return default_dev_ctx;
}
void HandleOperatorBase(const platform::Place& place,
std::shared_ptr<OperatorBase> op,
OpFuncNode* op_func_node,
......@@ -537,6 +587,8 @@ void BuildOpFuncList(const platform::Place& place,
}
auto unused_var_map = GetUnusedVars(block, ops);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
bool flag_log_is_printed = false;
for (size_t i = 0; i < ops.size(); ++i) {
auto op = ops[i].get();
......@@ -655,8 +707,9 @@ void BuildOpFuncList(const platform::Place& place,
runtime_scope = local_scope;
}
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// construct the device context
auto* dev_ctx = ConstructDeviceContext(op, place);
SetDeviceCommContext(op, dev_ctx);
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
......
......@@ -146,6 +146,7 @@ def monkey_patch_variable():
In Static Graph Mode:
.. code-block:: python
import paddle
paddle.enable_static()
......@@ -162,7 +163,7 @@ def monkey_patch_variable():
persistable=False,
stop_gradient=True,
)
# 0 means cpu place, see paddle/fluid/operators/memcpy_op.h
# 0 means cpu place, see paddle/phi/kernels/memcpy_kernel.cc
attrs = {'dst_place_type': 0}
block.append_op(
type='memcpy',
......@@ -173,13 +174,57 @@ def monkey_patch_variable():
return output
@static_only
def cuda(self):
def cuda(self, device_id=None, blocking=True):
"""
Variable should not have cpu() and cuda() interface.
But this interface can greatly facilitate dy2static.
We do nothing here.
In dy2static, Variable also needs cpu() and cuda() interface.
But, the underneath operator has only forward op but not backward one.
Args:
self(Variable): The variable itself.
device_id(int, optional): The destination GPU device id. Default: None, means current device.
We add this argument for dy2static translation, please do not use it.
blocking(bool, optional): Whether blocking or not, Default: True.
We add this argument for dy2static translation, please do not use it.
Returns:
The tensor which has copied to cuda place.
Examples:
In Static Graph Mode:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.static.data(name="x", shape=[2,2], dtype='float32')
y = x.cpu()
z = y.cuda()
"""
return self
if device_id is not None:
warnings.warn("device_id is not supported, and it will be ignored.")
if blocking is not True:
warnings.warn("blocking is not supported, and it will be ignored.")
block = current_block(self)
tmp_name = unique_tmp_name()
output = block.create_var(
name=tmp_name,
dtype=self.dtype,
shape=self.shape,
type=self.type,
persistable=False,
stop_gradient=True,
)
# 1 means cuda place, see paddle/phi/kernels/memcpy_kernel.cc
attrs = {'dst_place_type': 1}
block.append_op(
type='memcpy',
inputs={'X': [self]},
outputs={'Out': [output]},
attrs=attrs,
)
return output
@static_only
def place(self):
......
......@@ -26,6 +26,20 @@ def tensor_copy_to_cpu(x):
return y
@paddle.jit.to_static
def tensor_copy_to_cuda(x):
x = paddle.to_tensor(x)
y = x.cuda()
return y
@paddle.jit.to_static
def tensor_copy_to_cuda_with_warning(x, device_id=None, blocking=True):
x = paddle.to_tensor(x)
y = x.cuda(device_id, blocking)
return y
class TestTensorCopyToCpuOnDefaultCPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
......@@ -46,5 +60,67 @@ class TestTensorCopyToCpuOnDefaultCPU(unittest.TestCase):
self.assertTrue(static_place.is_cpu_place())
class TestTensorCopyToCUDAOnDefaultCPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cuda(x1)
return x1.place, x2.place, x2.numpy()
def test_tensor_cuda_on_default_cpu(self):
if not paddle.fluid.is_compiled_with_cuda():
return
"""
Note(liudongxue01): If the following asserts fail to run,
please check the workaround logic for memcpy OP
whether is still taking effect or not.
See ConstructDeviceContext() in interpreter_util.cc.
"""
paddle.fluid.framework._set_expected_place(paddle.CPUPlace())
dygraph_x1_place, dygraph_place, dygraph_res = self._run(
to_static=False
)
static_x1_place, static_place, static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
self.assertTrue(dygraph_x1_place.is_cpu_place())
self.assertTrue(static_x1_place.is_cpu_place())
self.assertTrue(dygraph_place.is_gpu_place())
self.assertTrue(static_place.is_gpu_place())
class TestTensorCopyToCUDAWithWarningOnCPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cuda_with_warning(x1, device_id=1, blocking=False)
return x1.place, x2.place, x2.numpy()
def test_with_warning_on_cpu(self):
if not paddle.fluid.is_compiled_with_cuda():
return
paddle.fluid.framework._set_expected_place(paddle.CPUPlace())
x1 = paddle.ones([1, 2, 3])
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x1, device_id=1, blocking=True
)
self.assertIn('math_op_patch.py', cm.filename)
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x1, device_id=None, blocking=False
)
self.assertIn('math_op_patch.py', cm.filename)
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x1, device_id=2, blocking=False
)
self.assertIn('math_op_patch.py', cm.filename)
if __name__ == '__main__':
unittest.main()
......@@ -27,6 +27,20 @@ def tensor_copy_to_cpu(x):
return y
@paddle.jit.to_static
def tensor_copy_to_cuda(x):
x = paddle.to_tensor(x)
y = x.cuda()
return y
@paddle.jit.to_static
def tensor_copy_to_cuda_with_warning(x, device_id=None, blocking=True):
x = paddle.to_tensor(x)
y = x.cuda(device_id, blocking)
return y
class TestTensorCopyToCpuOnDefaultGPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
......@@ -53,5 +67,67 @@ class TestTensorCopyToCpuOnDefaultGPU(unittest.TestCase):
self.assertTrue(static_place.is_cpu_place())
class TestTensorCopyToCUDAOnDefaultGPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cuda(x1)
return x1.place, x2.place, x2.numpy()
def test_tensor_cuda_on_default_gpu(self):
if paddle.fluid.is_compiled_with_cuda():
place = paddle.CUDAPlace(
int(os.environ.get('FLAGS_selected_gpus', 0))
)
else:
return
paddle.fluid.framework._set_expected_place(place)
dygraph_x1_place, dygraph_place, dygraph_res = self._run(
to_static=False
)
static_x1_place, static_place, static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
self.assertTrue(dygraph_x1_place.is_gpu_place())
self.assertTrue(static_x1_place.is_gpu_place())
self.assertTrue(dygraph_place.is_gpu_place())
self.assertTrue(static_place.is_gpu_place())
class TestTensorCopyToCUDAWithWarningOnGPU(unittest.TestCase):
def _run(self, to_static):
paddle.jit.enable_to_static(to_static)
x1 = paddle.ones([1, 2, 3])
x2 = tensor_copy_to_cuda_with_warning(x1, device_id=1, blocking=False)
return x1.place, x2.place, x2.numpy()
def test_with_warning_on_gpu(self):
if paddle.fluid.is_compiled_with_cuda():
place = paddle.CUDAPlace(
int(os.environ.get('FLAGS_selected_gpus', 0))
)
else:
return
paddle.fluid.framework._set_expected_place(place)
x1 = paddle.ones([1, 2, 3])
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x1, device_id=1, blocking=True
)
self.assertIn('math_op_patch.py', cm.filename)
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x1, device_id=None, blocking=False
)
self.assertIn('math_op_patch.py', cm.filename)
with self.assertWarns(UserWarning, msg="ignored") as cm:
x2 = tensor_copy_to_cuda_with_warning(
x1, device_id=2, blocking=False
)
self.assertIn('math_op_patch.py', cm.filename)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册