未验证 提交 5d6d14bc 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] fix test_var_base (#41397)

* eager test var base

* refine, test=develop
上级 afb56e8c
...@@ -78,6 +78,10 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name, ...@@ -78,6 +78,10 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name,
phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype), phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
ddims)); ddims));
self->tensor.set_impl(dense_tensor); self->tensor.set_impl(dense_tensor);
} else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
std::shared_ptr<phi::SelectedRows> tensor =
std::make_shared<phi::SelectedRows>();
self->tensor.set_impl(tensor);
} }
if (!autograd_meta->GetMutableGradNode()) { if (!autograd_meta->GetMutableGradNode()) {
......
...@@ -465,6 +465,9 @@ static PyObject* tensor__share_buffer_to(TensorObject* self, PyObject* args, ...@@ -465,6 +465,9 @@ static PyObject* tensor__share_buffer_to(TensorObject* self, PyObject* args,
self->tensor.name())); self->tensor.name()));
auto* src_tensor = auto* src_tensor =
static_cast<paddle::framework::Tensor*>(self->tensor.impl().get()); static_cast<paddle::framework::Tensor*>(self->tensor.impl().get());
if (!dst_ptr->defined()) {
dst_ptr->set_impl(std::make_shared<phi::DenseTensor>());
}
auto dst_tensor = auto dst_tensor =
static_cast<paddle::framework::Tensor*>(dst_ptr->impl().get()); static_cast<paddle::framework::Tensor*>(dst_ptr->impl().get());
dst_tensor->ShareDataWith(*src_tensor); dst_tensor->ShareDataWith(*src_tensor);
...@@ -565,6 +568,10 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self, ...@@ -565,6 +568,10 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
Py_IncRef(Py_None);
return Py_None;
}
if (self->tensor.is_dense_tensor()) { if (self->tensor.is_dense_tensor()) {
auto* tensor = auto* tensor =
static_cast<paddle::framework::LoDTensor*>(self->tensor.impl().get()); static_cast<paddle::framework::LoDTensor*>(self->tensor.impl().get());
...@@ -577,6 +584,25 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self, ...@@ -577,6 +584,25 @@ static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (!self->tensor.defined()) {
Py_IncRef(Py_None);
return Py_None;
}
if (self->tensor.is_selected_rows()) {
auto* selected_rows =
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
return ToPyObject(selected_rows);
} else {
Py_IncRef(Py_None);
return Py_None;
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
...@@ -1214,6 +1240,9 @@ static PyObject* tensor_method_get_non_zero_cols(TensorObject* self, ...@@ -1214,6 +1240,9 @@ static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args, static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
return ToPyObject(false);
}
return ToPyObject(self->tensor.is_sparse_coo_tensor() || return ToPyObject(self->tensor.is_sparse_coo_tensor() ||
self->tensor.is_sparse_csr_tensor()); self->tensor.is_sparse_csr_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
...@@ -1222,6 +1251,9 @@ static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args, ...@@ -1222,6 +1251,9 @@ static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args,
static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args, static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
return ToPyObject(false);
}
return ToPyObject(self->tensor.is_sparse_coo_tensor()); return ToPyObject(self->tensor.is_sparse_coo_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
...@@ -1229,6 +1261,9 @@ static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args, ...@@ -1229,6 +1261,9 @@ static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args,
static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args, static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
return ToPyObject(false);
}
return ToPyObject(self->tensor.is_sparse_csr_tensor()); return ToPyObject(self->tensor.is_sparse_csr_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
...@@ -1307,6 +1342,9 @@ static PyObject* tensor_method_is_selected_rows(TensorObject* self, ...@@ -1307,6 +1342,9 @@ static PyObject* tensor_method_is_selected_rows(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
return ToPyObject(false);
}
return ToPyObject(self->tensor.is_selected_rows()); return ToPyObject(self->tensor.is_selected_rows());
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
...@@ -1323,6 +1361,13 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args, ...@@ -1323,6 +1361,13 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor_methon_element_size(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
return ToPyObject(paddle::experimental::SizeOf(self->tensor.dtype()));
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self, static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
...@@ -1420,6 +1465,9 @@ PyMethodDef variable_methods[] = { ...@@ -1420,6 +1465,9 @@ PyMethodDef variable_methods[] = {
{"get_tensor", {"get_tensor",
(PyCFunction)(void (*)(void))tensor_method_get_underline_tensor, (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"get_selected_rows",
(PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_getitem_index_not_tensor", {"_getitem_index_not_tensor",
(PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor, (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
...@@ -1482,6 +1530,8 @@ PyMethodDef variable_methods[] = { ...@@ -1482,6 +1530,8 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows, {"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)(void (*)(void))tensor_methon_element_size,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_reset_grad_inplace_version", {"_reset_grad_inplace_version",
(PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version, (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
......
...@@ -43,8 +43,14 @@ PyObject* tensor_properties_get_name(TensorObject* self, void* closure) { ...@@ -43,8 +43,14 @@ PyObject* tensor_properties_get_name(TensorObject* self, void* closure) {
PyObject* tensor_properties_get_type(TensorObject* self, void* closure) { PyObject* tensor_properties_get_type(TensorObject* self, void* closure) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
// be same to old dygraph
return ToPyObject(paddle::framework::proto::VarType::LOD_TENSOR);
}
if (self->tensor.is_dense_tensor()) { if (self->tensor.is_dense_tensor()) {
return ToPyObject(paddle::framework::proto::VarType::LOD_TENSOR); return ToPyObject(paddle::framework::proto::VarType::LOD_TENSOR);
} else if (self->tensor.is_selected_rows()) {
return ToPyObject(paddle::framework::proto::VarType::SELECTED_ROWS);
} else { } else {
Py_INCREF(Py_None); Py_INCREF(Py_None);
return Py_None; return Py_None;
...@@ -137,8 +143,11 @@ int tensor_properties_set_persistable(TensorObject* self, PyObject* value, ...@@ -137,8 +143,11 @@ int tensor_properties_set_persistable(TensorObject* self, PyObject* value,
PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) { PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) {
EAGER_TRY EAGER_TRY
auto ddim = self->tensor.shape();
std::vector<int64_t> value; std::vector<int64_t> value;
if (!self->tensor.defined()) {
return ToPyObject(value);
}
auto ddim = self->tensor.shape();
size_t rank = static_cast<size_t>(ddim.size()); size_t rank = static_cast<size_t>(ddim.size());
value.resize(rank); value.resize(rank);
for (size_t i = 0; i < rank; i++) { for (size_t i = 0; i < rank; i++) {
...@@ -165,6 +174,10 @@ PyObject* tensor_properties_get_place_str(TensorObject* self, void* closure) { ...@@ -165,6 +174,10 @@ PyObject* tensor_properties_get_place_str(TensorObject* self, void* closure) {
PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) { PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) {
EAGER_TRY EAGER_TRY
if (!self->tensor.defined()) {
// be same to old dygraph
return ToPyObject(framework::proto::VarType::FP32);
}
return ToPyObject( return ToPyObject(
paddle::framework::TransToProtoVarType(self->tensor.type())); paddle::framework::TransToProtoVarType(self->tensor.type()));
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
......
...@@ -577,6 +577,12 @@ PyObject* ToPyObject(const paddle::framework::LoDTensor* value) { ...@@ -577,6 +577,12 @@ PyObject* ToPyObject(const paddle::framework::LoDTensor* value) {
return obj.ptr(); return obj.ptr();
} }
PyObject* ToPyObject(const phi::SelectedRows* value) {
auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
obj.inc_ref();
return obj.ptr();
}
PyObject* ToPyObject(const void* value) { PyObject* ToPyObject(const void* value) {
if (value == nullptr) { if (value == nullptr) {
Py_INCREF(Py_None); Py_INCREF(Py_None);
......
...@@ -75,6 +75,7 @@ PyObject* ToPyObject(const std::vector<paddle::experimental::Tensor>& value, ...@@ -75,6 +75,7 @@ PyObject* ToPyObject(const std::vector<paddle::experimental::Tensor>& value,
bool return_py_none_if_not_initialize = false); bool return_py_none_if_not_initialize = false);
PyObject* ToPyObject(const platform::Place& value); PyObject* ToPyObject(const platform::Place& value);
PyObject* ToPyObject(const framework::LoDTensor* value); PyObject* ToPyObject(const framework::LoDTensor* value);
PyObject* ToPyObject(const phi::SelectedRows* value);
PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype); PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype);
PyObject* ToPyObject(const paddle::framework::proto::VarType& type); PyObject* ToPyObject(const paddle::framework::proto::VarType& type);
PyObject* ToPyObject(const void* value); PyObject* ToPyObject(const void* value);
......
...@@ -101,7 +101,11 @@ int64_t Tensor::size() const { return impl_->numel(); } ...@@ -101,7 +101,11 @@ int64_t Tensor::size() const { return impl_->numel(); }
phi::DDim Tensor::dims() const { return impl_->dims(); } phi::DDim Tensor::dims() const { return impl_->dims(); }
std::vector<int64_t> Tensor::shape() const { std::vector<int64_t> Tensor::shape() const {
return phi::vectorize<int64_t>(impl_->dims()); auto dims = impl_->dims();
if (dims.size() == 1 && dims.at(0) == 0) {
return {};
}
return phi::vectorize<int64_t>(dims);
} }
void Tensor::reshape(const std::vector<int64_t> &shape) { void Tensor::reshape(const std::vector<int64_t> &shape) {
......
...@@ -846,7 +846,11 @@ def monkey_patch_varbase(): ...@@ -846,7 +846,11 @@ def monkey_patch_varbase():
return res return res
@framework.dygraph_only @framework.dygraph_only
def cuda(self, device_id, blocking): def cuda(self, device_id=0, blocking=True):
if device_id is None:
device_id = 0
if not isinstance(device_id, int):
raise ValueError("\'device_id\' must be a positive integer")
if self.place.is_gpu_place(): if self.place.is_gpu_place():
return self return self
else: else:
......
...@@ -31,7 +31,7 @@ class TestVarBase(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestVarBase(unittest.TestCase):
self.dtype = np.float32 self.dtype = np.float32
self.array = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) self.array = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
def test_to_tensor(self): def func_test_to_tensor(self):
def _test_place(place): def _test_place(place):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
paddle.set_default_dtype('float32') paddle.set_default_dtype('float32')
...@@ -262,7 +262,12 @@ class TestVarBase(unittest.TestCase): ...@@ -262,7 +262,12 @@ class TestVarBase(unittest.TestCase):
_test_place(core.NPUPlace(0)) _test_place(core.NPUPlace(0))
_test_place("npu:0") _test_place("npu:0")
def test_to_tensor_not_change_input_stop_gradient(self): def test_to_tensor(self):
with _test_eager_guard():
self.func_test_to_tensor()
self.func_test_to_tensor()
def func_test_to_tensor_not_change_input_stop_gradient(self):
with paddle.fluid.dygraph.guard(core.CPUPlace()): with paddle.fluid.dygraph.guard(core.CPUPlace()):
a = paddle.zeros([1024]) a = paddle.zeros([1024])
a.stop_gradient = False a.stop_gradient = False
...@@ -270,7 +275,12 @@ class TestVarBase(unittest.TestCase): ...@@ -270,7 +275,12 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(a.stop_gradient, False) self.assertEqual(a.stop_gradient, False)
self.assertEqual(b.stop_gradient, True) self.assertEqual(b.stop_gradient, True)
def test_to_tensor_change_place(self): def test_to_tensor_not_change_input_stop_gradient(self):
with _test_eager_guard():
self.func_test_to_tensor_not_change_input_stop_gradient()
self.func_test_to_tensor_not_change_input_stop_gradient()
def func_test_to_tensor_change_place(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
a_np = np.random.rand(1024, 1024) a_np = np.random.rand(1024, 1024)
with paddle.fluid.dygraph.guard(core.CPUPlace()): with paddle.fluid.dygraph.guard(core.CPUPlace()):
...@@ -288,7 +298,12 @@ class TestVarBase(unittest.TestCase): ...@@ -288,7 +298,12 @@ class TestVarBase(unittest.TestCase):
a = paddle.to_tensor(a, place=paddle.CUDAPinnedPlace()) a = paddle.to_tensor(a, place=paddle.CUDAPinnedPlace())
self.assertEqual(a.place.__repr__(), "Place(gpu_pinned)") self.assertEqual(a.place.__repr__(), "Place(gpu_pinned)")
def test_to_tensor_with_lodtensor(self): def test_to_tensor_change_place(self):
with _test_eager_guard():
self.func_test_to_tensor_change_place()
self.func_test_to_tensor_change_place()
def func_test_to_tensor_with_lodtensor(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
a_np = np.random.rand(1024, 1024) a_np = np.random.rand(1024, 1024)
with paddle.fluid.dygraph.guard(core.CPUPlace()): with paddle.fluid.dygraph.guard(core.CPUPlace()):
...@@ -304,7 +319,12 @@ class TestVarBase(unittest.TestCase): ...@@ -304,7 +319,12 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(np.array_equal(a_np, a.numpy())) self.assertTrue(np.array_equal(a_np, a.numpy()))
self.assertTrue(a.place.__repr__(), "Place(cpu)") self.assertTrue(a.place.__repr__(), "Place(cpu)")
def test_to_variable(self): def test_to_tensor_with_lodtensor(self):
with _test_eager_guard():
self.func_test_to_tensor_with_lodtensor()
self.func_test_to_tensor_with_lodtensor()
def func_test_to_variable(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array, name="abc") var = fluid.dygraph.to_variable(self.array, name="abc")
self.assertTrue(np.array_equal(var.numpy(), self.array)) self.assertTrue(np.array_equal(var.numpy(), self.array))
...@@ -323,7 +343,12 @@ class TestVarBase(unittest.TestCase): ...@@ -323,7 +343,12 @@ class TestVarBase(unittest.TestCase):
linear = fluid.dygraph.Linear(32, 64) linear = fluid.dygraph.Linear(32, 64)
var = linear._helper.to_variable("test", name="abc") var = linear._helper.to_variable("test", name="abc")
def test_list_to_variable(self): def test_to_variable(self):
with _test_eager_guard():
self.func_test_to_variable()
self.func_test_to_variable()
def func_test_list_to_variable(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]] array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]]
var = fluid.dygraph.to_variable(array, dtype='int32') var = fluid.dygraph.to_variable(array, dtype='int32')
...@@ -332,7 +357,12 @@ class TestVarBase(unittest.TestCase): ...@@ -332,7 +357,12 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(var.dtype, core.VarDesc.VarType.INT32) self.assertEqual(var.dtype, core.VarDesc.VarType.INT32)
self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR)
def test_tuple_to_variable(self): def test_list_to_variable(self):
with _test_eager_guard():
self.func_test_list_to_variable()
self.func_test_list_to_variable()
def func_test_tuple_to_variable(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
array = (((1, 2), (1, 2), (1, 2)), ((1, 2), (1, 2), (1, 2))) array = (((1, 2), (1, 2), (1, 2)), ((1, 2), (1, 2), (1, 2)))
var = fluid.dygraph.to_variable(array, dtype='float32') var = fluid.dygraph.to_variable(array, dtype='float32')
...@@ -341,14 +371,24 @@ class TestVarBase(unittest.TestCase): ...@@ -341,14 +371,24 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(var.dtype, core.VarDesc.VarType.FP32) self.assertEqual(var.dtype, core.VarDesc.VarType.FP32)
self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR)
def test_tensor_to_variable(self): def test_tuple_to_variable(self):
with _test_eager_guard():
self.func_test_tuple_to_variable()
self.func_test_tuple_to_variable()
def func_test_tensor_to_variable(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
t = fluid.Tensor() t = fluid.Tensor()
t.set(np.random.random((1024, 1024)), fluid.CPUPlace()) t.set(np.random.random((1024, 1024)), fluid.CPUPlace())
var = fluid.dygraph.to_variable(t) var = fluid.dygraph.to_variable(t)
self.assertTrue(np.array_equal(t, var.numpy())) self.assertTrue(np.array_equal(t, var.numpy()))
def test_leaf_tensor(self): def test_tensor_to_variable(self):
with _test_eager_guard():
self.func_test_tensor_to_variable()
self.func_test_tensor_to_variable()
def func_test_leaf_tensor(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x = paddle.to_tensor(np.random.uniform(-1, 1, size=[10, 10])) x = paddle.to_tensor(np.random.uniform(-1, 1, size=[10, 10]))
self.assertTrue(x.is_leaf) self.assertTrue(x.is_leaf)
...@@ -374,7 +414,12 @@ class TestVarBase(unittest.TestCase): ...@@ -374,7 +414,12 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(linear.bias.is_leaf) self.assertTrue(linear.bias.is_leaf)
self.assertFalse(out.is_leaf) self.assertFalse(out.is_leaf)
def test_detach(self): def test_leaf_tensor(self):
with _test_eager_guard():
self.func_test_leaf_tensor()
self.func_test_leaf_tensor()
def func_test_detach(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x = paddle.to_tensor(1.0, dtype="float64", stop_gradient=False) x = paddle.to_tensor(1.0, dtype="float64", stop_gradient=False)
detach_x = x.detach() detach_x = x.detach()
...@@ -407,7 +452,12 @@ class TestVarBase(unittest.TestCase): ...@@ -407,7 +452,12 @@ class TestVarBase(unittest.TestCase):
detach_x[:] = 5.0 detach_x[:] = 5.0
y.backward() y.backward()
def test_write_property(self): def test_detach(self):
with _test_eager_guard():
self.func_test_detach()
self.func_test_detach()
def func_test_write_property(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
...@@ -423,9 +473,17 @@ class TestVarBase(unittest.TestCase): ...@@ -423,9 +473,17 @@ class TestVarBase(unittest.TestCase):
var.stop_gradient = False var.stop_gradient = False
self.assertEqual(var.stop_gradient, False) self.assertEqual(var.stop_gradient, False)
def test_deep_copy(self): def test_write_property(self):
with _test_eager_guard():
self.func_test_write_property()
self.func_test_write_property()
def func_test_deep_copy(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
empty_var = core.VarBase() if _in_legacy_dygraph():
empty_var = core.VarBase()
else:
empty_var = core.eager.Tensor()
empty_var_copy = copy.deepcopy(empty_var) empty_var_copy = copy.deepcopy(empty_var)
self.assertEqual(empty_var.stop_gradient, self.assertEqual(empty_var.stop_gradient,
empty_var_copy.stop_gradient) empty_var_copy.stop_gradient)
...@@ -462,9 +520,15 @@ class TestVarBase(unittest.TestCase): ...@@ -462,9 +520,15 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(id(y_copy), id(y_copy2)) self.assertEqual(id(y_copy), id(y_copy2))
# test copy selected rows # test copy selected rows
x = core.VarBase(core.VarDesc.VarType.FP32, [3, 100], if _in_legacy_dygraph():
"selected_rows", x = core.VarBase(core.VarDesc.VarType.FP32, [3, 100],
core.VarDesc.VarType.SELECTED_ROWS, True) "selected_rows",
core.VarDesc.VarType.SELECTED_ROWS, True)
else:
x = core.eager.Tensor(core.VarDesc.VarType.FP32, [3, 100],
"selected_rows",
core.VarDesc.VarType.SELECTED_ROWS, True)
selected_rows = x.value().get_selected_rows() selected_rows = x.value().get_selected_rows()
selected_rows.get_tensor().set( selected_rows.get_tensor().set(
np.random.rand(3, 100), core.CPUPlace()) np.random.rand(3, 100), core.CPUPlace())
...@@ -486,8 +550,13 @@ class TestVarBase(unittest.TestCase): ...@@ -486,8 +550,13 @@ class TestVarBase(unittest.TestCase):
np.array(copy_selected_rows.get_tensor()), np.array(copy_selected_rows.get_tensor()),
np.array(selected_rows.get_tensor()))) np.array(selected_rows.get_tensor())))
def test_deep_copy(self):
with _test_eager_guard():
self.func_test_deep_copy()
self.func_test_deep_copy()
# test some patched methods # test some patched methods
def test_set_value(self): def func_test_set_value(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
tmp1 = np.random.uniform(0.1, 1, [2, 2, 3]).astype(self.dtype) tmp1 = np.random.uniform(0.1, 1, [2, 2, 3]).astype(self.dtype)
...@@ -497,12 +566,22 @@ class TestVarBase(unittest.TestCase): ...@@ -497,12 +566,22 @@ class TestVarBase(unittest.TestCase):
var.set_value(tmp2) var.set_value(tmp2)
self.assertTrue(np.array_equal(var.numpy(), tmp2)) self.assertTrue(np.array_equal(var.numpy(), tmp2))
def test_to_string(self): def test_set_value(self):
with _test_eager_guard():
self.func_test_set_value()
self.func_test_set_value()
def func_test_to_string(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertTrue(isinstance(str(var), str)) self.assertTrue(isinstance(str(var), str))
def test_element_size(self): def test_to_string(self):
with _test_eager_guard():
self.func_test_to_string()
self.func_test_to_string()
def func_test_element_size(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x = paddle.to_tensor(1, dtype='bool') x = paddle.to_tensor(1, dtype='bool')
self.assertEqual(x.element_size(), 1) self.assertEqual(x.element_size(), 1)
...@@ -537,7 +616,12 @@ class TestVarBase(unittest.TestCase): ...@@ -537,7 +616,12 @@ class TestVarBase(unittest.TestCase):
x = paddle.to_tensor(1, dtype='complex128') x = paddle.to_tensor(1, dtype='complex128')
self.assertEqual(x.element_size(), 16) self.assertEqual(x.element_size(), 16)
def test_backward(self): def test_element_size(self):
with _test_eager_guard():
self.func_test_element_size()
self.func_test_element_size()
def func_test_backward(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
var.stop_gradient = False var.stop_gradient = False
...@@ -546,7 +630,12 @@ class TestVarBase(unittest.TestCase): ...@@ -546,7 +630,12 @@ class TestVarBase(unittest.TestCase):
grad_var = var._grad_ivar() grad_var = var._grad_ivar()
self.assertEqual(grad_var.shape, self.shape) self.assertEqual(grad_var.shape, self.shape)
def test_gradient(self): def test_backward(self):
with _test_eager_guard():
self.func_test_backward()
self.func_test_backward()
def func_test_gradient(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
var.stop_gradient = False var.stop_gradient = False
...@@ -555,12 +644,22 @@ class TestVarBase(unittest.TestCase): ...@@ -555,12 +644,22 @@ class TestVarBase(unittest.TestCase):
grad_var = var.gradient() grad_var = var.gradient()
self.assertEqual(grad_var.shape, self.array.shape) self.assertEqual(grad_var.shape, self.array.shape)
def test_block(self): def test_gradient(self):
with _test_eager_guard():
self.func_test_gradient()
self.func_test_gradient()
def func_test_block(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertEqual(var.block, self.assertEqual(var.block,
fluid.default_main_program().global_block()) fluid.default_main_program().global_block())
def test_block(self):
with _test_eager_guard():
self.func_test_block()
self.func_test_block()
def _test_slice(self): def _test_slice(self):
w = fluid.dygraph.to_variable( w = fluid.dygraph.to_variable(
np.random.random((784, 100, 100)).astype('float64')) np.random.random((784, 100, 100)).astype('float64'))
...@@ -916,14 +1015,19 @@ class TestVarBase(unittest.TestCase): ...@@ -916,14 +1015,19 @@ class TestVarBase(unittest.TestCase):
self.func_test_slice() self.func_test_slice()
self.func_test_slice() self.func_test_slice()
def test_var_base_to_np(self): def func_test_var_base_to_np(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertTrue( self.assertTrue(
np.array_equal(var.numpy(), np.array_equal(var.numpy(),
fluid.framework._var_base_to_np(var))) fluid.framework._var_base_to_np(var)))
def test_var_base_as_np(self): def test_var_base_to_np(self):
with _test_eager_guard():
self.func_test_var_base_to_np()
self.func_test_var_base_to_np()
def func_test_var_base_as_np(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var.numpy(), np.array(var))) self.assertTrue(np.array_equal(var.numpy(), np.array(var)))
...@@ -932,7 +1036,12 @@ class TestVarBase(unittest.TestCase): ...@@ -932,7 +1036,12 @@ class TestVarBase(unittest.TestCase):
var.numpy(), np.array( var.numpy(), np.array(
var, dtype=np.float32))) var, dtype=np.float32)))
def test_if(self): def test_var_base_as_np(self):
with _test_eager_guard():
self.func_test_var_base_as_np()
self.func_test_var_base_as_np()
def func_test_if(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var1 = fluid.dygraph.to_variable(np.array([[[0]]])) var1 = fluid.dygraph.to_variable(np.array([[[0]]]))
var2 = fluid.dygraph.to_variable(np.array([[[1]]])) var2 = fluid.dygraph.to_variable(np.array([[[1]]]))
...@@ -951,7 +1060,12 @@ class TestVarBase(unittest.TestCase): ...@@ -951,7 +1060,12 @@ class TestVarBase(unittest.TestCase):
assert bool(var1) == False, "bool(var1) is False" assert bool(var1) == False, "bool(var1) is False"
assert bool(var2) == True, "bool(var2) is True" assert bool(var2) == True, "bool(var2) is True"
def test_to_static_var(self): def test_if(self):
with _test_eager_guard():
self.func_test_if()
self.func_test_if()
def func_test_to_static_var(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
# Convert VarBase into Variable or Parameter # Convert VarBase into Variable or Parameter
var_base = fluid.dygraph.to_variable(self.array, name="var_base_1") var_base = fluid.dygraph.to_variable(self.array, name="var_base_1")
...@@ -974,6 +1088,11 @@ class TestVarBase(unittest.TestCase): ...@@ -974,6 +1088,11 @@ class TestVarBase(unittest.TestCase):
static_param = weight._to_static_var() static_param = weight._to_static_var()
self._assert_to_static(weight, static_param, True) self._assert_to_static(weight, static_param, True)
def test_to_static_var(self):
with _test_eager_guard():
self.func_test_to_static_var()
self.func_test_to_static_var()
def _assert_to_static(self, var_base, static_var, is_param=False): def _assert_to_static(self, var_base, static_var, is_param=False):
if is_param: if is_param:
self.assertTrue(isinstance(static_var, fluid.framework.Parameter)) self.assertTrue(isinstance(static_var, fluid.framework.Parameter))
...@@ -1015,7 +1134,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1015,7 +1134,6 @@ class TestVarBase(unittest.TestCase):
[0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])''' [0.2665, 0.8483, 0.5389, ..., 0.4956, 0.6862, 0.9178]])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str(self): def test_tensor_str(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1032,7 +1150,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1032,7 +1150,6 @@ class TestVarBase(unittest.TestCase):
[0. , 0. ]])''' [0. , 0. ]])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str2(self): def test_tensor_str2(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1049,7 +1166,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1049,7 +1166,6 @@ class TestVarBase(unittest.TestCase):
[ 0. , -0.5000]])''' [ 0. , -0.5000]])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str3(self): def test_tensor_str3(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1065,7 +1181,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1065,7 +1181,6 @@ class TestVarBase(unittest.TestCase):
False)''' False)'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str_scaler(self): def test_tensor_str_scaler(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1082,7 +1197,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1082,7 +1197,6 @@ class TestVarBase(unittest.TestCase):
[])''' [])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str_shape_with_zero(self): def test_tensor_str_shape_with_zero(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1115,7 +1229,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1115,7 +1229,6 @@ class TestVarBase(unittest.TestCase):
0.4678, 0.5047])''' 0.4678, 0.5047])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str_linewidth(self): def test_tensor_str_linewidth(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1143,7 +1256,6 @@ class TestVarBase(unittest.TestCase): ...@@ -1143,7 +1256,6 @@ class TestVarBase(unittest.TestCase):
8.9448e-01, 7.0981e-01, 8.0783e-01, 4.7065e-01, 5.7154e-01, 7.2319e-01, 4.6777e-01, 5.0465e-01])''' 8.9448e-01, 7.0981e-01, 8.0783e-01, 4.7065e-01, 5.7154e-01, 7.2319e-01, 4.6777e-01, 5.0465e-01])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str_linewidth2(self): def test_tensor_str_linewidth2(self):
with _test_eager_guard(): with _test_eager_guard():
...@@ -1162,14 +1274,18 @@ class TestVarBase(unittest.TestCase): ...@@ -1162,14 +1274,18 @@ class TestVarBase(unittest.TestCase):
[0. , 0. ]])''' [0. , 0. ]])'''
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_tensor_str_bf16(self): def test_tensor_str_bf16(self):
with _test_eager_guard(): with _test_eager_guard():
self.func_tensor_str_bf16() self.func_tensor_str_bf16()
self.func_tensor_str_bf16() self.func_tensor_str_bf16()
def test_print_tensor_dtype(self): def test_tensor_str_bf16(self):
with _test_eager_guard():
self.func_tensor_str_bf16()
self.func_tensor_str_bf16()
def func_test_print_tensor_dtype(self):
paddle.disable_static(paddle.CPUPlace()) paddle.disable_static(paddle.CPUPlace())
a = paddle.rand([1]) a = paddle.rand([1])
a_str = str(a.dtype) a_str = str(a.dtype)
...@@ -1177,11 +1293,15 @@ class TestVarBase(unittest.TestCase): ...@@ -1177,11 +1293,15 @@ class TestVarBase(unittest.TestCase):
expected = 'paddle.float32' expected = 'paddle.float32'
self.assertEqual(a_str, expected) self.assertEqual(a_str, expected)
paddle.enable_static()
def test_print_tensor_dtype(self):
with _test_eager_guard():
self.func_test_print_tensor_dtype()
self.func_test_print_tensor_dtype()
class TestVarBaseSetitem(unittest.TestCase): class TestVarBaseSetitem(unittest.TestCase):
def setUp(self): def func_setUp(self):
self.set_dtype() self.set_dtype()
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype)) self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype) self.np_value = np.random.random((2, 3)).astype(self.dtype)
...@@ -1225,9 +1345,9 @@ class TestVarBaseSetitem(unittest.TestCase): ...@@ -1225,9 +1345,9 @@ class TestVarBaseSetitem(unittest.TestCase):
def test_value_tensor(self): def test_value_tensor(self):
with _test_eager_guard(): with _test_eager_guard():
self.setUp() self.func_setUp()
self.func_test_value_tensor() self.func_test_value_tensor()
self.setUp() self.func_setUp()
self.func_test_value_tensor() self.func_test_value_tensor()
def func_test_value_numpy(self): def func_test_value_numpy(self):
...@@ -1235,9 +1355,9 @@ class TestVarBaseSetitem(unittest.TestCase): ...@@ -1235,9 +1355,9 @@ class TestVarBaseSetitem(unittest.TestCase):
def test_value_numpy(self): def test_value_numpy(self):
with _test_eager_guard(): with _test_eager_guard():
self.setUp() self.func_setUp()
self.func_test_value_numpy() self.func_test_value_numpy()
self.setUp() self.func_setUp()
self.func_test_value_numpy() self.func_test_value_numpy()
def func_test_value_int(self): def func_test_value_int(self):
...@@ -1245,9 +1365,9 @@ class TestVarBaseSetitem(unittest.TestCase): ...@@ -1245,9 +1365,9 @@ class TestVarBaseSetitem(unittest.TestCase):
def test_value_int(self): def test_value_int(self):
with _test_eager_guard(): with _test_eager_guard():
self.setUp() self.func_setUp()
self.func_test_value_int() self.func_test_value_int()
self.setUp() self.func_setUp()
self.func_test_value_int() self.func_test_value_int()
...@@ -1260,10 +1380,17 @@ class TestVarBaseSetitemFp32(TestVarBaseSetitem): ...@@ -1260,10 +1380,17 @@ class TestVarBaseSetitemFp32(TestVarBaseSetitem):
def set_dtype(self): def set_dtype(self):
self.dtype = "float32" self.dtype = "float32"
def test_value_float(self): def func_test_value_float(self):
paddle.disable_static() paddle.disable_static()
self._test(3.3) self._test(3.3)
def test_value_float(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_value_float()
self.func_setUp()
self.func_test_value_float()
class TestVarBaseSetitemFp64(TestVarBaseSetitem): class TestVarBaseSetitemFp64(TestVarBaseSetitem):
def set_dtype(self): def set_dtype(self):
...@@ -1271,7 +1398,7 @@ class TestVarBaseSetitemFp64(TestVarBaseSetitem): ...@@ -1271,7 +1398,7 @@ class TestVarBaseSetitemFp64(TestVarBaseSetitem):
class TestVarBaseSetitemBoolIndex(unittest.TestCase): class TestVarBaseSetitemBoolIndex(unittest.TestCase):
def setUp(self): def func_setUp(self):
paddle.disable_static() paddle.disable_static()
self.set_dtype() self.set_dtype()
self.set_input() self.set_input()
...@@ -1314,18 +1441,39 @@ class TestVarBaseSetitemBoolIndex(unittest.TestCase): ...@@ -1314,18 +1441,39 @@ class TestVarBaseSetitemBoolIndex(unittest.TestCase):
self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result)) self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x)) self.assertEqual(id_origin, id(self.tensor_x))
def test_value_tensor(self): def func_test_value_tensor(self):
paddle.disable_static() paddle.disable_static()
self._test(self.tensor_value) self._test(self.tensor_value)
def test_value_numpy(self): def test_value_tensor(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_value_tensor()
self.func_setUp()
self.func_test_value_tensor()
def func_test_value_numpy(self):
paddle.disable_static() paddle.disable_static()
self._test(self.np_value) self._test(self.np_value)
def test_value_int(self): def test_value_numpy(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_value_numpy()
self.func_setUp()
self.func_test_value_numpy()
def func_test_value_int(self):
paddle.disable_static() paddle.disable_static()
self._test(10) self._test(10)
def test_value_int(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_value_int()
self.func_setUp()
self.func_test_value_int()
class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase): class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase):
def set_input(self): def set_input(self):
...@@ -1353,7 +1501,7 @@ class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase): ...@@ -1353,7 +1501,7 @@ class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase):
class TestVarBaseInplaceVersion(unittest.TestCase): class TestVarBaseInplaceVersion(unittest.TestCase):
def test_setitem(self): def func_test_setitem(self):
paddle.disable_static() paddle.disable_static()
var = paddle.ones(shape=[4, 2, 3], dtype="float32") var = paddle.ones(shape=[4, 2, 3], dtype="float32")
...@@ -1365,7 +1513,12 @@ class TestVarBaseInplaceVersion(unittest.TestCase): ...@@ -1365,7 +1513,12 @@ class TestVarBaseInplaceVersion(unittest.TestCase):
var[1:2] = 1 var[1:2] = 1
self.assertEqual(var.inplace_version, 2) self.assertEqual(var.inplace_version, 2)
def test_bump_inplace_version(self): def test_setitem(self):
with _test_eager_guard():
self.func_test_setitem()
self.func_test_setitem()
def func_test_bump_inplace_version(self):
paddle.disable_static() paddle.disable_static()
var = paddle.ones(shape=[4, 2, 3], dtype="float32") var = paddle.ones(shape=[4, 2, 3], dtype="float32")
self.assertEqual(var.inplace_version, 0) self.assertEqual(var.inplace_version, 0)
...@@ -1376,9 +1529,14 @@ class TestVarBaseInplaceVersion(unittest.TestCase): ...@@ -1376,9 +1529,14 @@ class TestVarBaseInplaceVersion(unittest.TestCase):
var._bump_inplace_version() var._bump_inplace_version()
self.assertEqual(var.inplace_version, 2) self.assertEqual(var.inplace_version, 2)
def test_bump_inplace_version(self):
with _test_eager_guard():
self.func_test_bump_inplace_version()
self.func_test_bump_inplace_version()
class TestVarBaseSlice(unittest.TestCase): class TestVarBaseSlice(unittest.TestCase):
def test_slice(self): def func_test_slice(self):
paddle.disable_static() paddle.disable_static()
np_x = np.random.random((3, 8, 8)) np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64") x = paddle.to_tensor(np_x, dtype="float64")
...@@ -1386,15 +1544,25 @@ class TestVarBaseSlice(unittest.TestCase): ...@@ -1386,15 +1544,25 @@ class TestVarBaseSlice(unittest.TestCase):
actual_x = paddle.to_tensor(actual_x) actual_x = paddle.to_tensor(actual_x)
self.assertEqual(actual_x.numpy().all(), np_x[0:1].all()) self.assertEqual(actual_x.numpy().all(), np_x[0:1].all())
def test_slice(self):
with _test_eager_guard():
self.func_test_slice()
self.func_test_slice()
class TestVarBaseClear(unittest.TestCase): class TestVarBaseClear(unittest.TestCase):
def test_clear(self): def func_test_clear(self):
paddle.disable_static() paddle.disable_static()
np_x = np.random.random((3, 8, 8)) np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64") x = paddle.to_tensor(np_x, dtype="float64")
x._clear() x._clear()
self.assertEqual(str(x), "Tensor(Not initialized)") self.assertEqual(str(x), "Tensor(Not initialized)")
def test_clear(self):
with _test_eager_guard():
self.func_test_clear()
self.func_test_clear()
class TestVarBaseOffset(unittest.TestCase): class TestVarBaseOffset(unittest.TestCase):
def func_offset(self): def func_offset(self):
...@@ -1413,23 +1581,31 @@ class TestVarBaseOffset(unittest.TestCase): ...@@ -1413,23 +1581,31 @@ class TestVarBaseOffset(unittest.TestCase):
class TestVarBaseShareBufferTo(unittest.TestCase): class TestVarBaseShareBufferTo(unittest.TestCase):
def test_share_buffer_To(self): def func_test_share_buffer_To(self):
paddle.disable_static() paddle.disable_static()
np_src = np.random.random((3, 8, 8)) np_src = np.random.random((3, 8, 8))
src = paddle.to_tensor(np_src, dtype="float64") src = paddle.to_tensor(np_src, dtype="float64")
# empty_var # empty_var
dst = core.VarBase() if _in_legacy_dygraph():
dst = core.VarBase()
else:
dst = core.eager.Tensor()
src._share_buffer_to(dst) src._share_buffer_to(dst)
self.assertEqual(src._is_shared_buffer_with(dst), True) self.assertEqual(src._is_shared_buffer_with(dst), True)
def test_share_buffer_To(self):
with _test_eager_guard():
self.func_test_share_buffer_To()
self.func_test_share_buffer_To()
class TestVarBaseTo(unittest.TestCase): class TestVarBaseTo(unittest.TestCase):
def setUp(self): def func_setUp(self):
paddle.disable_static() paddle.disable_static()
self.np_x = np.random.random((3, 8, 8)) self.np_x = np.random.random((3, 8, 8))
self.x = paddle.to_tensor(self.np_x, dtype="float32") self.x = paddle.to_tensor(self.np_x, dtype="float32")
def test_to_api(self): def func_test_to_api(self):
x_double = self.x._to(dtype='double') x_double = self.x._to(dtype='double')
self.assertEqual(x_double.dtype, paddle.fluid.core.VarDesc.VarType.FP64) self.assertEqual(x_double.dtype, paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(np.allclose(self.np_x, x_double)) self.assertTrue(np.allclose(self.np_x, x_double))
...@@ -1476,9 +1652,16 @@ class TestVarBaseTo(unittest.TestCase): ...@@ -1476,9 +1652,16 @@ class TestVarBaseTo(unittest.TestCase):
self.assertRaises(ValueError, self.x._to, device=1) self.assertRaises(ValueError, self.x._to, device=1)
self.assertRaises(AssertionError, self.x._to, blocking=1) self.assertRaises(AssertionError, self.x._to, blocking=1)
def test_to_api(self):
with _test_eager_guard():
self.func_setUp()
self.func_test_to_api()
self.func_setUp()
self.func_test_to_api()
class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase): class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase):
def test_varbase_init(self): def func_test_varbase_init(self):
paddle.disable_static() paddle.disable_static()
t = fluid.Tensor() t = fluid.Tensor()
np_x = np.random.random((3, 8, 8)) np_x = np.random.random((3, 8, 8))
...@@ -1486,17 +1669,28 @@ class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase): ...@@ -1486,17 +1669,28 @@ class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase):
if paddle.fluid.is_compiled_with_cuda(): if paddle.fluid.is_compiled_with_cuda():
device = paddle.CUDAPlace(0) device = paddle.CUDAPlace(0)
tmp = fluid.core.VarBase(t, device) if _in_legacy_dygraph():
tmp = fluid.core.VarBase(t, device)
else:
tmp = fluid.core.eager.Tensor(t, device)
self.assertTrue(tmp.place.is_gpu_place()) self.assertTrue(tmp.place.is_gpu_place())
self.assertEqual(tmp.numpy().all(), np_x.all()) self.assertEqual(tmp.numpy().all(), np_x.all())
device = paddle.CPUPlace() device = paddle.CPUPlace()
tmp = fluid.core.VarBase(t, device) if _in_legacy_dygraph():
tmp = fluid.core.VarBase(t, device)
else:
tmp = fluid.core.eager.Tensor(t, device)
self.assertEqual(tmp.numpy().all(), np_x.all()) self.assertEqual(tmp.numpy().all(), np_x.all())
def test_varbase_init(self):
with _test_eager_guard():
self.func_test_varbase_init()
self.func_test_varbase_init()
class TestVarBaseNumel(unittest.TestCase): class TestVarBaseNumel(unittest.TestCase):
def test_numel_normal(self): def func_test_numel_normal(self):
paddle.disable_static() paddle.disable_static()
np_x = np.random.random((3, 8, 8)) np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64") x = paddle.to_tensor(np_x, dtype="float64")
...@@ -1504,15 +1698,28 @@ class TestVarBaseNumel(unittest.TestCase): ...@@ -1504,15 +1698,28 @@ class TestVarBaseNumel(unittest.TestCase):
x_expected_numel = np.product((3, 8, 8)) x_expected_numel = np.product((3, 8, 8))
self.assertEqual(x_actual_numel, x_expected_numel) self.assertEqual(x_actual_numel, x_expected_numel)
def test_numel_without_holder(self): def test_numel_normal(self):
with _test_eager_guard():
self.func_test_numel_normal()
self.func_test_numel_normal()
def func_test_numel_without_holder(self):
paddle.disable_static() paddle.disable_static()
x_without_holder = core.VarBase() if _in_legacy_dygraph():
x_without_holder = core.VarBase()
else:
x_without_holder = core.eager.Tensor()
x_actual_numel = x_without_holder._numel() x_actual_numel = x_without_holder._numel()
self.assertEqual(x_actual_numel, 0) self.assertEqual(x_actual_numel, 0)
def ttest_numel_without_holder(self):
with _test_eager_guard():
self.func_test_numel_without_holder()
self.func_test_numel_without_holder()
class TestVarBaseCopyGradientFrom(unittest.TestCase): class TestVarBaseCopyGradientFrom(unittest.TestCase):
def test_copy_gradient_from(self): def func_test_copy_gradient_from(self):
paddle.disable_static() paddle.disable_static()
np_x = np.random.random((2, 2)) np_x = np.random.random((2, 2))
np_y = np.random.random((2, 2)) np_y = np.random.random((2, 2))
...@@ -1523,7 +1730,11 @@ class TestVarBaseCopyGradientFrom(unittest.TestCase): ...@@ -1523,7 +1730,11 @@ class TestVarBaseCopyGradientFrom(unittest.TestCase):
x._copy_gradient_from(y) x._copy_gradient_from(y)
self.assertEqual(x.grad.numpy().all(), np_y.all()) self.assertEqual(x.grad.numpy().all(), np_y.all())
def test_copy_gradient_from(self):
with _test_eager_guard():
self.func_test_copy_gradient_from()
self.func_test_copy_gradient_from()
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册