diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index c7f14cd021f97731bf497ccfd689949c66c641dd..616a99b9bcc888ee0b6326dbcc3c2798ad4b3077 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -137,7 +137,10 @@ class EagerUtils { template static bool ComputeRequireGrad(T trace_backward, Args&&... args) { - if (!trace_backward) return false; + if (!trace_backward) { + VLOG(6) << "Do not require grad because trace_backward = false"; + return false; + } auto iter = ComputeRequireGradIter(); iter.apply(std::forward(args)...); diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index c1e1bfc88c41c86ba4a12e8dc97a573c9c66fb94..895d1db43cbf4f316d0b5d4db4880f01894b66f2 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -25,7 +25,6 @@ limitations under the License. */ #include "paddle/fluid/eager/hooks.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" @@ -33,7 +32,6 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/slice_utils.h" -#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/phi/api/include/api.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" @@ -41,6 +39,9 @@ limitations under the License. */ #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "pybind11/detail/internals.h" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#include "paddle/fluid/framework/python_headers.h" +#include "paddle/fluid/pybind/tensor_py.h" namespace paddle { namespace pybind { @@ -682,6 +683,103 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + auto ptr = static_cast(self->tensor.impl().get()); + PADDLE_ENFORCE_NOT_NULL( + ptr, platform::errors::InvalidArgument("%s is not a DenseTensor.", + self->tensor.name())); + const auto& tensor = *ptr; + PADDLE_ENFORCE_EQ( + tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "Tensor of %s is Empty, please check if it has no data.", + self->tensor.name())); + + const auto& tensor_dims = tensor.dims(); + + std::vector dims(tensor_dims.size()); + std::vector strides(tensor_dims.size()); + + size_t numel = 1; + for (int i = tensor_dims.size() - 1; i >= 0; --i) { + strides[i] = numel; + dims[i] = static_cast(tensor_dims[i]); + numel *= dims[i]; + } + size_t offset = 0; + if (PyTuple_Size(args) == 0) { + PADDLE_ENFORCE_EQ(numel, 1, + platform::errors::InvalidArgument( + "only one element tensors can be converted to Python " + "scalars when no input coordinates")); + } else if (PyTuple_Size(args) == 1) { + offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0); + PADDLE_ENFORCE_LT( + offset, numel, + platform::errors::InvalidArgument( + "index %d is out of bounds for size %d", offset, numel)); + } else { + PADDLE_ENFORCE_EQ(PyTuple_Size(args), dims.size(), + platform::errors::InvalidArgument( + "incorrect number of indices for Tensor")); + + for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) { + size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i); + PADDLE_ENFORCE_LT( + index, dims[i], + platform::errors::InvalidArgument( + "index %d is out fo bounds for axis %d with size %d", index, i, + dims[i])); + offset += index * strides[i]; + } + } +#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \ + _(bool, DataType::BOOL) \ + _(int8_t, DataType::INT8) \ + _(uint8_t, DataType::UINT8) \ + _(int16_t, DataType::INT16) \ + _(uint16_t, DataType::UINT16) \ + _(int32_t, DataType::INT32) \ + _(uint32_t, DataType::UINT32) \ + _(int64_t, DataType::INT64) \ + _(uint64_t, DataType::UINT64) \ + _(bfloat16, DataType::BFLOAT16) \ + _(float16, DataType::FLOAT16) \ + _(float, DataType::FLOAT32) \ + _(double, DataType::FLOAT64) \ + _(complex64, DataType::COMPLEX64) \ + _(complex128, DataType::COMPLEX128) + +#define TENSOR_TO_PY_SCALAR(T, proto_type) \ + if (tensor.dtype() == proto_type) { \ + auto numpy_dtype = TensorDtype2NumpyDtype(proto_type); \ + T b = paddle::pybind::TensorGetElement(tensor, offset); \ + Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank]; \ + Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank]; \ + py_dims[0] = 1; \ + py_strides[0] = 1; \ + auto& api = pybind11::detail::npy_api::get(); \ + PyObject* array = api.PyArray_NewFromDescr_( \ + api.PyArray_Type_, api.PyArray_DescrFromType_(numpy_dtype), 1, \ + py_dims, py_strides, nullptr, \ + pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ | \ + pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_, \ + nullptr); \ + std::memcpy( \ + reinterpret_cast(pybind11::detail::array_proxy(array)->data), \ + static_cast(&b), sizeof(b)); \ + return array; \ + } + + PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR); +#undef TENSOR_TO_PY_SCALAR + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported tensor data type: %s", tensor.dtype())); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, PyObject* args, PyObject* kwargs) { @@ -1252,6 +1350,9 @@ PyMethodDef variable_methods[] = { {"_getitem_index_not_tensor", (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_getitem_from_offset", + (PyCFunction)(void (*)(void))tensor__getitem_from_offset, + METH_VARARGS | METH_KEYWORDS, NULL}, {"__setitem_eager_tensor__", (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 17300e5ce90ac342dfd5d84f15c723c5b58f53f8..af89861d151bea3eb3eadee4a9dbbaff3fcc0566 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -633,8 +633,15 @@ paddle::optional GetOptionalTensorFromArgs( return paddle::none; } - return paddle::make_optional( - reinterpret_cast(obj)->tensor); + if (PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type))) { + return paddle::make_optional( + reinterpret_cast(obj)->tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got %s", op_type, + arg_name, arg_idx, + reinterpret_cast(obj->ob_type)->tp_name)); + } } static paddle::experimental::Tensor& GetTensorFromPyObject( @@ -654,7 +661,14 @@ static paddle::experimental::Tensor& GetTensorFromPyObject( return emptytensor; } - return reinterpret_cast(obj)->tensor; + if (PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type))) { + return reinterpret_cast(obj)->tensor; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got %s", op_type, + arg_name, arg_idx, + reinterpret_cast(obj->ob_type)->tp_name)); + } } // For Intermediate State Dygraph, @@ -744,7 +758,14 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type, return &emptytensor; } - return &(reinterpret_cast(obj)->tensor); + if (PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type))) { + return &(reinterpret_cast(obj)->tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got %s", op_type, + arg_name, arg_idx, + reinterpret_cast(obj->ob_type)->tp_name)); + } } std::vector GetTensorPtrListFromArgs( diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index f6baee02f85c77bfdf723ca2ba6de1cbf2960a75..54a245aab81c90abc2d745ed597270b051f45f99 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -764,7 +764,7 @@ class Layer(object): elif tensor is not None and not (type(tensor) == core.VarBase or type(tensor) == core.eager.Tensor): raise TypeError( - "The registered buffer should be a core.VarBase, but received {}.". + "The registered buffer should be a Paddle.Tensor, but received {}.". format(type(tensor).__name__)) else: self._buffers[name] = tensor @@ -1158,8 +1158,7 @@ class Layer(object): layers[name] = None else: _buffers = self.__dict__.get('_buffers', None) - if type(value) == core.VarBase or \ - type(value) == core.eager.Tensor: + if isinstance(value, (core.VarBase, core.eager.Tensor)): if _buffers is None: raise ValueError( "super(YourLayer, self).__init__() should be called first" diff --git a/python/paddle/fluid/tests/unittests/test_base_layer.py b/python/paddle/fluid/tests/unittests/test_base_layer.py index 789cfa82658f43d2adb148fe41fd2fb380e96fba..3bdd03b32127642dcfa33b04fdb7a78f9a1fc0a7 100644 --- a/python/paddle/fluid/tests/unittests/test_base_layer.py +++ b/python/paddle/fluid/tests/unittests/test_base_layer.py @@ -18,8 +18,9 @@ import numpy as np import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import to_variable -from paddle.fluid.framework import ParamBase +from paddle.fluid.framework import ParamBase, EagerParamBase from paddle.jit import ProgramTranslator +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class L1(fluid.Layer): @@ -57,7 +58,7 @@ class L3(fluid.Layer): class TestBaseLayer(unittest.TestCase): - def test_one_level(self): + def func_test_one_level(self): with fluid.dygraph.guard(): l = L1() ret = l() @@ -68,7 +69,12 @@ class TestBaseLayer(unittest.TestCase): idx += 1 self.assertTrue(np.allclose(ret.numpy(), 0.2 * np.ones([2, 2]))) - def test_three_level(self): + def test_one_level(self): + with _test_eager_guard(): + self.func_test_one_level() + self.func_test_one_level() + + def func_test_three_level(self): with fluid.dygraph.guard(): l = L3() expected_names = [ @@ -88,7 +94,12 @@ class TestBaseLayer(unittest.TestCase): ret = l() self.assertTrue(np.allclose(ret.numpy(), 0.8 * np.ones([2, 2]))) - def test_add_parameter_with_error(self): + def test_three_level(self): + with _test_eager_guard(): + self.func_test_three_level() + self.func_test_three_level() + + def func_test_add_parameter_with_error(self): with fluid.dygraph.guard(): net = fluid.Layer() param = net.create_parameter(shape=[1]) @@ -113,6 +124,11 @@ class TestBaseLayer(unittest.TestCase): net._loaddict_holder[load_param.name] = load_param net.add_parameter("load_param", load_param) + def test_add_parameter_with_error(self): + with _test_eager_guard(): + self.func_test_add_parameter_with_error() + self.func_test_add_parameter_with_error() + class BufferLayer(fluid.Layer): def __init__(self): @@ -140,7 +156,7 @@ class BufferNet(fluid.Layer): class TestBuffer(unittest.TestCase): - def test_buffers_and_named_buffers(self): + def func_test_buffers_and_named_buffers(self): def names(named_buffers): return [name for name, _ in named_buffers] @@ -161,7 +177,12 @@ class TestBuffer(unittest.TestCase): names(net.named_buffers(include_sublayers=False)), ['net_buffer', 'new_buffer']) - def test_register_buffer_with_error(self): + def test_buffers_and_named_buffers(self): + with _test_eager_guard(): + self.func_test_buffers_and_named_buffers() + self.func_test_buffers_and_named_buffers() + + def func_test_register_buffer_with_error(self): with fluid.dygraph.guard(): net = fluid.Layer() var = to_variable(np.zeros([1])) @@ -171,8 +192,13 @@ class TestBuffer(unittest.TestCase): net.register_buffer(12, var) with self.assertRaisesRegexp(TypeError, - "buffer should be a core.VarBase"): - net.register_buffer("buffer_name", ParamBase([2, 2], 'float32')) + "buffer should be a Paddle.Tensor"): + if in_dygraph_mode(): + net.register_buffer("buffer_name", + EagerParamBase([2, 2], 'float32')) + else: + net.register_buffer("buffer_name", + ParamBase([2, 2], 'float32')) with self.assertRaisesRegexp(KeyError, "name of buffer can not contain"): @@ -187,11 +213,19 @@ class TestBuffer(unittest.TestCase): net.register_buffer("attr_name", var) del net.attr_name - net.attr_name = ParamBase([2, 2], 'float32') + if in_dygraph_mode(): + net.attr_name = EagerParamBase([2, 2], 'float32') + else: + net.attr_name = ParamBase([2, 2], 'float32') with self.assertRaisesRegexp(KeyError, "already exists"): net.register_buffer("attr_name", var) - def test_register_buffer_same_name(self): + def test_register_buffer_with_error(self): + with _test_eager_guard(): + self.func_test_register_buffer_with_error() + self.func_test_register_buffer_with_error() + + def func_test_register_buffer_same_name(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([1])) @@ -205,7 +239,12 @@ class TestBuffer(unittest.TestCase): net.register_buffer("buffer_name", var3) self.assert_var_base_equal(net.buffer_name, var3) - def test_buffer_not_persistable(self): + def test_register_buffer_same_name(self): + with _test_eager_guard(): + self.func_test_register_buffer_same_name() + self.func_test_register_buffer_same_name() + + def func_test_buffer_not_persistable(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([1])) @@ -214,7 +253,12 @@ class TestBuffer(unittest.TestCase): self.assertEqual(len(net.buffers()), 1) self.assertEqual(len(net.state_dict()), 0) - def test_buffer_not_persistable_del(self): + def test_buffer_not_persistable(self): + with _test_eager_guard(): + self.func_test_buffer_not_persistable() + self.func_test_buffer_not_persistable() + + def func_test_buffer_not_persistable_del(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([1])) @@ -222,7 +266,12 @@ class TestBuffer(unittest.TestCase): del net.buffer_name self.assertEqual(len(net.buffers()), 0) - def test_buffer_not_persistable_overwrite(self): + def test_buffer_not_persistable_del(self): + with _test_eager_guard(): + self.func_test_buffer_not_persistable_del() + self.func_test_buffer_not_persistable_del() + + def func_test_buffer_not_persistable_overwrite(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([1])) @@ -238,7 +287,12 @@ class TestBuffer(unittest.TestCase): self.assertEqual(len(net.buffers()), 1) self.assertEqual(len(net.state_dict()), 0) - def test_buffer_not_persistable_assign(self): + def test_buffer_not_persistable_overwrite(self): + with _test_eager_guard(): + self.func_test_buffer_not_persistable_overwrite() + self.func_test_buffer_not_persistable_overwrite() + + def func_test_buffer_not_persistable_assign(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([1])) @@ -255,18 +309,31 @@ class TestBuffer(unittest.TestCase): self.assertEqual(len(net.state_dict()), 0) # Re-assign a ParamBase will remove the buffer. - net.buffer_name = ParamBase([2, 2], 'float32') + if in_dygraph_mode(): + net.buffer_name = EagerParamBase([2, 2], 'float32') + else: + net.buffer_name = ParamBase([2, 2], 'float32') self.assertEqual(len(net.buffers()), 0) self.assertEqual(len(net.state_dict()), 1) - def test_buffer_not_persistable_load(self): + def test_buffer_not_persistable_assign(self): + with _test_eager_guard(): + self.func_test_buffer_not_persistable_assign() + self.func_test_buffer_not_persistable_assign() + + def func_test_buffer_not_persistable_load(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([1])) net.register_buffer("buffer_name", var1, persistable=False) net.load_dict({}) - def test_buffer_state_dict(self): + def test_buffer_not_persistable_load(self): + with _test_eager_guard(): + self.func_test_buffer_not_persistable_load() + self.func_test_buffer_not_persistable_load() + + def func_test_buffer_state_dict(self): with fluid.dygraph.guard(): net = fluid.Layer() var1 = to_variable(np.zeros([2, 3])) @@ -286,6 +353,11 @@ class TestBuffer(unittest.TestCase): self.assert_var_base_equal(net_load.buffer_var1, var1) + def test_buffer_state_dict(self): + with _test_eager_guard(): + self.func_test_buffer_state_dict() + self.func_test_buffer_state_dict() + def assert_var_base_equal(self, var1, var2): self.assertTrue(np.array_equal(var1.numpy(), var2.numpy())) @@ -308,7 +380,7 @@ class BufferNetWithModification(paddle.nn.Layer): class TestModifiedBuffer(unittest.TestCase): - def setUp(self): + def funcsetUp(self): paddle.disable_static() self.prog_trans = ProgramTranslator() self.shape = [10, 16] @@ -322,7 +394,8 @@ class TestModifiedBuffer(unittest.TestCase): return out, net.buffer1, net.buffer2 - def test_modified(self): + def func_test_modified(self): + self.funcsetUp() dy_outs = self._run(False) st_outs = self._run(True) @@ -330,9 +403,14 @@ class TestModifiedBuffer(unittest.TestCase): self.assertTrue( np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy())) + def test_modified(self): + with _test_eager_guard(): + self.func_test_modified() + self.func_test_modified() + class TestLayerTo(unittest.TestCase): - def setUp(self): + def funcsetUp(self): paddle.disable_static() self.linear = paddle.nn.Linear(2, 2) self.new_grad = np.random.random([2, 2]) @@ -343,7 +421,7 @@ class TestLayerTo(unittest.TestCase): sublayer = paddle.nn.Conv1D(3, 2, 3) self.linear.add_sublayer("1", sublayer) - def test_to_api(self): + def func_test_to_api(self): self.linear.to(dtype='double') self.assertEqual(self.linear.weight.dtype, paddle.fluid.core.VarDesc.VarType.FP64) @@ -364,7 +442,11 @@ class TestLayerTo(unittest.TestCase): self.assertEqual(self.linear.weight._grad_ivar().dtype, paddle.fluid.core.VarDesc.VarType.FP64) for p in self.linear.parameters(): - self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) + if in_dygraph_mode(): + self.assertTrue( + isinstance(p, paddle.fluid.framework.EagerParamBase)) + else: + self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) if paddle.fluid.is_compiled_with_cuda(): self.linear.to(device=paddle.CUDAPlace(0)) @@ -387,7 +469,12 @@ class TestLayerTo(unittest.TestCase): self.assertEqual( self.linear.weight._grad_ivar().place.gpu_device_id(), 0) for p in self.linear.parameters(): - self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) + if in_dygraph_mode(): + self.assertTrue( + isinstance(p, paddle.fluid.framework.EagerParamBase)) + else: + self.assertTrue( + isinstance(p, paddle.fluid.framework.ParamBase)) self.linear.to(device=paddle.CPUPlace()) self.assertTrue(self.linear.weight.place.is_cpu_place()) @@ -403,7 +490,7 @@ class TestLayerTo(unittest.TestCase): self.assertRaises(AssertionError, self.linear.to, blocking=1) - def test_to_api_paddle_dtype(self): + def func_test_to_api_paddle_dtype(self): self.linear.to(dtype=paddle.float64) self.assertEqual(self.linear.weight.dtype, paddle.fluid.core.VarDesc.VarType.FP64) @@ -424,9 +511,13 @@ class TestLayerTo(unittest.TestCase): self.assertEqual(self.linear.weight._grad_ivar().dtype, paddle.fluid.core.VarDesc.VarType.FP64) for p in self.linear.parameters(): - self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) + if in_dygraph_mode(): + self.assertTrue( + isinstance(p, paddle.fluid.framework.EagerParamBase)) + else: + self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) - def test_to_api_numpy_dtype(self): + def func_test_to_api_numpy_dtype(self): self.linear.to(dtype=np.float64) self.assertEqual(self.linear.weight.dtype, paddle.fluid.core.VarDesc.VarType.FP64) @@ -447,7 +538,22 @@ class TestLayerTo(unittest.TestCase): self.assertEqual(self.linear.weight._grad_ivar().dtype, paddle.fluid.core.VarDesc.VarType.FP64) for p in self.linear.parameters(): - self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) + if in_dygraph_mode(): + self.assertTrue( + isinstance(p, paddle.fluid.framework.EagerParamBase)) + else: + self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) + + def test_main(self): + with _test_eager_guard(): + self.funcsetUp() + self.func_test_to_api() + self.func_test_to_api_paddle_dtype() + self.func_test_to_api_numpy_dtype() + self.funcsetUp() + self.func_test_to_api() + self.func_test_to_api_paddle_dtype() + self.func_test_to_api_numpy_dtype() if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_cov.py b/python/paddle/fluid/tests/unittests/test_cov.py index 93ecf13bdcbe7b8d488ad496f751bfea9de070a1..5c4b9cbab27904d956f8c5b50b76ac98e0b6ae3f 100644 --- a/python/paddle/fluid/tests/unittests/test_cov.py +++ b/python/paddle/fluid/tests/unittests/test_cov.py @@ -17,6 +17,7 @@ import unittest import numpy as np import six import paddle +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode def numpy_cov(np_arr, rowvar=True, ddof=1, fweights=None, aweights=None): @@ -32,7 +33,7 @@ class Cov_Test(unittest.TestCase): self.shape = [20, 10] self.weightshape = [10] - def test_tensor_cov_default(self): + def func_test_tensor_cov_default(self): typelist = ['float64'] places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -56,7 +57,12 @@ class Cov_Test(unittest.TestCase): np_arr, rowvar=True, ddof=1, fweights=None, aweights=None) self.assertTrue(np.allclose(np_cov, cov.numpy())) - def test_tensor_cov_rowvar(self): + def test_tensor_cov_default(self): + with _test_eager_guard(): + self.func_test_tensor_cov_default() + self.func_test_tensor_cov_default() + + def func_test_tensor_cov_rowvar(self): typelist = ['float64'] places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -80,7 +86,12 @@ class Cov_Test(unittest.TestCase): np_arr, rowvar=False, ddof=1, fweights=None, aweights=None) self.assertTrue(np.allclose(np_cov, cov.numpy())) - def test_tensor_cov_ddof(self): + def test_tensor_cov_rowvar(self): + with _test_eager_guard(): + self.func_test_tensor_cov_rowvar() + self.func_test_tensor_cov_rowvar() + + def func_test_tensor_cov_ddof(self): typelist = ['float64'] places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -104,7 +115,12 @@ class Cov_Test(unittest.TestCase): np_arr, rowvar=True, ddof=0, fweights=None, aweights=None) self.assertTrue(np.allclose(np_cov, cov.numpy())) - def test_tensor_cov_fweights(self): + def test_tensor_cov_ddof(self): + with _test_eager_guard(): + self.func_test_tensor_cov_ddof() + self.func_test_tensor_cov_ddof() + + def func_test_tensor_cov_fweights(self): typelist = ['float64'] places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -131,7 +147,12 @@ class Cov_Test(unittest.TestCase): np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=None) self.assertTrue(np.allclose(np_cov, cov.numpy())) - def test_tensor_cov_aweights(self): + def test_tensor_cov_fweights(self): + with _test_eager_guard(): + self.func_test_tensor_cov_fweights() + self.func_test_tensor_cov_fweights() + + def func_test_tensor_cov_aweights(self): typelist = ['float64'] places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -158,7 +179,12 @@ class Cov_Test(unittest.TestCase): np_arr, rowvar=True, ddof=1, fweights=None, aweights=np_aw) self.assertTrue(np.allclose(np_cov, cov.numpy())) - def test_tensor_cov_weights(self): + def test_tensor_cov_aweights(self): + with _test_eager_guard(): + self.func_test_tensor_cov_aweights() + self.func_test_tensor_cov_aweights() + + def func_test_tensor_cov_weights(self): typelist = ['float64'] places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -187,6 +213,11 @@ class Cov_Test(unittest.TestCase): np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=np_aw) self.assertTrue(np.allclose(np_cov, cov.numpy())) + def test_tensor_cov_weights(self): + with _test_eager_guard(): + self.func_test_tensor_cov_weights() + self.func_test_tensor_cov_weights() + class Cov_Test2(Cov_Test): def setUp(self): @@ -203,7 +234,7 @@ class Cov_Test3(unittest.TestCase): self.fw_s = 1. self.aw_s = 1. - def test_errors(self): + def func_test_errors(self): def test_err(): np_arr = np.random.rand(*self.shape).astype('float64') np_fw = self.fw_s * np.random.rand( @@ -221,6 +252,11 @@ class Cov_Test3(unittest.TestCase): self.assertRaises(ValueError, test_err) + def test_errors(self): + with _test_eager_guard(): + self.func_test_errors() + self.func_test_errors() + #Input(fweights) only support N-D (N<=1) tensor class Cov_Test4(Cov_Test3): diff --git a/python/paddle/fluid/tests/unittests/test_inner.py b/python/paddle/fluid/tests/unittests/test_inner.py index ff9f15ebbfc8204de042d7731ed94035152f46eb..2174c20c9a0954fdf09bf8ff8afeffac790c43cc 100644 --- a/python/paddle/fluid/tests/unittests/test_inner.py +++ b/python/paddle/fluid/tests/unittests/test_inner.py @@ -19,6 +19,7 @@ import numpy as np import paddle from paddle.static import Program, program_guard +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class TestMultiplyApi(unittest.TestCase): @@ -48,7 +49,7 @@ class TestMultiplyApi(unittest.TestCase): res = paddle.inner(x, y) return res.numpy() - def test_multiply(self): + def func_test_multiply(self): np.random.seed(7) # test static computation graph: 3-d array @@ -105,9 +106,14 @@ class TestMultiplyApi(unittest.TestCase): res = self._run_dynamic_graph_case(x_data, y_data) self.assertTrue(np.allclose(res, np.inner(x_data, y_data))) + def test_multiply(self): + with _test_eager_guard(): + self.func_test_multiply() + self.func_test_multiply() + class TestMultiplyError(unittest.TestCase): - def test_errors(self): + def func_test_errors(self): # test static computation graph: dtype can not be int8 paddle.enable_static() with program_guard(Program(), Program()): @@ -161,6 +167,11 @@ class TestMultiplyError(unittest.TestCase): y_data = np.random.randn(200).astype(np.float32) self.assertRaises(ValueError, paddle.inner, x_data, y_data) + def test_errors(self): + with _test_eager_guard(): + self.func_test_errors() + self.func_test_errors() + if __name__ == '__main__': paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/test_multiply.py b/python/paddle/fluid/tests/unittests/test_multiply.py index 3fd6e3f0c865abcec9bb42d83c021fba94a5cf0b..e8463ed8ad235b3f0a06232c211dd920770bd6fb 100755 --- a/python/paddle/fluid/tests/unittests/test_multiply.py +++ b/python/paddle/fluid/tests/unittests/test_multiply.py @@ -20,6 +20,7 @@ import numpy as np import paddle import paddle.tensor as tensor from paddle.static import Program, program_guard +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class TestMultiplyApi(unittest.TestCase): @@ -49,7 +50,7 @@ class TestMultiplyApi(unittest.TestCase): res = paddle.multiply(x, y) return res.numpy() - def test_multiply(self): + def func_test_multiply(self): np.random.seed(7) # test static computation graph: 1-d array @@ -100,9 +101,14 @@ class TestMultiplyApi(unittest.TestCase): res = self._run_dynamic_graph_case(x_data, y_data) self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) + def test_multiply(self): + with _test_eager_guard(): + self.func_test_multiply() + self.func_test_multiply() + class TestMultiplyError(unittest.TestCase): - def test_errors(self): + def func_test_errors(self): # test static computation graph: dtype can not be int8 paddle.enable_static() with program_guard(Program(), Program()): @@ -175,6 +181,11 @@ class TestMultiplyError(unittest.TestCase): y_data = np.random.randn(200).astype(np.float32) self.assertRaises(ValueError, paddle.multiply, x_data, y_data) + def test_errors(self): + with _test_eager_guard(): + self.func_test_errors() + self.func_test_errors() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_outer.py b/python/paddle/fluid/tests/unittests/test_outer.py index 1b11a71bb2f09811669014ddb12da7febf24a42c..2c4d64344cfc758d88bc70a310dc26789bdccabb 100644 --- a/python/paddle/fluid/tests/unittests/test_outer.py +++ b/python/paddle/fluid/tests/unittests/test_outer.py @@ -19,6 +19,7 @@ import numpy as np import paddle from paddle.static import Program, program_guard +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class TestMultiplyApi(unittest.TestCase): @@ -48,7 +49,7 @@ class TestMultiplyApi(unittest.TestCase): res = paddle.outer(x, y) return res.numpy() - def test_multiply(self): + def func_test_multiply(self): np.random.seed(7) # test static computation graph: 3-d array @@ -105,9 +106,14 @@ class TestMultiplyApi(unittest.TestCase): res = self._run_dynamic_graph_case(x_data, y_data) self.assertTrue(np.allclose(res, np.outer(x_data, y_data))) + def test_multiply(self): + with _test_eager_guard(): + self.func_test_multiply() + self.func_test_multiply() + class TestMultiplyError(unittest.TestCase): - def test_errors(self): + def func_test_errors(self): # test static computation graph: dtype can not be int8 paddle.enable_static() with program_guard(Program(), Program()): @@ -148,6 +154,11 @@ class TestMultiplyError(unittest.TestCase): y_data = np.random.randn(200).astype(np.float32) self.assertRaises(ValueError, paddle.outer, x_data, y_data) + def test_errors(self): + with _test_eager_guard(): + self.func_test_errors() + self.func_test_errors() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index caed181d453052eefc911b63695bd6c5a18f4b0e..6e7e5678be0b0b5d437200ff8716b6b985d63e6a 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -127,7 +127,12 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): "\n\tFaild to convert input data to a regular ndarray :\n\t - Usually " "this means the input data contains nested lists with different lengths. " ) - elif isinstance(data, (paddle.Tensor, core.eager.Tensor)): + elif isinstance(data, paddle.Tensor) and not in_dygraph_mode(): + data = data._copy_to(place, False) + data = _handle_dtype(data, dtype) + data.stop_gradient = stop_gradient + return data + elif isinstance(data, core.eager.Tensor) and in_dygraph_mode(): data = data._copy_to(place, False) data = _handle_dtype(data, dtype) data.stop_gradient = stop_gradient @@ -136,7 +141,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): # should't expose it to users, just for internal use. # convert core.Tensor/core.LoDTensor to VarBase first # Currenly, there is no copy when places are same - data = paddle.Tensor(data) + if in_dygraph_mode(): + data = core.eager.Tensor(data) + else: + data = paddle.Tensor(data) if not data.place._equals(place): data = data._copy_to(place, False) data = _handle_dtype(data, dtype) diff --git a/python/paddle/tests/test_dlpack.py b/python/paddle/tests/test_dlpack.py index 3a3f748bb991e78fa579c8c94bb80cb190e25e02..458efd047de6898ec23276f3a9bc2eb73cbf6432 100644 --- a/python/paddle/tests/test_dlpack.py +++ b/python/paddle/tests/test_dlpack.py @@ -18,21 +18,31 @@ import numpy as np import paddle import paddle.fluid as fluid import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class TestDLPack(unittest.TestCase): - def test_dlpack_dygraph(self): + def func_test_dlpack_dygraph(self): paddle.disable_static() tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int')) dlpack = paddle.utils.dlpack.to_dlpack(tensor) out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) - self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor)) + if paddle.fluid.framework.in_dygraph_mode(): + self.assertTrue( + isinstance(out_from_dlpack, paddle.fluid.core.eager.Tensor)) + else: + self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor)) self.assertTrue( np.array_equal( np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype( 'int'))) - def test_dlpack_tensor_larger_than_2dim(self): + def test_dlpack_dygraph(self): + with _test_eager_guard(): + self.func_test_dlpack_dygraph() + self.func_test_dlpack_dygraph() + + def func_test_dlpack_tensor_larger_than_2dim(self): paddle.disable_static() numpy_data = np.random.randn(4, 5, 6) t = paddle.to_tensor(numpy_data) @@ -41,6 +51,11 @@ class TestDLPack(unittest.TestCase): out = paddle.utils.dlpack.from_dlpack(dlpack) self.assertTrue(np.allclose(numpy_data, out.numpy())) + def test_dlpack_tensor_larger_than_2dim(self): + with _test_eager_guard(): + self.func_test_dlpack_tensor_larger_than_2dim() + self.func_test_dlpack_tensor_larger_than_2dim() + def test_dlpack_static(self): paddle.enable_static() tensor = fluid.create_lod_tensor( @@ -67,7 +82,7 @@ class TestDLPack(unittest.TestCase): np.array(gout_from_dlpack), np.array([[1], [2], [3], [4]]).astype('int'))) - def test_dlpack_dtype_conversion(self): + def func_test_dlpack_dtype_conversion(self): paddle.disable_static() # DLpack does not explicitly support bool data type. dtypes = [ @@ -98,15 +113,30 @@ class TestDLPack(unittest.TestCase): self.assertEqual(x.dtype, o.dtype) self.assertTrue(np.allclose(x.numpy(), o.numpy())) + def test_dlpack_dtype_conversion(self): + with _test_eager_guard(): + self.func_test_dlpack_dtype_conversion() + self.func_test_dlpack_dtype_conversion() + class TestRaiseError(unittest.TestCase): - def test_from_dlpack_raise_type_error(self): + def func_test_from_dlpack_raise_type_error(self): self.assertRaises(TypeError, paddle.utils.dlpack.from_dlpack, np.zeros(5)) - def test_to_dlpack_raise_type_error(self): + def test_from_dlpack_raise_type_error(self): + with _test_eager_guard(): + self.func_test_from_dlpack_raise_type_error() + self.func_test_from_dlpack_raise_type_error() + + def func_test_to_dlpack_raise_type_error(self): self.assertRaises(TypeError, paddle.utils.dlpack.to_dlpack, np.zeros(5)) + def test_to_dlpack_raise_type_error(self): + with _test_eager_guard(): + self.func_test_to_dlpack_raise_type_error() + self.func_test_to_dlpack_raise_type_error() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py index a5e375fff148eab21b67d47b9f8abe180bb7efc6..1ece08daa27274f5d4b70f9c39464aaab38f5a62 100644 --- a/python/paddle/utils/dlpack.py +++ b/python/paddle/utils/dlpack.py @@ -48,7 +48,7 @@ def to_dlpack(x): """ if _non_static_mode(): - if not isinstance(x, paddle.Tensor): + if not isinstance(x, (paddle.Tensor, paddle.fluid.core.eager.Tensor)): raise TypeError( "The type of 'x' in to_dlpack must be paddle.Tensor," " but received {}.".format(type(x)))