未验证 提交 feea5977 编写于 作者: W Wilber 提交者: GitHub

[host] [Kernel] Add unsqueeze kernel for host. (#3629)

Move unsqueeze from ARM to host.
上级 546d4da8
......@@ -40,7 +40,6 @@ add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc DEPS ${lite
add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(unsqueeze_compute_arm ARM basic SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -3,6 +3,7 @@ message(STATUS "compile with lite host kernels")
add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(unsqueeze_compute_host Host basic SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(expand_compute_host Host basic SRCS expand_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_host Host extra SRCS shape_compute.cc DEPS ${lite_kernel_deps})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/unsqueeze_compute.h"
#include "lite/kernels/host/unsqueeze_compute.h"
#include <vector>
namespace paddle {
......@@ -24,10 +25,9 @@ void UnsqueezeCompute::Run() {
auto& param = Param<operators::UnsqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
auto output_dims = output->dims();
output->CopyDataFrom(*x);
output->Resize(output_dims);
}
void Unsqueeze2Compute::Run() {
......@@ -35,12 +35,12 @@ void Unsqueeze2Compute::Run() {
auto x = param.X;
auto output = param.Out;
auto xshape = param.XShape;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
memcpy(xshape_data, x_data, x_dims.production() * sizeof(float));
auto output_dims = output->dims();
auto xshape_dims = xshape->dims();
output->CopyDataFrom(*x);
xshape->CopyDataFrom(*x);
output->Resize(output_dims);
xshape->Resize(xshape_dims);
}
} // namespace host
......@@ -49,30 +49,44 @@ void Unsqueeze2Compute::Run() {
} // namespace paddle
REGISTER_LITE_KERNEL(unsqueeze,
kARM,
kFloat,
kNCHW,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::UnsqueezeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("AxesTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindInput("AxesTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
REGISTER_LITE_KERNEL(unsqueeze2,
kARM,
kFloat,
kNCHW,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::Unsqueeze2Compute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("AxesTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindInput("AxesTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("XShape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -22,14 +22,16 @@ namespace lite {
namespace kernels {
namespace host {
class UnsqueezeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class UnsqueezeCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
virtual ~UnsqueezeCompute() = default;
};
class Unsqueeze2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class Unsqueeze2Compute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
......
......@@ -47,7 +47,7 @@ class SqueezeComputeTester : public arena::TestCase {
bool should_squeeze[9] = {false};
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
for (size_t idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
......@@ -71,7 +71,7 @@ class SqueezeComputeTester : public arena::TestCase {
}
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
for (size_t in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
......@@ -135,7 +135,7 @@ class Squeeze2ComputeTester : public arena::TestCase {
bool should_squeeze[9] = {false};
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
for (size_t idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
......@@ -159,7 +159,7 @@ class Squeeze2ComputeTester : public arena::TestCase {
}
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
for (size_t in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
......
......@@ -84,7 +84,6 @@ class UnsqueezeComputeTester : public arena::TestCase {
output_shape[out_idx] = in_dims[in_idx++];
}
}
for (size_t i = 0; i < output_shape.size(); ++i)
out->Resize(DDim(output_shape));
auto* input_data = input->data<float>();
auto* out_data = out->mutable_data<float>();
......@@ -258,22 +257,19 @@ void test_unsqueeze2(Place place,
}
}
TEST(squeeze, precision) {
TEST(unsqueeze, precision) {
Place place;
float abs_error = 2e-5;
#ifdef LITE_WITH_NPU
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
place = TARGET(kHost);
#endif
test_unsqueeze(place, abs_error);
}
TEST(squeeze2, precision) {
TEST(unsqueeze2, precision) {
Place place;
float abs_error = 2e-5;
std::vector<std::string> ignored_outs = {};
......@@ -281,10 +277,8 @@ TEST(squeeze2, precision) {
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
ignored_outs.push_back("XShape"); // not supported out in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
place = TARGET(kHost);
#endif
test_unsqueeze2(place, abs_error, ignored_outs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册