Unverified commit b4a4f1bb, authored by xiongkun and committed by GitHub

[New Feature] add _inplace_assign interface for sot. (#56077)

* [New Feature] add _inplace_assign interface for sot

* add unittest for inplace_assign
Parent 6ff4c130
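The new _inplace_assign method lets one eager Tensor take over another Tensor's underlying data without creating a new Python object. A minimal usage sketch, based on the behaviour exercised by the unit test added in this commit (the exact aliasing semantics are an assumption drawn from the ShareTensor helper below):

import numpy as np
import paddle

a = paddle.ones((2, 2))
b = paddle.full((3,), 7.0)

# After the call, `a` holds `b`'s underlying tensor (shape, dtype and data).
a._inplace_assign(b)
np.testing.assert_array_equal(a.numpy(), b.numpy())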
@@ -63,6 +63,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/pybind.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
@@ -1623,6 +1624,18 @@ static PyObject* tensor_remove_grad_hook(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_inplace_assign(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(6) << "inplace assign for tensor:" << self->tensor.name();
PyObject* other = PyTuple_GET_ITEM(args, 0);
PyObject* self_obj = reinterpret_cast<PyObject*>(self);
ShareTensor(self_obj, other);
RETURN_PY_NONE;
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyDoc_STRVAR(tensor_method__register_reduce_hook__doc__,
R"DOC(_register_backward_hook($self, hook, /)
--
@@ -2455,6 +2468,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)())tensor_register_grad_hook,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_inplace_assign", // NOTE(xiongkun03): only used in sot.
(PyCFunction)(void (*)())tensor_inplace_assign,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_remove_grad_hook",
(PyCFunction)(void (*)())tensor_remove_grad_hook,
METH_VARARGS | METH_KEYWORDS,
......
@@ -32,6 +32,18 @@ bool PyCheckTensor(PyObject* obj) {
return PyObject_TypeCheck(obj, p_tensor_type);
}
void ShareTensor(PyObject* src, PyObject* dst) {
if (PyObject_TypeCheck(src, p_tensor_type) &&
PyObject_TypeCheck(dst, p_tensor_type)) {
auto& src_tensor = reinterpret_cast<TensorObject*>(src)->tensor;
const auto& dst_tensor = reinterpret_cast<TensorObject*>(dst)->tensor;
src_tensor = dst_tensor;
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("Share tensor only support DenseTensor."));
}
}
paddle::Tensor CastPyArg2Tensor(PyObject* obj, Py_ssize_t arg_pos) {
if (PyObject_TypeCheck(obj, p_tensor_type) ||
PyObject_TypeCheck(obj, p_string_tensor_type)) {
......
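On the C++ side, ShareTensor assigns the paddle::Tensor held by dst into src, so after _inplace_assign both Python handles should refer to the same underlying tensor implementation rather than a copy. A hedged illustration of that expectation (whether in-place updates through one handle become visible through the other rests on the shared-implementation assumption above; the tests in this commit do not check it directly):

import paddle

a = paddle.zeros((3,))
b = paddle.ones((3,))
a._inplace_assign(b)

# If the two handles now share one implementation, an in-place update made
# through b should also be observable through a.
b.add_(paddle.ones((3,)))
print(a.numpy())  # expected: [2. 2. 2.] under the shared-storage assumption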
@@ -37,6 +37,9 @@ typedef struct {
// Internal use only, to expose the Tensor type to Python.
bool PyCheckTensor(PyObject* obj);
// Share Tensor for inplace.
void ShareTensor(PyObject* src, PyObject* dst);
// Internal use only, to expose the Tensor type to Python.
paddle::Tensor CastPyArg2Tensor(PyObject* obj, Py_ssize_t arg_pos);
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle


class TestInplaceAssign(unittest.TestCase):
    def test_case0(self):
        a = paddle.ones((1024, 2)) * 1
        b = paddle.ones((1024, 3)) * 2
        c = paddle.ones((1024, 4)) * 3
        a._inplace_assign(b)
        np.testing.assert_array_equal(a.numpy(), b.numpy())
        b._inplace_assign(c)
        np.testing.assert_array_equal(b.numpy(), c.numpy())

    def test_case1(self):
        def func(x):
            a = 1 * x
            b = 2 * x
            a._inplace_assign(b)
            return a

        x = paddle.ones((1,))
        a = paddle.randn((1,))
        x.stop_gradient = False
        a.stop_gradient = False
        y = func(x)
        y.mean().backward()
        # After the assign, y aliases b = 2 * x, so dy/dx == 2.
        np.testing.assert_array_equal(x.grad.numpy(), np.array([2.0]))

    def test_case2(self):
        @paddle.jit.to_static
        def func(a, x):
            x[:] = a * 2.0
            return x

        def forward(a, x):
            output = func(a, x)
            # Write the to_static output back into the caller's tensor.
            x._inplace_assign(output)
            return x

        x = paddle.ones((1,))
        a = paddle.randn((1,))
        x.stop_gradient = False
        a.stop_gradient = False
        y = forward(a, x)
        y.mean().backward()
        # y holds a * 2.0, so dy/da == 2.
        np.testing.assert_array_equal(a.grad.numpy(), np.array([2.0]))


if __name__ == "__main__":
    unittest.main()
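Taken together, the new tests appear to capture the intended SOT usage pattern: an out-of-place result, such as the output of a paddle.jit.to_static function in test_case2, is written back into the caller's existing Tensor via _inplace_assign, so references held outside the compiled region observe the updated value while gradients still flow through the graph that produced it.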