Unverified commit 52645667 authored by W Weilong Wu, committed by GitHub

[New features] Support VarBase to expose func (#36965)

* Expose func for varbase

* Expose func for varbase and enhance varbase init func

* Change func name and add test case for _CopyGradientWith

* Rename func

* Add test cases to increase coverage

* Refine the logic of _to func

* Replace numel() with _numel(), Add test code
Parent 1580eae2
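Before the diff, a minimal usage sketch of the VarBase helpers this commit exposes (dygraph mode; the calls mirror the bindings and tests added below, and the printed values are illustrative only):

```python
import paddle

# Dygraph-mode sketch of the newly exposed VarBase helpers.
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
print(x._numel())                          # total number of elements: 4
print(x._offset())                         # offset into the underlying allocation: 0
row = paddle.to_tensor(x._slice(0, 1))     # _slice returns a framework Tensor view
y = x._to(device='cpu', dtype='float64')   # device/dtype transfer helper added below

b = paddle.to_tensor([[0.0, 0.0], [0.0, 0.0]])
x._share_buffer_with(b)                    # x now aliases b's allocation
print(x._is_shared_buffer_with(b))         # True
x._clear()                                 # release the underlying tensor storage
```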
@@ -356,6 +356,37 @@ void VarBase::BumpInplaceVersion() {
MutableVar()->BumpInplaceVersion();
}
// NOTE(weilong wu):
// This function tries to copy the data from the source VarBase
// and fill it into the grad_var_ of the current VarBase.
void VarBase::_CopyGradientFrom(const VarBase& src) {
if (Var().IsInitialized()) {
PADDLE_ENFORCE_EQ(DataType(), src.DataType(),
platform::errors::PreconditionNotMet(
"Tensor %s has different data type with Tensor %s",
Name(), src.Name()));
PADDLE_ENFORCE_EQ(Type(), src.Type(),
platform::errors::PreconditionNotMet(
"Tensor %s has different type with Tensor %s, Tensor "
"ShareGradientDataWith cannot be performed!",
Name(), src.Name()));
}
VLOG(4) << " VarBase copy gradient with " << src.Name();
if (grad_var_) {
auto& src_tensor = src.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized", src.Name()));
auto* grad_t = grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(grad_t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor %s has not been initialized", Name()));
auto* var_ = MutableVar()->GetMutable<framework::LoDTensor>();
grad_t->ShareDataWith(src_tensor);
grad_t->Resize(var_->dims());
}
}
pten::KernelContext OpBase::pt_kernel_context_;
void OpBase::SetType(const std::string& type) {
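On the Python side, the _CopyGradientFrom method added above is reached through the `_copy_gradient_from` binding; a minimal sketch, mirroring the unit test added later in this commit:

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.random((2, 2)), dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np.random.random((2, 2)), dtype="float64")
out = x + x
out.backward()               # populate x.grad
x._copy_gradient_from(y)     # x.grad now shares y's data buffer
assert np.allclose(x.grad.numpy(), y.numpy())
```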
@@ -231,6 +231,8 @@ class VarBase {
void BumpInplaceVersion();
void _CopyGradientFrom(const imperative::VarBase& src);
/* Hook related method: now only used for GradVarBase */
bool HasVariableWrapperHook() const { return var_->HasVariableWrapperHook(); }
@@ -282,6 +282,27 @@ static void InitVarBaseFromTensorWithArgDefault(
}
}
template <typename P>
static void InitVarBaseFromTensorWithArg(imperative::VarBase *self,
const framework::Tensor &tensor,
const P &place) {
VLOG(4) << "Init VarBase";
new (self) imperative::VarBase(
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"));
self->SetPersistable(false);
self->SetType(framework::proto::VarType::LOD_TENSOR);
self->SetDataType(tensor.type());
auto *new_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
// Same place, share data directly
if (platform::is_same_place(place, tensor.place())) {
new_tensor->ShareDataWith(tensor);
VLOG(4) << "Same place, do ShareDataWith";
} else {
framework::TensorCopy(tensor, place, new_tensor);
VLOG(4) << "Different place, do TensorCopy";
}
}
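The overload above lets a VarBase be constructed from an existing framework Tensor plus a target place; a short sketch of the CPU case, mirroring the test added later in this commit:

```python
import numpy as np
import paddle
import paddle.fluid as fluid

t = fluid.Tensor()
t.set(np.random.random((3, 8, 8)), fluid.CPUPlace())
# Same place: the new VarBase shares t's data; a different place triggers TensorCopy.
vb = fluid.core.VarBase(t, paddle.CPUPlace())
print(vb.shape)  # [3, 8, 8]
```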
static std::string GetTypeName(const imperative::VarBase &var) {
if (var.Type() == framework::proto::VarType::RAW) {
return "RAW";
@@ -899,6 +920,16 @@ void BindImperative(py::module *m_ptr) {
py::arg("stop_gradient") = -1)
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CPUPlace>,
py::arg("tensor"), py::arg("place"))
.def("__init__", &InitVarBaseFromTensorWithArg<platform::XPUPlace>,
py::arg("tensor"), py::arg("place"))
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPlace>,
py::arg("tensor"), py::arg("place"))
.def("__init__", &InitVarBaseFromTensorWithArg<platform::CUDAPinnedPlace>,
py::arg("tensor"), py::arg("place"))
.def("__init__", &InitVarBaseFromTensorWithArg<platform::NPUPlace>,
py::arg("tensor"), py::arg("place"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def(
"__setitem_varbase__",
@@ -1865,6 +1896,70 @@ void BindImperative(py::module *m_ptr) {
py::return_value_policy::copy)
.def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
py::return_value_policy::reference)
.def("_clear",
[](const std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
t->clear();
})
.def("_offset",
[](const std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
return t->offset();
})
.def("_share_buffer_with",
[](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &target_t) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
auto *t_t =
target_t->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
PADDLE_ENFORCE_EQ(t_t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
t->ShareBufferWith(*t_t);
})
.def("_is_shared_buffer_with",
[](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &target_t) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
auto *t_t =
target_t->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
PADDLE_ENFORCE_EQ(t_t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
return t->IsSharedBufferWith(*t_t);
})
.def("_slice",
[](const std::shared_ptr<imperative::VarBase> &self,
int64_t begin_idx, int64_t end_idx) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
return t->Slice(begin_idx, end_idx);
})
.def("_copy_gradient_from",
[](std::shared_ptr<imperative::VarBase> &self,
const imperative::VarBase &src) { self->_CopyGradientFrom(src); })
.def("_numel",
[](std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
return t->numel();
})
.def_property("name", &imperative::VarBase::Name,
&imperative::VarBase::SetName)
.def_property("stop_gradient",
@@ -357,6 +357,81 @@ def monkey_patch_varbase():
helper = TensorHookRemoveHelper(self, hook_id)
return helper
@framework.dygraph_only
def _to(self, device=None, dtype=None, blocking=None):
if device is None and dtype is None and blocking is None:
return self
if device is not None:
if isinstance(device, str):
device = paddle.device._convert_to_place(device)
elif isinstance(device, (core.CPUPlace, core.CUDAPlace,
core.CUDAPinnedPlace, core.XPUPlace)):
pass
else:
raise ValueError(
"device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
+ type(device).__name__)
if blocking is None:
blocking = True
else:
assert isinstance(
blocking,
bool), "blocking value error, must be the True, False or None"
def transform(t, device, dtype, blocking):
if device is None:
device = t.place
if dtype is None:
dtype = t.dtype
# 1. For GPU places, first check whether memory is sufficient for the allocation.
if t.place.is_gpu_place():
gpu_memory_available = core.gpu_memory_available()
# for gpu, minimum memory allocation unit is 256 bytes.
if type(dtype) is str:
size_dtype = core.size_of_dtype(
framework.convert_np_dtype_to_dtype_(dtype))
else:
size_dtype = core.size_of_dtype(dtype)
# Note(weilong wu): Paddle's GPU minimum memory allocation unit is 256 bytes,
# so waiting_alloc_memory estimates the space 't' will occupy, rounded up to
# that unit. The coefficient 1.2 adds headroom to avoid OOM when the available
# memory is only marginally sufficient.
waiting_alloc_memory = (
(t._numel() * size_dtype) / 256 + 1) * 256 * 1.2
if gpu_memory_available < waiting_alloc_memory:
# Copy Tensor to cpu
t_used = t._copy_to(paddle.CPUPlace(), blocking)
# Release memory of t
t._clear()
else:
# Tensor still in GPU
t_used = t
else:
t_used = t
# 2. cast Tensor to dtype
if dtype is not None and dtype != t_used.dtype:
t_casted = t_used.cast(dtype=dtype)
else:
t_casted = t_used
# 3. Copy the casted Tensor (on CPU or GPU) to the target device
new_t = t_casted._copy_to(device, blocking)
# 4. Share the new Tensor's storage back into the original Tensor
dst_tensor = t.value().get_tensor()
src_tensor = new_t.value().get_tensor()
dst_tensor._share_data_with(src_tensor)
return t
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
return transform(self, device, dtype, blocking)
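As a quick check of the estimate used in transform() above (a standalone sketch, not part of the patch): for a float32 tensor with one million elements, the raw byte size is rounded up to the next 256-byte allocation unit and then padded by 20%:

```python
# Standalone re-computation of the waiting_alloc_memory estimate from transform().
numel = 1_000_000            # e.g. a float32 tensor with 1e6 elements
size_dtype = 4               # bytes per float32 element
waiting_alloc_memory = ((numel * size_dtype) / 256 + 1) * 256 * 1.2
print(waiting_alloc_memory)  # 4800307.2 bytes (~4.6 MiB), compared against free GPU memory
```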
@property
def grad(self):
"""
@@ -650,7 +725,7 @@ def monkey_patch_varbase():
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__), ("item", item),
("__setitem__", __setitem__)):
("__setitem__", __setitem__), ("_to", _to)):
setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
@@ -1154,5 +1154,141 @@ class TestVarBaseInplaceVersion(unittest.TestCase):
self.assertEqual(var.inplace_version, 2)
class TestVarBaseSlice(unittest.TestCase):
def test_slice(self):
paddle.disable_static()
np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64")
actual_x = x._slice(0, 1)
actual_x = paddle.to_tensor(actual_x)
self.assertEqual(actual_x.numpy().all(), np_x[0:1].all())
class TestVarBaseClear(unittest.TestCase):
def test_clear(self):
paddle.disable_static()
np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64")
x._clear()
self.assertEqual(str(x), "Tensor(Not initialized)")
class TestVarBaseOffset(unittest.TestCase):
def test_offset(self):
paddle.disable_static()
np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64")
expected_offset = 0
actual_x = x._slice(expected_offset, 1)
actual_x = paddle.to_tensor(actual_x)
self.assertEqual(actual_x._offset(), expected_offset)
class TestVarBaseShareBufferWith(unittest.TestCase):
def test_share_buffer_with(self):
paddle.disable_static()
np_x = np.random.random((3, 8, 8))
np_y = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64")
y = paddle.to_tensor(np_y, dtype="float64")
x._share_buffer_with(y)
self.assertEqual(x._is_shared_buffer_with(y), True)
class TestVarBaseTo(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.np_x = np.random.random((3, 8, 8))
self.x = paddle.to_tensor(self.np_x, dtype="float32")
def test_to_api(self):
x_double = self.x._to(dtype='double')
self.assertEqual(x_double.dtype, paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(np.allclose(self.np_x, x_double))
x_ = self.x._to()
self.assertEqual(self.x.dtype, paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(np.allclose(self.np_x, x_))
if paddle.fluid.is_compiled_with_cuda():
x_gpu = self.x._to(device=paddle.CUDAPlace(0))
self.assertTrue(x_gpu.place.is_gpu_place())
self.assertEqual(x_gpu.place.gpu_device_id(), 0)
x_gpu0 = self.x._to(device='gpu:0')
self.assertTrue(x_gpu0.place.is_gpu_place())
self.assertEqual(x_gpu0.place.gpu_device_id(), 0)
x_gpu1 = self.x._to(device='gpu:0', dtype="float64")
self.assertTrue(x_gpu1.place.is_gpu_place())
self.assertEqual(x_gpu1.place.gpu_device_id(), 0)
self.assertEqual(x_gpu1.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
x_gpu2 = self.x._to(device='gpu:0', dtype="float16")
self.assertTrue(x_gpu2.place.is_gpu_place())
self.assertEqual(x_gpu2.place.gpu_device_id(), 0)
self.assertEqual(x_gpu2.dtype,
paddle.fluid.core.VarDesc.VarType.FP16)
x_cpu = self.x._to(device=paddle.CPUPlace())
self.assertTrue(x_cpu.place.is_cpu_place())
x_cpu0 = self.x._to(device='cpu')
self.assertTrue(x_cpu0.place.is_cpu_place())
x_cpu1 = self.x._to(device=paddle.CPUPlace(), dtype="float64")
self.assertTrue(x_cpu1.place.is_cpu_place())
self.assertEqual(x_cpu1.dtype, paddle.fluid.core.VarDesc.VarType.FP64)
x_cpu2 = self.x._to(device='cpu', dtype="float16")
self.assertTrue(x_cpu2.place.is_cpu_place())
self.assertEqual(x_cpu2.dtype, paddle.fluid.core.VarDesc.VarType.FP16)
self.assertRaises(ValueError, self.x._to, device=1)
self.assertRaises(AssertionError, self.x._to, blocking=1)
class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase):
def test_varbase_init(self):
paddle.disable_static()
t = fluid.Tensor()
np_x = np.random.random((3, 8, 8))
t.set(np_x, fluid.CPUPlace())
if paddle.fluid.is_compiled_with_cuda():
device = paddle.CUDAPlace(0)
tmp = fluid.core.VarBase(t, device)
self.assertTrue(tmp.place.is_gpu_place())
self.assertEqual(tmp.numpy().all(), np_x.all())
device = paddle.CPUPlace()
tmp = fluid.core.VarBase(t, device)
self.assertEqual(tmp.numpy().all(), np_x.all())
class TestVarBaseNumel(unittest.TestCase):
def test_numel(self):
paddle.disable_static()
np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64")
x_actual_numel = x._numel()
x_expected_numel = np.product((3, 8, 8))
self.assertEqual(x_actual_numel, x_expected_numel)
class TestVarBaseCopyGradientFrom(unittest.TestCase):
def test_copy_gradient_from(self):
paddle.disable_static()
np_x = np.random.random((2, 2))
np_y = np.random.random((2, 2))
x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np_y, dtype="float64")
out = x + x
out.backward()
x._copy_gradient_from(y)
self.assertEqual(x.grad.numpy().all(), np_y.all())
if __name__ == '__main__':
unittest.main()