From 25e723e75d0762fbc8890a6b496318abf7fb76a7 Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Sun, 25 Apr 2021 16:28:04 +0800
Subject: [PATCH] [Setitem] Support grad computation of op set_value (#32431)

---
 paddle/fluid/operators/set_value_op.cc        | 69 +++++++++++++++---
 paddle/fluid/pybind/imperative.cc             |  3 +-
 .../tests/unittests/test_set_value_op.py      | 71 +++++++++++++++++++
 3 files changed, 134 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc
index 105d61015fc..96a132ac6ab 100644
--- a/paddle/fluid/operators/set_value_op.cc
+++ b/paddle/fluid/operators/set_value_op.cc
@@ -146,22 +146,75 @@ Assignment to a Tensor in static mode.
 )DOC");
   }
 };
+
+template <typename T>
+class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    if (this->HasInput("ValueTensor")) {
+      op->SetType("slice");
+      op->SetInput("Input", this->OutputGrad("Out"));
+      if (this->HasInput("StartsTensorList")) {
+        op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
+      }
+      if (this->HasInput("EndsTensorList")) {
+        op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
+      }
+
+      // convert std::vector<int64_t> to std::vector<int>
+      std::vector<int64_t> axes_int64 = static_cast<std::vector<int64_t>>(
+          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes")));
+      std::vector<int64_t> starts_int64 = static_cast<std::vector<int64_t>>(
+          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts")));
+      std::vector<int64_t> ends_int64 = static_cast<std::vector<int64_t>>(
+          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends")));
+      std::vector<int64_t> decrease_axes_int64 =
+          static_cast<std::vector<int64_t>>(BOOST_GET_CONST(
+              std::vector<int64_t>, this->GetAttr("decrease_axes")));
+
+      std::vector<int> axes(axes_int64.begin(), axes_int64.end());
+      std::vector<int> starts(starts_int64.begin(), starts_int64.end());
+      std::vector<int> ends(ends_int64.begin(), ends_int64.end());
+      std::vector<int> decrease_axes(decrease_axes_int64.begin(),
+                                     decrease_axes_int64.end());
+
+      op->SetAttr("axes", axes);
+      op->SetAttr("starts", starts);
+      op->SetAttr("ends", ends);
+      op->SetAttr("decrease_axis", decrease_axes);
+      op->SetAttr("infer_flags", std::vector<int>({}));
+
+      op->SetOutput("Out", this->InputGrad("ValueTensor"));
+    } else {
+      op->SetType("assign");
+      op->SetInput("X", this->OutputGrad("Out"));
+      op->SetOutput("Out", this->InputGrad("Input"));
+    }
+  }
+};
+
+DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
-REGISTER_OPERATOR(
-    set_value, ops::SetValue, ops::SetValueMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
+                  ops::SetValueGradMaker<paddle::framework::OpDesc>,
+                  ops::SetValueGradMaker<paddle::imperative::OpBase>,
+                  ops::SetValueOpInplaceInferer);
 
 REGISTER_OP_CPU_KERNEL(
     set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::SetValueKernel<plat::CPUDeviceContext, int64_t>,
+    ops::SetValueKernel<plat::CPUDeviceContext, float>,
+    ops::SetValueKernel<plat::CPUDeviceContext, double>,
+    ops::SetValueKernel<plat::CPUDeviceContext, bool>);
 
 REGISTER_OP_VERSION(set_value)
     .AddCheckpoint(
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 0817dc33671..ace62d210b3 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -718,7 +718,8 @@ void BindImperative(py::module *m_ptr) {
           {
             // Release gil and do tracing
             py::gil_scoped_release release;
-            tracer->TraceOp("set_value", ins, outs, std::move(attrs));
+            tracer->TraceOp("set_value", ins, outs, std::move(attrs),
+                            {{"Input", "Out"}});
           }
         } else {
           auto self_numpy = TensorToPyArray(*self_tensor);
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index 0885891cdbe..9534e4fe954 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -775,5 +775,76 @@ class TestError(TestSetValueBase):
         self._broadcast_mismatch()
 
 
+# 5. Test backward
+
+
+class Model(paddle.nn.Layer):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.conv = paddle.nn.Conv2D(12, 12, 3)
+
+    def forward(self, x, y):
+        x = self.conv(x)
+        y = self.conv(y)
+        var = y.flatten()
+
+        x[0, :, 0, 0] = var
+        loss = paddle.mean(x)
+        return loss, var, x
+
+
+class TestBackward(unittest.TestCase):
+    def test_static(self):
+        paddle.enable_static()
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+
+        x_np = np.random.random(size=(4, 4)).astype('float32')
+        y_np = np.random.random(size=(4, 4)).astype('float32')
+        label_np = np.random.randint(2, size=(4, 1)).astype('int64')
+
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
+            y = paddle.static.data(name="y", shape=[4, 4], dtype='float32')
+
+            label = paddle.static.data(
+                name="label", shape=[4, 1], dtype='int64')
+
+            z = paddle.add(x, y)
+            var = y[0, :]
+            z[0, :] = var
+
+            prediction = paddle.static.nn.fc(x=z, size=2, activation='softmax')
+
+            cost = paddle.nn.functional.cross_entropy(
+                input=prediction, label=label)
+            loss = paddle.mean(cost)
+            sgd = paddle.optimizer.SGD(learning_rate=0.01)
+            sgd.minimize(loss)
+
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        exe.run(startup_program)
+
+        var_grad, z_grad = exe.run(
+            main_program,
+            feed={"x": x_np,
+                  "y": y_np,
+                  "label": label_np},
+            fetch_list=[var.name + "@GRAD", z.name + "@GRAD"])
+
+        self.assertTrue((var_grad == z_grad[0, :]).all())
+
+    def test_dynamic(self):
+        paddle.disable_static()
+        model = Model()
+        x = paddle.ones([1, 12, 3, 3]).astype("float32")
+        y = paddle.ones([1, 12, 3, 3]).astype("float32")
+        loss, var, x = model(x, y)
+        loss.backward()
+
+        self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
+        self.assertTrue((var.grad == x.grad[0, :, 0, 0]).all())
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
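
A minimal usage sketch of what this patch enables (not part of the patch itself; distilled from the TestBackward.test_dynamic case above, with illustrative names, shapes, and values). After this change, a dygraph `__setitem__` with a tensor value is traced as `set_value` with the inplace pair {"Input", "Out"}, and SetValueGradMaker routes grad(Out) through a `slice` op back to the assigned value tensor:

    import paddle

    paddle.disable_static()

    y = paddle.ones([3])
    y.stop_gradient = False

    x = paddle.ones([2, 3])
    z = x * 1      # non-leaf copy, so the inplace write is not on a leaf
    z[0, :] = y    # traced as op set_value with ValueTensor = y
    loss = paddle.mean(z)
    loss.backward()

    # grad(Out) is 1/6 everywhere (mean over 6 elements); slicing [0, :]
    # out of it yields y's gradient, so each element of y.grad is ~0.1667.
    print(y.grad)

Without the new grad maker, set_value was registered with EmptyGradOpMaker and `loss.backward()` here would produce no gradient for `y`.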