提交 88628016 编写于 作者: L liym27 提交者: Aurelius84

add tensor(tensor and tensor in list) support for argument starts and ends in slice op; (#19208)

add support parameter inference when arguments starts or ends is a list containing integer and tensor variable;
test=develop,test=document_preview

improve slice op according to review(from hongyu). test=develop

fix slice op according to review: infer_flags, test=develop

fix slice op: improve overload operator __getitem__ to support attrs(starts and ends) are Variable.
test=develop,test=document_preview

fix test_slice_op: add TestSliceOp_decs_dim_6 to resolve conflict with test_slice_ngraph_op. test=develop

add stop_gradient=True when attr(starts) or attr(ends) is tensor Variable.
test=develop,test=document_preview
上级 e9e3c087
......@@ -243,7 +243,7 @@ paddle.fluid.layers.gaussian_random (ArgSpec(args=['shape', 'mean', 'std', 'seed
paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0, 'float32')), ('document', 'c39b647b6cf08e058d96ee503d5284fe'))
paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2'))
paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310'))
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '3ca6a761570d86e303e473afba99bb49'))
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff'))
paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b'))
paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3'))
paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe'))
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/slice_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
namespace paddle {
......@@ -26,44 +27,81 @@ class SliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input (Input) of slice op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output (Out) of slice op should not be null.");
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input (Input) of slice op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Out) of slice op should not be null.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE(in_dims.size() < 7,
"The rank of input should be less than 7.");
PADDLE_ENFORCE_LT(in_dims.size(), 7,
"The rank of input should be less than 7.");
framework::DDim out_dims(in_dims);
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
PADDLE_ENFORCE_EQ(starts.size(), ends.size());
PADDLE_ENFORCE_EQ(starts.size(), axes.size());
auto starts_size = starts.size();
auto ends_size = ends.size();
if (infer_flags.empty()) {
// Initialize infer_flags with 1.
// To be compatible with other op tests in which infer_flags is not set.
infer_flags = std::vector<int>(axes.size(), 1);
}
if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList");
PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
"StartsTensorList size can't be zero");
starts_size = StartsTensorList.size();
}
if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList");
PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
"EndsTensorList size can't be zero");
ends_size = EndsTensorList.size();
}
if (ctx->HasInput("StartsTensor") == false) {
PADDLE_ENFORCE_EQ(
starts_size, axes.size(),
"The size of starts must be equal to the size of axes.");
}
if (ctx->HasInput("EndsTensor") == false) {
PADDLE_ENFORCE_EQ(ends_size, axes.size(),
"The size of ends must be equal to the size of axes.");
}
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
// start = std::min(start, dim_value);
end = std::min(end, dim_value);
// start = std::min(start, end);
PADDLE_ENFORCE_GT(end, start, "end should greater than start");
out_dims[axes[i]] = end - start;
PADDLE_ENFORCE_LT(static_cast<int>(axes[i]), in_dims.size(),
"The index of dimension in axes must be less "
"than the size of input shape.");
if (infer_flags[i] == -1) {
out_dims[axes[i]] = -1;
} else {
// infer out_dim shape
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(end, start, "end should greater than start");
out_dims[axes[i]] = end - start;
}
}
}
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime()) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
}
......@@ -81,7 +119,6 @@ class SliceOp : public framework::OperatorWithKernel {
out_dims = framework::make_ddim(new_out_shape);
}
ctx->SetOutputDim("Out", out_dims);
if (axes[0] != 0) {
ctx->ShareLoD("Input", /*->*/ "Out");
......@@ -90,28 +127,67 @@ class SliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.Input<Tensor>("Input")->place());
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "Tensor of data to extract slices from.");
AddInput("Input", "(Tensor) Tensor of data to extract slices from.");
AddInput("StartsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StartsTensor, StartsTensorList "
"and attr(starts).")
.AsDispensable();
AddInput("EndsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of EndsTensor, EndsTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StartsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(starts).")
.AsDuplicable()
.AsDispensable();
AddInput(
"EndsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(ends).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out", "Sliced data tensor.");
AddAttr<std::vector<int>>(
"axes",
"(list<int>) Axes that `starts` and `ends` apply to. It's optional."
"If not present, will be treated as [0, 1, ..., len(`starts`) - 1].");
AddAttr<std::vector<int>>(
"starts",
"(list<int>) Starting indices of corresponding axis in `axes`");
"(list<int>) Starting indices of corresponding axis in `axes`")
.SetDefault({});
AddAttr<std::vector<int>>(
"ends", "(list<int>) Ending indices of corresponding axis in `axes`.")
.SetDefault({});
AddAttr<std::vector<int>>(
"ends",
"(list<int>) Starting indices of corresponding axis in `axes`.");
"infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({});
AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
.SetDefault({});
AddComment(R"DOC(
......@@ -155,22 +231,33 @@ class SliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
......@@ -180,8 +267,12 @@ class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
auto *bind = new framework::OpDesc();
bind->SetInput("Input", Input("Input"));
bind->SetInput("StartsTensor", Input("StartsTensor"));
bind->SetInput("EndsTensor", Input("EndsTensor"));
bind->SetInput("StartsTensorList", Input("StartsTensorList"));
bind->SetInput("EndsTensorList", Input("EndsTensorList"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs());
......
......@@ -65,6 +65,16 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
auto list_new_starts_tensor =
ctx.MultiInput<framework::Tensor>("StartsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (ctx.HasInput("StartsTensor")) {
auto* starts_tensor = ctx.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
for (size_t i = 0; i < starts.size(); ++i) {
if (starts[i] < 0) {
starts[i] += in_dims[axes[i]];
......
......@@ -20,6 +20,39 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int> get_new_data_from_tensorlist(
const std::vector<const Tensor*>& list_new_data_tensor) {
// get tensor from
std::vector<int> vec_new_data;
for (size_t i = 0; i < list_new_data_tensor.size(); ++i) {
auto tensor = list_new_data_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_data.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_data.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_data;
}
inline std::vector<int> get_new_data_from_tensor(
const Tensor* new_data_tensor) {
std::vector<int> vec_new_data;
auto* new_data = new_data_tensor->data<int>();
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<int>();
}
vec_new_data =
std::vector<int>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
......@@ -58,8 +91,90 @@ class SliceKernel : public framework::OpKernel<T> {
auto out_dims = out->dims();
auto in_dims = in->dims();
// resize out_dims
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
bool need_infer = false;
if (context.HasInput("StartsTensor") || context.HasInput("EndsTensor")) {
need_infer = true;
}
if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) {
need_infer = true;
}
if (need_infer) {
if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
} else if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
}
PADDLE_ENFORCE_EQ(
starts.size(), axes.size(),
"The size of starts must be equal to the size of axes.");
if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
} else if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
}
PADDLE_ENFORCE_EQ(ends.size(), axes.size(),
"The size of ends must be equal to the size of axes.");
out_dims = in_dims;
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
// when end = start+1 and start == -1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret =
std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = 10000000;
}
}
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(end, start, "end should greater than start");
out_dims[axes[i]] = end - start;
}
}
out->Resize(out_dims);
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
}
}
// resize out_dims
if (decrease_axis.size() > 0) {
if (decrease_axis.size() == (size_t)in_dims.size()) {
std::vector<int> vec_origin_out_shape(decrease_axis.size(), 1);
......@@ -85,8 +200,6 @@ class SliceKernel : public framework::OpKernel<T> {
}
out->mutable_data<T>(context.GetPlace());
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>();
......@@ -157,6 +270,26 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto in_dims = d_input->dims();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
if (decrease_axis.size() > 0) {
......
......@@ -902,6 +902,21 @@ class Variable(object):
slice_end = []
reverse_axis = []
def fill_constant(shape, dtype, value, force_cpu=False, out=None):
self.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu()
},
stop_gradient=True)
out.stop_gradient = True
return out
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
......@@ -927,17 +942,81 @@ class Variable(object):
slice_start.append(start)
slice_end.append(end)
else:
# int
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
if isinstance(slice_item, Variable):
temp_1 = self.block.create_var(dtype='int32')
fill_constant([1], 'int32', 1, force_cpu=True, out=temp_1)
temp_end = self.block.create_var(dtype='int32')
self.block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = self.block.create_var(dtype='int32')
fill_constant(
[1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': [self]}
attrs = {
'axes': slice_axis,
'starts': [],
'ends': [],
'decrease_axis': decrease_axis
}
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if not contain_var(slice_start):
attrs['starts'] = slice_start
else:
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if not contain_var(slice_end):
attrs['ends'] = slice_end
else:
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# infer_flags
attrs['infer_flags'] = infer_flags
out = self
if len(slice_axis) > 0:
# append slice_op here
slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
"_slice"),
......@@ -945,14 +1024,9 @@ class Variable(object):
self.block.append_op(
type="slice",
inputs={'Input': [out]},
inputs=inputs,
outputs={'Out': [slice_out_var]},
attrs={
'axes': slice_axis,
'starts': slice_start,
'ends': slice_end,
'decrease_axis': decrease_axis
})
attrs=attrs)
out = slice_out_var
......
......@@ -10635,8 +10635,8 @@ def slice(input, axes, starts, ends):
Args:
input (Variable): ${input_comment}.
axes (List): ${axes_comment}
starts (List): ${starts_comment}
ends (List): ${ends_comment}
starts (List|Variable): ${starts_comment}
ends (List|Variable): ${ends_comment}
Returns:
out (Variable): ${out_comment}
......@@ -10645,27 +10645,105 @@ def slice(input, axes, starts, ends):
.. code-block:: python
import paddle.fluid as fluid
starts = [1, 0, 2]
ends = [3, 3, 4]
axes = [0, 1, 2]
input = fluid.layers.data(
name="input", shape=[3, 4, 5, 6], dtype='float32')
out = fluid.layers.slice(input, axes=axes, starts=starts, ends=ends)
# example 1:
# attr starts is a list which doesn't contain tensor Variable.
axes = [0, 1, 2]
starts = [-3, 0, 2]
ends = [3, 2, 4]
sliced_1 = fluid.layers.slice(input, axes=axes, starts=starts, ends=ends)
# example 2:
# attr starts is a list which contain tensor Variable.
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
sliced_2 = fluid.layers.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends)
"""
if not isinstance(starts, (list, tuple, Variable)):
raise ValueError(
"Input starts must be an Variable, python list or tuple.")
if not isinstance(ends, (list, tuple, Variable)):
raise ValueError(
"Input ends must be an Variable, python list or tuple.")
helper = LayerHelper('slice', **locals())
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': input}
attrs = {'axes': axes}
infer_flags = list(1 for i in range(len(axes)))
if in_dygraph_mode():
inputs = {'Input': input}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'infer_flags': infer_flags
}
else:
# starts
if isinstance(starts, Variable):
starts.stop_gradient = True
inputs['StartsTensor'] = starts
infer_flags = list(-1 for i in range(len(axes)))
elif isinstance(starts, (list, tuple)):
attrs['starts'] = []
if not contain_var(starts):
attrs['starts'] = starts
else:
inputs['StartsTensorList'] = get_new_list_tensor(starts)
for i, dim in enumerate(starts):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if isinstance(ends, Variable):
ends.stop_gradient = True
inputs['EndsTensor'] = ends
infer_flags = list(-1 for i in range(len(axes)))
elif isinstance(ends, (list, tuple)):
attrs['ends'] = []
if not contain_var(ends):
attrs['ends'] = ends
else:
inputs['EndsTensorList'] = get_new_list_tensor(ends)
for i, dim in enumerate(ends):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# infer_flags
attrs['infer_flags'] = infer_flags
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op(
type='slice',
inputs={'Input': input},
outputs={'Out': out},
attrs={'axes': axes,
'starts': starts,
'ends': ends})
type='slice', inputs=inputs, attrs=attrs, outputs={'Out': out})
return out
......
......@@ -18,8 +18,11 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
# Situation 1: starts(list, no tensor), ends(list, no tensor)
# 1.1 without attr(decrease)
class TestSliceOp(OpTest):
def setUp(self):
self.op_type = "slice"
......@@ -29,7 +32,8 @@ class TestSliceOp(OpTest):
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends
'ends': self.ends,
'infer_flags': self.infer_flags
}
def config(self):
......@@ -37,6 +41,7 @@ class TestSliceOp(OpTest):
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.infer_flags = [1, 1, 1]
self.out = self.input[1:3, 0:3, 2:4, :]
def test_check_output(self):
......@@ -46,6 +51,27 @@ class TestSliceOp(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestCase1(TestSliceOp):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 2]
self.infer_flags = [1, 1, 1]
self.out = self.input[-3:3, 0:100, 2:-1, :]
class TestCase2(TestSliceOp):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.infer_flags = [1, 1, 1]
self.out = self.input[-3:3, 0:100, :, 2:-1]
# 1.2 with attr(decrease)
class TestSliceOp_decs_dim(OpTest):
def setUp(self):
self.op_type = "slice"
......@@ -56,6 +82,7 @@ class TestSliceOp_decs_dim(OpTest):
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
......@@ -65,6 +92,7 @@ class TestSliceOp_decs_dim(OpTest):
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0]
self.infer_flags = [1, 1, 1]
self.out = self.input[1, 0:3, 2:4, :]
def test_check_output(self):
......@@ -74,26 +102,91 @@ class TestSliceOp_decs_dim(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestSliceOp_decs_dim_2(OpTest):
class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 1, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.infer_flags = [1, 1, 1]
self.out = self.input[1, 0, 2:4, :]
class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1, 0, 2]
self.ends = [1000000, 1, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.infer_flags = [1, 1, 1]
self.out = self.input[-1, 0, 2:4, :]
class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim):
def config(self):
self.input = np.random.random([3, 4, 5, 7]).astype("float32")
self.starts = [0, 1, 2, 3]
self.ends = [1, 2, 3, 4]
self.axes = [0, 1, 2, 3]
self.decrease_axis = [0, 1, 2, 3]
self.infer_flags = [1, 1, 1]
self.out = self.input[0, 1, 2, 3:4]
class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1]
self.ends = [1000000]
self.axes = [3]
self.decrease_axis = [3]
self.infer_flags = [1, 1, 1]
self.out = self.input[:, :, :, -1]
class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [0, 1, 2, 3]
self.ends = [1, 2, 3, 4]
self.axes = [0, 1, 2, 3]
self.decrease_axis = [0, 1, 2, 3]
self.infer_flags = [1, 1, 1]
self.out = self.input[0, 1, 2, 3:4]
# Situation 2: starts(list, have tensor), ends(list, no tensor)
# without attr(decrease)
class TestSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'starts': self.starts_infer,
'ends': self.ends,
'decrease_axis': self.decrease_axis,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 1, 4]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.out = self.input[1, 0, 2:4, :]
self.infer_flags = [-1, 1, -1]
self.out = self.input[1:3, 0:3, 2:4, :]
self.starts_infer = [-1, 0, -1]
def test_check_output(self):
self.check_output()
......@@ -102,26 +195,39 @@ class TestSliceOp_decs_dim_2(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestSliceOp_decs_dim_3(OpTest):
# Situation 2: starts(list, have tensor), ends(list, no tensor)
# with attr(decrease)
class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'starts': self.starts_infer,
'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1, 0, 2]
self.ends = [1000000, 1, 4]
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.out = self.input[-1, 0, 2:4, :]
self.decrease_axis = [0]
self.infer_flags = [1, -1, 1]
self.out = self.input[1, 0:3, 2:4, :]
self.starts_infer = [1, -1, 2]
def test_check_output(self):
self.check_output()
......@@ -130,26 +236,48 @@ class TestSliceOp_decs_dim_3(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestSliceOp_decs_dim_5(OpTest):
class TestSliceOp_decs_dim_5_starts_ListTensor(
TestSliceOp_decs_dim_starts_ListTensor):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1]
self.ends = [1000000]
self.axes = [3]
self.decrease_axis = [3]
self.infer_flags = [-1]
self.out = self.input[:, :, :, -1]
self.starts_infer = [-1]
# Situation 3: starts(tensor), ends(list, no tensor)
# with attr(decrease)
class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32")
}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
#'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1]
self.ends = [1000000]
self.axes = [3]
self.decrease_axis = [3]
self.out = self.input[:, :, :, -1]
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0]
self.infer_flags = [-1, -1, -1]
self.out = self.input[1, 0:3, 2:4, :]
def test_check_output(self):
self.check_output()
......@@ -158,26 +286,35 @@ class TestSliceOp_decs_dim_5(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestSliceOp_decs_dim_6(OpTest):
# Situation 4: starts(tensor), ends(tensor)
# without attr(decrease)
class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
"EndsTensor": np.array(
self.ends, dtype="int32")
}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'decrease_axis': self.decrease_axis,
#'starts': self.starts,
#'ends': self.ends_infer,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [0, 1, 2, 3]
self.ends = [1, 2, 3, 4]
self.axes = [0, 1, 2, 3]
self.decrease_axis = [0, 1, 2, 3]
self.out = self.input[0, 1, 2, 3:4]
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.infer_flags = [-1, -1, -1]
self.out = self.input[1:3, 0:3, 2:4, :]
def test_check_output(self):
self.check_output()
......@@ -186,27 +323,103 @@ class TestSliceOp_decs_dim_6(OpTest):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestCase1(TestSliceOp):
# Situation 5: starts(tensor), ends(tensor)
# with attr(decrease)
class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
"EndsTensor": np.array(
self.ends, dtype="int32")
}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
#'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.starts = [1, 0, 2]
self.ends = [2, 1, 4]
self.axes = [0, 1, 2]
self.out = self.input[-3:3, 0:100, 2:-1, :]
self.decrease_axis = [0, 1]
self.infer_flags = [-1, -1, -1]
self.out = self.input[1, 0, 2:4, :]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestCase2(TestSliceOp):
# Situation 6: starts(tensor), ends(list, have tensor)
# without attr(decrease)
class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("y" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
'EndsTensorList': ends_tensor
}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
'ends': self.ends_infer,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.out = self.input[-3:3, 0:100, :, 2:-1]
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.infer_flags = [-1, -1, -1]
self.out = self.input[1:3, 0:3, 2:4, :]
self.ends_infer = [-1, 3, 4]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Test CUDA float16
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16(TestSliceOp):
class TestFP16(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags
}
def config(self):
self.dtype = "float16"
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
......@@ -214,6 +427,7 @@ class TestFP16(TestSliceOp):
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.out = self.input[-3:3, 0:100, :, 2:-1]
self.infer_flags = [1, 1, 1]
def test_check_output(self):
place = core.CUDAPlace(0)
......@@ -229,7 +443,19 @@ class TestFP16(TestSliceOp):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16_2(TestSliceOp):
class TestFP16_2(OpTest):
def setUp(self):
self.op_type = "slice"
self.config()
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.out}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags
}
def config(self):
self.dtype = "float16"
self.input = np.random.random([3, 4, 5]).astype(self.dtype)
......@@ -237,6 +463,7 @@ class TestFP16_2(TestSliceOp):
self.ends = [1]
self.axes = [1]
self.out = self.input[:, 0:1, :]
self.infer_flags = [1]
def test_check_output(self):
place = core.CUDAPlace(0)
......@@ -253,5 +480,53 @@ class TestFP16_2(TestSliceOp):
numeric_grad_delta=0.5)
# Test python API
class TestSliceAPI(OpTest):
def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float32")
minus_1 = fluid.layers.fill_constant([1], "int32", -1)
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
starts = fluid.layers.data(
name='starts', shape=[1, 3], append_batch_size=False)
ends = fluid.layers.data(
name='ends', shape=[3], append_batch_size=False)
x = fluid.layers.data(
name="x",
shape=[3, 4, 5, 6],
append_batch_size=False,
dtype="float32")
out_1 = fluid.layers.slice(
x, axes=[0, 1, 2], starts=[-3, 0, 2], ends=[3, 100, -1])
out_2 = fluid.layers.slice(
x, axes=[0, 1, 3], starts=[minus_3, 0, 2], ends=[3, 100, -1])
out_3 = fluid.layers.slice(
x, axes=[0, 1, 3], starts=[minus_3, 0, 2], ends=[3, 100, minus_1])
out_4 = fluid.layers.slice(x, axes=[0, 1, 2], starts=starts, ends=ends)
out_5 = x[-3:3, 0:100, 2:-1]
out_6 = x[minus_3:3, 0:100, :, 2:-1]
out_7 = x[minus_1, 0:100, :, 2:minus_1]
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
fluid.default_main_program(),
feed={
"x": input,
'starts': np.array([-3, 0, 2]).astype("int32"),
'ends': np.array([3, 100, -1]).astype("int32")
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册