未验证 提交 706a7897 编写于 作者: Z zhangbo9674 提交者: GitHub

Fix Layer.to() of device bug (#37156)

上级 34a44d59
......@@ -1556,19 +1556,18 @@ class Layer(core.Layer):
if dtype is None:
dtype = t.dtype
if type(dtype) is str:
dtype = convert_np_dtype_to_dtype_(dtype)
# 1. gpu place need to determine whether the memory is sufficient for allocation:
if t.place.is_gpu_place():
gpu_memory_available = core.gpu_memory_available()
# for gpu, minimum memory allocation unit is 256 bytes.
if type(dtype) is str:
size_dtype = core.size_of_dtype(
convert_np_dtype_to_dtype_(dtype))
else:
size_dtype = core.size_of_dtype(dtype)
# Note(zhangbo): Paddle GPU minimum memory allocation unit is 256 bytes, waiting_alloc_memory will comput ‘t’ occupied memory space.
# Coefficient 1.2 is used to avoid OOM that may occur in this critical state when the memory is just enough.
waiting_alloc_memory = (
(t.numel().numpy()[0] * size_dtype) / 256 + 1) * 256 * 1.2
(np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2
gpu_memory_available = core.gpu_memory_available()
if gpu_memory_available < waiting_alloc_memory:
# Copy param / Tensor to cpu
t_used = t._copy_to(paddle.CPUPlace(),
......@@ -1582,26 +1581,17 @@ class Layer(core.Layer):
# 2. cast param / Tensor to dtype
if dtype is not None and dtype != t_used.dtype:
if isinstance(t_used, framework.ParamBase):
from paddle.fluid.layer_helper import LayerHelper
helper = LayerHelper("cast", **locals())
t_casted = helper.create_variable_for_type_inference(
dtype=dtype)
framework._dygraph_tracer().trace_op(
type='cast',
inputs={'X': t_used},
outputs={'Out': t_casted},
attrs={
'in_dtype': t_used.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})
else:
with paddle.fluid.framework._dygraph_place_guard(
place=t_used.place):
t_casted = t_used.cast(dtype=dtype)
else:
t_casted = t_used
# 3. Copy casted cpu param / Tensor to device
if device is not None and not t_casted.place._equals(device):
new_t = t_casted._copy_to(device, blocking)
else:
new_t = t_casted
# 4. share Tensor to origin param / Tensor
dst_tensor = t.value().get_tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册