Unverified commit 13bbd2b8 authored by liu zhengxi, committed by GitHub

add slice op, reshape op, reshape2 op, squeeze op, squeeze2 op for x86 (#2005)

add slice op, reshape op, reshape2 op, squeeze op and squeeze2 op and their unit tests for x86
Parent 2dcff5ca
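For orientation, a minimal sketch (not part of the commit) of how one of the new x86 kernels is driven; it mirrors the unit tests in this diff, and the function name run_slice_example is illustrative only:

#include "lite/kernels/x86/slice_compute.h"
#include <memory>
#include <utility>
#include <vector>

void run_slice_example() {
  using namespace paddle::lite;  // for brevity in this sketch
  Tensor x, out;
  x.Resize(DDim(std::vector<int64_t>({3, 4})));
  out.Resize(DDim(std::vector<int64_t>({2, 4})));  // rows [1, 3) of x
  auto* x_data = x.mutable_data<float>();
  for (int64_t i = 0; i < x.dims().production(); ++i) {
    x_data[i] = static_cast<float>(i);
  }
  kernels::x86::SliceCompute<float> slice;
  operators::SliceParam param;
  param.X = &x;
  param.Out = &out;
  param.axes = {0};    // slice along the first dimension
  param.starts = {1};  // negative values wrap, as in Python indexing
  param.ends = {3};
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  slice.SetContext(std::move(ctx));
  slice.SetParam(param);
  slice.Run();  // out now holds rows 1 and 2 of x
}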
......@@ -72,6 +72,9 @@ build
cmake-build-debug
cmake-build-release
# vscode
.vscode
# ios
tools/libomp.a
......
......@@ -139,6 +139,7 @@ USE_LITE_KERNEL(assign, kARM, kFloat, kNCHW, def);
// USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(slice, kX86, kFloat, kNCHW, def);
// USE_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW, def);
// USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def);
// USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
......
......@@ -11,6 +11,9 @@ endif()
# lite_cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps})
# lite_cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
# lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
# lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
......@@ -36,6 +39,9 @@ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEP
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/reshape_compute.h"
REGISTER_LITE_KERNEL(reshape,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ReshapeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(reshape2,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::Reshape2Compute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/reshape_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
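// Shared helper for the reshape and reshape2 kernels: when the optional
// actual_shape tensor is present, its contents override the static shape
// attribute and are validated against the input dims before the data is
// copied through.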
template <typename T>
void Compute(const lite::Tensor* in,
const lite::Tensor* actual_shape,
lite::Tensor* out) {
auto out_dims = out->dims();
auto in_dims = in->dims();
if (actual_shape) {
auto shape_dims = actual_shape->dims();
const int* shape_data = actual_shape->data<int>();
std::vector<int> shape =
std::vector<int>(shape_data, shape_data + shape_dims.production());
out_dims = lite::operators::ValidateShape(shape, in_dims);
out->Resize(out_dims);
}
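// CopyDataFrom also copies the source dims, so the validated out_dims must
// be re-applied afterwards.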
out->CopyDataFrom(*in);
out->Resize(out_dims);
}
template <typename T>
class ReshapeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReshapeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
Compute<T>(param.x, param.actual_shape, param.output);
}
virtual ~ReshapeCompute() = default;
};
template <typename T>
class Reshape2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReshapeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
Compute<T>(param.x, param.actual_shape, param.output);
}
virtual ~Reshape2Compute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/reshape_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// reshape
TEST(reshape_x86, retrieve_op) {
auto reshape =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape");
ASSERT_FALSE(reshape.empty());
ASSERT_TRUE(reshape.front());
}
TEST(reshape_x86, init) {
lite::kernels::x86::ReshapeCompute<float> reshape;
ASSERT_EQ(reshape.precision(), PRECISION(kFloat));
ASSERT_EQ(reshape.target(), TARGET(kX86));
}
TEST(reshape_x86, run_test) {
lite::Tensor x, actual_shape;
lite::Tensor out;
std::vector<int64_t> x_shape({1, 2, 4, 1});
x.Resize(lite::DDim(x_shape));
actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
std::vector<int64_t> out_shape({1, 8, 1, 1});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto actual_data = actual_shape.mutable_data<int>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
actual_data[0] = 1;
actual_data[1] = 4;
actual_data[2] = 2;
std::vector<int> shape({1, 8, 1, 1});
// ReshapeCompute reshape;
ReshapeCompute<float> reshape;
operators::ReshapeParam param;
param.x = &x;
param.output = &out;
param.shape = shape;
param.actual_shape = &actual_shape;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
// Set the context once: ctx is moved-from after this call, so moving it
// again on the second loop iteration would install a null context.
reshape.SetContext(std::move(ctx));
for (int i = 0; i < 2; ++i) {
if (1 == i) param.actual_shape = nullptr;
reshape.SetParam(param);
reshape.Run();
for (int j = 0; j < out.dims().production(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
}
}
}
// reshape2
TEST(reshape2_x86, retrieve_op) {
auto reshape2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"reshape2");
ASSERT_FALSE(reshape2.empty());
ASSERT_TRUE(reshape2.front());
}
TEST(reshape2_x86, init) {
lite::kernels::x86::Reshape2Compute<float> reshape2;
ASSERT_EQ(reshape2.precision(), PRECISION(kFloat));
ASSERT_EQ(reshape2.target(), TARGET(kX86));
}
TEST(reshape2_x86, run_test) {
lite::Tensor x, actual_shape;
lite::Tensor out, xshape;
std::vector<int64_t> x_shape({1, 2, 4});
x.Resize(lite::DDim(x_shape));
actual_shape.Resize(lite::DDim(std::vector<int64_t>({3})));
std::vector<int64_t> out_shape({1, 4, 2});
out.Resize(lite::DDim(out_shape));
std::vector<int64_t> xshape_shape({1, 2, 4});
xshape.Resize(lite::DDim(xshape_shape));
auto x_data = x.mutable_data<float>();
auto actual_data = actual_shape.mutable_data<int>();
auto out_data = out.mutable_data<float>();
auto xshape_data = xshape.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
xshape_data[i] = static_cast<float>(i);
}
actual_data[0] = 1;
actual_data[1] = 4;
actual_data[2] = 2;
std::vector<int> shape({0, -1, 2});
// Reshape2Compute reshape2;
Reshape2Compute<float> reshape2;
operators::ReshapeParam param;
param.x = &x;
param.output = &out;
param.xshape = &xshape;
param.shape = shape;
param.actual_shape = &actual_shape;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
// Set the context once: ctx is moved-from after this call.
reshape2.SetContext(std::move(ctx));
for (int i = 0; i < 2; ++i) {
if (1 == i) param.actual_shape = nullptr;
reshape2.SetParam(param);
reshape2.Run();
for (int j = 0; j < out.dims().production(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
}
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(reshape, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(reshape2, kX86, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/slice_compute.h"
REGISTER_LITE_KERNEL(slice,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SliceCompute<float>,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <algorithm>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/relu_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
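// Rank-templated slice kernel. When decrease_axis is non-empty, the output
// is temporarily resized to the un-decreased shape so its rank matches the
// input. Offsets are derived from `starts` (negative values wrap around, as
// in Python indexing) and extents come from the output dims, which already
// encode `ends`; the copy itself is done by Eigen's slice(). Note that only
// float data is handled here, although slice_compute_ below is templated
// on T.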
template <size_t D>
void slice_compute(const lite::Tensor* in,
lite::Tensor* out,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
std::vector<int> decrease_axis) {
auto out_dims = out->dims();
auto in_dims = in->dims();
// resize out_dims
if (decrease_axis.size() > 0) {
if (decrease_axis.size() == (size_t)in_dims.size()) {
std::vector<int64_t> vec_origin_out_shape(decrease_axis.size(), 1);
// lite::DDim dims(vec_origin_out_shape);
out->Resize(vec_origin_out_shape);
} else {
std::vector<int64_t> vec_origin_out_shape(
out_dims.size() + decrease_axis.size(), -1);
for (size_t i = 0; i < decrease_axis.size(); ++i) {
vec_origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < vec_origin_out_shape.size(); ++i) {
if (-1 == vec_origin_out_shape[i]) {
vec_origin_out_shape[i] = out_dims[index];
++index;
}
}
// lite::DDim dims(vec_origin_out_shape);
out->Resize(vec_origin_out_shape);
}
}
out->mutable_data<float>(lite::TargetType::kX86);
auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = new_out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
offsets[axes[i]] = start;
}
auto in_t =
lite::fluid::EigenTensor<float, D, Eigen::RowMajor, Eigen::DenseIndex>::
From(*in, in->dims());
auto out_t =
lite::fluid::EigenTensor<float, D, Eigen::RowMajor, Eigen::DenseIndex>::
From(*out, new_out_dims);
out_t = in_t.slice(offsets, extents);
out->Resize(out_dims);
}
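// Dispatch on the input rank; ranks 1 through 6 are supported (any other
// rank falls through as a no-op).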
template <typename T>
void slice_compute_(const lite::Tensor* Input,
lite::Tensor* Out,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
std::vector<int> decrease_axis) {
int rank = Input->dims().size();
switch (rank) {
case 1:
slice_compute<1>(Input, Out, axes, starts, ends, decrease_axis);
break;
case 2:
slice_compute<2>(Input, Out, axes, starts, ends, decrease_axis);
break;
case 3:
slice_compute<3>(Input, Out, axes, starts, ends, decrease_axis);
break;
case 4:
slice_compute<4>(Input, Out, axes, starts, ends, decrease_axis);
break;
case 5:
slice_compute<5>(Input, Out, axes, starts, ends, decrease_axis);
break;
case 6:
slice_compute<6>(Input, Out, axes, starts, ends, decrease_axis);
break;
}
}
template <typename T>
class SliceCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SliceParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
slice_compute_<T>(param.X,
param.Out,
param.axes,
param.starts,
param.ends,
param.decrease_axis);
}
virtual ~SliceCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/slice_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(slice_x86, retrieve_op) {
auto slice =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("slice");
ASSERT_FALSE(slice.empty());
ASSERT_TRUE(slice.front());
}
TEST(slice_x86, init) {
lite::kernels::x86::SliceCompute<float> slice;
ASSERT_EQ(slice.precision(), PRECISION(kFloat));
ASSERT_EQ(slice.target(), TARGET(kX86));
}
void test_case1(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3});
std::vector<int> ends({3});
std::vector<int> axes({0});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
void test_case2(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0});
std::vector<int> ends({3, 100});
std::vector<int> axes({0, 1});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
void test_case3(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4, 2});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0, 2});
std::vector<int> ends({3, 100, -1});
std::vector<int> axes({0, 1, 2});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
void test_case4(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5, 6});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4, 2, 6});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0, 2});
std::vector<int> ends({3, 100, -1});
std::vector<int> axes({0, 1, 2});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
void test_case5(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5, 6, 3});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4, 2, 6, 3});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0, 2});
std::vector<int> ends({3, 100, -1});
std::vector<int> axes({0, 1, 2});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
void test_case6(lite::Tensor x, lite::Tensor out) {
std::vector<int64_t> x_shape({3, 4, 5, 6, 5, 2});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 4, 2, 6, 5, 2});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
std::vector<int> starts({-3, 0, 2});
std::vector<int> ends({3, 100, -1});
std::vector<int> axes({0, 1, 2});
// SliceCompute slice;
SliceCompute<float> slice;
operators::SliceParam param;
param.X = &x;
param.Out = &out;
param.axes = axes;
param.starts = starts;
param.ends = ends;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
slice.SetContext(std::move(ctx));
slice.SetParam(param);
slice.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
TEST(slice_x86, run_test) {
lite::Tensor x;
lite::Tensor out;
test_case1(x, out);
test_case2(x, out);
test_case3(x, out);
test_case4(x, out);
test_case5(x, out);
test_case6(x, out);
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(slice, kX86, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/squeeze_compute.h"
REGISTER_LITE_KERNEL(squeeze,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SqueezeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(squeeze2,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::Squeeze2Compute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/squeeze_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
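// Squeeze removes only size-1 dimensions, so the element count and order are
// unchanged; both kernels below therefore just copy the raw buffer and
// assume the output dims were already set (by the op's InferShape).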
template <typename T>
class SqueezeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SqueezeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto x = param.X;
auto output = param.Out;
auto x_dims = x->dims();
auto* x_data = x->data<T>();
auto* out_data = output->mutable_data<T>();
memcpy(out_data, x_data, x_dims.production() * sizeof(T));
}
virtual ~SqueezeCompute() = default;
};
template <typename T>
class Squeeze2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SqueezeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto x = param.X;
auto output = param.Out;
auto xshape = param.XShape;
auto x_dims = x->dims();
auto* x_data = x->data<T>();
auto* out_data = output->mutable_data<T>();
auto* xshape_data = xshape->mutable_data<T>();
memcpy(out_data, x_data, x_dims.production() * sizeof(T));
memcpy(xshape_data, x_data, x_dims.production() * sizeof(T));
}
virtual ~Squeeze2Compute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/squeeze_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// squeeze
TEST(squeeze_x86, retrieve_op) {
auto squeeze =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"squeeze");
ASSERT_FALSE(squeeze.empty());
ASSERT_TRUE(squeeze.front());
}
TEST(squeeze_x86, init) {
lite::kernels::x86::SqueezeCompute<float> squeeze;
ASSERT_EQ(squeeze.precision(), PRECISION(kFloat));
ASSERT_EQ(squeeze.target(), TARGET(kX86));
}
TEST(squeeze_x86, run_test) {
lite::Tensor x;
lite::Tensor out;
std::vector<int64_t> x_shape({1, 3, 1, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 5});
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
}
// SqueezeCompute squeeze;
SqueezeCompute<float> squeeze;
operators::SqueezeParam param;
param.X = &x;
param.Out = &out;
std::vector<std::vector<float>> ref_res({{3, 5}, {3, 5}});
std::vector<std::vector<int>> axes({{0, -2}, {}});
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
// Set the context once: ctx is moved-from after this call.
squeeze.SetContext(std::move(ctx));
for (int i = 0; i < 2; ++i) {
param.axes = axes[i];
squeeze.SetParam(param);
squeeze.Run();
for (int j = 0; j < out.dims().production(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
}
}
}
// squeeze2
TEST(squeeze2_x86, retrieve_op) {
auto squeeze2 =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"squeeze2");
ASSERT_FALSE(squeeze2.empty());
ASSERT_TRUE(squeeze2.front());
}
TEST(squeeze2_x86, init) {
lite::kernels::x86::Squeeze2Compute<float> squeeze2;
ASSERT_EQ(squeeze2.precision(), PRECISION(kFloat));
ASSERT_EQ(squeeze2.target(), TARGET(kX86));
}
TEST(squeeze2_x86, run_test) {
lite::Tensor x;
lite::Tensor xshape;
lite::Tensor out;
std::vector<int64_t> x_shape({1, 3, 1, 5});
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape({3, 5});
out.Resize(lite::DDim(out_shape));
std::vector<int64_t> xshape_shape({1, 3, 1, 5});
xshape.Resize(lite::DDim(xshape_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<float>();
auto xshape_data = xshape.mutable_data<float>();
for (int64_t i = 0; i < x.dims().production(); ++i) {
x_data[i] = static_cast<float>(i);
xshape_data[i] = static_cast<float>(i);
}
// Squeeze2Compute squeeze2;
Squeeze2Compute<float> squeeze2;
operators::SqueezeParam param;
param.X = &x;
param.Out = &out;
param.XShape = &xshape;
std::vector<std::vector<float>> ref_res({{3, 5}, {3, 5}});
std::vector<std::vector<int>> axes({{0, -2}, {}});
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
// Set the context once: ctx is moved-from after this call.
squeeze2.SetContext(std::move(ctx));
for (int i = 0; i < 2; ++i) {
param.axes = axes[i];
squeeze2.SetParam(param);
squeeze2.Run();
for (int j = 0; j < out.dims().production(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
}
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(squeeze, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(squeeze2, kX86, kFloat, kNCHW, def);
......@@ -60,6 +60,7 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
}
const std::vector<int> shape_vector = param_.shape;
lite::Tensor *shape_tensor = new lite::Tensor;
shape_tensor->Resize({static_cast<int64_t>(shape_vector.size())});
int *data_shape = shape_tensor->mutable_data<int>();
for (int i = 0; i < shape_vector.size(); i++) {
......@@ -83,6 +84,7 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
<< "The shape information must be set by Attr(shape).";
const std::vector<int> shape_vector = param_.shape;
lite::Tensor *shape_tensor = new lite::Tensor;
shape_tensor->Resize({static_cast<int64_t>(shape_vector.size())});
int *data_shape = shape_tensor->mutable_data<int>();
for (int i = 0; i < shape_vector.size(); i++) {
......@@ -120,18 +122,19 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
}
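// ValidateShape: a shape entry of -1 is inferred from the remaining elements
// (at most one -1 is allowed) and an entry of 0 copies the corresponding
// input dimension. Example: input dims {4, 8, 2} with shape {0, -1, 4}
// produce {4, 64 / (4 * 4), 4} = {4, 4, 4}.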
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
-  const DDim::value_type input_size = input_dims.production();
+  const lite::DDim::value_type input_size = input_dims.production();
auto input_shape = input_dims.Vectorize();
-  bool all_positive = std::all_of(input_shape.cbegin(),
-                                  input_shape.cend(),
-                                  [](DDim::value_type i) { return i > 0; });
+  bool all_positive = std::all_of(
+      input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
+        return i > 0;
+      });
// only one dimension can be set to -1, whose size will be automatically
// inferred.
const int unk_dim_val = -1;
const int copy_dim_val = 0;
-  std::vector<DDim::value_type> output_shape(shape.size(), 0);
-  DDim::value_type capacity = 1;
+  std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
+  lite::DDim::value_type capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
......@@ -147,10 +150,10 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
"be negtive except one unknown dimension.";
}
-    capacity *=
-        (shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_shape[i]);
-    output_shape[i] =
-        (shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_shape[i]);
+    capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
+                          : input_shape[i]);
+    output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
+                                : input_shape[i]);
}
if (unk_dim_idx != -1) {
......@@ -168,7 +171,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
} else {
CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
}
-  return DDim(output_shape);
+  return lite::DDim(output_shape);
}
} // namespace operators
......