Unverified commit 80355949, authored by liym27, committed by GitHub

[Dy2Stat]Support LoDTensorArray for slice op (#23091)

* Support LoDTensorArray for slice op.
* Support reading elements of a list in dygraph_to_static.
* Fix infershape and add a test for it.
* Support Tensor for Attr(starts) and Attr(ends). 
* Use new interfaces in VarTypeInference. 
Parent 84cf5db8
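Before the diff: this commit lets the slice op (and its gradient) take a LoDTensorArray as input, which is what list slicing in dygraph_to_static lowers to. As a quick orientation, here is a minimal forward-only sketch of the resulting user-facing behavior, adapted from the TestSliceApiWithLoDTensorArray test added at the bottom of this commit; variable names, shapes, and the plain-int slice bounds are illustrative.

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

main_program = fluid.Program()
with fluid.program_guard(main_program):
    x = fluid.data(name='x', shape=(3, 4), dtype='float32')
    arr = layers.create_array(dtype='float32')  # a LoDTensorArray variable
    for _ in range(3):
        arr = layers.array_write(x=x, i=layers.array_length(arr), array=arr)

    # Slicing a range keeps the LoDTensorArray type; stacking the selected
    # elements turns them into a dense tensor that can be fetched.
    sub_arr = arr[0:2]
    out, _ = layers.tensor_array_to_tensor(sub_arr, axis=1, use_stack=True)

exe = fluid.Executor(fluid.CPUPlace())
data = np.random.random((3, 4)).astype('float32')
res, = exe.run(main_program, feed={'x': data}, fetch_list=[out])
print(res.shape)  # (3, 2, 4): two copies of x stacked along axis 1
```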
...@@ -33,13 +33,32 @@ class SliceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Out) of slice op should not be null.");
auto x_var_type = ctx->GetInputsVarType("Input")[0];
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
PADDLE_ENFORCE_EQ(axes.size(), 1,
platform::errors::InvalidArgument(
"The size of axes must be 1 when the Input of "
"SliceOp is LoDTensorArray, "
"but received %d.",
axes.size()));
if (ctx->IsRuntime()) {
// If the var type of input is LOD_TENSOR_ARRAY,
// the output shape is determined by SliceKernel::Compute at runtime.
return;
} else {
// NOTE: A better way is needed to get accurate dims of tensor array.
// The resulting dim of GetInputDim("Input") is the dim of the
// last item written into TensorArray "Input"; this may be a bug to fix.
ctx->SetOutputDim("Out", ctx->GetInputDim("Input"));
return;
}
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(in_dims.size(), 7,
"The rank of input should be less than 7.");
framework::DDim out_dims(in_dims);
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
...@@ -146,6 +165,25 @@ class SliceOp : public framework::OperatorWithKernel {
}
};
class SliceOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = "Input";
auto out_name = "Out";
auto decrease_axis = ctx->GetAttr("decrease_axis");
auto not_decrease = boost::get<std::vector<int>>(decrease_axis).size() == 0;
if (not_decrease) {
// The default type of out is LoDTensor.
// However, if no axis is decreased and the type of input is not
// LoDTensor, the type of out should be the same as input.
// For example, if the input is a LoDTensorArray and no axis is decreased,
// the output should also be a LoDTensorArray.
ctx->SetOutputType(out_name, ctx->GetInputType(x_name));
ctx->SetOutputDataType(out_name, ctx->GetInputDataType(x_name));
}
}
};
class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
...@@ -236,6 +274,14 @@ class SliceOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null");
auto x_var_type = ctx->GetInputsVarType("Input")[0];
if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
// If the var type of input is LOD_TENSOR_ARRAY,
// the output shape is determined by SliceGradKernel::Compute at runtime.
if (ctx->IsRuntime()) {
return;
}
}
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
...@@ -262,6 +308,21 @@ class SliceOpGrad : public framework::OperatorWithKernel {
}
};
class SliceOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x = "Input";
auto d_out = framework::GradVarName("Out");
auto out = framework::GradVarName("Input");
// The types of grad_input and input should always be the same.
// The default type of out is LoDTensor, but the type of input can be
// LoDTensor or LoDTensorArray, so make the output type match the input.
ctx->SetOutputType(out, ctx->GetInputType(x));
ctx->SetOutputDataType(out, ctx->GetInputDataType(d_out));
}
};
template <typename T>
class SliceOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
...@@ -324,11 +385,13 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SliceOpGradNoNeedBufferVarsInference,
namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
ops::SliceOpGradMaker<paddle::framework::OpDesc>,
ops::SliceOpGradMaker<paddle::imperative::OpBase>,
ops::SliceOpVarTypeInference);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
ops::SliceDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::SliceDoubleOpGradMaker<paddle::imperative::OpBase>,
ops::SliceOpGradNoNeedBufferVarsInference,
ops::SliceOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
...
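The SliceOpVarTypeInference registered above propagates the input's var type to Out only when Attr(decrease_axis) is empty; indexing a single element decreases the axis, so it still produces a LoDTensor. A small compile-time sketch of that rule, mirroring test_case_1 and test_case_2 added later in this commit (names are illustrative):

```python
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.data(name='x', shape=(3, 4), dtype='float32')
    arr = layers.create_array(dtype='float32')
    arr = layers.array_write(x=x, i=layers.array_length(arr), array=arr)
    arr = layers.array_write(x=x, i=layers.array_length(arr), array=arr)

    single = arr[0]    # decrease_axis is set  -> Out is LOD_TENSOR
    ranged = arr[0:1]  # no axis is decreased  -> Out keeps the input's type

assert single.type == core.VarDesc.VarType.LOD_TENSOR
assert ranged.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
```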
...@@ -17,6 +17,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
...@@ -58,7 +59,12 @@ template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
int rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
...@@ -86,17 +92,17 @@ class SliceKernel : public framework::OpKernel<T> {
void SliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
const framework::Variable* input_var = context.InputVar("Input");
framework::Variable* out_var = context.OutputVar("Out");
bool input_is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
bool out_is_tensor_array = out_var->IsType<framework::LoDTensorArray>();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
...@@ -109,7 +115,6 @@ class SliceKernel : public framework::OpKernel<T> {
if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) {
need_infer = true;
}
if (need_infer) {
if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
...@@ -117,17 +122,70 @@ class SliceKernel : public framework::OpKernel<T> {
} else if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
}
if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
} else if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
}
}
PADDLE_ENFORCE_EQ(
starts.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(
ends.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
if (input_is_tensor_array) {
auto in_array = context.Input<framework::LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int in_size = in_array->size();
int start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, in_size);
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d.",
end, start));
int out_size = end - start;
if (out_is_tensor_array) {
auto out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(out_size);
for (int i = 0; i < out_size; ++i) {
auto* out_tensor = &out_array->at(i);
auto in_tensor = in_array->at(i + start);
out_tensor->set_lod(in_tensor.lod());
if (in_tensor.memory_size() > 0) {
TensorCopy(in_tensor, context.GetPlace(), out_tensor);
} else {
VLOG(10)
<< "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
<< i << "].";
}
}
} else {
auto out = context.Output<framework::Tensor>("Out");
auto in_tensor = in_array->at(start);
TensorCopy(in_tensor, context.GetPlace(), out);
}
return;
}
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
auto out_dims = out->dims();
auto in_dims = in->dims();
if (need_infer) {
out_dims = in_dims;
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
...@@ -233,7 +291,12 @@ template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
size_t rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
...@@ -261,17 +324,9 @@ class SliceGradKernel : public framework::OpKernel<T> {
void SliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
...@@ -290,6 +345,66 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
framework::Variable* d_input_var =
context.OutputVar(framework::GradVarName("Input"));
const framework::Variable* d_out_var =
context.InputVar(framework::GradVarName("Out"));
bool d_input_is_tensor_array =
d_input_var->IsType<framework::LoDTensorArray>();
bool d_out_is_tensor_array = d_out_var->IsType<framework::LoDTensorArray>();
if (d_input_is_tensor_array) {
auto* input_array = context.Input<framework::LoDTensorArray>("Input");
auto* d_input_array = context.Output<framework::LoDTensorArray>(
framework::GradVarName("Input"));
int d_in_size = input_array->size();
d_input_array->resize(d_in_size);
// If the input is LoDTensorArray, the rank of input is 1.
// So only use the 0th element of starts.
int start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
start = std::max(start, 0);
// set zero
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(context.GetPlace());
T value = 0.0;
math::SetConstant<DeviceContext, T> functor;
for (int i = 0; i < d_in_size; ++i) {
auto dim = input_array->at(i).dims();
d_input_array->at(i).Resize(dim);
d_input_array->at(i).mutable_data<T>(context.GetPlace());
functor(reinterpret_cast<const DeviceContext&>(dev_ctx),
&d_input_array->at(i), static_cast<T>(value));
}
if (d_out_is_tensor_array) {
auto* d_out_array = context.Input<framework::LoDTensorArray>(
framework::GradVarName("Out"));
int d_out_size = d_out_array->size();
for (int i = 0; i < d_out_size; ++i) {
TensorCopy(d_out_array->at(i), context.GetPlace(),
&(d_input_array->at(start + i)));
}
} else {
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
TensorCopy(*d_out, context.GetPlace(), &(d_input_array->at(start)));
}
return;
}
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(context.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
if (decrease_axis.size() > 0) {
...
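For the LoDTensorArray branch of SliceGradKernel above: every element of the input gradient array is first zero-filled, then the upstream gradients are copied back only into the sliced positions. A NumPy reference sketch of that rule (the helper name is made up for illustration):

```python
import numpy as np

def slice_grad_on_array(d_out_list, in_list, start):
    """Zero-fill the input gradient array, then scatter the upstream
    gradients back into positions [start, start + len(d_out_list))."""
    d_in = [np.zeros_like(t) for t in in_list]
    for i, g in enumerate(d_out_list):
        d_in[start + i] = np.array(g, copy=True)
    return d_in

arr = [np.ones((3, 4), dtype=np.float32) for _ in range(3)]
d_out = [np.full((3, 4), 2.0, dtype=np.float32) for _ in range(2)]
d_in = slice_grad_on_array(d_out, arr, start=0)
assert np.array_equal(d_in[2], np.zeros((3, 4), dtype=np.float32))  # untouched element stays zero
```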
...@@ -42,28 +42,29 @@ def test_slice_in_if(x):
shape=[1, 2], value=9, dtype="int64"))
if x.numpy()[0] > 0:
a[0] = x
out = a[0:]
return out
def test_slice_in_while_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
iter_num_var = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32")
a = []
i = 0
# Note: `i < iter_num` can't be supported in dygraph mode now,
# but PR22892 is fixing it https://github.com/PaddlePaddle/Paddle/pull/22892.
# If PR22892 merged, change `i < iter_num.numpy()[0]` to `i < iter_num`.
while i < iter_num_var.numpy()[0]:
a.append(x)
i += 1
i = 0
while i < iter_num_var.numpy()[0]:
a[i] = fluid.layers.fill_constant(shape=[2], value=2, dtype="float32")
i += 1
out = a[0:iter_num]
return out
def test_slice_in_for_loop(x, iter_num):
...@@ -79,7 +80,8 @@ def test_slice_in_for_loop(x, iter_num):
for i in range(iter_num):
a[i] = x
out = a[2]
return out
class TestSliceWithoutControlFlow(unittest.TestCase):
...@@ -148,6 +150,8 @@ class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
def run_dygraph_mode(self):
with fluid.dygraph.guard():
var_res = self.dygraph_func(self.input, self.iter_num)
if not isinstance(var_res, list):
var_res = [var_res]
numpy_res = [ele.numpy() for ele in var_res]
return numpy_res
...@@ -173,6 +177,15 @@ class TestSliceInForLoop(TestSliceInWhileLoop):
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_for_loop
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_out = dygraph_to_static_func(self.dygraph_func)(
self.input, self.iter_num)
exe = fluid.Executor(self.place)
numpy_res = exe.run(main_program, fetch_list=static_out)
return numpy_res
if __name__ == '__main__':
unittest.main()
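run_static_mode above routes self.dygraph_func through dygraph_to_static_func before executing the resulting program, while run_dygraph_mode simply calls it under a dygraph guard. For reference, a direct dygraph-mode call of the test_slice_in_while_loop helper defined in this test file would look like the sketch below; the input values are illustrative.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # test_slice_in_while_loop is the helper defined earlier in this file.
    sliced = test_slice_in_while_loop(np.array([1.0, 2.0], dtype='float32'), 3)
    print([item.numpy() for item in sliced])  # the sliced list of variables
```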
...@@ -19,6 +19,7 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
# Situation 1: starts(list, no tensor), ends(list, no tensor)
...@@ -528,5 +529,85 @@ class TestSliceAPI(unittest.TestCase):
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1])
class TestSliceApiWithLoDTensorArray(unittest.TestCase):
def setUp(self):
self.shape = (3, 4)
self.data = np.random.random(size=self.shape).astype('float32')
self.idx = 0
self.start = 0
self.end = 2
self.axis = 1
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.exe = fluid.Executor(self.place)
def set_program_and_run(self, main_program, case_num):
with fluid.program_guard(main_program):
x = [
fluid.data(
name='x0', shape=self.shape, dtype="float32"), fluid.data(
name='x1', shape=self.shape, dtype="float32"),
fluid.data(
name='x2', shape=self.shape, dtype="float32")
]
for each_x in x:
each_x.stop_gradient = False
arr = layers.create_array(dtype="float32")
for i in range(3):
idx = layers.array_length(arr)
arr = layers.array_write(x=x[i], i=idx, array=arr)
if case_num == 1:
self.sliced_arr = output = arr[0]
elif case_num == 2:
end = fluid.layers.array_length(arr) - 1
end = fluid.layers.cast(end, "int32")
self.sliced_arr = slice_arr = arr[self.start:end]
output, _ = fluid.layers.tensor_array_to_tensor(
slice_arr, axis=self.axis, use_stack=True)
loss = fluid.layers.reduce_sum(output)
fluid.backward.append_backward(loss)
g_vars = list(
map(main_program.global_block().var,
[each_x.name + "@GRAD" for each_x in x]))
self.out, self.g_x0, self.g_x1, self.g_x2 = \
self.exe.run(main_program,
feed = {'x0': self.data,
'x1': self.data,
'x2': self.data},
fetch_list=[output] + g_vars)
def test_case_1(self):
main_program = fluid.Program()
self.set_program_and_run(main_program, 1)
self.assertTrue(self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR)
self.assertEqual(self.sliced_arr.shape, self.shape)
self.assertTrue(np.array_equal(self.out, self.data))
self.assertTrue(np.array_equal(self.g_x0, np.ones_like(self.data)))
self.assertTrue(np.array_equal(self.g_x1, np.zeros_like(self.data)))
self.assertTrue(np.array_equal(self.g_x2, np.zeros_like(self.data)))
def test_case_2(self):
main_program = fluid.Program()
self.set_program_and_run(main_program, 2)
self.assertTrue(
self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY)
self.assertEqual(self.sliced_arr.shape, self.shape)
self.assertTrue(
np.array_equal(
self.out, np.stack(
[self.data, self.data], axis=self.axis)))
self.assertTrue(np.array_equal(self.g_x0, np.ones_like(self.data)))
self.assertTrue(np.array_equal(self.g_x1, np.ones_like(self.data)))
self.assertTrue(np.array_equal(self.g_x2, np.zeros_like(self.data)))
if __name__ == '__main__':
unittest.main()