未验证 提交 50f8e974 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support test_var_base _offset in eager mode (#41369)

* [Eager]Polish enable/disable_legacy_dygraph logic

* Support _offset in eager mode

* Update framework.py

* Update framework.py
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
上级 780c7a1d
...@@ -1344,6 +1344,19 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self, ...@@ -1344,6 +1344,19 @@ static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor__offset(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument("Tensor %s has not been initialized!",
self->tensor.name()));
return ToPyObject(t->offset());
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self, PyObject* args, static PyObject* tensor_method__uva(TensorObject* self, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
...@@ -1472,6 +1485,8 @@ PyMethodDef variable_methods[] = { ...@@ -1472,6 +1485,8 @@ PyMethodDef variable_methods[] = {
{"_reset_grad_inplace_version", {"_reset_grad_inplace_version",
(PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version, (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"_offset", (PyCFunction)(void (*)(void))tensor__offset,
METH_VARARGS | METH_KEYWORDS, NULL},
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
{"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva, {"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
......
...@@ -426,6 +426,8 @@ PyObject* ToPyObject(int value) { return PyLong_FromLong(value); } ...@@ -426,6 +426,8 @@ PyObject* ToPyObject(int value) { return PyLong_FromLong(value); }
PyObject* ToPyObject(uint32_t value) { return PyLong_FromUnsignedLong(value); } PyObject* ToPyObject(uint32_t value) { return PyLong_FromUnsignedLong(value); }
PyObject* ToPyObject(size_t value) { return PyLong_FromLong(value); }
PyObject* ToPyObject(int64_t value) { return PyLong_FromLongLong(value); } PyObject* ToPyObject(int64_t value) { return PyLong_FromLongLong(value); }
PyObject* ToPyObject(float value) { return PyLong_FromDouble(value); } PyObject* ToPyObject(float value) { return PyLong_FromDouble(value); }
......
...@@ -55,6 +55,7 @@ framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ...@@ -55,6 +55,7 @@ framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
PyObject* ToPyObject(int value); PyObject* ToPyObject(int value);
PyObject* ToPyObject(uint32_t value); PyObject* ToPyObject(uint32_t value);
PyObject* ToPyObject(size_t value);
PyObject* ToPyObject(bool value); PyObject* ToPyObject(bool value);
PyObject* ToPyObject(int64_t value); PyObject* ToPyObject(int64_t value);
PyObject* ToPyObject(float value); PyObject* ToPyObject(float value);
......
...@@ -1396,7 +1396,7 @@ class TestVarBaseClear(unittest.TestCase): ...@@ -1396,7 +1396,7 @@ class TestVarBaseClear(unittest.TestCase):
class TestVarBaseOffset(unittest.TestCase): class TestVarBaseOffset(unittest.TestCase):
def test_offset(self): def func_offset(self):
paddle.disable_static() paddle.disable_static()
np_x = np.random.random((3, 8, 8)) np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64") x = paddle.to_tensor(np_x, dtype="float64")
...@@ -1405,6 +1405,11 @@ class TestVarBaseOffset(unittest.TestCase): ...@@ -1405,6 +1405,11 @@ class TestVarBaseOffset(unittest.TestCase):
actual_x = paddle.to_tensor(actual_x) actual_x = paddle.to_tensor(actual_x)
self.assertEqual(actual_x._offset(), expected_offset) self.assertEqual(actual_x._offset(), expected_offset)
def test_offset(self):
with _test_eager_guard():
self.func_offset()
self.func_offset()
class TestVarBaseShareBufferTo(unittest.TestCase): class TestVarBaseShareBufferTo(unittest.TestCase):
def test_share_buffer_To(self): def test_share_buffer_To(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册