Unverified commit da61df5c, authored by zyfncg, committed by GitHub

Fix inplace problem of setitem (#38298)

* add inplace_map for trace_op in pybind

* fix inplace problem of setitem

* refactor the param format of trace_op
Co-authored-by: pangyoki <pangyoki@126.com>
Parent commit: 4e578c2b
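
How the hunks below fit together (the call chain is inferred from the diff itself): in dygraph mode, `x[index] = value` goes through `_setitem_impl_`, which now bumps the target's `inplace_version` and appends a `set_value` op with an `inplace_map` hint; `Block.append_op` forwards the hint to `Tracer.trace_op`, which passes it to the pybind `trace` entry points and on to `imperative::Tracer::TraceOp`. A condensed sketch of the user-visible behavior, adapted from the new unit test further down:

    # Call chain of `x[idx] = v` in dygraph mode after this change
    # (inferred from the hunks in this commit):
    #   _setitem_impl_           -> bumps x.inplace_version
    #   Block.append_op(...)     -> inplace_map={"Input": "Out"}
    #   Tracer.trace_op(...)     -> normalizes None to {}
    #   core Tracer.trace(...)   -> pybind overloads below
    #   imperative::Tracer::TraceOp(..., inplace_map)
    import numpy as np
    import paddle

    paddle.disable_static()
    paddle.seed(100)
    a = paddle.rand(shape=[1, 4])
    a.stop_gradient = False
    b = a[:]
    c = b                            # plain Python aliasing
    b[paddle.to_tensor(0)] = 1.0     # now an in-place write on b
    assert id(b) == id(c)
    assert np.array_equal(b.numpy(), c.numpy())
    assert b.inplace_version == 1    # bumped exactly once
    paddle.enable_static()
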
@@ -2189,65 +2189,75 @@ void BindImperative(py::module *m_ptr) {
            [](imperative::Tracer &self, const std::string &type,
               const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
               framework::AttributeMap attrs, const platform::XPUPlace &place,
-              bool trace_backward) {
+              bool trace_backward,
+              const std::map<std::string, std::string> &inplace_map = {}) {
              auto ins_map = ConvertToNameVarBaseMap(ins);
              auto outs_map = ConvertToNameVarBaseMap(outs);
              {
                py::gil_scoped_release release;
                self.TraceOp(type, std::move(ins_map), std::move(outs_map),
-                            std::move(attrs), place, trace_backward);
+                            std::move(attrs), place, trace_backward,
+                            inplace_map);
              }
            })
       .def("trace",
            [](imperative::Tracer &self, const std::string &type,
               const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
               framework::AttributeMap attrs, const platform::CUDAPlace &place,
-              bool trace_backward) {
+              bool trace_backward,
+              const std::map<std::string, std::string> &inplace_map = {}) {
              auto ins_map = ConvertToNameVarBaseMap(ins);
              auto outs_map = ConvertToNameVarBaseMap(outs);
              {
                py::gil_scoped_release release;
                self.TraceOp(type, std::move(ins_map), std::move(outs_map),
-                            std::move(attrs), place, trace_backward);
+                            std::move(attrs), place, trace_backward,
+                            inplace_map);
              }
            })
       .def("trace",
            [](imperative::Tracer &self, const std::string &type,
               const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
               framework::AttributeMap attrs, const platform::NPUPlace &place,
-              bool trace_backward) {
+              bool trace_backward,
+              const std::map<std::string, std::string> &inplace_map = {}) {
              auto ins_map = ConvertToNameVarBaseMap(ins);
              auto outs_map = ConvertToNameVarBaseMap(outs);
              {
                py::gil_scoped_release release;
                self.TraceOp(type, std::move(ins_map), std::move(outs_map),
-                            std::move(attrs), place, trace_backward);
+                            std::move(attrs), place, trace_backward,
+                            inplace_map);
              }
            })
       .def("trace",
            [](imperative::Tracer &self, const std::string &type,
               const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
               framework::AttributeMap attrs, const platform::MLUPlace &place,
-              bool trace_backward) {
+              bool trace_backward,
+              const std::map<std::string, std::string> &inplace_map = {}) {
              auto ins_map = ConvertToNameVarBaseMap(ins);
              auto outs_map = ConvertToNameVarBaseMap(outs);
              {
                py::gil_scoped_release release;
                self.TraceOp(type, std::move(ins_map), std::move(outs_map),
-                            std::move(attrs), place, trace_backward);
+                            std::move(attrs), place, trace_backward,
+                            inplace_map);
              }
            })
       .def("trace",
            [](imperative::Tracer &self, const std::string &type,
               const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
               framework::AttributeMap attrs, const platform::CPUPlace &place,
-              bool trace_backward) {
+              bool trace_backward,
+              const std::map<std::string, std::string> &inplace_map = {}) {
              auto ins_map = ConvertToNameVarBaseMap(ins);
              auto outs_map = ConvertToNameVarBaseMap(outs);
              {
                py::gil_scoped_release release;
                self.TraceOp(type, std::move(ins_map), std::move(outs_map),
-                            std::move(attrs), place, trace_backward);
+                            std::move(attrs), place, trace_backward,
+                            inplace_map);
              }
            });
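
All five place-specific `trace` overloads (XPU, CUDA, NPU, MLU, CPU) gain the same trailing `inplace_map` parameter and forward it to `Tracer::TraceOp`. Note that pybind11 does not turn a default argument in a lambda signature into a Python-side default by itself (that would require a `py::arg(...) = ...` declaration), but the Python wrapper in the next hunk always passes the argument explicitly, so the `= {}` default matters only for C++ callers. Conceptually, an inplace map tells the tracer that an output slot reuses the buffer of an input slot; a minimal self-contained sketch of that idea (not Paddle's actual implementation):

    def run_op_with_inplace(op, ins, outs, inplace_map):
        # Alias each mapped output slot to its input slot before the op
        # runs, so the op writes through the existing buffer instead of
        # allocating a fresh output.
        for in_slot, out_slot in inplace_map.items():
            outs[out_slot] = ins[in_slot]
        op(ins, outs)

    def toy_set_value(ins, outs):
        outs["Out"][0] = 1.0   # mutates the aliased buffer

    ins = {"Input": [0.0, 0.0]}
    outs = {}
    run_op_with_inplace(toy_set_value, ins, outs, {"Input": "Out"})
    assert ins["Input"] == [1.0, 0.0]   # the input saw the write
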
@@ -39,10 +39,16 @@ class Tracer(core.Tracer):
         self._train_mode = True
 
-    def trace_op(self, type, inputs, outputs, attrs, stop_gradient=False):
+    def trace_op(self,
+                 type,
+                 inputs,
+                 outputs,
+                 attrs,
+                 stop_gradient=False,
+                 inplace_map=None):
         self.trace(type, inputs, outputs, attrs,
                    framework._current_expected_place(), self._has_grad and
-                   not stop_gradient)
+                   not stop_gradient, inplace_map if inplace_map else {})
 
     def train_mode(self):
         self._train_mode = True
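
The wrapper defaults `inplace_map` to `None` rather than `{}` (the usual idiom for avoiding a shared mutable default) and normalizes it to an empty dict at the call site, which is what the C++ binding expects. The pattern in isolation:

    def trace_op_like(op_type, inplace_map=None):
        # None (not {}) as the default avoids one dict object being
        # shared across all calls; the body maps None back to {}.
        effective = inplace_map if inplace_map else {}
        return op_type, effective

    assert trace_op_like("relu") == ("relu", {})
    assert trace_op_like("set_value", {"Input": "Out"}) == \
        ("set_value", {"Input": "Out"})
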
@@ -3262,6 +3262,7 @@ class Block(object):
         """
         if in_dygraph_mode():
             attrs = kwargs.get("attrs", {})
+            inplace_map = kwargs.get("inplace_map", None)
             type = kwargs.get("type", None)
             op = Operator(
                 block=self,
@@ -3280,7 +3281,8 @@
                 kwargs.get("inputs", {}),
                 kwargs.get("outputs", {}), attrs
                 if attrs else {},
-                kwargs.get("stop_gradient", False))
+                kwargs.get("stop_gradient", False),
+                inplace_map)
         else:
             from paddle.fluid.dygraph.base import param_guard
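
Only the dygraph branch of `Block.append_op` reads the new kwarg; the static-graph branch below it still builds an `Operator` into the program and ignores `inplace_map`, presumably because in-place reuse there is left to the usual graph-level passes. Because the kwarg defaults to `None`, all existing call sites are unaffected:

    def append_op_like(**kwargs):
        # Mirrors the kwargs plumbing above: absent kwarg -> None.
        return kwargs.get("inplace_map", None)

    assert append_op_like(type="relu") is None
    assert append_op_like(
        type="set_value", inplace_map={"Input": "Out"}) == {"Input": "Out"}
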
@@ -1330,6 +1330,24 @@ class TestGradientTruncated(unittest.TestCase):
             array = array[0]
 
 
+class TestSetValueInplace(unittest.TestCase):
+    def test_inplace(self):
+        paddle.disable_static()
+        with paddle.fluid.dygraph.guard():
+            paddle.seed(100)
+            a = paddle.rand(shape=[1, 4])
+            a.stop_gradient = False
+            b = a[:]
+            c = b
+
+            b[paddle.to_tensor(0)] = 1.0
+
+            self.assertTrue(id(b) == id(c))
+            self.assertTrue(np.array_equal(b.numpy(), c.numpy()))
+            self.assertEqual(b.inplace_version, 1)
+
+        paddle.enable_static()
+
+
+class TestSetValueInplaceLeafVar(unittest.TestCase):
+    def test_inplace_var_become_leaf_var(self):
+        paddle.disable_static()
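
The test pins `inplace_version` to exactly 1 after one write. The counter, bumped by `_bump_inplace_version()` in the last hunk, gives autograd a cheap way to notice that a tensor was modified in place after being saved for backward; a conceptual, non-Paddle sketch of that bookkeeping:

    class VersionedTensor:
        # Toy stand-in for a tensor that tracks in-place writes.
        def __init__(self):
            self.inplace_version = 0

        def bump_inplace_version(self):
            self.inplace_version += 1   # what setitem now does

    t = VersionedTensor()
    saved = t.inplace_version           # snapshot when saved for backward
    t.bump_inplace_version()            # an in-place write happens later
    assert t.inplace_version != saved   # the write is detectable
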
@@ -665,9 +665,16 @@ def _setitem_impl_(var, item, value):
                 "paddle.Tensor to a paddle.Tensor, but received {}".format(
                     type(value)))
 
+    if paddle.fluid.framework.in_dygraph_mode():
+        var._bump_inplace_version()
+
     cur_block = default_main_program().current_block()
     cur_block.append_op(
-        type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs)
+        type="set_value",
+        inputs=inputs,
+        outputs={'Out': var},
+        attrs=attrs,
+        inplace_map={"Input": "Out"})
 
     return var
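
The fix in `_setitem_impl_` is two-sided: it records the write by bumping the target's version counter, and it passes `inplace_map={"Input": "Out"}` so the tracer reuses the `Input` buffer of `set_value` as its `Out` instead of allocating a new output (the keys must match the op's input and output slot names). A hedged check of the resulting behavior through the public API; the 0 -> 1 -> 2 progression assumes one bump per write:

    import paddle

    paddle.disable_static()
    x = paddle.zeros(shape=[3])
    assert x.inplace_version == 0
    x[paddle.to_tensor(0)] = 1.0
    assert x.inplace_version == 1
    x[paddle.to_tensor(1)] = 2.0
    assert x.inplace_version == 2
    paddle.enable_static()
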