提交 52f933d4 编写于 作者: W Wilber 提交者: Xiaoyang LI

modify slice op and add slice test (#1944)

* modify slice op and add slice test

* modify slice op bug
上级 3618f488
......@@ -51,7 +51,7 @@ void slice(const Dtype* input,
real_ends[axes[i]] = end;
}
}
const int LEN = in_dims.size() - 1;
const int LEN = in_dims.size();
int dst_step[LEN];
for (int i = 0; i < in_dims.size(); ++i) {
dst_step[i] = 1;
......@@ -62,15 +62,17 @@ void slice(const Dtype* input,
}
int out_num = out_dims[in_dims.size() - 1];
for (int i = in_dims.size() - 2; i >= 0; i--) {
dst_step[i] = out_dims[i] * dst_step[i + 1];
src_step[i] = in_dims[i] * src_step[i + 1];
dst_step[i] = out_dims[i + 1] * dst_step[i + 1];
src_step[i] = in_dims[i + 1] * src_step[i + 1];
out_num *= out_dims[i];
}
for (int dst_id = 0; dst_id < out_num; dst_id++) {
int src_id = 0;
int index_id = dst_id;
for (int j = 0; j < out_dims.size(); j++) {
int cur_id = dst_id / dst_step[j];
int cur_id = index_id / dst_step[j];
index_id = index_id % dst_step[j];
src_id += (cur_id + real_starts[j]) * src_step[j];
}
out[dst_id] = input[src_id];
......
......@@ -41,6 +41,7 @@ endif()
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
static void slice_ref(const float* input,
std::vector<int64_t> in_dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out) {
auto out_dims = in_dims;
std::vector<int> real_starts(in_dims.size(), 0);
std::vector<int> real_ends(in_dims.size(), 0);
std::vector<int> real_step(in_dims.size(), 0);
for (int i = 0; i < in_dims.size(); i++) {
real_ends[i] = in_dims[i];
}
for (int i = 0; i < axes.size(); i++) {
int dim_value = in_dims[axes[i]];
if (dim_value > 0) {
int start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
int end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
out_dims[axes[i]] = end - start;
real_starts[axes[i]] = start;
real_ends[axes[i]] = end;
}
}
const int LEN = in_dims.size();
int dst_step[LEN];
for (int i = 0; i < in_dims.size(); ++i) {
dst_step[i] = 1;
}
int src_step[LEN];
for (int i = 0; i < in_dims.size(); ++i) {
src_step[i] = 1;
}
int out_num = out_dims[in_dims.size() - 1];
for (int i = in_dims.size() - 2; i >= 0; i--) {
dst_step[i] = out_dims[i + 1] * dst_step[i + 1];
src_step[i] = in_dims[i + 1] * src_step[i + 1];
out_num *= out_dims[i];
}
for (int dst_id = 0; dst_id < out_num; dst_id++) {
int src_id = 0;
int index_id = dst_id;
for (int j = 0; j < out_dims.size(); j++) {
int cur_id = index_id / dst_step[j];
index_id = index_id % dst_step[j];
src_id += (cur_id + real_starts[j]) * src_step[j];
}
out[dst_id] = input[src_id];
}
}
class SliceComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "Input";
std::string output_ = "Out";
std::vector<int> axes_;
std::vector<int> starts_;
std::vector<int> ends_;
std::vector<int> decrease_axis_;
DDim dims_;
public:
SliceComputeTester(const Place& place,
const std::string& alias,
const std::vector<int>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int>& decrease_axis,
const DDim& dims)
: TestCase(place, alias),
axes_(axes),
starts_(starts),
ends_(ends),
decrease_axis_(decrease_axis),
dims_(dims) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
auto* input = scope->FindTensor(input_);
CHECK(out);
CHECK(input);
const auto* input_data = input->data<float>();
auto in_dims = input->dims();
auto out_dims = in_dims;
int dim_value, start, end;
for (size_t i = 0; i < axes_.size(); ++i) {
dim_value = out_dims[axes_[i]];
if (dim_value > 0) {
start = starts_[i] < 0 ? (starts_[i] + dim_value) : starts_[i];
end = ends_[i] < 0 ? (ends_[i] + dim_value) : ends_[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
out_dims[axes_[i]] = end - start;
}
}
if (decrease_axis_.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis_.size(); ++i) {
out_dims[decrease_axis_[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
DDim new_dims;
new_dims.ConstructFrom(new_out_shape);
out_dims = new_dims;
}
out->Resize(out_dims);
auto* out_data = out->mutable_data<float>();
slice_ref(input_data, in_dims.data(), axes_, starts_, ends_, out_data);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("slice");
op_desc->SetInput("Input", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axes", axes_);
op_desc->SetAttr("starts", starts_);
op_desc->SetAttr("ends", ends_);
op_desc->SetAttr("decrease_axis", decrease_axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i;
}
SetCommonTensor(input_, dims_, data.data());
}
};
void test_slice(Place place) {
std::vector<int> axes({0, 1, 2});
std::vector<int> starts({2, 2, 2});
std::vector<int> ends({5, 6, 7});
std::vector<int> decrease_axis({});
DDim dims({10, 10, 10});
std::unique_ptr<arena::TestCase> tester(new SliceComputeTester(
place, "def", axes, starts, ends, decrease_axis, dims));
arena::Arena arena(std::move(tester), place, 2e-4);
arena.TestPrecision();
}
TEST(Slice, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_slice(place);
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册