Unverified commit 9d02313c, authored by WeiXin, committed by GitHub

`set_value_grad` propagates gradients to `Input` and `ValueTensor` (#34304)

* add set_value_grad op

* add unittest.

* polish unittest.

* polish code.

* support cuda kernel

* polish code according to CI

* polish code.

* polish code

* remove *.pyc

* polish code.

* add unittest to improve coverage.

* polish code.
Parent 3429c04b
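In short, the new gradient rule lets the upstream gradient of `Out` flow to `Input` everywhere except the assigned slice (which is zeroed), while the assigned slice of the upstream gradient, reduced over any broadcast axes, flows to `ValueTensor`. A minimal NumPy sketch of these assumed semantics (helper names here are illustrative, not Paddle APIs):

import numpy as np

def reduce_to_shape(grad, shape):
    # Sum over the axes along which `value` was broadcast in the forward pass.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

def set_value_grad_ref(d_out, slices, value_shape):
    d_input = d_out.copy()
    d_input[slices] = 0  # Input@GRAD: the overwritten region receives no gradient
    d_value = reduce_to_shape(d_out[slices], value_shape)  # ValueTensor@GRAD
    return d_input, d_value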
@@ -157,39 +157,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
protected:
void Apply(GradOpPtr<T> op) const override {
if (this->HasInput("ValueTensor")) {
op->SetType("slice");
op->SetInput("Input", this->OutputGrad("Out"));
op->SetType("set_value_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("ValueTensor", this->Input("ValueTensor"));
if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
// convert std::vector<int64_t > to std::vector<int >
std::vector<int64_t> axes_int64 = static_cast<std::vector<int64_t>>(
BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes")));
std::vector<int64_t> starts_int64 = static_cast<std::vector<int64_t>>(
BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts")));
std::vector<int64_t> ends_int64 = static_cast<std::vector<int64_t>>(
BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends")));
std::vector<int64_t> decrease_axes_int64 =
static_cast<std::vector<int64_t>>(BOOST_GET_CONST(
std::vector<int64_t>, this->GetAttr("decrease_axes")));
std::vector<int> axes(axes_int64.begin(), axes_int64.end());
std::vector<int> starts(starts_int64.begin(), starts_int64.end());
std::vector<int> ends(ends_int64.begin(), ends_int64.end());
std::vector<int> decrease_axes(decrease_axes_int64.begin(),
decrease_axes_int64.end());
op->SetAttr("axes", axes);
op->SetAttr("starts", starts);
op->SetAttr("ends", ends);
op->SetAttr("decrease_axis", decrease_axes);
op->SetAttr("infer_flags", std::vector<int>({}));
op->SetOutput("Out", this->InputGrad("ValueTensor"));
} else {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
@@ -198,6 +185,50 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
}
};
class SetValueGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "set_value_grad");
auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
platform::errors::InvalidArgument(
"The dimension of set_value_grad operator's input should be less "
"than 7, but received dimension is %d.",
in_dims.size()));
if (ctx->HasOutput(framework::GradVarName("ValueTensor"))) {
ctx->ShareDim("ValueTensor",
/*->*/ framework::GradVarName("ValueTensor"));
ctx->ShareLoD("ValueTensor",
/*->*/ framework::GradVarName("ValueTensor"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto in_tensor = ctx.Input<Tensor>(framework::GradVarName("Out"));
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
in_tensor->place());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StepsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
} // namespace operators
@@ -218,6 +249,16 @@ REGISTER_OP_CPU_KERNEL(
ops::SetValueKernel<plat::CPUDeviceContext, double>,
ops::SetValueKernel<plat::CPUDeviceContext, bool>);
REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);
REGISTER_OP_CPU_KERNEL(
set_value_grad,
ops::SetValueGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SetValueGradKernel<plat::CPUDeviceContext, int64_t>,
ops::SetValueGradKernel<plat::CPUDeviceContext, float>,
ops::SetValueGradKernel<plat::CPUDeviceContext, double>,
ops::SetValueGradKernel<plat::CPUDeviceContext, bool>);
REGISTER_OP_VERSION(set_value)
.AddCheckpoint(
R"ROC(
......
@@ -22,3 +22,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
set_value_grad,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, bool>);
@@ -22,8 +22,10 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
@@ -31,6 +33,24 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
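// GetOffsets recursively enumerates, along every axis where small_dim is
// smaller than big_dim, the start offset of each small_dim-shaped tile inside
// big_dim. For example, big_dim = [2, 3] and small_dim = [1, 3] yield the
// offsets [0, 0] and [1, 0]. SetValueGradKernel uses these offsets to
// accumulate the sliced upstream gradient into a value tensor that was
// broadcast in the forward pass.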
inline void GetOffsets(const DDim& big_dim, const DDim& small_dim,
DDim start_offset, int cur_dim,
std::vector<DDim>* offsets) {
if (cur_dim == big_dim.size()) {
offsets->push_back(start_offset);
return;
}
if (small_dim[cur_dim] == big_dim[cur_dim]) {
GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
} else {
for (int i = 0; i < big_dim[cur_dim]; i++) {
GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
start_offset[cur_dim] += 1;
}
}
}
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
std::string value_name;
@@ -292,5 +312,253 @@ class SetValueKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class SetValueGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
SetValueGradCompute<1>(ctx);
break;
case 2:
SetValueGradCompute<2>(ctx);
break;
case 3:
SetValueGradCompute<3>(ctx);
break;
case 4:
SetValueGradCompute<4>(ctx);
break;
case 5:
SetValueGradCompute<5>(ctx);
break;
case 6:
SetValueGradCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of set_value_grad's input should be less than 7, but "
"received %d.",
rank));
}
}
private:
template <size_t D>
void SetValueGradCompute(const framework::ExecutionContext& context) const {
auto starts = context.Attr<std::vector<int64_t>>("starts");
auto ends = context.Attr<std::vector<int64_t>>("ends");
auto steps = context.Attr<std::vector<int64_t>>("steps");
auto axes_int64 = context.Attr<std::vector<int64_t>>("axes");
std::vector<int> axes(axes_int64.begin(), axes_int64.end());
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto steps_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_steps_tensor =
context.MultiInput<framework::Tensor>("StepsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
}
if (list_new_steps_tensor.size() > 0) {
steps = GetDataFromTensorList<int64_t>(list_new_steps_tensor);
}
auto in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
in->IsInitialized(), true,
platform::errors::PermissionDenied(
"The input of `set_value_grad`(%s) has not been initialized",
framework::GradVarName("Out")));
auto grad_value = context.Output<framework::Tensor>(
framework::GradVarName("ValueTensor"));
auto grad_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
auto in_dims = in->dims();
auto decrease_axis_int64 =
context.Attr<std::vector<int64_t>>("decrease_axes");
std::vector<int> decrease_axis(decrease_axis_int64.begin(),
decrease_axis_int64.end());
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, steps, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), steps.data(), axes.data(),
reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
steps_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts[axis];
ends_indices[axis_index] = ends[axis];
steps_indices[axis_index] = steps[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
auto& dev_ctx = context.template device_context<DeviceContext>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
math::SetConstant<DeviceContext, T> set_zero;
if (grad_input) {
// Set gradient of `Input`
TensorCopy(*in, context.GetPlace(), grad_input);
auto grad_input_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*grad_input);
framework::Tensor tmp(grad_input->type());
tmp.mutable_data<T>(out_dims, context.GetPlace());
set_zero(dev_ctx, &tmp, static_cast<T>(0));
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
grad_input_t.stridedSlice(starts_indices, ends_indices, steps_indices)
.device(place) = tmp_t;
}
if (grad_value) {
grad_value->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, grad_value, static_cast<T>(0));
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*in);
if (grad_value->dims() == out_dims) {
auto grad_value_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*grad_value);
if (need_reverse) {
framework::Tensor tmp(grad_value->type());
tmp.mutable_data<T>(out_dims, context.GetPlace());
set_zero(dev_ctx, &tmp, static_cast<T>(0));
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
grad_value_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
grad_value_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
}
} else {
int out_dims_size = out_dims.size();
auto grad_value_dims = grad_value->dims();
auto fake_grad_value_dims = out_dims;
// Create an extended shape for grad_value according to the broadcasting rules.
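// Axes listed in decrease_axes, and leading axes that `value` does not have,
// are set to 1; the remaining axes keep value's own extents, each of which
// must equal the corresponding out_dims entry or be 1 (ordinary broadcasting).
// GetOffsets then enumerates every tile of this extended shape inside
// out_dims so the sliced gradient can be summed into grad_value.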
auto grad_value_dims_size = grad_value_dims.size();
int num_decrease = 0;
int decrease_axis_size = decrease_axis.size();
for (int i = 0; i < out_dims_size; i++) {
if (decrease_axis.end() !=
std::find(decrease_axis.begin(), decrease_axis.end(), i)) {
fake_grad_value_dims[i] = 1;
num_decrease++;
} else if (i < out_dims_size - (grad_value_dims_size +
decrease_axis_size - num_decrease)) {
fake_grad_value_dims[i] = 1;
} else {
auto index_grad =
i - (out_dims_size - (grad_value_dims_size +
decrease_axis_size - num_decrease));
fake_grad_value_dims[i] = grad_value_dims[index_grad];
PADDLE_ENFORCE_EQ((out_dims[i] == grad_value_dims[index_grad]) ||
(grad_value_dims[index_grad] == 1),
true,
platform::errors::InvalidArgument(
"An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
framework::GradVarName("ValueTensor"),
out_dims, grad_value_dims));
}
}
VLOG(3) << "Dimensions of " << framework::GradVarName("ValueTensor")
<< "([" << grad_value_dims << "])is broadcasted into ["
<< fake_grad_value_dims << "].";
auto extent = Eigen::DSizes<Eigen::DenseIndex, D>();
auto offset = out_dims;
for (int i = 0; i < out_dims_size; i++) {
offset[i] = 0;
extent[i] = fake_grad_value_dims[i];
}
std::vector<DDim> offsets;
GetOffsets(out_dims, fake_grad_value_dims, offset, 0, &offsets);
auto grad_value_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::
From(*grad_value, fake_grad_value_dims);
framework::Tensor tmp(grad_value->type());
tmp.mutable_data<T>(out_dims, context.GetPlace());
set_zero(dev_ctx, &tmp, static_cast<T>(0));
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
// accumulate gradient
for (auto offset : offsets) {
grad_value_t.device(place) =
grad_value_t +
tmp_t.slice(framework::EigenDim<D>::From(offset), extent);
}
if (need_reverse) {
framework::Tensor tmp_value(grad_value->type());
tmp_value.mutable_data<T>(fake_grad_value_dims, context.GetPlace());
auto tmp_value_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp_value);
tmp_value_t.device(place) = grad_value_t.reverse(reverse_axis);
grad_value_t.device(place) = tmp_value_t;
}
}
}
}
};
} // namespace operators
} // namespace paddle
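For reference, the expected gradients hard-coded in case 1 of the new unit tests below can be derived by hand. With a = t*t, a[0, 1] = value (value of shape (1, 4) broadcast over the size-3 axis) and loss = (a*a).sum(), the gradient of t is 4*t**3 outside the assigned slice and 0 inside it, and each value element receives 2*value from each of its 3 broadcast copies. A small NumPy check of those numbers (a sketch independent of Paddle, not part of the commit):

import numpy as np

t = np.arange(1, 1 + 2 * 3 * 4, dtype="float32").reshape([1, 2, 1, 3, 1, 4])
value = np.arange(100, 104, dtype="float32").reshape(1, 4)

input_grad = 4 * t ** 3
input_grad[0, 1] = 0      # the overwritten slice contributes no gradient to t
value_grad = 6 * value    # 2 * value, summed over the 3 broadcast copies

print(value_grad)         # [[600. 606. 612. 618.]], matching value_grad in case 1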
@@ -20,6 +20,8 @@ import unittest
import numpy as np
import paddle
from paddle.fluid.layer_helper import LayerHelper
from functools import reduce
class TestSetValueBase(unittest.TestCase):
@@ -915,7 +917,317 @@ class TestBackward(unittest.TestCase):
loss.backward()
self.assertTrue(var.grad.shape == x.grad[0, :, 0, 0].shape)
self.assertTrue((var.grad == x.grad[0, :, 0, 0]).all())
#
self.assertTrue((0 == x.grad[0, :, 0, 0]).all())
class TestGradientTruncated(unittest.TestCase):
def test_consistent_with_competitor(self):
paddle.disable_static()
def set_value(t, value):
a = t * t
a[0, 1] = value
y = a * a
return y.sum()
# case 1
array = np.arange(
1, 1 + 2 * 3 * 4, dtype="float32").reshape([1, 2, 1, 3, 1, 4])
value = np.arange(100, 104, dtype="float32").reshape(1, 4)
inps = paddle.to_tensor(array, stop_gradient=False)
value = paddle.to_tensor(value, stop_gradient=False)
loss = set_value(inps, value)
loss.backward()
value_grad = np.array([[600., 606., 612., 618.]])
input_grad = np.array(
[[[[[[4., 32., 108., 256.]], [[500., 864., 1372., 2048.]],
[[2916., 4000., 5324., 6912.]]]],
[[[[0., 0., 0., 0.]], [[0., 0., 0., 0.]], [[0., 0., 0., 0.]]]]]])
self.assertTrue(
np.array_equal(inps.grad.numpy(), input_grad),
msg="The gradient of value should be \n{},\n but reveived {}".
format(input_grad, inps.grad.numpy()))
self.assertTrue(
np.array_equal(value.grad.numpy(), value_grad),
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value.grad.numpy()))
# case 2
array = np.arange(1, 2 * 3 * 4 + 1, dtype="float32").reshape([4, 2, 3])
value = np.arange(100, 100 + 1, dtype="float32")
inps2 = paddle.to_tensor(array, stop_gradient=False)
value2 = paddle.to_tensor(value, stop_gradient=False)
loss = set_value(inps2, value2)
loss.backward()
value_grad2 = np.array([600.])
input_grad2 = np.array(
[[[4., 32., 108.], [0., 0., 0.]], [[1372., 2048., 2916.],
[4000., 5324., 6912.]],
[[8788., 10976., 13500.], [16384., 19652., 23328.]],
[[27436., 32000., 37044.], [42592., 48668., 55296.]]])
self.assertTrue(
np.array_equal(inps2.grad.numpy(), input_grad2),
msg="The gradient of value should be \n{},\n but reveived {}".
format(input_grad, inps2.grad.numpy()))
self.assertTrue(
np.array_equal(value2.grad.numpy(), value_grad2),
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value2.grad.numpy()))
# case 3
def set_value3(t, value):
a = t * t
a[0, :, 0, :] = value
y = a * a
return y.sum()
array = np.arange(
1, 1 + 2 * 3 * 4, dtype="float32").reshape([4, 3, 1, 1, 2, 1])
value = np.arange(100, 100 + 2, dtype="float32").reshape(1, 2, 1)
inps = paddle.to_tensor(array, stop_gradient=False)
value = paddle.to_tensor(value, stop_gradient=False)
loss = set_value3(inps, value)
loss.backward()
value_grad = np.array([[[600.], [606.]]])
input_grad = np.array(
[[[[[[0.], [0.]]]], [[[[0.], [0.]]]], [[[[0.], [0.]]]]],
[[[[[1372.], [2048.]]]], [[[[2916.], [4000.]]]],
[[[[5324.], [6912.]]]]], [[[[[8788.], [10976.]]]], [[[[13500.],
[16384.]]]],
[[[[19652.], [23328.]]]]],
[[[[[27436.], [32000.]]]], [[[[37044.], [42592.]]]],
[[[[48668.], [55296.]]]]]])
self.assertTrue(
np.array_equal(inps.grad.numpy(), input_grad),
msg="The gradient of value should be \n{},\n but reveived {}".
format(input_grad, inps.grad.numpy()))
self.assertTrue(
np.array_equal(value.grad.numpy(), value_grad),
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value.grad.numpy()))
# case 4: step > 0
def set_value4(t, value):
a = t * t
a[0, :, 0, ::3] = value
y = a * a
return y.sum()
array = np.arange(
1, 1 + 2 * 3 * 4, dtype="float32").reshape([2, 3, 1, 4, 1])
value = np.arange(100, 100 + 2, dtype="float32").reshape(1, 2, 1)
inps = paddle.to_tensor(array, stop_gradient=False)
value = paddle.to_tensor(value, stop_gradient=False)
loss = set_value4(inps, value)
loss.backward()
value_grad = np.array([[[600.], [606.]]])
input_grad = np.array([[[[[0.], [32.], [108.],
[0.]]], [[[0.], [864.], [1372.], [0.]]],
[[[0.], [4000.], [5324.], [0.]]]],
[[[[8788.], [10976.], [13500.], [16384.]]],
[[[19652.], [23328.], [27436.], [32000.]]],
[[[37044.], [42592.], [48668.], [55296.]]]]])
self.assertTrue(
np.array_equal(inps.grad.numpy(), input_grad),
msg="The gradient of value should be \n{},\n but reveived {}".
format(input_grad, inps.grad.numpy()))
self.assertTrue(
np.array_equal(value.grad.numpy(), value_grad),
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value.grad.numpy()))
# case 5: a[0].shape == value.shape
def set_value5(t, value):
a = t * t
a[0] = value
y = a * a
return y.sum()
array = np.arange(1, 1 + 2 * 3 * 4, dtype="float32").reshape([2, 3, 4])
value = np.arange(100, 100 + 12, dtype="float32").reshape(3, 4)
inps = paddle.to_tensor(array, stop_gradient=False)
value = paddle.to_tensor(value, stop_gradient=False)
loss = set_value5(inps, value)
loss.backward()
value_grad = np.array([[200., 202., 204., 206.],
[208., 210., 212., 214.],
[216., 218., 220., 222.]])
input_grad = np.array([[[0., 0., 0., 0.], [0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[8788., 10976., 13500., 16384.],
[19652., 23328., 27436., 32000.],
[37044., 42592., 48668., 55296.]]])
self.assertTrue(
np.array_equal(inps.grad.numpy(), input_grad),
msg="The gradient of value should be \n{},\n but reveived {}".
format(input_grad, inps.grad.numpy()))
self.assertTrue(
np.array_equal(value.grad.numpy(), value_grad),
msg="The gradient of input should be \n{},\n but reveived {}".
format(value_grad, value.grad.numpy()))
def test_static_graph(self):
paddle.enable_static()
to_string = lambda x, i: x + '_' + str(i)
numel = lambda input_shape: reduce(lambda x, y: x * y, input_shape)
def op1(x):
value = paddle.fluid.layers.fill_constant([1], "float32", 1)
# test stop_gradient
value.stop_gradient = True
x.stop_gradient = False
start = paddle.fluid.layers.fill_constant(
[1], "int32", 5, force_cpu=True)
end = paddle.fluid.layers.fill_constant(
[1], "int32", 0, force_cpu=True)
step = paddle.fluid.layers.fill_constant(
[1], "int32", -2, force_cpu=True)
inputs = {
'Input': x,
'ValueTensor': value,
'StartsTensorList': [start, ],
'EndsTensorList': [end, ],
'StepsTensorList': [step, ]
}
helper = LayerHelper("set_value")
y = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': y},
attrs={'axes': [0]})
return y, value
def op2(x):
value = paddle.fluid.layers.fill_constant([1, 3, 2], "float32", 1)
# test stop_gradient
value.stop_gradient = False
x.stop_gradient = False
attrs = {
'axes': [0],
'starts': [6],
'ends': [0],
'steps': [-4],
'decrease_axes': [],
'none_axes': [],
'dtype': paddle.float32
}
inputs = {'Input': x, 'ValueTensor': value}
helper = LayerHelper("set_value")
y = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': y},
attrs=attrs)
return y, value
def op3(x):
value = paddle.fluid.layers.fill_constant([1], "float32", 1)
x.stop_gradient = True
value.stop_gradient = False
start = paddle.fluid.layers.fill_constant(
[1], "int32", 0, force_cpu=True)
end = paddle.fluid.layers.fill_constant(
[1], "int32", 5, force_cpu=True)
step = paddle.fluid.layers.fill_constant(
[1], "int32", 3, force_cpu=True)
inputs = {
'Input': x,
'ValueTensor': value,
'StartsTensorList': [start, ],
'EndsTensorList': [end, ],
'StepsTensorList': [step, ]
}
helper = LayerHelper("set_value")
y = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': y},
attrs={'axes': [0]})
return y, value
def set_value(array, i, op):
name_x = to_string('x', i)
x = paddle.static.data(
name=name_x, shape=array.shape, dtype='float32')
# The set_value op used by __getitem__/__setitem__ is an in-place operation.
# When `input.stop_gradient = True` and `value.stop_gradient = False`,
# set_value_grad is not run during the backward pass.
y, value = op(x)
y2 = y + 1
loss = paddle.fluid.layers.reduce_sum(y2)
sgd = paddle.optimizer.Adam()
sgd.minimize(loss)
place = paddle.fluid.CPUPlace(
) if not paddle.fluid.core.is_compiled_with_cuda(
) else paddle.fluid.CUDAPlace(0)
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = []
if not x.stop_gradient:
fetch_list.append(x.grad_name)
if not value.stop_gradient:
fetch_list.append(value.grad_name)
out = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)
return out
input_shape = [7, 6, 5, 4, 3, 2]
array = np.arange(
0, numel(input_shape), dtype="float32").reshape(input_shape)
for i in range(len(input_shape)):
program = paddle.static.Program()
with paddle.static.program_guard(program):
out1 = set_value(array, i, op1)
self.assertTrue((out1[0][5:0:-2] == 0).all())
if len(array.shape) > 2:
program2 = paddle.static.Program()
with paddle.static.program_guard(program2):
out2 = set_value(array, i, op2)
self.assertTrue((out2[0][6:0:-4] == 0).all())
program3 = paddle.static.Program()
with paddle.static.program_guard(program3):
out3 = set_value(array, i, op3)
self.assertTrue((numel(out1[0][0:5:3].shape) == out3[0]).all())
array = array[0]
if __name__ == '__main__':
......