Unverified commit 25e723e7 authored by liym27, committed by GitHub

[Setitem] Support grad computation of op set_value (#32431)

Parent 5943ff7b
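This commit makes the set_value op (the op behind Tensor __setitem__ in both static and dygraph mode) differentiable: when a ValueTensor is assigned, its gradient is the slice of the output gradient over the assigned region; otherwise the output gradient is passed through to the input. A rough dygraph sketch of what this enables is below; the variable names and shapes are illustrative and not taken from this commit.

import paddle

x = paddle.ones([2, 3])
x.stop_gradient = False
v = paddle.ones([3])
v.stop_gradient = False

y = x * 2.0          # non-leaf Tensor, so in-place assignment is allowed
y[0, :] = v          # lowered to the set_value op
loss = paddle.mean(y)
loss.backward()

# With this change, v receives a gradient: the slice of y's gradient over
# the assigned region (here 1/6 per element, from paddle.mean).
print(v.grad)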
@@ -146,22 +146,75 @@ Assignment to a Tensor in static mode.
)DOC");
  }
};

template <typename T>
class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    if (this->HasInput("ValueTensor")) {
      // A ValueTensor was assigned: its gradient is the slice of Out@GRAD
      // over the region that the forward set_value wrote to.
      op->SetType("slice");
      op->SetInput("Input", this->OutputGrad("Out"));
      if (this->HasInput("StartsTensorList")) {
        op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
      }
      if (this->HasInput("EndsTensorList")) {
        op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
      }

      // Convert std::vector<int64_t> to std::vector<int>, since the slice op
      // takes int attributes.
      std::vector<int64_t> axes_int64 = static_cast<std::vector<int64_t>>(
          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes")));
      std::vector<int64_t> starts_int64 = static_cast<std::vector<int64_t>>(
          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts")));
      std::vector<int64_t> ends_int64 = static_cast<std::vector<int64_t>>(
          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends")));
      std::vector<int64_t> decrease_axes_int64 =
          static_cast<std::vector<int64_t>>(BOOST_GET_CONST(
              std::vector<int64_t>, this->GetAttr("decrease_axes")));

      std::vector<int> axes(axes_int64.begin(), axes_int64.end());
      std::vector<int> starts(starts_int64.begin(), starts_int64.end());
      std::vector<int> ends(ends_int64.begin(), ends_int64.end());
      std::vector<int> decrease_axes(decrease_axes_int64.begin(),
                                     decrease_axes_int64.end());

      op->SetAttr("axes", axes);
      op->SetAttr("starts", starts);
      op->SetAttr("ends", ends);
      op->SetAttr("decrease_axis", decrease_axes);
      op->SetAttr("infer_flags", std::vector<int>({}));

      op->SetOutput("Out", this->InputGrad("ValueTensor"));
    } else {
      // Values were given as attributes (no ValueTensor): pass Out@GRAD
      // through to Input@GRAD unchanged via an assign op.
      op->SetType("assign");
      op->SetInput("X", this->OutputGrad("Out"));
      op->SetOutput("Out", this->InputGrad("Input"));
    }
  }
};
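// Worked example of the slice branch above (illustrative values, not taken
// from this commit): for x[0, :] = v with x of shape [2, 3], the forward op
// records axes=[0], starts=[0], ends=[1], decrease_axes=[0]. If
// Out@GRAD = [[0, 1, 2], [3, 4, 5]], the generated slice op yields
// ValueTensor@GRAD = [0, 1, 2].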
DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
    set_value, ops::SetValue, ops::SetValueMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
                  ops::SetValueGradMaker<paddle::framework::OpDesc>,
                  ops::SetValueGradMaker<paddle::imperative::OpBase>,
                  ops::SetValueOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(
    set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
    ops::SetValueKernel<plat::CPUDeviceContext, int64_t>,
    ops::SetValueKernel<plat::CPUDeviceContext, float>,
    ops::SetValueKernel<plat::CPUDeviceContext, double>,
    ops::SetValueKernel<plat::CPUDeviceContext, bool>);

REGISTER_OP_VERSION(set_value)
    .AddCheckpoint(
......
@@ -718,7 +718,8 @@ void BindImperative(py::module *m_ptr) {
{
// Release gil and do tracing
py::gil_scoped_release release;
tracer->TraceOp("set_value", ins, outs, std::move(attrs));
tracer->TraceOp("set_value", ins, outs, std::move(attrs),
{{"Input", "Out"}});
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
......
@@ -775,5 +775,76 @@ class TestError(TestSetValueBase):
        self._broadcast_mismatch()


# 5. Test backward
class Model(paddle.nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = paddle.nn.Conv2D(12, 12, 3)

    def forward(self, x, y):
        x = self.conv(x)
        y = self.conv(y)
        var = y.flatten()

        x[0, :, 0, 0] = var
        loss = paddle.mean(x)
        return loss, var, x


class TestBackward(unittest.TestCase):
    def test_static(self):
        paddle.enable_static()
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()

        x_np = np.random.random(size=(4, 4)).astype('float32')
        y_np = np.random.random(size=(4, 4)).astype('float32')
        label_np = np.random.randint(2, size=(4, 1)).astype('int64')

        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            y = paddle.static.data(name="y", shape=[4, 4], dtype='float32')

            label = paddle.static.data(
                name="label", shape=[4, 1], dtype='int64')

            z = paddle.add(x, y)
            var = y[0, :]
            z[0, :] = var

            prediction = paddle.static.nn.fc(x=z, size=2, activation='softmax')

            cost = paddle.nn.functional.cross_entropy(
                input=prediction, label=label)
            loss = paddle.mean(cost)
            sgd = paddle.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup_program)

        var_grad, z_grad = exe.run(
            main_program,
            feed={"x": x_np,
                  "y": y_np,
                  "label": label_np},
            fetch_list=[var.name + "@GRAD", z.name + "@GRAD"])

        self.assertTrue((var_grad == z_grad[0, :]).all())

    def test_dynamic(self):
        paddle.disable_static()
        model = Model()
        x = paddle.ones([1, 12, 3, 3]).astype("float32")
        y = paddle.ones([1, 12, 3, 3]).astype("float32")
        loss, var, x = model(x, y)

        loss.backward()

        self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
        self.assertTrue((var.grad == x.grad[0, :, 0, 0]).all())


if __name__ == '__main__':
    unittest.main()