Unverified commit 067f558c, authored by chentianyu03, committed by GitHub

add varbasecopy func to fix the ParamBase type bug in layers.to API (#32789)

* add varbasecopy func to fix the ParamBase type bug in layers.to API

* overload _copy_to func for ParamBase

* add xpuplace

* wait for varbase_copy completion when not blocking

* fix dst_device bug

* modify varbase to shared_ptr
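
The net effect on the public API: after this change, parameters moved or cast through Layer.to keep their ParamBase type instead of decaying to plain VarBase objects. A minimal sketch of that behavior (assuming a PaddlePaddle 2.x dygraph environment; the layer here is only illustrative):

import paddle

linear = paddle.nn.Linear(2, 2)

# Move and cast the layer in place via the layers.to API touched by this PR.
linear.to(device=paddle.CPUPlace(), dtype='float64', blocking=True)

for p in linear.parameters():
    # With the fix, the copied parameters are still ParamBase instances.
    assert isinstance(p, paddle.fluid.framework.ParamBase)
    print(type(p).__name__, p.dtype)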
Parent 2611ed25
@@ -469,6 +469,62 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
  if (!PyTuple_Check(_index)) Py_DecRef(index);
}
template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,
                        imperative::VarBase &dst, const P &dst_device,
                        const bool blocking) {
  if (dst.SharedVar()->IsEmpty()) {
    VLOG(3) << "deep copy Variable from " << src->Name() << " to "
            << dst.Name();
    dst.SetPersistable(src->Persistable());
    dst.SetDataType(src->DataType());
    dst.SetType(src->Type());
    dst.SetOverridedStopGradient(src->OverridedStopGradient());
    if (!src->SharedVar()->IsEmpty()) {
      if (src->Var().IsType<framework::LoDTensor>()) {
        auto &src_tensor = src->Var().Get<framework::LoDTensor>();
        auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
        dst_tensor->set_lod(src_tensor.lod());
        framework::TensorCopy(src_tensor, dst_device, dst_tensor);
        if (blocking) {
          platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
          auto src_device = src_tensor.place();
          if (!(src_device == dst_device)) {
            platform::DeviceContextPool::Instance().Get(src_device)->Wait();
          }
        }
      } else if (src->Var().IsType<framework::SelectedRows>()) {
        auto &src_selected_rows = src->Var().Get<framework::SelectedRows>();
        auto *dst_selected_rows =
            dst.MutableVar()->GetMutable<framework::SelectedRows>();
        dst_selected_rows->set_height(src_selected_rows.height());
        dst_selected_rows->set_rows(src_selected_rows.rows());
        framework::TensorCopy(src_selected_rows.value(), dst_device,
                              dst_selected_rows->mutable_value());
        if (blocking) {
          platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
          auto src_device = src_selected_rows.value().place();
          if (!(src_device == dst_device)) {
            platform::DeviceContextPool::Instance().Get(src_device)->Wait();
          }
        }
      }

      if (!blocking) {
        IncreaseVarbaseReferenceCountUntilCopyComplete(src, dst_device);
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The source Tensor(%s) can not copy when it is empty.", src->Name()));
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The destination Tensor(%s) can not copy when it is not empty.",
        dst.Name()));
  }
}
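
VarBaseCopy deep-copies an empty destination VarBase from src: it clones the metadata (persistable flag, dtype, type, stop_gradient), copies the underlying LoDTensor or SelectedRows to dst_device, waits on the relevant device contexts when blocking is true, and otherwise keeps src alive until the asynchronous copy finishes. It is exposed to Python below as core.varbase_copy. A rough usage sketch of the binding (normally it is reached through ParamBase._copy_to further down; this direct call is illustrative and assumes a CPU build):

import paddle
from paddle.fluid import core, framework

src = paddle.nn.Linear(2, 2).weight                  # an existing ParamBase
dst = framework.ParamBase(src.shape, src.dtype)      # empty destination

# Deep copy src into dst on the CPU; blocking=True waits on the device
# context(s) before returning, mirroring the call made by ParamBase._copy_to.
core.varbase_copy(src, dst, paddle.CPUPlace(), True)
print(dst.shape, dst.dtype)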
// Bind Methods
void BindImperative(py::module *m_ptr) {
  auto &m = *m_ptr;
@@ -1639,6 +1695,11 @@ void BindImperative(py::module *m_ptr) {
            self.nrings_ = nrings;
          });

  m.def("varbase_copy", &VarBaseCopy<platform::Place>);
  m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>);
  m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>);
  m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>);

  m.def(
      "dygraph_partial_grad",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
......
@@ -34,7 +34,7 @@ from .base import program_desc_tracing_guard, param_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
-from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph import no_grad
import paddle.utils.deprecated as deprecated
@@ -1427,8 +1427,19 @@ class Layer(core.Layer):
            dtype = t.dtype
            new_t = t._copy_to(device, blocking)
-            if dtype is not None and dtype != t.dtype:
-                new_t = new_t.cast(dtype=dtype)
            if isinstance(t, framework.ParamBase):
                if dtype is not None and dtype != t.dtype:
                    framework._dygraph_tracer().trace_op(
                        type='cast',
                        inputs={'X': new_t},
                        outputs={'Out': new_t},
                        attrs={
                            'in_dtype': t.dtype,
                            'out_dtype': convert_np_dtype_to_dtype_(dtype)
                        })
            else:
                if dtype is not None and dtype != t.dtype:
                    new_t = new_t.cast(dtype=dtype)
            return new_t
......
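
The reason for the ParamBase branch above: VarBase.cast() allocates and returns a brand-new tensor, so using it on a copied parameter would silently change its type from ParamBase to a plain Tensor. Tracing a cast op whose input and output are both new_t converts the dtype in place and keeps the original ParamBase object. A small illustration of the type difference (names are illustrative):

import paddle

w = paddle.nn.Linear(2, 2).weight       # a ParamBase
casted = w.cast('float64')              # cast() returns a new plain tensor
print(type(w).__name__, type(casted).__name__)  # ParamBase vs. Tensor/VarBase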
@@ -5855,6 +5855,13 @@ class ParamBase(core.VarBase):
        new_param.copy_(self, True)
        return new_param

    def _copy_to(self, device, blocking):
        print("in ParamBase copy_to func")
        state = copy.deepcopy(self.__dict__)
        new_param = ParamBase(self.shape, self.dtype, **state)
        core.varbase_copy(self, new_param, device, blocking)
        return new_param

    __repr__ = __str__
......
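
A short usage sketch of the new overload (assumes a CPU-only setup; in practice _copy_to is invoked for every parameter and buffer inside Layer.to rather than called directly):

import paddle

param = paddle.nn.Linear(2, 2).weight

# Copy the parameter to the CPU, blocking until the copy has finished.
moved = param._copy_to(paddle.CPUPlace(), True)

# Unlike the base VarBase._copy_to, the result is still a ParamBase.
print(type(moved).__name__, moved.place)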
@@ -341,7 +341,7 @@ class TestLayerTo(unittest.TestCase):
        self.linear.register_buffer("buf_name", buffer, persistable=True)
        sublayer = paddle.nn.Conv1D(3, 2, 3)
-        self.linear.add_sublayer(1, sublayer)
        self.linear.add_sublayer("1", sublayer)

    def test_to_api(self):
        self.linear.to(dtype='double')
@@ -351,8 +351,8 @@ class TestLayerTo(unittest.TestCase):
                         paddle.fluid.core.VarDesc.VarType.FP64)
        self.assertTrue(
            np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
-        self.assertTrue(self.linear.weight._grad_ivar().dtype,
-                        paddle.fluid.core.VarDesc.VarType.FP64)
        self.assertEqual(self.linear.weight._grad_ivar().dtype,
                         paddle.fluid.core.VarDesc.VarType.FP64)

        self.linear.to()
        self.assertEqual(self.linear.weight.dtype,
@@ -361,8 +361,10 @@ class TestLayerTo(unittest.TestCase):
                         paddle.fluid.core.VarDesc.VarType.FP64)
        self.assertTrue(
            np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
-        self.assertTrue(self.linear.weight._grad_ivar().dtype,
-                        paddle.fluid.core.VarDesc.VarType.FP64)
        self.assertEqual(self.linear.weight._grad_ivar().dtype,
                         paddle.fluid.core.VarDesc.VarType.FP64)
        for p in self.linear.parameters():
            self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))

        if paddle.fluid.is_compiled_with_cuda():
            self.linear.to(device=paddle.CUDAPlace(0))
@@ -384,6 +386,8 @@ class TestLayerTo(unittest.TestCase):
            ))
            self.assertEqual(
                self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
            for p in self.linear.parameters():
                self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))

            self.linear.to(device=paddle.CPUPlace())
            self.assertTrue(self.linear.weight.place.is_cpu_place())
......
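
Note the assertTrue → assertEqual fixes in the test above: unittest's assertTrue(a, b) treats the second argument as a failure message and only checks that a is truthy, so the old assertions could never catch a dtype mismatch. A minimal standalone illustration of the pitfall:

import unittest

class AssertTruePitfall(unittest.TestCase):
    def test_pitfall(self):
        # Passes even though 1 != 2: the second argument is just the message.
        self.assertTrue(1, 2)
        # assertEqual actually compares the two values.
        self.assertEqual(1, 1)

if __name__ == '__main__':
    unittest.main()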