未验证 提交 e74609b7 编写于 作者: W Wilber 提交者: GitHub

update slice and reshape op and test on one op fake model test=develop (#2377)

update reshape op to support multiple input types of shape.
priority: input(ShapeTensor) > input(Shape) > attr(shape)

update slice op to support multiple iput types of starts and ends.
priority: input(StartsTensor) > input(StartsTensorList) > attr(starts)
上级 9d97d56e
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/slice_compute.h" #include "lite/kernels/arm/slice_compute.h"
#include <algorithm>
#include <vector> #include <vector>
#include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/funcs.h"
...@@ -20,22 +21,145 @@ namespace lite { ...@@ -20,22 +21,145 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
inline std::vector<int32_t> get_new_data_from_tensorlist(
const std::vector<lite::Tensor*>& list_new_data_tensor) {
// get tensor
std::vector<int32_t> vec_new_data;
for (size_t i = 0; i < list_new_data_tensor.size(); ++i) {
auto tensor = list_new_data_tensor[i];
CHECK_EQ(tensor->dims(), DDim({1})) << "shape of dim tensor should be [1]";
vec_new_data.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
return vec_new_data;
}
inline std::vector<int32_t> get_new_data_from_tensor(
const lite::Tensor* new_data_tensor) {
std::vector<int32_t> vec_new_data;
auto* new_data = new_data_tensor->data<int32_t>();
vec_new_data =
std::vector<int32_t>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
void SliceCompute::PrepareForRun() {} void SliceCompute::PrepareForRun() {}
void SliceCompute::Run() { void SliceCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::SliceParam>(); auto& param = this->Param<operators::SliceParam>();
auto input_dims = param.X->dims(); auto in = param.X;
int dim_size = param.X->dims().size(); auto in_dims = in->dims();
auto out = param.Out;
auto out_dims = out->dims();
std::vector<int> starts = param.starts;
std::vector<int> ends = param.ends;
std::vector<int> axes = param.axes; std::vector<int> axes = param.axes;
const auto* x_data = param.X->data<int>(); std::vector<int32_t> starts = param.starts;
auto* o_data = param.Out->mutable_data<int>(); std::vector<int32_t> ends = param.ends;
std::vector<int> decrease_axis = param.decrease_axis;
std::vector<int> infer_flags = param.infer_flags;
auto list_new_ends_tensor = param.EndsTensorList;
auto list_new_starts_tensor = param.StartsTensorList;
bool need_infer = false;
if (param.StartsTensor || param.EndsTensor) {
need_infer = true;
}
if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) {
need_infer = true;
}
if (need_infer) {
if (param.StartsTensor) {
starts = get_new_data_from_tensor(param.StartsTensor);
} else if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
}
CHECK_EQ(starts.size(), axes.size())
<< "The size of starts must be equal to the size of axes.";
if (param.EndsTensor) {
ends = get_new_data_from_tensor(param.EndsTensor);
} else if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
}
CHECK_EQ(ends.size(), axes.size())
<< "The size of ends must be equal to the size of axes.";
out_dims = in_dims;
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
// when end = start+1 and start == -1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret =
std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = 10000000;
}
}
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
CHECK_GT(end, start) << "end should greater than start";
out_dims[axes[i]] = end - start;
}
}
out->Resize(out_dims);
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
CHECK_EQ(out_dims[decrease_axis[i]], 1) << "decrease dim should be 1";
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
DDim new_dims;
new_dims.ConstructFrom(new_out_shape);
out_dims = new_dims;
}
}
// resize out dims
if (decrease_axis.size() > 0) {
if (decrease_axis.size() == (size_t)in_dims.size()) {
std::vector<int64_t> vec_origin_out_shape(decrease_axis.size(), 1);
out->Resize(DDim(vec_origin_out_shape));
} else {
std::vector<int64_t> vec_origin_out_shape(
out_dims.size() + decrease_axis.size(), -1);
for (size_t i = 0; i < decrease_axis.size(); ++i) {
vec_origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
if (vec_origin_out_shape[i] == -1) {
vec_origin_out_shape[i] = out_dims[index];
++index;
}
}
out->Resize(DDim(vec_origin_out_shape));
}
}
auto new_out_dims = out->dims();
const auto* x_data = in->data<int>();
auto* o_data = out->mutable_data<int>();
lite::arm::math::slice( lite::arm::math::slice(
x_data, input_dims.data(), axes, starts, ends, o_data, &ctx); x_data, in_dims.data(), axes, starts, ends, o_data, &ctx);
} }
} // namespace arm } // namespace arm
...@@ -46,12 +170,9 @@ void SliceCompute::Run() { ...@@ -46,12 +170,9 @@ void SliceCompute::Run() {
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
slice, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SliceCompute, def) slice, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SliceCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
// REGISTER_LITE_KERNEL(
// slice, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SliceCompute, def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), Precision(kINT32))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM),
// Precision(kINT32))})
// .Finalize();
...@@ -21,6 +21,8 @@ REGISTER_LITE_KERNEL(reshape, ...@@ -21,6 +21,8 @@ REGISTER_LITE_KERNEL(reshape,
paddle::lite::kernels::x86::ReshapeCompute<float>, paddle::lite::kernels::x86::ReshapeCompute<float>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -31,6 +33,8 @@ REGISTER_LITE_KERNEL(reshape2, ...@@ -31,6 +33,8 @@ REGISTER_LITE_KERNEL(reshape2,
paddle::lite::kernels::x86::Reshape2Compute<float>, paddle::lite::kernels::x86::Reshape2Compute<float>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -41,6 +45,8 @@ REGISTER_LITE_KERNEL(reshape2, ...@@ -41,6 +45,8 @@ REGISTER_LITE_KERNEL(reshape2,
paddle::lite::kernels::x86::Reshape2Compute<int64_t>, paddle::lite::kernels::x86::Reshape2Compute<int64_t>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("XShape", .BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
......
...@@ -21,5 +21,9 @@ REGISTER_LITE_KERNEL(slice, ...@@ -21,5 +21,9 @@ REGISTER_LITE_KERNEL(slice,
paddle::lite::kernels::x86::SliceCompute<float>, paddle::lite::kernels::x86::SliceCompute<float>,
def) def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -28,16 +28,111 @@ namespace lite { ...@@ -28,16 +28,111 @@ namespace lite {
namespace kernels { namespace kernels {
namespace x86 { namespace x86 {
inline std::vector<int> get_new_data_from_tensorlist(
const std::vector<lite::Tensor*>& list_new_data_tensor) {
// get tensor from
std::vector<int> vec_new_data;
for (size_t i = 0; i < list_new_data_tensor.size(); ++i) {
auto tensor = list_new_data_tensor[i];
CHECK_EQ(tensor->dims(), DDim({1})) << "shape of dim tensor should be [1]";
vec_new_data.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
return vec_new_data;
}
inline std::vector<int> get_new_data_from_tensor(
const Tensor* new_data_tensor) {
std::vector<int> vec_new_data;
auto* new_data = new_data_tensor->data<int>();
vec_new_data =
std::vector<int>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
template <size_t D> template <size_t D>
void slice_compute(const lite::Tensor* in, void slice_compute(const lite::Tensor* in,
lite::Tensor* out, lite::Tensor* out,
std::vector<int> axes, std::vector<int> axes,
std::vector<int> starts, std::vector<int> starts,
std::vector<int> ends, std::vector<int> ends,
std::vector<int> decrease_axis) { std::vector<int> decrease_axis,
lite::Tensor* StartsTensor,
lite::Tensor* EndsTensor,
std::vector<lite::Tensor*> StartsTensorList,
std::vector<lite::Tensor*> EndsTensorList,
std::vector<int> infer_flags) {
auto out_dims = out->dims(); auto out_dims = out->dims();
auto in_dims = in->dims(); auto in_dims = in->dims();
bool need_infer = false;
if (StartsTensor || EndsTensor) {
need_infer = true;
} else if (StartsTensorList.size() > 0 || EndsTensorList.size() > 0) {
need_infer = true;
}
if (need_infer) {
if (StartsTensor) {
starts = get_new_data_from_tensor(StartsTensor);
} else if (StartsTensorList.size() > 0) {
starts = get_new_data_from_tensorlist(StartsTensorList);
}
CHECK_EQ(starts.size(), axes.size())
<< "The size of starts must be equal to the size of axes.";
if (EndsTensor) {
ends = get_new_data_from_tensor(EndsTensor);
} else if (EndsTensorList.size() > 0) {
ends = get_new_data_from_tensorlist(EndsTensorList);
}
CHECK_EQ(ends.size(), axes.size())
<< "The size of ends must be equal to the size of axes.";
out_dims = in_dims;
int dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
// when end = start + 1 and start == -1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret =
std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = 10000000;
}
}
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
CHECK_GT(end, start) << "end should greater than start";
out_dims[axes[i]] = end - start;
}
}
out->Resize(out_dims);
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
CHECK_EQ(out_dims[decrease_axis[i]], 1) << "decrease dim should be 1";
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
DDim new_dims;
new_dims.ConstructFrom(new_out_shape);
out_dims = new_dims;
}
}
// resize out_dims // resize out_dims
if (decrease_axis.size() > 0) { if (decrease_axis.size() > 0) {
if (decrease_axis.size() == (size_t)in_dims.size()) { if (decrease_axis.size() == (size_t)in_dims.size()) {
...@@ -97,26 +192,91 @@ void slice_compute_(const lite::Tensor* Input, ...@@ -97,26 +192,91 @@ void slice_compute_(const lite::Tensor* Input,
std::vector<int> axes, std::vector<int> axes,
std::vector<int> starts, std::vector<int> starts,
std::vector<int> ends, std::vector<int> ends,
std::vector<int> decrease_axis) { std::vector<int> decrease_axis,
lite::Tensor* StartsTensor,
lite::Tensor* EndsTensor,
std::vector<lite::Tensor*> StartsTensorList,
std::vector<lite::Tensor*> EndsTensorList,
std::vector<int> infer_flags) {
int rank = Input->dims().size(); int rank = Input->dims().size();
switch (rank) { switch (rank) {
case 1: case 1:
slice_compute<1>(Input, Out, axes, starts, ends, decrease_axis); slice_compute<1>(Input,
Out,
axes,
starts,
ends,
decrease_axis,
StartsTensor,
EndsTensor,
StartsTensorList,
EndsTensorList,
infer_flags);
break; break;
case 2: case 2:
slice_compute<2>(Input, Out, axes, starts, ends, decrease_axis); slice_compute<2>(Input,
Out,
axes,
starts,
ends,
decrease_axis,
StartsTensor,
EndsTensor,
StartsTensorList,
EndsTensorList,
infer_flags);
break; break;
case 3: case 3:
slice_compute<3>(Input, Out, axes, starts, ends, decrease_axis); slice_compute<3>(Input,
Out,
axes,
starts,
ends,
decrease_axis,
StartsTensor,
EndsTensor,
StartsTensorList,
EndsTensorList,
infer_flags);
break; break;
case 4: case 4:
slice_compute<4>(Input, Out, axes, starts, ends, decrease_axis); slice_compute<4>(Input,
Out,
axes,
starts,
ends,
decrease_axis,
StartsTensor,
EndsTensor,
StartsTensorList,
EndsTensorList,
infer_flags);
break; break;
case 5: case 5:
slice_compute<5>(Input, Out, axes, starts, ends, decrease_axis); slice_compute<5>(Input,
Out,
axes,
starts,
ends,
decrease_axis,
StartsTensor,
EndsTensor,
StartsTensorList,
EndsTensorList,
infer_flags);
break; break;
case 6: case 6:
slice_compute<6>(Input, Out, axes, starts, ends, decrease_axis); slice_compute<6>(Input,
Out,
axes,
starts,
ends,
decrease_axis,
StartsTensor,
EndsTensor,
StartsTensorList,
EndsTensorList,
infer_flags);
break; break;
} }
} }
...@@ -133,7 +293,12 @@ class SliceCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -133,7 +293,12 @@ class SliceCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param.axes, param.axes,
param.starts, param.starts,
param.ends, param.ends,
param.decrease_axis); param.decrease_axis,
param.StartsTensor,
param.EndsTensor,
param.StartsTensorList,
param.EndsTensorList,
param.infer_flags);
} }
virtual ~SliceCompute() = default; virtual ~SliceCompute() = default;
......
...@@ -24,6 +24,60 @@ namespace lite { ...@@ -24,6 +24,60 @@ namespace lite {
namespace kernels { namespace kernels {
namespace x86 { namespace x86 {
static void slice_ref(const float* input,
std::vector<int64_t> in_dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out) {
auto out_dims = in_dims;
std::vector<int> real_starts(in_dims.size(), 0);
std::vector<int> real_ends(in_dims.size(), 0);
std::vector<int> real_step(in_dims.size(), 0);
for (int i = 0; i < in_dims.size(); i++) {
real_ends[i] = in_dims[i];
}
for (int i = 0; i < axes.size(); i++) {
int dim_value = in_dims[axes[i]];
if (dim_value > 0) {
int start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
int end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
out_dims[axes[i]] = end - start;
real_starts[axes[i]] = start;
real_ends[axes[i]] = end;
}
}
const int LEN = in_dims.size();
int dst_step[LEN];
for (int i = 0; i < in_dims.size(); ++i) {
dst_step[i] = 1;
}
int src_step[LEN];
for (int i = 0; i < in_dims.size(); ++i) {
src_step[i] = 1;
}
int out_num = out_dims[in_dims.size() - 1];
for (int i = in_dims.size() - 2; i >= 0; i--) {
dst_step[i] = out_dims[i + 1] * dst_step[i + 1];
src_step[i] = in_dims[i + 1] * src_step[i + 1];
out_num *= out_dims[i];
}
for (int dst_id = 0; dst_id < out_num; dst_id++) {
int src_id = 0;
int index_id = dst_id;
for (int j = 0; j < out_dims.size(); j++) {
int cur_id = index_id / dst_step[j];
index_id = index_id % dst_step[j];
src_id += (cur_id + real_starts[j]) * src_step[j];
}
out[dst_id] = input[src_id];
}
}
TEST(slice_x86, retrive_op) { TEST(slice_x86, retrive_op) {
auto slice = auto slice =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("slice"); KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("slice");
...@@ -59,6 +113,9 @@ void test_case1(lite::Tensor x, lite::Tensor out) { ...@@ -59,6 +113,9 @@ void test_case1(lite::Tensor x, lite::Tensor out) {
operators::SliceParam param; operators::SliceParam param;
param.X = &x; param.X = &x;
param.axes = axes;
param.starts = starts;
param.ends = ends;
param.Out = &out; param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
...@@ -67,8 +124,11 @@ void test_case1(lite::Tensor x, lite::Tensor out) { ...@@ -67,8 +124,11 @@ void test_case1(lite::Tensor x, lite::Tensor out) {
slice.SetParam(param); slice.SetParam(param);
slice.Run(); slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
} }
} }
...@@ -95,6 +155,9 @@ void test_case2(lite::Tensor x, lite::Tensor out) { ...@@ -95,6 +155,9 @@ void test_case2(lite::Tensor x, lite::Tensor out) {
param.X = &x; param.X = &x;
param.Out = &out; param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>(); ctx->As<X86Context>();
...@@ -102,8 +165,11 @@ void test_case2(lite::Tensor x, lite::Tensor out) { ...@@ -102,8 +165,11 @@ void test_case2(lite::Tensor x, lite::Tensor out) {
slice.SetParam(param); slice.SetParam(param);
slice.Run(); slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
} }
} }
...@@ -130,6 +196,9 @@ void test_case3(lite::Tensor x, lite::Tensor out) { ...@@ -130,6 +196,9 @@ void test_case3(lite::Tensor x, lite::Tensor out) {
param.X = &x; param.X = &x;
param.Out = &out; param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>(); ctx->As<X86Context>();
...@@ -137,8 +206,11 @@ void test_case3(lite::Tensor x, lite::Tensor out) { ...@@ -137,8 +206,11 @@ void test_case3(lite::Tensor x, lite::Tensor out) {
slice.SetParam(param); slice.SetParam(param);
slice.Run(); slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
} }
} }
void test_case4(lite::Tensor x, lite::Tensor out) { void test_case4(lite::Tensor x, lite::Tensor out) {
...@@ -164,6 +236,9 @@ void test_case4(lite::Tensor x, lite::Tensor out) { ...@@ -164,6 +236,9 @@ void test_case4(lite::Tensor x, lite::Tensor out) {
param.X = &x; param.X = &x;
param.Out = &out; param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>(); ctx->As<X86Context>();
...@@ -171,8 +246,11 @@ void test_case4(lite::Tensor x, lite::Tensor out) { ...@@ -171,8 +246,11 @@ void test_case4(lite::Tensor x, lite::Tensor out) {
slice.SetParam(param); slice.SetParam(param);
slice.Run(); slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
} }
} }
...@@ -199,6 +277,9 @@ void test_case5(lite::Tensor x, lite::Tensor out) { ...@@ -199,6 +277,9 @@ void test_case5(lite::Tensor x, lite::Tensor out) {
param.X = &x; param.X = &x;
param.Out = &out; param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>(); ctx->As<X86Context>();
...@@ -206,10 +287,14 @@ void test_case5(lite::Tensor x, lite::Tensor out) { ...@@ -206,10 +287,14 @@ void test_case5(lite::Tensor x, lite::Tensor out) {
slice.SetParam(param); slice.SetParam(param);
slice.Run(); slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
} }
} }
void test_case6(lite::Tensor x, lite::Tensor out) { void test_case6(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5, 6, 5, 2}); std::vector<int64_t> x_shape({3, 4, 5, 6, 5, 2});
x.Resize(lite::DDim(x_shape)); x.Resize(lite::DDim(x_shape));
...@@ -233,6 +318,153 @@ void test_case6(lite::Tensor x, lite::Tensor out) { ...@@ -233,6 +318,153 @@ void test_case6(lite::Tensor x, lite::Tensor out) {
param.X = &x; param.X = &x;
param.Out = &out; param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
}
}
void test_tensor_case1(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({10});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({5});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({3});
std::vector<int> ends({8});
std::vector<int> axes({0});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
lite::Tensor starts_tensor, ends_tensor;
starts_tensor.Resize(DDim({1}));
ends_tensor.Resize(DDim({1}));
starts_tensor.mutable_data<int>()[0] = starts[0];
ends_tensor.mutable_data<int>()[0] = ends[0];
param.StartsTensor = &starts_tensor;
param.EndsTensor = &ends_tensor;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
}
}
void test_tensor_case3(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4, 2});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0, 2});
std::vector<int> ends({3, 100, -1});
std::vector<int> axes({0, 1, 2});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
lite::Tensor starts_tensor, ends_tensor;
starts_tensor.Resize(DDim({3}));
ends_tensor.Resize(DDim({3}));
for (int i = 0; i < starts.size(); ++i) {
starts_tensor.mutable_data<int>()[i] = starts[i];
ends_tensor.mutable_data<int>()[i] = ends[i];
}
param.StartsTensor = &starts_tensor;
param.EndsTensor = &ends_tensor;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
}
}
void test_tensor_list_case1(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({10});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({5});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({3});
std::vector<int> ends({8});
std::vector<int> axes({0});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.StartsTensorList.clear();
param.EndsTensorList.clear();
lite::Tensor starts_tensor, ends_tensor;
for (int i = 0; i < 1; ++i) {
starts_tensor.Resize(DDim({1}));
ends_tensor.Resize(DDim({1}));
starts_tensor.mutable_data<int>()[0] = starts[0];
ends_tensor.mutable_data<int>()[0] = ends[0];
param.StartsTensorList.push_back(&starts_tensor);
param.EndsTensorList.push_back(&ends_tensor);
}
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>(); ctx->As<X86Context>();
...@@ -240,8 +472,73 @@ void test_case6(lite::Tensor x, lite::Tensor out) { ...@@ -240,8 +472,73 @@ void test_case6(lite::Tensor x, lite::Tensor out) {
slice.SetParam(param); slice.SetParam(param);
slice.Run(); slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i]; EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
}
}
void test_tensor_list_case3(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4, 2});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0, 2});
std::vector<int> ends({3, 100, -1});
std::vector<int> axes({0, 1, 2});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.StartsTensorList.clear();
param.EndsTensorList.clear();
lite::Tensor starts_tensor0, ends_tensor0;
lite::Tensor starts_tensor1, ends_tensor1;
lite::Tensor starts_tensor2, ends_tensor2;
starts_tensor0.Resize(DDim({1}));
starts_tensor1.Resize(DDim({1}));
starts_tensor2.Resize(DDim({1}));
ends_tensor0.Resize(DDim({1}));
ends_tensor1.Resize(DDim({1}));
ends_tensor2.Resize(DDim({1}));
starts_tensor0.mutable_data<int>()[0] = starts[0];
starts_tensor1.mutable_data<int>()[0] = starts[1];
starts_tensor2.mutable_data<int>()[0] = starts[2];
ends_tensor0.mutable_data<int>()[0] = ends[0];
ends_tensor1.mutable_data<int>()[0] = ends[1];
ends_tensor2.mutable_data<int>()[0] = ends[2];
param.StartsTensorList.emplace_back(&starts_tensor0);
param.StartsTensorList.emplace_back(&starts_tensor1);
param.StartsTensorList.emplace_back(&starts_tensor2);
param.EndsTensorList.emplace_back(&ends_tensor0);
param.EndsTensorList.emplace_back(&ends_tensor1);
param.EndsTensorList.emplace_back(&ends_tensor2);
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
std::vector<float> out_ref(out.numel(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
} }
} }
...@@ -257,6 +554,22 @@ TEST(slice_x86, run_test) { ...@@ -257,6 +554,22 @@ TEST(slice_x86, run_test) {
test_case6(x, out); test_case6(x, out);
} }
TEST(slice_x86, test_tensor) {
lite::Tensor x;
lite::Tensor out;
test_tensor_case1(x, out);
test_tensor_case3(x, out);
}
TEST(slice_x86, test_tensor_list) {
lite::Tensor x;
lite::Tensor out;
test_tensor_list_case1(x, out);
test_tensor_list_case3(x, out);
}
} // namespace x86 } // namespace x86
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -788,6 +788,11 @@ struct SliceParam { ...@@ -788,6 +788,11 @@ struct SliceParam {
std::vector<int> starts{}; std::vector<int> starts{};
std::vector<int> ends{}; std::vector<int> ends{};
std::vector<int> decrease_axis{}; std::vector<int> decrease_axis{};
std::vector<int> infer_flags{};
std::vector<lite::Tensor*> StartsTensorList{};
std::vector<lite::Tensor*> EndsTensorList{};
lite::Tensor* StartsTensor{nullptr};
lite::Tensor* EndsTensor{nullptr};
}; };
struct AffineChannelParam { struct AffineChannelParam {
......
...@@ -60,8 +60,8 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -60,8 +60,8 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.output = param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
if (opdesc.HasInput("ShapeTensor") && // prority: input(ShapeTensor) > input(Shape) > attr(shape)
opdesc.Input("ShapeTensor").size() > 0) { if (opdesc.HasInput("ShapeTensor") && !opdesc.Input("ShapeTensor").empty()) {
auto args = opdesc.Input("ShapeTensor"); auto args = opdesc.Input("ShapeTensor");
for (auto arg : args) { for (auto arg : args) {
auto *var = scope->FindVar(arg); auto *var = scope->FindVar(arg);
...@@ -69,8 +69,13 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -69,8 +69,13 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>()); param_.shape_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
} }
} }
CHECK_GT(param_.shape_tensor_vct.size(), 0)
<< "ShapeError: When `shape` in ReshapeOp is a list or tuple "
"which contains Tensor, the shape's size can't be zero. "
"But received shape's size is "
<< param_.shape_tensor_vct.size();
} }
if (opdesc.HasInput("Shape") && opdesc.Input("Shape").size() > 0) { if (opdesc.HasInput("Shape") && !opdesc.Input("Shape").empty()) {
auto var = scope->FindVar(opdesc.Input("Shape").front()); auto var = scope->FindVar(opdesc.Input("Shape").front());
if (var != nullptr) { if (var != nullptr) {
param_.shape_tensor = var->GetMutable<lite::Tensor>(); param_.shape_tensor = var->GetMutable<lite::Tensor>();
......
...@@ -22,6 +22,8 @@ namespace operators { ...@@ -22,6 +22,8 @@ namespace operators {
bool SliceOp::CheckShape() const { bool SliceOp::CheckShape() const {
CHECK_OR_FALSE(param_.X); CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
CHECK_LT(param_.X->dims().size(), 7)
<< "The rank of input X should be less than 7";
return true; return true;
} }
...@@ -30,14 +32,21 @@ bool SliceOp::InferShape() const { ...@@ -30,14 +32,21 @@ bool SliceOp::InferShape() const {
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto in_dims = param_.X->dims(); auto in_dims = param_.X->dims();
auto out_dims = in_dims; auto out_dims = in_dims;
CHECK_EQ(param_.starts.size(), param_.ends.size()) // CHECK_EQ(param_.starts.size(), param_.ends.size())
<< "for slice op starts and ends must be equal"; // << "for slice op starts and ends must be equal";
int dim_value, start, end; int dim_value, start, end;
auto axes = param_.axes; auto axes = param_.axes;
auto starts = param_.starts; auto starts = param_.starts;
auto ends = param_.ends; auto ends = param_.ends;
auto decrease_axis = param_.decrease_axis; auto decrease_axis = param_.decrease_axis;
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
CHECK_LT(param_.axes[i], in_dims.size()) << "The index of dimension in "
"axes must be less than the "
"size of input shape.";
if (param_.infer_flags[i] == -1) {
out_dims[axes[i]] = -1;
} else {
// infer out_dim shape
dim_value = out_dims[axes[i]]; dim_value = out_dims[axes[i]];
if (dim_value > 0) { if (dim_value > 0) {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
...@@ -48,9 +57,14 @@ bool SliceOp::InferShape() const { ...@@ -48,9 +57,14 @@ bool SliceOp::InferShape() const {
out_dims[axes[i]] = end - start; out_dims[axes[i]] = end - start;
} }
} }
}
// generate new shape
if (decrease_axis.size() > 0) { if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape; std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) { for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (param_.infer_flags[i] != -1) {
CHECK_EQ(out_dims[decrease_axis[i]], 1) << "decrease dim should be 1";
}
out_dims[decrease_axis[i]] = 0; out_dims[decrease_axis[i]] = 0;
} }
for (int i = 0; i < out_dims.size(); ++i) { for (int i = 0; i < out_dims.size(); ++i) {
...@@ -66,6 +80,10 @@ bool SliceOp::InferShape() const { ...@@ -66,6 +80,10 @@ bool SliceOp::InferShape() const {
out_dims = new_dims; out_dims = new_dims;
} }
param_.Out->Resize(out_dims); param_.Out->Resize(out_dims);
if (axes[0] != 0) {
param_.Out->set_lod(param_.X->lod());
}
LOG(INFO) << "infer shape done";
return true; return true;
} }
...@@ -77,11 +95,74 @@ bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -77,11 +95,74 @@ bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
CHECK(param_.X); CHECK(param_.X);
CHECK(param_.Out); CHECK(param_.Out);
param_.axes = opdesc.GetAttr<std::vector<int>>("axes"); param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
param_.starts = opdesc.GetAttr<std::vector<int>>("starts");
param_.ends = opdesc.GetAttr<std::vector<int>>("ends"); if (opdesc.HasAttr("infer_flags")) {
param_.infer_flags = opdesc.GetAttr<std::vector<int>>("infer_flags");
} else {
// Initialize infer_flags with 1.
// To be compatible with other op tests in which infer_flags is not set.
param_.infer_flags = std::vector<int>(param_.axes.size(), 1);
}
if (opdesc.HasAttr("decrease_axis")) { if (opdesc.HasAttr("decrease_axis")) {
param_.decrease_axis = opdesc.GetAttr<std::vector<int>>("decrease_axis"); param_.decrease_axis = opdesc.GetAttr<std::vector<int>>("decrease_axis");
} }
// The priority: StartsTensor > StartsTensorList > attr(starts).
// The priority: EndsTensor > EndsTensorList > attr(ends).
int starts_size, ends_size;
if (opdesc.HasAttr("starts")) {
param_.starts = opdesc.GetAttr<std::vector<int>>("starts");
}
if (opdesc.HasAttr("ends")) {
param_.ends = opdesc.GetAttr<std::vector<int>>("ends");
}
starts_size = param_.starts.size();
ends_size = param_.ends.size();
if (opdesc.HasInput("StartsTensorList") &&
!opdesc.Input("StartsTensorList").empty()) {
LOG(INFO) << "opdesc input size "
<< opdesc.Input("StartsTensorList").size();
LOG(INFO) << "param init size " << param_.StartsTensorList.size();
auto StartsTensorList = opdesc.Input("StartsTensorList");
param_.StartsTensorList.clear();
for (auto var : StartsTensorList) {
param_.StartsTensorList.push_back(
scope->FindVar(var)->GetMutable<lite::Tensor>());
}
CHECK_GT(param_.StartsTensorList.size(), 0)
<< "StartsTensorList size can't be zero";
starts_size = param_.StartsTensorList.size();
}
if (opdesc.HasInput("EndsTensorList") &&
!opdesc.Input("EndsTensorList").empty()) {
auto EndsTensorList = opdesc.Input("EndsTensorList");
param_.EndsTensorList.clear();
for (auto var : EndsTensorList) {
param_.EndsTensorList.push_back(
scope->FindVar(var)->GetMutable<lite::Tensor>());
}
CHECK_GT(param_.EndsTensorList.size(), 0)
<< "EndsTensorList size can't be zero";
ends_size = param_.EndsTensorList.size();
}
if (opdesc.HasInput("StartsTensor") &&
!opdesc.Input("StartsTensor").empty()) {
param_.StartsTensor = scope->FindVar(opdesc.Input("StartsTensor").front())
->GetMutable<lite::Tensor>();
} else {
CHECK_EQ(starts_size, param_.axes.size())
<< "The size of starts must be equal to the size of axes.";
}
if (opdesc.HasInput("EndsTensor") && !opdesc.Input("EndsTensor").empty()) {
param_.EndsTensor = scope->FindVar(opdesc.Input("EndsTensor").front())
->GetMutable<lite::Tensor>();
} else {
CHECK_EQ(ends_size, param_.axes.size())
<< "The size of ends must be equal to the size of axes.";
}
LOG(INFO) << "attach impl done";
return true; return true;
} }
......
...@@ -84,6 +84,13 @@ class SliceComputeTester : public arena::TestCase { ...@@ -84,6 +84,13 @@ class SliceComputeTester : public arena::TestCase {
std::vector<int> ends_; std::vector<int> ends_;
std::vector<int> decrease_axis_; std::vector<int> decrease_axis_;
DDim dims_; DDim dims_;
std::vector<int> infer_flags_;
std::string starts_tensor_ = "StartsTensor";
std::string ends_tensor_ = "EndsTensor";
// std::string starts_tensor_list_ = "StartsTensorList";
// std::string ends_tensor_list_ = "EndsTensorList";
bool use_tensor_;
bool use_tensor_list_;
public: public:
SliceComputeTester(const Place& place, SliceComputeTester(const Place& place,
...@@ -92,13 +99,19 @@ class SliceComputeTester : public arena::TestCase { ...@@ -92,13 +99,19 @@ class SliceComputeTester : public arena::TestCase {
const std::vector<int>& starts, const std::vector<int>& starts,
const std::vector<int>& ends, const std::vector<int>& ends,
const std::vector<int>& decrease_axis, const std::vector<int>& decrease_axis,
const DDim& dims) const DDim& dims,
bool use_tensor = false,
bool use_tensor_list = false,
const std::vector<int>& infer_flags = {})
: TestCase(place, alias), : TestCase(place, alias),
axes_(axes), axes_(axes),
starts_(starts), starts_(starts),
ends_(ends), ends_(ends),
decrease_axis_(decrease_axis), decrease_axis_(decrease_axis),
dims_(dims) {} dims_(dims),
infer_flags_(infer_flags),
use_tensor_(use_tensor),
use_tensor_list_(use_tensor_list) {}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_); auto* out = scope->NewTensor(output_);
...@@ -146,6 +159,25 @@ class SliceComputeTester : public arena::TestCase { ...@@ -146,6 +159,25 @@ class SliceComputeTester : public arena::TestCase {
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("slice"); op_desc->SetType("slice");
op_desc->SetInput("Input", {input_}); op_desc->SetInput("Input", {input_});
if (use_tensor_) {
op_desc->SetInput("StartsTensor", {starts_tensor_});
op_desc->SetInput("EndsTensor", {ends_tensor_});
} else if (use_tensor_list_) {
std::vector<std::string> starts_tensor_list_;
std::vector<std::string> ends_tensor_list_;
for (int i = 0; i < starts_.size(); ++i) {
starts_tensor_list_.push_back("starts_tensor_list_" +
std::to_string(i));
ends_tensor_list_.push_back("ends_tensor_list_" + std::to_string(i));
}
op_desc->SetInput("StartsTensorList", {starts_tensor_list_});
op_desc->SetInput("EndsTensorList", {ends_tensor_list_});
}
if (infer_flags_.size() > 0) {
op_desc->SetAttr("infer_flags", infer_flags_);
}
op_desc->SetOutput("Out", {output_}); op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axes", axes_); op_desc->SetAttr("axes", axes_);
op_desc->SetAttr("starts", starts_); op_desc->SetAttr("starts", starts_);
...@@ -161,6 +193,30 @@ class SliceComputeTester : public arena::TestCase { ...@@ -161,6 +193,30 @@ class SliceComputeTester : public arena::TestCase {
} }
SetCommonTensor(input_, dims_, data.data()); SetCommonTensor(input_, dims_, data.data());
if (use_tensor_) {
SetCommonTensor(starts_tensor_,
DDim({static_cast<int64_t>(starts_.size())}),
starts_.data());
SetCommonTensor(ends_tensor_,
DDim({static_cast<int64_t>(ends_.size())}),
ends_.data());
} else if (use_tensor_list_) {
Scope& scope_ = this->scope();
for (int i = 0; i < starts_.size(); ++i) {
auto* tensor =
scope_.NewTensor("starts_tensor_list_" + std::to_string(i));
tensor->Resize(DDim({1}));
auto* d = tensor->mutable_data<int>();
d[0] = starts_[i];
}
for (int i = 0; i < ends_.size(); ++i) {
auto* tensor =
scope_.NewTensor("ends_tensor_list_" + std::to_string(i));
tensor->Resize(DDim({1}));
auto* d = tensor->mutable_data<int>();
d[0] = ends_[i];
}
}
} }
}; };
...@@ -176,6 +232,40 @@ void test_slice(Place place) { ...@@ -176,6 +232,40 @@ void test_slice(Place place) {
arena.TestPrecision(); arena.TestPrecision();
} }
void test_slice_tensor(Place place) {
std::vector<int> axes({0, 1, 2});
std::vector<int> starts({2, 2, 2});
std::vector<int> ends({5, 6, 7});
std::vector<int> decrease_axis({});
DDim dims({10, 10, 10});
std::unique_ptr<arena::TestCase> tester(new SliceComputeTester(
place, "def", axes, starts, ends, decrease_axis, dims, true));
arena::Arena arena(std::move(tester), place, 2e-4);
arena.TestPrecision();
}
void test_slice_tensor_list(Place place) {
std::vector<int> axes({0, 1, 2});
std::vector<int> starts({2, 2, 2});
std::vector<int> ends({5, 6, 7});
std::vector<int> decrease_axis({});
std::vector<int> infer_flags({});
DDim dims({10, 10, 10});
std::unique_ptr<arena::TestCase> tester(new SliceComputeTester(place,
"def",
axes,
starts,
ends,
decrease_axis,
dims,
false,
true,
infer_flags));
arena::Arena arena(std::move(tester), place, 2e-4);
arena.TestPrecision();
}
TEST(Slice, precision) { TEST(Slice, precision) {
#ifdef LITE_WITH_X86 #ifdef LITE_WITH_X86
Place place(TARGET(kX86)); Place place(TARGET(kX86));
...@@ -183,6 +273,8 @@ TEST(Slice, precision) { ...@@ -183,6 +273,8 @@ TEST(Slice, precision) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
Place place(TARGET(kARM)); Place place(TARGET(kARM));
test_slice(place); test_slice(place);
test_slice_tensor(place);
test_slice_tensor_list(place);
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册