未验证 提交 067f558c 编写于 作者: C chentianyu03 提交者: GitHub

add varbasecopy func to fix the ParamBase type bug in layers.to API (#32789)

* add varbasecopy func to fix the paraBase type bug in layers.to API

* overload _copy_to func for ParamBase

* add xpuplace

* add waiting varbsecopy completion when not blocking

* fix dst_device bug

* modify varbase to shared_ptr
上级 2611ed25
......@@ -469,6 +469,62 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
if (!PyTuple_Check(_index)) Py_DecRef(index);
}
template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,
imperative::VarBase &dst, const P &dst_device,
const bool blocking) {
if (dst.SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src->Name() << " to "
<< dst.Name();
dst.SetPersistable(src->Persistable());
dst.SetDataType(src->DataType());
dst.SetType(src->Type());
dst.SetOverridedStopGradient(src->OverridedStopGradient());
if (!src->SharedVar()->IsEmpty()) {
if (src->Var().IsType<framework::LoDTensor>()) {
auto &src_tensor = src->Var().Get<framework::LoDTensor>();
auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, dst_device, dst_tensor);
if (blocking) {
platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
auto src_device = src_tensor.place();
if (!(src_device == dst_device)) {
platform::DeviceContextPool::Instance().Get(src_device)->Wait();
}
}
} else if (src->Var().IsType<framework::SelectedRows>()) {
auto &src_selected_rows = src->Var().Get<framework::SelectedRows>();
auto *dst_selected_rows =
dst.MutableVar()->GetMutable<framework::SelectedRows>();
dst_selected_rows->set_height(src_selected_rows.height());
dst_selected_rows->set_rows(src_selected_rows.rows());
framework::TensorCopy(src_selected_rows.value(), dst_device,
dst_selected_rows->mutable_value());
if (blocking) {
platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
auto src_device = src_selected_rows.value().place();
if (!(src_device == dst_device)) {
platform::DeviceContextPool::Instance().Get(src_device)->Wait();
}
}
}
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(src, dst_device);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The source Tensor(%s) can not copy when it is empty.", src->Name()));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The destion Tensor(%s) can not copy when it is not empty.",
dst.Name()));
}
}
// Bind Methods
void BindImperative(py::module *m_ptr) {
auto &m = *m_ptr;
......@@ -1639,6 +1695,11 @@ void BindImperative(py::module *m_ptr) {
self.nrings_ = nrings;
});
m.def("varbase_copy", &VarBaseCopy<platform::Place>);
m.def("varbase_copy", &VarBaseCopy<platform::CPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>);
m.def(
"dygraph_partial_grad",
[](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
......
......@@ -34,7 +34,7 @@ from .base import program_desc_tracing_guard, param_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph import no_grad
import paddle.utils.deprecated as deprecated
......@@ -1427,8 +1427,19 @@ class Layer(core.Layer):
dtype = t.dtype
new_t = t._copy_to(device, blocking)
if dtype is not None and dtype != t.dtype:
new_t = new_t.cast(dtype=dtype)
if isinstance(t, framework.ParamBase):
if dtype is not None and dtype != t.dtype:
framework._dygraph_tracer().trace_op(
type='cast',
inputs={'X': new_t},
outputs={'Out': new_t},
attrs={
'in_dtype': t.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype)
})
else:
if dtype is not None and dtype != t.dtype:
new_t = new_t.cast(dtype=dtype)
return new_t
......
......@@ -5855,6 +5855,13 @@ class ParamBase(core.VarBase):
new_param.copy_(self, True)
return new_param
def _copy_to(self, device, blocking):
print("in ParamBase copy_to func")
state = copy.deepcopy(self.__dict__)
new_param = ParamBase(self.shape, self.dtype, **state)
core.varbase_copy(self, new_param, device, blocking)
return new_param
__repr__ = __str__
......
......@@ -341,7 +341,7 @@ class TestLayerTo(unittest.TestCase):
self.linear.register_buffer("buf_name", buffer, persistable=True)
sublayer = paddle.nn.Conv1D(3, 2, 3)
self.linear.add_sublayer(1, sublayer)
self.linear.add_sublayer("1", sublayer)
def test_to_api(self):
self.linear.to(dtype='double')
......@@ -351,8 +351,8 @@ class TestLayerTo(unittest.TestCase):
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertTrue(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.linear.to()
self.assertEqual(self.linear.weight.dtype,
......@@ -361,8 +361,10 @@ class TestLayerTo(unittest.TestCase):
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(
np.allclose(self.linear.weight.grad.numpy(), self.new_grad))
self.assertTrue(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
if paddle.fluid.is_compiled_with_cuda():
self.linear.to(device=paddle.CUDAPlace(0))
......@@ -384,6 +386,8 @@ class TestLayerTo(unittest.TestCase):
))
self.assertEqual(
self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
self.linear.to(device=paddle.CPUPlace())
self.assertTrue(self.linear.weight.place.is_cpu_place())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册