diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc
index 9a6c43dee6d9d1a9fce5c7ec5f7af3d495e1f957..41e6d2d40061e81d148e857bce6eb2c7876edf09 100644
--- a/paddle/fluid/operators/set_value_op.cc
+++ b/paddle/fluid/operators/set_value_op.cc
@@ -157,39 +157,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
  protected:
   void Apply(GradOpPtr<T> op) const override {
     if (this->HasInput("ValueTensor")) {
-      op->SetType("slice");
-      op->SetInput("Input", this->OutputGrad("Out"));
+      op->SetType("set_value_grad");
+
+      op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+      op->SetInput("ValueTensor", this->Input("ValueTensor"));
       if (this->HasInput("StartsTensorList")) {
         op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
       }
       if (this->HasInput("EndsTensorList")) {
         op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
       }
+      if (this->HasInput("StepsTensorList")) {
+        op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
+      }
+
+      op->SetAttrMap(this->Attrs());
+
+      op->SetOutput(framework::GradVarName("ValueTensor"),
+                    this->InputGrad("ValueTensor"));
+      op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
 
-      // convert std::vector<int64_t> to std::vector<int>
-      std::vector<int64_t> axes_int64 = static_cast<std::vector<int64_t>>(
-          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes")));
-      std::vector<int64_t> starts_int64 = static_cast<std::vector<int64_t>>(
-          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts")));
-      std::vector<int64_t> ends_int64 = static_cast<std::vector<int64_t>>(
-          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends")));
-      std::vector<int64_t> decrease_axes_int64 =
-          static_cast<std::vector<int64_t>>(BOOST_GET_CONST(
-              std::vector<int64_t>, this->GetAttr("decrease_axes")));
-
-      std::vector<int> axes(axes_int64.begin(), axes_int64.end());
-      std::vector<int> starts(starts_int64.begin(), starts_int64.end());
-      std::vector<int> ends(ends_int64.begin(), ends_int64.end());
-      std::vector<int> decrease_axes(decrease_axes_int64.begin(),
-                                     decrease_axes_int64.end());
-
-      op->SetAttr("axes", axes);
-      op->SetAttr("starts", starts);
-      op->SetAttr("ends", ends);
-      op->SetAttr("decrease_axis", decrease_axes);
-      op->SetAttr("infer_flags", std::vector<int>({}));
-
-      op->SetOutput("Out", this->InputGrad("ValueTensor"));
     } else {
       op->SetType("assign");
       op->SetInput("X", this->OutputGrad("Out"));
@@ -198,6 +185,50 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class SetValueGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "set_value_grad");
+
+    auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    PADDLE_ENFORCE_LT(
+        in_dims.size(), 7,
+        platform::errors::InvalidArgument(
+            "The dimension of set_value_grad operator's input should be less "
+            "than 7, but received dimension is %d.",
+            in_dims.size()));
+
+    if (ctx->HasOutput(framework::GradVarName("ValueTensor"))) {
+      ctx->ShareDim("ValueTensor",
+                    /*->*/ framework::GradVarName("ValueTensor"));
+      ctx->ShareLoD("ValueTensor",
+                    /*->*/ framework::GradVarName("ValueTensor"));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto in_tensor =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   in_tensor->place());
+  }
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
+        var_name == "StepsTensorList") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
 
 }  // namespace operators
@@ -218,6 +249,16 @@ REGISTER_OP_CPU_KERNEL(
     ops::SetValueKernel<plat::CPUDeviceContext, double>,
     ops::SetValueKernel<plat::CPUDeviceContext, bool>);
 
+REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    set_value_grad,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, int>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, int64_t>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, float>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, double>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, bool>);
+
 REGISTER_OP_VERSION(set_value)
     .AddCheckpoint(
         R"ROC(
diff --git a/paddle/fluid/operators/set_value_op.cu b/paddle/fluid/operators/set_value_op.cu
index b65e1691b99c5d78069e5a176e4ace6f5fcc6470..f9701b0acaac769bd91bbba156a010c2e05e42c3 100644
--- a/paddle/fluid/operators/set_value_op.cu
+++ b/paddle/fluid/operators/set_value_op.cu
@@ -22,3 +22,11 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
+
+REGISTER_OP_CUDA_KERNEL(
+    set_value_grad,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, bool>);
diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h
index eed8a9c9b22bc824a80475fbcab00911013fdf3c..72b94dfa7727922a95f3b1d86f9b3327ff00238e 100644
--- a/paddle/fluid/operators/set_value_op.h
+++ b/paddle/fluid/operators/set_value_op.h
@@ -22,8 +22,10 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/assign_value_op.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/slice_utils.h"
+#include "paddle/fluid/operators/strided_slice_op.h"
 #include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -31,6 +33,24 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+
+// Enumerate the start offset of every copy of `small_dim` inside `big_dim`
+// under broadcasting, so the gradient of a broadcast value tensor can be
+// accumulated from each copy.
+inline void GetOffsets(const DDim& big_dim, const DDim& small_dim,
+                       DDim start_offset, int cur_dim,
+                       std::vector<DDim>* offsets) {
+  if (cur_dim == big_dim.size()) {
+    offsets->push_back(start_offset);
+    return;
+  }
+  if (small_dim[cur_dim] == big_dim[cur_dim]) {
+    GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+  } else {
+    for (int i = 0; i < big_dim[cur_dim]; i++) {
+      GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
+      start_offset[cur_dim] += 1;
+    }
+  }
+}
 
 inline std::string GetValueName(framework::proto::VarType::Type data_type) {
   std::string value_name;
@@ -292,5 +312,253 @@ class SetValueKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class SetValueGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    int rank = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
+
+    switch (rank) {
+      case 1:
+        SetValueGradCompute<1>(ctx);
+        break;
+      case 2:
+        SetValueGradCompute<2>(ctx);
+        break;
+      case 3:
+        SetValueGradCompute<3>(ctx);
+        break;
+      case 4:
+        SetValueGradCompute<4>(ctx);
+        break;
+      case 5:
+        SetValueGradCompute<5>(ctx);
+        break;
+      case 6:
+        SetValueGradCompute<6>(ctx);
+        break;
+      default:
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "The rank of set_value_grad's input should be less than 7, but "
+            "received %d.",
+            rank));
+    }
+  }
+
+ private:
+  template <size_t D>
+  void SetValueGradCompute(const framework::ExecutionContext& context) const {
+    auto starts = context.Attr<std::vector<int64_t>>("starts");
+    auto ends = context.Attr<std::vector<int64_t>>("ends");
+    auto steps = context.Attr<std::vector<int64_t>>("steps");
+
+    auto axes_int64 = context.Attr<std::vector<int64_t>>("axes");
+    std::vector<int> axes(axes_int64.begin(), axes_int64.end());
+
+    auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
+    auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
+    auto steps_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
+    auto reverse_axis = Eigen::array<bool, D>();
+
+    auto list_new_ends_tensor =
+        context.MultiInput<framework::Tensor>("EndsTensorList");
+    auto list_new_starts_tensor =
+        context.MultiInput<framework::Tensor>("StartsTensorList");
+    auto list_new_steps_tensor =
+        context.MultiInput<framework::Tensor>("StepsTensorList");
+
+    if (list_new_starts_tensor.size() > 0) {
+      starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
+    }
+
+    if (list_new_ends_tensor.size() > 0) {
+      ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
+    }
+
+    if (list_new_steps_tensor.size() > 0) {
+      steps = GetDataFromTensorList<int64_t>(list_new_steps_tensor);
+    }
+
+    auto in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    PADDLE_ENFORCE_EQ(
+        in->IsInitialized(), true,
+        platform::errors::PermissionDenied(
+            "The input of `set_value_grad`(%s) has not been initialized",
+            framework::GradVarName("Out")));
+    auto grad_value = context.Output<framework::Tensor>(
+        framework::GradVarName("ValueTensor"));
+    auto grad_input =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+    auto in_dims = in->dims();
+
+    auto decrease_axis_int64 =
+        context.Attr<std::vector<int64_t>>("decrease_axes");
+    std::vector<int> decrease_axis(decrease_axis_int64.begin(),
+                                   decrease_axis_int64.end());
+    std::vector<int> infer_flags(axes.size(), 1);
+    std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
+    StridedSliceOutDims(starts, ends, steps, axes, infer_flags, in_dims,
+                        decrease_axis, out_dims_vector.data(), axes.size(),
+                        false);
+
+    framework::DDim out_dims(framework::make_ddim(out_dims_vector));
+
+    std::vector<int> reverse_vector(starts.size(), 0);
+    StridedSliceFunctor(starts.data(), ends.data(), steps.data(), axes.data(),
+                        reverse_vector.data(), in_dims, infer_flags,
+                        decrease_axis, starts.size());
+
+    for (size_t axis = 0; axis < D; axis++) {
+      starts_indices[axis] = 0;
+      ends_indices[axis] = out_dims[axis];
+      steps_indices[axis] = 1;
+      reverse_axis[axis] = false;
+    }
+
+    for (size_t axis = 0; axis < axes.size(); axis++) {
+      int axis_index = axes[axis];
+      starts_indices[axis_index] = starts[axis];
+      ends_indices[axis_index] = ends[axis];
+      steps_indices[axis_index] = steps[axis];
+      reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
+    }
+
+    bool need_reverse = false;
+    for (size_t axis = 0; axis < axes.size(); axis++) {
+      if (reverse_vector[axis] == 1) {
+        need_reverse = true;
+        break;
+      }
+    }
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    math::SetConstant<DeviceContext, T> set_zero;
+
+    if (grad_input) {
+      // Set gradient of `Input`
+      TensorCopy(*in, context.GetPlace(), grad_input);
+
+      auto grad_input_t =
+          framework::EigenTensor<T, D, Eigen::RowMajor,
+                                 Eigen::DenseIndex>::From(*grad_input);
+
+      framework::Tensor tmp(grad_input->type());
+      tmp.mutable_data<T>(out_dims, context.GetPlace());
+      set_zero(dev_ctx, &tmp, static_cast<T>(0));
+      auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                          Eigen::DenseIndex>::From(tmp);
+
+      grad_input_t.stridedSlice(starts_indices, ends_indices, steps_indices)
+          .device(place) = tmp_t;
+    }
+    if (grad_value) {
+      grad_value->mutable_data<T>(context.GetPlace());
+      set_zero(dev_ctx, grad_value, static_cast<T>(0));
+
+      auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                         Eigen::DenseIndex>::From(*in);
+
+      if (grad_value->dims() == out_dims) {
+        auto grad_value_t =
+            framework::EigenTensor<T, D, Eigen::RowMajor,
+                                   Eigen::DenseIndex>::From(*grad_value);
+        if (need_reverse) {
+          framework::Tensor tmp(grad_value->type());
+          tmp.mutable_data<T>(out_dims, context.GetPlace());
+          set_zero(dev_ctx, &tmp, static_cast<T>(0));
+          auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                              Eigen::DenseIndex>::From(tmp);
+
+          tmp_t.device(place) =
+              in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+          grad_value_t.device(place) = tmp_t.reverse(reverse_axis);
+        } else {
+          grad_value_t.device(place) =
+              in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+        }
+      } else {
+        int out_dims_size = out_dims.size();
+        auto grad_value_dims = grad_value->dims();
+        auto fake_grad_value_dims = out_dims;
+
+        // Create an extended shape according to the rules of broadcast.
+        auto grad_value_dims_size = grad_value_dims.size();
+
+        int num_decrease = 0;
+
+        int decrease_axis_size = decrease_axis.size();
+        for (int i = 0; i < out_dims_size; i++) {
+          if (decrease_axis.end() !=
+              std::find(decrease_axis.begin(), decrease_axis.end(), i)) {
+            fake_grad_value_dims[i] = 1;
+            num_decrease++;
+          } else if (i < out_dims_size - (grad_value_dims_size +
+                                          decrease_axis_size - num_decrease)) {
+            fake_grad_value_dims[i] = 1;
+          } else {
+            auto index_grad =
+                i - (out_dims_size - (grad_value_dims_size +
+                                      decrease_axis_size - num_decrease));
+            fake_grad_value_dims[i] = grad_value_dims[index_grad];
+
+            PADDLE_ENFORCE_EQ((out_dims[i] == grad_value_dims[index_grad]) ||
+                                  (grad_value_dims[index_grad] == 1),
+                              true,
+                              platform::errors::InvalidArgument(
+                                  "An error occurred while calculating %s: "
+                                  "[%s] can not be accumulated into [%s].",
+                                  framework::GradVarName("ValueTensor"),
+                                  out_dims, grad_value_dims));
+          }
+        }
+
+        VLOG(3) << "Dimensions of " << framework::GradVarName("ValueTensor")
+                << "([" << grad_value_dims << "]) is broadcasted into ["
+                << fake_grad_value_dims << "].";
+
+        auto extent = Eigen::DSizes<Eigen::DenseIndex, D>();
+        auto offset = out_dims;
+        for (int i = 0; i < out_dims_size; i++) {
+          offset[i] = 0;
+          extent[i] = fake_grad_value_dims[i];
+        }
+        std::vector<DDim> offsets;
+        GetOffsets(out_dims, fake_grad_value_dims, offset, 0, &offsets);
+
+        auto grad_value_t =
+            framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::
+                From(*grad_value, fake_grad_value_dims);
+
+        framework::Tensor tmp(grad_value->type());
+        tmp.mutable_data<T>(out_dims, context.GetPlace());
+        set_zero(dev_ctx, &tmp, static_cast<T>(0));
+        auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
+                                            Eigen::DenseIndex>::From(tmp);
+
+        tmp_t.device(place) =
+            in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
+
+        // accumulate gradient
+        for (auto offset : offsets) {
+          grad_value_t.device(place) =
+              grad_value_t +
+              tmp_t.slice(framework::EigenDim<D>::From(offset), extent);
+        }
+        if (need_reverse) {
+          framework::Tensor tmp_value(grad_value->type());
+          tmp_value.mutable_data<T>(fake_grad_value_dims, context.GetPlace());
+          auto tmp_value_t =
+              framework::EigenTensor<T, D, Eigen::RowMajor,
+                                     Eigen::DenseIndex>::From(tmp_value);
+          tmp_value_t.device(place) = grad_value_t.reverse(reverse_axis);
+          grad_value_t.device(place) = tmp_value_t;
+        }
+      }
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index 6f2f669913eb6a9b532442aaa45d01dacf56dcaa..d26055b3166d6e784dd22b7a6c8b4b3b7bf0d799 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -20,6 +20,8 @@ import unittest
 import numpy as np
 
 import paddle
+from paddle.fluid.layer_helper import LayerHelper
+from functools import reduce
 
 
 class TestSetValueBase(unittest.TestCase):
@@ -915,7 +917,317 @@ class TestBackward(unittest.TestCase):
         loss.backward()
 
         self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
-        self.assertTrue((var.grad == x.grad[0, :, 0, 0]).all())
+        # The gradient of the overwritten part of `x` is truncated to 0.
+        self.assertTrue((0 == x.grad[0, :, 0, 0]).all())
+
+
+class TestGradientTruncated(unittest.TestCase):
+    def test_consistent_with_competitor(self):
+        paddle.disable_static()
+
+        def set_value(t, value):
+            a = t * t
+            a[0, 1] = value
+            y = a * a
+            return y.sum()
+
+        # case 1
+        array = np.arange(
+            1, 1 + 2 * 3 * 4, dtype="float32").reshape([1, 2, 1, 3, 1, 4])
+        value = np.arange(100, 104, dtype="float32").reshape(1, 4)
+
+        inps = paddle.to_tensor(array, stop_gradient=False)
+        value = paddle.to_tensor(value, stop_gradient=False)
+
+        loss = set_value(inps, value)
+        loss.backward()
+
+        value_grad = np.array([[600., 606., 612., 618.]])
+        input_grad = np.array(
+            [[[[[[4., 32., 108., 256.]], [[500., 864., 1372., 2048.]],
+                [[2916., 4000., 5324., 6912.]]]],
+              [[[[0., 0., 0., 0.]], [[0., 0., 0., 0.]], [[0., 0., 0., 0.]]]]]])
+        self.assertTrue(
+            np.array_equal(inps.grad.numpy(), input_grad),
+            msg="The gradient of input should be \n{},\n but received {}".
+            format(input_grad, inps.grad.numpy()))
+        self.assertTrue(
+            np.array_equal(value.grad.numpy(), value_grad),
+            msg="The gradient of value should be \n{},\n but received {}".
+            format(value_grad, value.grad.numpy()))
+
+        # case 2
+        array = np.arange(1, 2 * 3 * 4 + 1, dtype="float32").reshape([4, 2, 3])
+        value = np.arange(100, 100 + 1, dtype="float32")
+
+        inps2 = paddle.to_tensor(array, stop_gradient=False)
+        value2 = paddle.to_tensor(value, stop_gradient=False)
+
+        loss = set_value(inps2, value2)
+        loss.backward()
+
+        value_grad2 = np.array([600.])
+        input_grad2 = np.array(
+            [[[4., 32., 108.], [0., 0., 0.]],
+             [[1372., 2048., 2916.], [4000., 5324., 6912.]],
+             [[8788., 10976., 13500.], [16384., 19652., 23328.]],
+             [[27436., 32000., 37044.], [42592., 48668., 55296.]]])
+        self.assertTrue(
+            np.array_equal(inps2.grad.numpy(), input_grad2),
+            msg="The gradient of input should be \n{},\n but received {}".
+            format(input_grad2, inps2.grad.numpy()))
+        self.assertTrue(
+            np.array_equal(value2.grad.numpy(), value_grad2),
+            msg="The gradient of value should be \n{},\n but received {}".
+            format(value_grad2, value2.grad.numpy()))
+
+        # case 3
+        def set_value3(t, value):
+            a = t * t
+            a[0, :, 0, :] = value
+            y = a * a
+            return y.sum()
+
+        array = np.arange(
+            1, 1 + 2 * 3 * 4, dtype="float32").reshape([4, 3, 1, 1, 2, 1])
+        value = np.arange(100, 100 + 2, dtype="float32").reshape(1, 2, 1)
+
+        inps = paddle.to_tensor(array, stop_gradient=False)
+        value = paddle.to_tensor(value, stop_gradient=False)
+
+        loss = set_value3(inps, value)
+        loss.backward()
+
+        value_grad = np.array([[[600.], [606.]]])
+        input_grad = np.array(
+            [[[[[[0.], [0.]]]], [[[[0.], [0.]]]], [[[[0.], [0.]]]]],
+             [[[[[1372.], [2048.]]]], [[[[2916.], [4000.]]]],
+              [[[[5324.], [6912.]]]]],
+             [[[[[8788.], [10976.]]]], [[[[13500.], [16384.]]]],
+              [[[[19652.], [23328.]]]]],
+             [[[[[27436.], [32000.]]]], [[[[37044.], [42592.]]]],
+              [[[[48668.], [55296.]]]]]])
+        self.assertTrue(
+            np.array_equal(inps.grad.numpy(), input_grad),
+            msg="The gradient of input should be \n{},\n but received {}".
+            format(input_grad, inps.grad.numpy()))
+        self.assertTrue(
+            np.array_equal(value.grad.numpy(), value_grad),
+            msg="The gradient of value should be \n{},\n but received {}".
+            format(value_grad, value.grad.numpy()))
+
+        # case 4: step > 0
+        def set_value4(t, value):
+            a = t * t
+            a[0, :, 0, ::3] = value
+            y = a * a
+            return y.sum()
+
+        array = np.arange(
+            1, 1 + 2 * 3 * 4, dtype="float32").reshape([2, 3, 1, 4, 1])
+        value = np.arange(100, 100 + 2, dtype="float32").reshape(1, 2, 1)
+
+        inps = paddle.to_tensor(array, stop_gradient=False)
+        value = paddle.to_tensor(value, stop_gradient=False)
+
+        loss = set_value4(inps, value)
+        loss.backward()
+
+        value_grad = np.array([[[600.], [606.]]])
+        input_grad = np.array(
+            [[[[[0.], [32.], [108.], [0.]]],
+              [[[0.], [864.], [1372.], [0.]]],
+              [[[0.], [4000.], [5324.], [0.]]]],
+             [[[[8788.], [10976.], [13500.], [16384.]]],
+              [[[19652.], [23328.], [27436.], [32000.]]],
+              [[[37044.], [42592.], [48668.], [55296.]]]]])
+        self.assertTrue(
+            np.array_equal(inps.grad.numpy(), input_grad),
+            msg="The gradient of input should be \n{},\n but received {}".
+            format(input_grad, inps.grad.numpy()))
+        self.assertTrue(
+            np.array_equal(value.grad.numpy(), value_grad),
+            msg="The gradient of value should be \n{},\n but received {}".
+            format(value_grad, value.grad.numpy()))
+
+        # case 5: a[0].shape == value.shape
+        def set_value5(t, value):
+            a = t * t
+            a[0] = value
+            y = a * a
+            return y.sum()
+
+        array = np.arange(1, 1 + 2 * 3 * 4, dtype="float32").reshape([2, 3, 4])
+        value = np.arange(100, 100 + 12, dtype="float32").reshape(3, 4)
+
+        inps = paddle.to_tensor(array, stop_gradient=False)
+        value = paddle.to_tensor(value, stop_gradient=False)
+
+        loss = set_value5(inps, value)
+        loss.backward()
+
+        value_grad = np.array([[200., 202., 204., 206.],
+                               [208., 210., 212., 214.],
+                               [216., 218., 220., 222.]])
+        input_grad = np.array([[[0., 0., 0., 0.], [0., 0., 0., 0.],
+                                [0., 0., 0., 0.]],
+                               [[8788., 10976., 13500., 16384.],
+                                [19652., 23328., 27436., 32000.],
+                                [37044., 42592., 48668., 55296.]]])
+        self.assertTrue(
+            np.array_equal(inps.grad.numpy(), input_grad),
+            msg="The gradient of input should be \n{},\n but received {}".
+            format(input_grad, inps.grad.numpy()))
+        self.assertTrue(
+            np.array_equal(value.grad.numpy(), value_grad),
+            msg="The gradient of value should be \n{},\n but received {}".
+            format(value_grad, value.grad.numpy()))
+
+    def test_static_graph(self):
+        paddle.enable_static()
+
+        to_string = lambda x, i: x + '_' + str(i)
+        numel = lambda input_shape: reduce(lambda x, y: x * y, input_shape)
+
+        def op1(x):
+            value = paddle.fluid.layers.fill_constant([1], "float32", 1)
+            # test stop_gradient
+            value.stop_gradient = True
+            x.stop_gradient = False
+            start = paddle.fluid.layers.fill_constant(
+                [1], "int32", 5, force_cpu=True)
+            end = paddle.fluid.layers.fill_constant(
+                [1], "int32", 0, force_cpu=True)
+            step = paddle.fluid.layers.fill_constant(
+                [1], "int32", -2, force_cpu=True)
+
+            inputs = {
+                'Input': x,
+                'ValueTensor': value,
+                'StartsTensorList': [start, ],
+                'EndsTensorList': [end, ],
+                'StepsTensorList': [step, ]
+            }
+
+            helper = LayerHelper("set_value")
+            y = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+            helper.append_op(
+                type="set_value",
+                inputs=inputs,
+                outputs={'Out': y},
+                attrs={'axes': [0]})
+
+            return y, value
+
+        def op2(x):
+            value = paddle.fluid.layers.fill_constant([1, 3, 2], "float32", 1)
+            # test stop_gradient
+            value.stop_gradient = False
+            x.stop_gradient = False
+            attrs = {
+                'axes': [0],
+                'starts': [6],
+                'ends': [0],
+                'steps': [-4],
+                'decrease_axes': [],
+                'none_axes': [],
+                'dtype': paddle.float32
+            }
+            inputs = {'Input': x, 'ValueTensor': value}
+
+            helper = LayerHelper("set_value")
+            y = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+            helper.append_op(
+                type="set_value",
+                inputs=inputs,
+                outputs={'Out': y},
+                attrs=attrs)
+
+            return y, value
+
+        def op3(x):
+            value = paddle.fluid.layers.fill_constant([1], "float32", 1)
+            x.stop_gradient = True
+            value.stop_gradient = False
+            start = paddle.fluid.layers.fill_constant(
+                [1], "int32", 0, force_cpu=True)
+            end = paddle.fluid.layers.fill_constant(
+                [1], "int32", 5, force_cpu=True)
+            step = paddle.fluid.layers.fill_constant(
+                [1], "int32", 3, force_cpu=True)
+
+            inputs = {
+                'Input': x,
+                'ValueTensor': value,
+                'StartsTensorList': [start, ],
+                'EndsTensorList': [end, ],
+                'StepsTensorList': [step, ]
+            }
+
+            helper = LayerHelper("set_value")
+            y = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+            helper.append_op(
+                type="set_value",
+                inputs=inputs,
+                outputs={'Out': y},
+                attrs={'axes': [0]})
+
+            return y, value
+
+        def set_value(array, i, op):
+            name_x = to_string('x', i)
+            x = paddle.static.data(
+                name=name_x, shape=array.shape, dtype='float32')
+
+            # set_value_op in __get/setitem__ is an inplace operation.
+            # When `input.stop_gradient = True` and `value.stop_gradient = False`,
+            # set_value_grad_op will not be run during backward.
+            y, value = op(x)
+
+            y2 = y + 1
+            loss = paddle.fluid.layers.reduce_sum(y2)
+            sgd = paddle.optimizer.Adam()
+            sgd.minimize(loss)
+            place = paddle.fluid.CPUPlace(
+            ) if not paddle.fluid.core.is_compiled_with_cuda(
+            ) else paddle.fluid.CUDAPlace(0)
+
+            prog = paddle.static.default_main_program()
+            exe = paddle.static.Executor(place)
+            exe.run(paddle.static.default_startup_program())
+            fetch_list = []
+            if not x.stop_gradient:
+                fetch_list.append(x.grad_name)
+            if not value.stop_gradient:
+                fetch_list.append(value.grad_name)
+            out = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)
+            return out
+
+        input_shape = [7, 6, 5, 4, 3, 2]
+
+        array = np.arange(
+            0, numel(input_shape), dtype="float32").reshape(input_shape)
+
+        for i in range(len(input_shape)):
+            program = paddle.static.Program()
+            with paddle.static.program_guard(program):
+                out1 = set_value(array, i, op1)
+                self.assertTrue((out1[0][5:0:-2] == 0).all())
+
+            if len(array.shape) > 2:
+                program2 = paddle.static.Program()
+                with paddle.static.program_guard(program2):
+                    out2 = set_value(array, i, op2)
+                    self.assertTrue((out2[0][6:0:-4] == 0).all())
+
+            program3 = paddle.static.Program()
+            with paddle.static.program_guard(program3):
+                out3 = set_value(array, i, op3)
+                self.assertTrue((numel(out1[0][0:5:3].shape) == out3[0]).all())
+
+            array = array[0]
 
 
 if __name__ == '__main__':
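
Reviewer note: as a quick, framework-free cross-check of the gradient semantics implemented above, the NumPy sketch below mirrors what `SetValueGradKernel` is expected to produce for a simple basic-slice assignment: the gradient w.r.t. `Input` is the upstream gradient with the assigned region zeroed out, and the gradient w.r.t. `ValueTensor` is the sliced upstream gradient, sum-reduced over the axes along which `value` was broadcast. The helper name `set_value_grad_ref` and the shapes used here are illustrative only and not part of this PR.

```python
import numpy as np


def set_value_grad_ref(out_grad, value_shape, slices):
    """Reference gradients of `out = set_value(x, value)` for a basic slice.

    out_grad    : upstream gradient dL/d(out), same shape as x.
    value_shape : shape of the assigned `value` tensor.
    slices      : tuple of Python slice objects describing the assigned region.
    """
    # Gradient w.r.t. Input: the overwritten region receives no gradient.
    x_grad = out_grad.copy()
    x_grad[slices] = 0

    # Gradient w.r.t. ValueTensor: slice the upstream gradient, then reduce
    # over every axis along which `value` was broadcast.
    v_grad = out_grad[slices]
    padded = (1, ) * (v_grad.ndim - len(value_shape)) + tuple(value_shape)
    reduce_axes = tuple(
        i for i, n in enumerate(padded) if n == 1 and v_grad.shape[i] > 1)
    v_grad = v_grad.sum(axis=reduce_axes).reshape(value_shape)
    return x_grad, v_grad


if __name__ == '__main__':
    # Loosely mirrors "case 5" above: a[0] = value, value.shape == a[0].shape.
    out_grad = np.arange(24, dtype="float32").reshape([2, 3, 4])
    x_grad, v_grad = set_value_grad_ref(out_grad, (3, 4), (slice(0, 1), ))
    assert (x_grad[0] == 0).all()
    assert np.array_equal(v_grad, out_grad[0])
```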