未验证 提交 da61df5c 编写于 作者: Z zyfncg 提交者: GitHub

Fix inplace problem of setitem (#38298)

* add inplace_map for trace_op in pybind

* fix inplace problem of setitem

* refactor the param format  of trace_op
Co-authored-by: Npangyoki <pangyoki@126.com>
上级 4e578c2b
...@@ -2189,65 +2189,75 @@ void BindImperative(py::module *m_ptr) { ...@@ -2189,65 +2189,75 @@ void BindImperative(py::module *m_ptr) {
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs, const platform::XPUPlace &place, framework::AttributeMap attrs, const platform::XPUPlace &place,
bool trace_backward) { bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins); auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp(type, std::move(ins_map), std::move(outs_map),
std::move(attrs), place, trace_backward); std::move(attrs), place, trace_backward,
inplace_map);
} }
}) })
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs, const platform::CUDAPlace &place, framework::AttributeMap attrs, const platform::CUDAPlace &place,
bool trace_backward) { bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins); auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp(type, std::move(ins_map), std::move(outs_map),
std::move(attrs), place, trace_backward); std::move(attrs), place, trace_backward,
inplace_map);
} }
}) })
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs, const platform::NPUPlace &place, framework::AttributeMap attrs, const platform::NPUPlace &place,
bool trace_backward) { bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins); auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp(type, std::move(ins_map), std::move(outs_map),
std::move(attrs), place, trace_backward); std::move(attrs), place, trace_backward,
inplace_map);
} }
}) })
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs, const platform::MLUPlace &place, framework::AttributeMap attrs, const platform::MLUPlace &place,
bool trace_backward) { bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins); auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp(type, std::move(ins_map), std::move(outs_map),
std::move(attrs), place, trace_backward); std::move(attrs), place, trace_backward,
inplace_map);
} }
}) })
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs, const platform::CPUPlace &place, framework::AttributeMap attrs, const platform::CPUPlace &place,
bool trace_backward) { bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins); auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp(type, std::move(ins_map), std::move(outs_map),
std::move(attrs), place, trace_backward); std::move(attrs), place, trace_backward,
inplace_map);
} }
}); });
......
...@@ -39,10 +39,16 @@ class Tracer(core.Tracer): ...@@ -39,10 +39,16 @@ class Tracer(core.Tracer):
self._train_mode = True self._train_mode = True
def trace_op(self, type, inputs, outputs, attrs, stop_gradient=False): def trace_op(self,
type,
inputs,
outputs,
attrs,
stop_gradient=False,
inplace_map=None):
self.trace(type, inputs, outputs, attrs, self.trace(type, inputs, outputs, attrs,
framework._current_expected_place(), self._has_grad and framework._current_expected_place(), self._has_grad and
not stop_gradient) not stop_gradient, inplace_map if inplace_map else {})
def train_mode(self): def train_mode(self):
self._train_mode = True self._train_mode = True
......
...@@ -3262,6 +3262,7 @@ class Block(object): ...@@ -3262,6 +3262,7 @@ class Block(object):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
attrs = kwargs.get("attrs", {}) attrs = kwargs.get("attrs", {})
inplace_map = kwargs.get("inplace_map", None)
type = kwargs.get("type", None) type = kwargs.get("type", None)
op = Operator( op = Operator(
block=self, block=self,
...@@ -3280,7 +3281,8 @@ class Block(object): ...@@ -3280,7 +3281,8 @@ class Block(object):
kwargs.get("inputs", {}), kwargs.get("inputs", {}),
kwargs.get("outputs", {}), attrs kwargs.get("outputs", {}), attrs
if attrs else {}, if attrs else {},
kwargs.get("stop_gradient", False)) kwargs.get("stop_gradient", False),
inplace_map)
else: else:
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
......
...@@ -1330,6 +1330,24 @@ class TestGradientTruncated(unittest.TestCase): ...@@ -1330,6 +1330,24 @@ class TestGradientTruncated(unittest.TestCase):
array = array[0] array = array[0]
class TestSetValueInplace(unittest.TestCase):
def test_inplace(self):
paddle.disable_static()
with paddle.fluid.dygraph.guard():
paddle.seed(100)
a = paddle.rand(shape=[1, 4])
a.stop_gradient = False
b = a[:]
c = b
b[paddle.to_tensor(0)] = 1.0
self.assertTrue(id(b) == id(c))
self.assertTrue(np.array_equal(b.numpy(), c.numpy()))
self.assertEqual(b.inplace_version, 1)
paddle.enable_static()
class TestSetValueInplaceLeafVar(unittest.TestCase): class TestSetValueInplaceLeafVar(unittest.TestCase):
def test_inplace_var_become_leaf_var(self): def test_inplace_var_become_leaf_var(self):
paddle.disable_static() paddle.disable_static()
......
...@@ -665,9 +665,16 @@ def _setitem_impl_(var, item, value): ...@@ -665,9 +665,16 @@ def _setitem_impl_(var, item, value):
"paddle.Tensor to a paddle.Tensor, but received {}".format( "paddle.Tensor to a paddle.Tensor, but received {}".format(
type(value))) type(value)))
if paddle.fluid.framework.in_dygraph_mode():
var._bump_inplace_version()
cur_block = default_main_program().current_block() cur_block = default_main_program().current_block()
cur_block.append_op( cur_block.append_op(
type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs) type="set_value",
inputs=inputs,
outputs={'Out': var},
attrs=attrs,
inplace_map={"Input": "Out"})
return var return var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册