未验证 提交 4d300224 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] dlpack (#40811)

* dlpack eager, test=develop

* eager test_base_layer, test=develop

* fix error report, test=develop

* eager _getitem_from_offset, test=develop

* refine, test=develop

* refine offset, test=develop

* add test_inner test_outer, test=develop

* refine, test=develop

* refine, test=develop
上级 13f1641d
...@@ -137,7 +137,10 @@ class EagerUtils { ...@@ -137,7 +137,10 @@ class EagerUtils {
template <typename T, typename... Args> template <typename T, typename... Args>
static bool ComputeRequireGrad(T trace_backward, Args&&... args) { static bool ComputeRequireGrad(T trace_backward, Args&&... args) {
if (!trace_backward) return false; if (!trace_backward) {
VLOG(6) << "Do not require grad because trace_backward = false";
return false;
}
auto iter = ComputeRequireGradIter(); auto iter = ComputeRequireGradIter();
iter.apply(std::forward<Args>(args)...); iter.apply(std::forward<Args>(args)...);
......
...@@ -25,7 +25,6 @@ limitations under the License. */ ...@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/eager/hooks.h" #include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -33,7 +32,6 @@ limitations under the License. */ ...@@ -33,7 +32,6 @@ limitations under the License. */
#include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/slice_utils.h" #include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
...@@ -41,6 +39,9 @@ limitations under the License. */ ...@@ -41,6 +39,9 @@ limitations under the License. */
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
#include "pybind11/detail/internals.h" #include "pybind11/detail/internals.h"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/pybind/tensor_py.h"
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -682,6 +683,103 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, ...@@ -682,6 +683,103 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
PADDLE_ENFORCE_NOT_NULL(
ptr, platform::errors::InvalidArgument("%s is not a DenseTensor.",
self->tensor.name()));
const auto& tensor = *ptr;
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor of %s is Empty, please check if it has no data.",
self->tensor.name()));
const auto& tensor_dims = tensor.dims();
std::vector<size_t> dims(tensor_dims.size());
std::vector<size_t> strides(tensor_dims.size());
size_t numel = 1;
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
strides[i] = numel;
dims[i] = static_cast<size_t>(tensor_dims[i]);
numel *= dims[i];
}
size_t offset = 0;
if (PyTuple_Size(args) == 0) {
PADDLE_ENFORCE_EQ(numel, 1,
platform::errors::InvalidArgument(
"only one element tensors can be converted to Python "
"scalars when no input coordinates"));
} else if (PyTuple_Size(args) == 1) {
offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
PADDLE_ENFORCE_LT(
offset, numel,
platform::errors::InvalidArgument(
"index %d is out of bounds for size %d", offset, numel));
} else {
PADDLE_ENFORCE_EQ(PyTuple_Size(args), dims.size(),
platform::errors::InvalidArgument(
"incorrect number of indices for Tensor"));
for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
PADDLE_ENFORCE_LT(
index, dims[i],
platform::errors::InvalidArgument(
"index %d is out fo bounds for axis %d with size %d", index, i,
dims[i]));
offset += index * strides[i];
}
}
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(uint16_t, DataType::UINT16) \
_(int32_t, DataType::INT32) \
_(uint32_t, DataType::UINT32) \
_(int64_t, DataType::INT64) \
_(uint64_t, DataType::UINT64) \
_(bfloat16, DataType::BFLOAT16) \
_(float16, DataType::FLOAT16) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
#define TENSOR_TO_PY_SCALAR(T, proto_type) \
if (tensor.dtype() == proto_type) { \
auto numpy_dtype = TensorDtype2NumpyDtype(proto_type); \
T b = paddle::pybind::TensorGetElement<T>(tensor, offset); \
Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank]; \
Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank]; \
py_dims[0] = 1; \
py_strides[0] = 1; \
auto& api = pybind11::detail::npy_api::get(); \
PyObject* array = api.PyArray_NewFromDescr_( \
api.PyArray_Type_, api.PyArray_DescrFromType_(numpy_dtype), 1, \
py_dims, py_strides, nullptr, \
pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ | \
pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_, \
nullptr); \
std::memcpy( \
reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
static_cast<void*>(&b), sizeof(b)); \
return array; \
}
PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported tensor data type: %s", tensor.dtype()));
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
...@@ -1252,6 +1350,9 @@ PyMethodDef variable_methods[] = { ...@@ -1252,6 +1350,9 @@ PyMethodDef variable_methods[] = {
{"_getitem_index_not_tensor", {"_getitem_index_not_tensor",
(PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor, (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"_getitem_from_offset",
(PyCFunction)(void (*)(void))tensor__getitem_from_offset,
METH_VARARGS | METH_KEYWORDS, NULL},
{"__setitem_eager_tensor__", {"__setitem_eager_tensor__",
(PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor, (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
......
...@@ -633,8 +633,15 @@ paddle::optional<const paddle::experimental::Tensor&> GetOptionalTensorFromArgs( ...@@ -633,8 +633,15 @@ paddle::optional<const paddle::experimental::Tensor&> GetOptionalTensorFromArgs(
return paddle::none; return paddle::none;
} }
return paddle::make_optional<const paddle::experimental::Tensor&>( if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
reinterpret_cast<TensorObject*>(obj)->tensor); return paddle::make_optional<const paddle::experimental::Tensor&>(
reinterpret_cast<TensorObject*>(obj)->tensor);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got %s", op_type,
arg_name, arg_idx,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
} }
static paddle::experimental::Tensor& GetTensorFromPyObject( static paddle::experimental::Tensor& GetTensorFromPyObject(
...@@ -654,7 +661,14 @@ static paddle::experimental::Tensor& GetTensorFromPyObject( ...@@ -654,7 +661,14 @@ static paddle::experimental::Tensor& GetTensorFromPyObject(
return emptytensor; return emptytensor;
} }
return reinterpret_cast<TensorObject*>(obj)->tensor; if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
return reinterpret_cast<TensorObject*>(obj)->tensor;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got %s", op_type,
arg_name, arg_idx,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
} }
// For Intermediate State Dygraph, // For Intermediate State Dygraph,
...@@ -744,7 +758,14 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type, ...@@ -744,7 +758,14 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type,
return &emptytensor; return &emptytensor;
} }
return &(reinterpret_cast<TensorObject*>(obj)->tensor); if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
return &(reinterpret_cast<TensorObject*>(obj)->tensor);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got %s", op_type,
arg_name, arg_idx,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
} }
std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs( std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
......
...@@ -764,7 +764,7 @@ class Layer(object): ...@@ -764,7 +764,7 @@ class Layer(object):
elif tensor is not None and not (type(tensor) == core.VarBase or elif tensor is not None and not (type(tensor) == core.VarBase or
type(tensor) == core.eager.Tensor): type(tensor) == core.eager.Tensor):
raise TypeError( raise TypeError(
"The registered buffer should be a core.VarBase, but received {}.". "The registered buffer should be a Paddle.Tensor, but received {}.".
format(type(tensor).__name__)) format(type(tensor).__name__))
else: else:
self._buffers[name] = tensor self._buffers[name] = tensor
...@@ -1158,8 +1158,7 @@ class Layer(object): ...@@ -1158,8 +1158,7 @@ class Layer(object):
layers[name] = None layers[name] = None
else: else:
_buffers = self.__dict__.get('_buffers', None) _buffers = self.__dict__.get('_buffers', None)
if type(value) == core.VarBase or \ if isinstance(value, (core.VarBase, core.eager.Tensor)):
type(value) == core.eager.Tensor:
if _buffers is None: if _buffers is None:
raise ValueError( raise ValueError(
"super(YourLayer, self).__init__() should be called first" "super(YourLayer, self).__init__() should be called first"
......
...@@ -18,8 +18,9 @@ import numpy as np ...@@ -18,8 +18,9 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import ParamBase from paddle.fluid.framework import ParamBase, EagerParamBase
from paddle.jit import ProgramTranslator from paddle.jit import ProgramTranslator
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
class L1(fluid.Layer): class L1(fluid.Layer):
...@@ -57,7 +58,7 @@ class L3(fluid.Layer): ...@@ -57,7 +58,7 @@ class L3(fluid.Layer):
class TestBaseLayer(unittest.TestCase): class TestBaseLayer(unittest.TestCase):
def test_one_level(self): def func_test_one_level(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
l = L1() l = L1()
ret = l() ret = l()
...@@ -68,7 +69,12 @@ class TestBaseLayer(unittest.TestCase): ...@@ -68,7 +69,12 @@ class TestBaseLayer(unittest.TestCase):
idx += 1 idx += 1
self.assertTrue(np.allclose(ret.numpy(), 0.2 * np.ones([2, 2]))) self.assertTrue(np.allclose(ret.numpy(), 0.2 * np.ones([2, 2])))
def test_three_level(self): def test_one_level(self):
with _test_eager_guard():
self.func_test_one_level()
self.func_test_one_level()
def func_test_three_level(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
l = L3() l = L3()
expected_names = [ expected_names = [
...@@ -88,7 +94,12 @@ class TestBaseLayer(unittest.TestCase): ...@@ -88,7 +94,12 @@ class TestBaseLayer(unittest.TestCase):
ret = l() ret = l()
self.assertTrue(np.allclose(ret.numpy(), 0.8 * np.ones([2, 2]))) self.assertTrue(np.allclose(ret.numpy(), 0.8 * np.ones([2, 2])))
def test_add_parameter_with_error(self): def test_three_level(self):
with _test_eager_guard():
self.func_test_three_level()
self.func_test_three_level()
def func_test_add_parameter_with_error(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
param = net.create_parameter(shape=[1]) param = net.create_parameter(shape=[1])
...@@ -113,6 +124,11 @@ class TestBaseLayer(unittest.TestCase): ...@@ -113,6 +124,11 @@ class TestBaseLayer(unittest.TestCase):
net._loaddict_holder[load_param.name] = load_param net._loaddict_holder[load_param.name] = load_param
net.add_parameter("load_param", load_param) net.add_parameter("load_param", load_param)
def test_add_parameter_with_error(self):
with _test_eager_guard():
self.func_test_add_parameter_with_error()
self.func_test_add_parameter_with_error()
class BufferLayer(fluid.Layer): class BufferLayer(fluid.Layer):
def __init__(self): def __init__(self):
...@@ -140,7 +156,7 @@ class BufferNet(fluid.Layer): ...@@ -140,7 +156,7 @@ class BufferNet(fluid.Layer):
class TestBuffer(unittest.TestCase): class TestBuffer(unittest.TestCase):
def test_buffers_and_named_buffers(self): def func_test_buffers_and_named_buffers(self):
def names(named_buffers): def names(named_buffers):
return [name for name, _ in named_buffers] return [name for name, _ in named_buffers]
...@@ -161,7 +177,12 @@ class TestBuffer(unittest.TestCase): ...@@ -161,7 +177,12 @@ class TestBuffer(unittest.TestCase):
names(net.named_buffers(include_sublayers=False)), names(net.named_buffers(include_sublayers=False)),
['net_buffer', 'new_buffer']) ['net_buffer', 'new_buffer'])
def test_register_buffer_with_error(self): def test_buffers_and_named_buffers(self):
with _test_eager_guard():
self.func_test_buffers_and_named_buffers()
self.func_test_buffers_and_named_buffers()
def func_test_register_buffer_with_error(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var = to_variable(np.zeros([1])) var = to_variable(np.zeros([1]))
...@@ -171,8 +192,13 @@ class TestBuffer(unittest.TestCase): ...@@ -171,8 +192,13 @@ class TestBuffer(unittest.TestCase):
net.register_buffer(12, var) net.register_buffer(12, var)
with self.assertRaisesRegexp(TypeError, with self.assertRaisesRegexp(TypeError,
"buffer should be a core.VarBase"): "buffer should be a Paddle.Tensor"):
net.register_buffer("buffer_name", ParamBase([2, 2], 'float32')) if in_dygraph_mode():
net.register_buffer("buffer_name",
EagerParamBase([2, 2], 'float32'))
else:
net.register_buffer("buffer_name",
ParamBase([2, 2], 'float32'))
with self.assertRaisesRegexp(KeyError, with self.assertRaisesRegexp(KeyError,
"name of buffer can not contain"): "name of buffer can not contain"):
...@@ -187,11 +213,19 @@ class TestBuffer(unittest.TestCase): ...@@ -187,11 +213,19 @@ class TestBuffer(unittest.TestCase):
net.register_buffer("attr_name", var) net.register_buffer("attr_name", var)
del net.attr_name del net.attr_name
net.attr_name = ParamBase([2, 2], 'float32') if in_dygraph_mode():
net.attr_name = EagerParamBase([2, 2], 'float32')
else:
net.attr_name = ParamBase([2, 2], 'float32')
with self.assertRaisesRegexp(KeyError, "already exists"): with self.assertRaisesRegexp(KeyError, "already exists"):
net.register_buffer("attr_name", var) net.register_buffer("attr_name", var)
def test_register_buffer_same_name(self): def test_register_buffer_with_error(self):
with _test_eager_guard():
self.func_test_register_buffer_with_error()
self.func_test_register_buffer_with_error()
def func_test_register_buffer_same_name(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([1])) var1 = to_variable(np.zeros([1]))
...@@ -205,7 +239,12 @@ class TestBuffer(unittest.TestCase): ...@@ -205,7 +239,12 @@ class TestBuffer(unittest.TestCase):
net.register_buffer("buffer_name", var3) net.register_buffer("buffer_name", var3)
self.assert_var_base_equal(net.buffer_name, var3) self.assert_var_base_equal(net.buffer_name, var3)
def test_buffer_not_persistable(self): def test_register_buffer_same_name(self):
with _test_eager_guard():
self.func_test_register_buffer_same_name()
self.func_test_register_buffer_same_name()
def func_test_buffer_not_persistable(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([1])) var1 = to_variable(np.zeros([1]))
...@@ -214,7 +253,12 @@ class TestBuffer(unittest.TestCase): ...@@ -214,7 +253,12 @@ class TestBuffer(unittest.TestCase):
self.assertEqual(len(net.buffers()), 1) self.assertEqual(len(net.buffers()), 1)
self.assertEqual(len(net.state_dict()), 0) self.assertEqual(len(net.state_dict()), 0)
def test_buffer_not_persistable_del(self): def test_buffer_not_persistable(self):
with _test_eager_guard():
self.func_test_buffer_not_persistable()
self.func_test_buffer_not_persistable()
def func_test_buffer_not_persistable_del(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([1])) var1 = to_variable(np.zeros([1]))
...@@ -222,7 +266,12 @@ class TestBuffer(unittest.TestCase): ...@@ -222,7 +266,12 @@ class TestBuffer(unittest.TestCase):
del net.buffer_name del net.buffer_name
self.assertEqual(len(net.buffers()), 0) self.assertEqual(len(net.buffers()), 0)
def test_buffer_not_persistable_overwrite(self): def test_buffer_not_persistable_del(self):
with _test_eager_guard():
self.func_test_buffer_not_persistable_del()
self.func_test_buffer_not_persistable_del()
def func_test_buffer_not_persistable_overwrite(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([1])) var1 = to_variable(np.zeros([1]))
...@@ -238,7 +287,12 @@ class TestBuffer(unittest.TestCase): ...@@ -238,7 +287,12 @@ class TestBuffer(unittest.TestCase):
self.assertEqual(len(net.buffers()), 1) self.assertEqual(len(net.buffers()), 1)
self.assertEqual(len(net.state_dict()), 0) self.assertEqual(len(net.state_dict()), 0)
def test_buffer_not_persistable_assign(self): def test_buffer_not_persistable_overwrite(self):
with _test_eager_guard():
self.func_test_buffer_not_persistable_overwrite()
self.func_test_buffer_not_persistable_overwrite()
def func_test_buffer_not_persistable_assign(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([1])) var1 = to_variable(np.zeros([1]))
...@@ -255,18 +309,31 @@ class TestBuffer(unittest.TestCase): ...@@ -255,18 +309,31 @@ class TestBuffer(unittest.TestCase):
self.assertEqual(len(net.state_dict()), 0) self.assertEqual(len(net.state_dict()), 0)
# Re-assign a ParamBase will remove the buffer. # Re-assign a ParamBase will remove the buffer.
net.buffer_name = ParamBase([2, 2], 'float32') if in_dygraph_mode():
net.buffer_name = EagerParamBase([2, 2], 'float32')
else:
net.buffer_name = ParamBase([2, 2], 'float32')
self.assertEqual(len(net.buffers()), 0) self.assertEqual(len(net.buffers()), 0)
self.assertEqual(len(net.state_dict()), 1) self.assertEqual(len(net.state_dict()), 1)
def test_buffer_not_persistable_load(self): def test_buffer_not_persistable_assign(self):
with _test_eager_guard():
self.func_test_buffer_not_persistable_assign()
self.func_test_buffer_not_persistable_assign()
def func_test_buffer_not_persistable_load(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([1])) var1 = to_variable(np.zeros([1]))
net.register_buffer("buffer_name", var1, persistable=False) net.register_buffer("buffer_name", var1, persistable=False)
net.load_dict({}) net.load_dict({})
def test_buffer_state_dict(self): def test_buffer_not_persistable_load(self):
with _test_eager_guard():
self.func_test_buffer_not_persistable_load()
self.func_test_buffer_not_persistable_load()
def func_test_buffer_state_dict(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
net = fluid.Layer() net = fluid.Layer()
var1 = to_variable(np.zeros([2, 3])) var1 = to_variable(np.zeros([2, 3]))
...@@ -286,6 +353,11 @@ class TestBuffer(unittest.TestCase): ...@@ -286,6 +353,11 @@ class TestBuffer(unittest.TestCase):
self.assert_var_base_equal(net_load.buffer_var1, var1) self.assert_var_base_equal(net_load.buffer_var1, var1)
def test_buffer_state_dict(self):
with _test_eager_guard():
self.func_test_buffer_state_dict()
self.func_test_buffer_state_dict()
def assert_var_base_equal(self, var1, var2): def assert_var_base_equal(self, var1, var2):
self.assertTrue(np.array_equal(var1.numpy(), var2.numpy())) self.assertTrue(np.array_equal(var1.numpy(), var2.numpy()))
...@@ -308,7 +380,7 @@ class BufferNetWithModification(paddle.nn.Layer): ...@@ -308,7 +380,7 @@ class BufferNetWithModification(paddle.nn.Layer):
class TestModifiedBuffer(unittest.TestCase): class TestModifiedBuffer(unittest.TestCase):
def setUp(self): def funcsetUp(self):
paddle.disable_static() paddle.disable_static()
self.prog_trans = ProgramTranslator() self.prog_trans = ProgramTranslator()
self.shape = [10, 16] self.shape = [10, 16]
...@@ -322,7 +394,8 @@ class TestModifiedBuffer(unittest.TestCase): ...@@ -322,7 +394,8 @@ class TestModifiedBuffer(unittest.TestCase):
return out, net.buffer1, net.buffer2 return out, net.buffer1, net.buffer2
def test_modified(self): def func_test_modified(self):
self.funcsetUp()
dy_outs = self._run(False) dy_outs = self._run(False)
st_outs = self._run(True) st_outs = self._run(True)
...@@ -330,9 +403,14 @@ class TestModifiedBuffer(unittest.TestCase): ...@@ -330,9 +403,14 @@ class TestModifiedBuffer(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy())) np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy()))
def test_modified(self):
with _test_eager_guard():
self.func_test_modified()
self.func_test_modified()
class TestLayerTo(unittest.TestCase): class TestLayerTo(unittest.TestCase):
def setUp(self): def funcsetUp(self):
paddle.disable_static() paddle.disable_static()
self.linear = paddle.nn.Linear(2, 2) self.linear = paddle.nn.Linear(2, 2)
self.new_grad = np.random.random([2, 2]) self.new_grad = np.random.random([2, 2])
...@@ -343,7 +421,7 @@ class TestLayerTo(unittest.TestCase): ...@@ -343,7 +421,7 @@ class TestLayerTo(unittest.TestCase):
sublayer = paddle.nn.Conv1D(3, 2, 3) sublayer = paddle.nn.Conv1D(3, 2, 3)
self.linear.add_sublayer("1", sublayer) self.linear.add_sublayer("1", sublayer)
def test_to_api(self): def func_test_to_api(self):
self.linear.to(dtype='double') self.linear.to(dtype='double')
self.assertEqual(self.linear.weight.dtype, self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64) paddle.fluid.core.VarDesc.VarType.FP64)
...@@ -364,7 +442,11 @@ class TestLayerTo(unittest.TestCase): ...@@ -364,7 +442,11 @@ class TestLayerTo(unittest.TestCase):
self.assertEqual(self.linear.weight._grad_ivar().dtype, self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64) paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters(): for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) if in_dygraph_mode():
self.assertTrue(
isinstance(p, paddle.fluid.framework.EagerParamBase))
else:
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
if paddle.fluid.is_compiled_with_cuda(): if paddle.fluid.is_compiled_with_cuda():
self.linear.to(device=paddle.CUDAPlace(0)) self.linear.to(device=paddle.CUDAPlace(0))
...@@ -387,7 +469,12 @@ class TestLayerTo(unittest.TestCase): ...@@ -387,7 +469,12 @@ class TestLayerTo(unittest.TestCase):
self.assertEqual( self.assertEqual(
self.linear.weight._grad_ivar().place.gpu_device_id(), 0) self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
for p in self.linear.parameters(): for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) if in_dygraph_mode():
self.assertTrue(
isinstance(p, paddle.fluid.framework.EagerParamBase))
else:
self.assertTrue(
isinstance(p, paddle.fluid.framework.ParamBase))
self.linear.to(device=paddle.CPUPlace()) self.linear.to(device=paddle.CPUPlace())
self.assertTrue(self.linear.weight.place.is_cpu_place()) self.assertTrue(self.linear.weight.place.is_cpu_place())
...@@ -403,7 +490,7 @@ class TestLayerTo(unittest.TestCase): ...@@ -403,7 +490,7 @@ class TestLayerTo(unittest.TestCase):
self.assertRaises(AssertionError, self.linear.to, blocking=1) self.assertRaises(AssertionError, self.linear.to, blocking=1)
def test_to_api_paddle_dtype(self): def func_test_to_api_paddle_dtype(self):
self.linear.to(dtype=paddle.float64) self.linear.to(dtype=paddle.float64)
self.assertEqual(self.linear.weight.dtype, self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64) paddle.fluid.core.VarDesc.VarType.FP64)
...@@ -424,9 +511,13 @@ class TestLayerTo(unittest.TestCase): ...@@ -424,9 +511,13 @@ class TestLayerTo(unittest.TestCase):
self.assertEqual(self.linear.weight._grad_ivar().dtype, self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64) paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters(): for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) if in_dygraph_mode():
self.assertTrue(
isinstance(p, paddle.fluid.framework.EagerParamBase))
else:
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
def test_to_api_numpy_dtype(self): def func_test_to_api_numpy_dtype(self):
self.linear.to(dtype=np.float64) self.linear.to(dtype=np.float64)
self.assertEqual(self.linear.weight.dtype, self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64) paddle.fluid.core.VarDesc.VarType.FP64)
...@@ -447,7 +538,22 @@ class TestLayerTo(unittest.TestCase): ...@@ -447,7 +538,22 @@ class TestLayerTo(unittest.TestCase):
self.assertEqual(self.linear.weight._grad_ivar().dtype, self.assertEqual(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64) paddle.fluid.core.VarDesc.VarType.FP64)
for p in self.linear.parameters(): for p in self.linear.parameters():
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) if in_dygraph_mode():
self.assertTrue(
isinstance(p, paddle.fluid.framework.EagerParamBase))
else:
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
def test_main(self):
with _test_eager_guard():
self.funcsetUp()
self.func_test_to_api()
self.func_test_to_api_paddle_dtype()
self.func_test_to_api_numpy_dtype()
self.funcsetUp()
self.func_test_to_api()
self.func_test_to_api_paddle_dtype()
self.func_test_to_api_numpy_dtype()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -17,6 +17,7 @@ import unittest ...@@ -17,6 +17,7 @@ import unittest
import numpy as np import numpy as np
import six import six
import paddle import paddle
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
def numpy_cov(np_arr, rowvar=True, ddof=1, fweights=None, aweights=None): def numpy_cov(np_arr, rowvar=True, ddof=1, fweights=None, aweights=None):
...@@ -32,7 +33,7 @@ class Cov_Test(unittest.TestCase): ...@@ -32,7 +33,7 @@ class Cov_Test(unittest.TestCase):
self.shape = [20, 10] self.shape = [20, 10]
self.weightshape = [10] self.weightshape = [10]
def test_tensor_cov_default(self): def func_test_tensor_cov_default(self):
typelist = ['float64'] typelist = ['float64']
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -56,7 +57,12 @@ class Cov_Test(unittest.TestCase): ...@@ -56,7 +57,12 @@ class Cov_Test(unittest.TestCase):
np_arr, rowvar=True, ddof=1, fweights=None, aweights=None) np_arr, rowvar=True, ddof=1, fweights=None, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy())) self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_rowvar(self): def test_tensor_cov_default(self):
with _test_eager_guard():
self.func_test_tensor_cov_default()
self.func_test_tensor_cov_default()
def func_test_tensor_cov_rowvar(self):
typelist = ['float64'] typelist = ['float64']
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -80,7 +86,12 @@ class Cov_Test(unittest.TestCase): ...@@ -80,7 +86,12 @@ class Cov_Test(unittest.TestCase):
np_arr, rowvar=False, ddof=1, fweights=None, aweights=None) np_arr, rowvar=False, ddof=1, fweights=None, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy())) self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_ddof(self): def test_tensor_cov_rowvar(self):
with _test_eager_guard():
self.func_test_tensor_cov_rowvar()
self.func_test_tensor_cov_rowvar()
def func_test_tensor_cov_ddof(self):
typelist = ['float64'] typelist = ['float64']
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -104,7 +115,12 @@ class Cov_Test(unittest.TestCase): ...@@ -104,7 +115,12 @@ class Cov_Test(unittest.TestCase):
np_arr, rowvar=True, ddof=0, fweights=None, aweights=None) np_arr, rowvar=True, ddof=0, fweights=None, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy())) self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_fweights(self): def test_tensor_cov_ddof(self):
with _test_eager_guard():
self.func_test_tensor_cov_ddof()
self.func_test_tensor_cov_ddof()
def func_test_tensor_cov_fweights(self):
typelist = ['float64'] typelist = ['float64']
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -131,7 +147,12 @@ class Cov_Test(unittest.TestCase): ...@@ -131,7 +147,12 @@ class Cov_Test(unittest.TestCase):
np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=None) np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy())) self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_aweights(self): def test_tensor_cov_fweights(self):
with _test_eager_guard():
self.func_test_tensor_cov_fweights()
self.func_test_tensor_cov_fweights()
def func_test_tensor_cov_aweights(self):
typelist = ['float64'] typelist = ['float64']
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -158,7 +179,12 @@ class Cov_Test(unittest.TestCase): ...@@ -158,7 +179,12 @@ class Cov_Test(unittest.TestCase):
np_arr, rowvar=True, ddof=1, fweights=None, aweights=np_aw) np_arr, rowvar=True, ddof=1, fweights=None, aweights=np_aw)
self.assertTrue(np.allclose(np_cov, cov.numpy())) self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_weights(self): def test_tensor_cov_aweights(self):
with _test_eager_guard():
self.func_test_tensor_cov_aweights()
self.func_test_tensor_cov_aweights()
def func_test_tensor_cov_weights(self):
typelist = ['float64'] typelist = ['float64']
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -187,6 +213,11 @@ class Cov_Test(unittest.TestCase): ...@@ -187,6 +213,11 @@ class Cov_Test(unittest.TestCase):
np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=np_aw) np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=np_aw)
self.assertTrue(np.allclose(np_cov, cov.numpy())) self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_weights(self):
with _test_eager_guard():
self.func_test_tensor_cov_weights()
self.func_test_tensor_cov_weights()
class Cov_Test2(Cov_Test): class Cov_Test2(Cov_Test):
def setUp(self): def setUp(self):
...@@ -203,7 +234,7 @@ class Cov_Test3(unittest.TestCase): ...@@ -203,7 +234,7 @@ class Cov_Test3(unittest.TestCase):
self.fw_s = 1. self.fw_s = 1.
self.aw_s = 1. self.aw_s = 1.
def test_errors(self): def func_test_errors(self):
def test_err(): def test_err():
np_arr = np.random.rand(*self.shape).astype('float64') np_arr = np.random.rand(*self.shape).astype('float64')
np_fw = self.fw_s * np.random.rand( np_fw = self.fw_s * np.random.rand(
...@@ -221,6 +252,11 @@ class Cov_Test3(unittest.TestCase): ...@@ -221,6 +252,11 @@ class Cov_Test3(unittest.TestCase):
self.assertRaises(ValueError, test_err) self.assertRaises(ValueError, test_err)
def test_errors(self):
with _test_eager_guard():
self.func_test_errors()
self.func_test_errors()
#Input(fweights) only support N-D (N<=1) tensor #Input(fweights) only support N-D (N<=1) tensor
class Cov_Test4(Cov_Test3): class Cov_Test4(Cov_Test3):
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import paddle import paddle
from paddle.static import Program, program_guard from paddle.static import Program, program_guard
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
class TestMultiplyApi(unittest.TestCase): class TestMultiplyApi(unittest.TestCase):
...@@ -48,7 +49,7 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -48,7 +49,7 @@ class TestMultiplyApi(unittest.TestCase):
res = paddle.inner(x, y) res = paddle.inner(x, y)
return res.numpy() return res.numpy()
def test_multiply(self): def func_test_multiply(self):
np.random.seed(7) np.random.seed(7)
# test static computation graph: 3-d array # test static computation graph: 3-d array
...@@ -105,9 +106,14 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -105,9 +106,14 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_dynamic_graph_case(x_data, y_data) res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data))) self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
def test_multiply(self):
with _test_eager_guard():
self.func_test_multiply()
self.func_test_multiply()
class TestMultiplyError(unittest.TestCase): class TestMultiplyError(unittest.TestCase):
def test_errors(self): def func_test_errors(self):
# test static computation graph: dtype can not be int8 # test static computation graph: dtype can not be int8
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -161,6 +167,11 @@ class TestMultiplyError(unittest.TestCase): ...@@ -161,6 +167,11 @@ class TestMultiplyError(unittest.TestCase):
y_data = np.random.randn(200).astype(np.float32) y_data = np.random.randn(200).astype(np.float32)
self.assertRaises(ValueError, paddle.inner, x_data, y_data) self.assertRaises(ValueError, paddle.inner, x_data, y_data)
def test_errors(self):
with _test_eager_guard():
self.func_test_errors()
self.func_test_errors()
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import paddle import paddle
import paddle.tensor as tensor import paddle.tensor as tensor
from paddle.static import Program, program_guard from paddle.static import Program, program_guard
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
class TestMultiplyApi(unittest.TestCase): class TestMultiplyApi(unittest.TestCase):
...@@ -49,7 +50,7 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -49,7 +50,7 @@ class TestMultiplyApi(unittest.TestCase):
res = paddle.multiply(x, y) res = paddle.multiply(x, y)
return res.numpy() return res.numpy()
def test_multiply(self): def func_test_multiply(self):
np.random.seed(7) np.random.seed(7)
# test static computation graph: 1-d array # test static computation graph: 1-d array
...@@ -100,9 +101,14 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -100,9 +101,14 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_dynamic_graph_case(x_data, y_data) res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) self.assertTrue(np.allclose(res, np.multiply(x_data, y_data)))
def test_multiply(self):
with _test_eager_guard():
self.func_test_multiply()
self.func_test_multiply()
class TestMultiplyError(unittest.TestCase): class TestMultiplyError(unittest.TestCase):
def test_errors(self): def func_test_errors(self):
# test static computation graph: dtype can not be int8 # test static computation graph: dtype can not be int8
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -175,6 +181,11 @@ class TestMultiplyError(unittest.TestCase): ...@@ -175,6 +181,11 @@ class TestMultiplyError(unittest.TestCase):
y_data = np.random.randn(200).astype(np.float32) y_data = np.random.randn(200).astype(np.float32)
self.assertRaises(ValueError, paddle.multiply, x_data, y_data) self.assertRaises(ValueError, paddle.multiply, x_data, y_data)
def test_errors(self):
with _test_eager_guard():
self.func_test_errors()
self.func_test_errors()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import paddle import paddle
from paddle.static import Program, program_guard from paddle.static import Program, program_guard
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
class TestMultiplyApi(unittest.TestCase): class TestMultiplyApi(unittest.TestCase):
...@@ -48,7 +49,7 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -48,7 +49,7 @@ class TestMultiplyApi(unittest.TestCase):
res = paddle.outer(x, y) res = paddle.outer(x, y)
return res.numpy() return res.numpy()
def test_multiply(self): def func_test_multiply(self):
np.random.seed(7) np.random.seed(7)
# test static computation graph: 3-d array # test static computation graph: 3-d array
...@@ -105,9 +106,14 @@ class TestMultiplyApi(unittest.TestCase): ...@@ -105,9 +106,14 @@ class TestMultiplyApi(unittest.TestCase):
res = self._run_dynamic_graph_case(x_data, y_data) res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data))) self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
def test_multiply(self):
with _test_eager_guard():
self.func_test_multiply()
self.func_test_multiply()
class TestMultiplyError(unittest.TestCase): class TestMultiplyError(unittest.TestCase):
def test_errors(self): def func_test_errors(self):
# test static computation graph: dtype can not be int8 # test static computation graph: dtype can not be int8
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -148,6 +154,11 @@ class TestMultiplyError(unittest.TestCase): ...@@ -148,6 +154,11 @@ class TestMultiplyError(unittest.TestCase):
y_data = np.random.randn(200).astype(np.float32) y_data = np.random.randn(200).astype(np.float32)
self.assertRaises(ValueError, paddle.outer, x_data, y_data) self.assertRaises(ValueError, paddle.outer, x_data, y_data)
def test_errors(self):
with _test_eager_guard():
self.func_test_errors()
self.func_test_errors()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -127,7 +127,12 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -127,7 +127,12 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
"\n\tFaild to convert input data to a regular ndarray :\n\t - Usually " "\n\tFaild to convert input data to a regular ndarray :\n\t - Usually "
"this means the input data contains nested lists with different lengths. " "this means the input data contains nested lists with different lengths. "
) )
elif isinstance(data, (paddle.Tensor, core.eager.Tensor)): elif isinstance(data, paddle.Tensor) and not in_dygraph_mode():
data = data._copy_to(place, False)
data = _handle_dtype(data, dtype)
data.stop_gradient = stop_gradient
return data
elif isinstance(data, core.eager.Tensor) and in_dygraph_mode():
data = data._copy_to(place, False) data = data._copy_to(place, False)
data = _handle_dtype(data, dtype) data = _handle_dtype(data, dtype)
data.stop_gradient = stop_gradient data.stop_gradient = stop_gradient
...@@ -136,7 +141,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -136,7 +141,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
# should't expose it to users, just for internal use. # should't expose it to users, just for internal use.
# convert core.Tensor/core.LoDTensor to VarBase first # convert core.Tensor/core.LoDTensor to VarBase first
# Currenly, there is no copy when places are same # Currenly, there is no copy when places are same
data = paddle.Tensor(data) if in_dygraph_mode():
data = core.eager.Tensor(data)
else:
data = paddle.Tensor(data)
if not data.place._equals(place): if not data.place._equals(place):
data = data._copy_to(place, False) data = data._copy_to(place, False)
data = _handle_dtype(data, dtype) data = _handle_dtype(data, dtype)
......
...@@ -18,21 +18,31 @@ import numpy as np ...@@ -18,21 +18,31 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
class TestDLPack(unittest.TestCase): class TestDLPack(unittest.TestCase):
def test_dlpack_dygraph(self): def func_test_dlpack_dygraph(self):
paddle.disable_static() paddle.disable_static()
tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int')) tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int'))
dlpack = paddle.utils.dlpack.to_dlpack(tensor) dlpack = paddle.utils.dlpack.to_dlpack(tensor)
out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack) out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack)
self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor)) if paddle.fluid.framework.in_dygraph_mode():
self.assertTrue(
isinstance(out_from_dlpack, paddle.fluid.core.eager.Tensor))
else:
self.assertTrue(isinstance(out_from_dlpack, paddle.Tensor))
self.assertTrue( self.assertTrue(
np.array_equal( np.array_equal(
np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype( np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype(
'int'))) 'int')))
def test_dlpack_tensor_larger_than_2dim(self): def test_dlpack_dygraph(self):
with _test_eager_guard():
self.func_test_dlpack_dygraph()
self.func_test_dlpack_dygraph()
def func_test_dlpack_tensor_larger_than_2dim(self):
paddle.disable_static() paddle.disable_static()
numpy_data = np.random.randn(4, 5, 6) numpy_data = np.random.randn(4, 5, 6)
t = paddle.to_tensor(numpy_data) t = paddle.to_tensor(numpy_data)
...@@ -41,6 +51,11 @@ class TestDLPack(unittest.TestCase): ...@@ -41,6 +51,11 @@ class TestDLPack(unittest.TestCase):
out = paddle.utils.dlpack.from_dlpack(dlpack) out = paddle.utils.dlpack.from_dlpack(dlpack)
self.assertTrue(np.allclose(numpy_data, out.numpy())) self.assertTrue(np.allclose(numpy_data, out.numpy()))
def test_dlpack_tensor_larger_than_2dim(self):
with _test_eager_guard():
self.func_test_dlpack_tensor_larger_than_2dim()
self.func_test_dlpack_tensor_larger_than_2dim()
def test_dlpack_static(self): def test_dlpack_static(self):
paddle.enable_static() paddle.enable_static()
tensor = fluid.create_lod_tensor( tensor = fluid.create_lod_tensor(
...@@ -67,7 +82,7 @@ class TestDLPack(unittest.TestCase): ...@@ -67,7 +82,7 @@ class TestDLPack(unittest.TestCase):
np.array(gout_from_dlpack), np.array(gout_from_dlpack),
np.array([[1], [2], [3], [4]]).astype('int'))) np.array([[1], [2], [3], [4]]).astype('int')))
def test_dlpack_dtype_conversion(self): def func_test_dlpack_dtype_conversion(self):
paddle.disable_static() paddle.disable_static()
# DLpack does not explicitly support bool data type. # DLpack does not explicitly support bool data type.
dtypes = [ dtypes = [
...@@ -98,15 +113,30 @@ class TestDLPack(unittest.TestCase): ...@@ -98,15 +113,30 @@ class TestDLPack(unittest.TestCase):
self.assertEqual(x.dtype, o.dtype) self.assertEqual(x.dtype, o.dtype)
self.assertTrue(np.allclose(x.numpy(), o.numpy())) self.assertTrue(np.allclose(x.numpy(), o.numpy()))
def test_dlpack_dtype_conversion(self):
with _test_eager_guard():
self.func_test_dlpack_dtype_conversion()
self.func_test_dlpack_dtype_conversion()
class TestRaiseError(unittest.TestCase): class TestRaiseError(unittest.TestCase):
def test_from_dlpack_raise_type_error(self): def func_test_from_dlpack_raise_type_error(self):
self.assertRaises(TypeError, paddle.utils.dlpack.from_dlpack, self.assertRaises(TypeError, paddle.utils.dlpack.from_dlpack,
np.zeros(5)) np.zeros(5))
def test_to_dlpack_raise_type_error(self): def test_from_dlpack_raise_type_error(self):
with _test_eager_guard():
self.func_test_from_dlpack_raise_type_error()
self.func_test_from_dlpack_raise_type_error()
def func_test_to_dlpack_raise_type_error(self):
self.assertRaises(TypeError, paddle.utils.dlpack.to_dlpack, np.zeros(5)) self.assertRaises(TypeError, paddle.utils.dlpack.to_dlpack, np.zeros(5))
def test_to_dlpack_raise_type_error(self):
with _test_eager_guard():
self.func_test_to_dlpack_raise_type_error()
self.func_test_to_dlpack_raise_type_error()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -48,7 +48,7 @@ def to_dlpack(x): ...@@ -48,7 +48,7 @@ def to_dlpack(x):
""" """
if _non_static_mode(): if _non_static_mode():
if not isinstance(x, paddle.Tensor): if not isinstance(x, (paddle.Tensor, paddle.fluid.core.eager.Tensor)):
raise TypeError( raise TypeError(
"The type of 'x' in to_dlpack must be paddle.Tensor," "The type of 'x' in to_dlpack must be paddle.Tensor,"
" but received {}.".format(type(x))) " but received {}.".format(type(x)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册