Unverified commit 25e723e7, authored by liym27, committed by GitHub

[Setitem] Support grad computation of op set_value (#32431)

Parent 5943ff7b
@@ -146,22 +146,75 @@ Assignment to a Tensor in static mode.
)DOC");
}
};

template <typename T>
class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    if (this->HasInput("ValueTensor")) {
      // Tensor-valued assignment (x[starts:ends] = value): the gradient
      // w.r.t. value is the matching slice of the output gradient, so the
      // grad op is a slice over the same axes/starts/ends.
      op->SetType("slice");
      op->SetInput("Input", this->OutputGrad("Out"));
      if (this->HasInput("StartsTensorList")) {
        op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
      }
      if (this->HasInput("EndsTensorList")) {
        op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
      }

      // Convert std::vector<int64_t> to std::vector<int>, since the slice
      // op takes int attributes.
      std::vector<int64_t> axes_int64 = static_cast<std::vector<int64_t>>(
          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes")));
      std::vector<int64_t> starts_int64 = static_cast<std::vector<int64_t>>(
          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts")));
      std::vector<int64_t> ends_int64 = static_cast<std::vector<int64_t>>(
          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends")));
      std::vector<int64_t> decrease_axes_int64 =
          static_cast<std::vector<int64_t>>(BOOST_GET_CONST(
              std::vector<int64_t>, this->GetAttr("decrease_axes")));

      std::vector<int> axes(axes_int64.begin(), axes_int64.end());
      std::vector<int> starts(starts_int64.begin(), starts_int64.end());
      std::vector<int> ends(ends_int64.begin(), ends_int64.end());
      std::vector<int> decrease_axes(decrease_axes_int64.begin(),
                                     decrease_axes_int64.end());

      op->SetAttr("axes", axes);
      op->SetAttr("starts", starts);
      op->SetAttr("ends", ends);
      op->SetAttr("decrease_axis", decrease_axes);
      op->SetAttr("infer_flags", std::vector<int>({}));

      op->SetOutput("Out", this->InputGrad("ValueTensor"));
    } else {
      // Scalar assignment (value packed into attributes, no ValueTensor):
      // the output gradient flows back to Input unchanged via assign.
      op->SetType("assign");
      op->SetInput("X", this->OutputGrad("Out"));
      op->SetOutput("Out", this->InputGrad("Input"));
    }
  }
};

DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-REGISTER_OPERATOR(
-    set_value, ops::SetValue, ops::SetValueMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
+                  ops::SetValueGradMaker<paddle::framework::OpDesc>,
+                  ops::SetValueGradMaker<paddle::imperative::OpBase>,
+                  ops::SetValueOpInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
+    ops::SetValueKernel<plat::CPUDeviceContext, int64_t>,
+    ops::SetValueKernel<plat::CPUDeviceContext, float>,
+    ops::SetValueKernel<plat::CPUDeviceContext, double>,
+    ops::SetValueKernel<plat::CPUDeviceContext, bool>);
REGISTER_OP_VERSION(set_value)
.AddCheckpoint(
......
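
The rule SetValueGradMaker encodes can be checked with plain numpy. A minimal sketch (illustrative only, not part of this diff): for out = set_value(x, value) over a slice, the gradient reaching value is the matching slice of the gradient of out, which is exactly why the grad op above is a slice over the same axes/starts/ends.

    import numpy as np

    # Forward: out is x with the slice [0, :] overwritten by value.
    x = np.zeros((2, 3), dtype=np.float32)
    value = np.arange(3, dtype=np.float32)
    out = x.copy()
    out[0, :] = value

    # Backward: given d(loss)/d(out), the gradient w.r.t. value is the same
    # slice of it -- the "slice" grad op built by SetValueGradMaker.
    grad_out = np.full_like(out, 0.5)
    grad_value = grad_out[0, :]
    print(grad_value)  # [0.5 0.5 0.5]
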
@@ -718,7 +718,8 @@ void BindImperative(py::module *m_ptr) {
        {
          // Release gil and do tracing
          py::gil_scoped_release release;
-         tracer->TraceOp("set_value", ins, outs, std::move(attrs));
+         tracer->TraceOp("set_value", ins, outs, std::move(attrs),
+                         {{"Input", "Out"}});
        }
      } else {
        auto self_numpy = TensorToPyArray(*self_tensor);
......
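
The extra {{"Input", "Out"}} argument to TraceOp declares an inplace mapping: the traced set_value op writes its result back into the variable it reads, which is what Python setitem semantics require in dygraph. A small sketch of the user-visible behavior (assuming a Paddle build that includes this change):

    import paddle

    t = paddle.zeros([2, 3])
    t[0, :] = 1.0  # traced as set_value, with Input reused as Out (inplace)
    print(t)       # the first row is now ones; t is still the same variable
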
@@ -775,5 +775,76 @@ class TestError(TestSetValueBase):
self._broadcast_mismatch()
# 5. Test backward
class Model(paddle.nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = paddle.nn.Conv2D(12, 12, 3)

    def forward(self, x, y):
        x = self.conv(x)
        y = self.conv(y)
        var = y.flatten()

        x[0, :, 0, 0] = var
        loss = paddle.mean(x)
        return loss, var, x


class TestBackward(unittest.TestCase):
    def test_static(self):
        paddle.enable_static()
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()

        x_np = np.random.random(size=(4, 4)).astype('float32')
        y_np = np.random.random(size=(4, 4)).astype('float32')
        label_np = np.random.randint(2, size=(4, 1)).astype('int64')

        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            y = paddle.static.data(name="y", shape=[4, 4], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[4, 1], dtype='int64')

            z = paddle.add(x, y)
            var = y[0, :]
            z[0, :] = var

            prediction = paddle.static.nn.fc(x=z, size=2, activation='softmax')

            cost = paddle.nn.functional.cross_entropy(
                input=prediction, label=label)
            loss = paddle.mean(cost)
            sgd = paddle.optimizer.SGD(learning_rate=0.01)
            sgd.minimize(loss)

        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup_program)

        var_grad, z_grad = exe.run(
            main_program,
            feed={"x": x_np,
                  "y": y_np,
                  "label": label_np},
            fetch_list=[var.name + "@GRAD", z.name + "@GRAD"])
        self.assertTrue((var_grad == z_grad[0, :]).all())

    def test_dynamic(self):
        paddle.disable_static()
        model = Model()
        x = paddle.ones([1, 12, 3, 3]).astype("float32")
        y = paddle.ones([1, 12, 3, 3]).astype("float32")

        loss, var, x = model(x, y)
        loss.backward()

        self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
        self.assertTrue((var.grad == x.grad[0, :, 0, 0]).all())


if __name__ == '__main__':
    unittest.main()
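
The tests above exercise the tensor-valued branch of SetValueGradMaker. The scalar branch (no ValueTensor, grad op "assign") can be probed the same way; a hedged sketch, reading the code as of this commit (the assign branch forwards the whole output gradient to Input, without masking the overwritten element):

    import paddle

    x = paddle.ones([2, 3])
    x.stop_gradient = False
    y = x * 2
    y[0, 0] = 0.0  # scalar value: no ValueTensor, so the grad op is "assign"
    loss = paddle.sum(y)
    loss.backward()
    print(x.grad)  # assign passes the full output gradient back to Input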