From 142ee7f26e93a8701869980e61b27f33e7669eb1 Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Thu, 11 Jun 2020 11:18:14 +0800
Subject: [PATCH] unsqueeze&squeeze ops' xshape holds no data (#3762)

---
 lite/kernels/arm/CMakeLists.txt               |   1 -
 lite/kernels/host/CMakeLists.txt              |   1 +
 lite/kernels/{arm => host}/squeeze_compute.cc |  51 ++++---
 lite/kernels/{arm => host}/squeeze_compute.h  |   6 +-
 lite/kernels/host/unsqueeze_compute.cc        |   4 -
 lite/kernels/x86/CMakeLists.txt               |   2 -
 lite/kernels/x86/squeeze_compute.cc           |  36 -----
 lite/kernels/x86/squeeze_compute.h            |  70 ---------
 lite/kernels/x86/squeeze_compute_test.cc      | 142 ------------------
 lite/operators/squeeze_op.cc                  |  14 +-
 lite/operators/unsqueeze_op.cc                |  14 +-
 lite/tests/kernels/squeeze_compute_test.cc    |  20 +--
 lite/tests/kernels/unsqueeze_compute_test.cc  |  20 +--
 13 files changed, 54 insertions(+), 327 deletions(-)
 rename lite/kernels/{arm => host}/squeeze_compute.cc (55%)
 rename lite/kernels/{arm => host}/squeeze_compute.h (84%)
 delete mode 100644 lite/kernels/x86/squeeze_compute.cc
 delete mode 100644 lite/kernels/x86/squeeze_compute.h
 delete mode 100644 lite/kernels/x86/squeeze_compute_test.cc

diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index 8246f200ae..0ab86b0f04 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -39,7 +39,6 @@ add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc DEPS ${
 add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt
index 4334ee220a..a70345708c 100644
--- a/lite/kernels/host/CMakeLists.txt
+++ b/lite/kernels/host/CMakeLists.txt
@@ -3,6 +3,7 @@ message(STATUS "compile with lite host kernels")
 add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(squeeze_compute_host Host basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(unsqueeze_compute_host Host basic SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(expand_compute_host Host basic SRCS expand_compute.cc DEPS ${lite_kernel_deps})
diff --git a/lite/kernels/arm/squeeze_compute.cc b/lite/kernels/host/squeeze_compute.cc
similarity index 55%
rename from lite/kernels/arm/squeeze_compute.cc
rename to lite/kernels/host/squeeze_compute.cc
index 0f79d5c385..c1e24b697b 100644
--- a/lite/kernels/arm/squeeze_compute.cc
+++ b/lite/kernels/host/squeeze_compute.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "lite/kernels/arm/squeeze_compute.h" +#include "lite/kernels/host/squeeze_compute.h" #include namespace paddle { @@ -24,23 +24,18 @@ void SqueezeCompute::Run() { auto& param = Param(); auto x = param.X; auto output = param.Out; - auto x_dims = x->dims(); - auto* x_data = x->data(); - auto* out_data = output->mutable_data(); - memcpy(out_data, x_data, x_dims.production() * sizeof(float)); + auto output_dims = output->dims(); + output->CopyDataFrom(*x); + output->Resize(output_dims); } void Squeeze2Compute::Run() { auto& param = Param(); auto x = param.X; auto output = param.Out; - auto xshape = param.XShape; - auto x_dims = x->dims(); - auto* x_data = x->data(); - auto* out_data = output->mutable_data(); - auto* xshape_data = xshape->mutable_data(); - memcpy(out_data, x_data, x_dims.production() * sizeof(float)); - memcpy(xshape_data, x_data, x_dims.production() * sizeof(float)); + auto output_dims = output->dims(); + output->CopyDataFrom(*x); + output->Resize(output_dims); } } // namespace host @@ -49,22 +44,32 @@ void Squeeze2Compute::Run() { } // namespace paddle REGISTER_LITE_KERNEL(squeeze, - kARM, - kFloat, - kNCHW, + kHost, + kAny, + kAny, paddle::lite::kernels::host::SqueezeCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .Finalize(); REGISTER_LITE_KERNEL(squeeze2, - kARM, - kFloat, - kNCHW, + kHost, + kAny, + kAny, paddle::lite::kernels::host::Squeeze2Compute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("XShape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .Finalize(); diff --git a/lite/kernels/arm/squeeze_compute.h b/lite/kernels/host/squeeze_compute.h similarity index 84% rename from lite/kernels/arm/squeeze_compute.h rename to lite/kernels/host/squeeze_compute.h index c9e4c2a17c..3a30f0bd8d 100644 --- a/lite/kernels/arm/squeeze_compute.h +++ b/lite/kernels/host/squeeze_compute.h @@ -22,14 +22,16 @@ namespace lite { namespace kernels { namespace host { -class SqueezeCompute : public KernelLite { +class SqueezeCompute + : public KernelLite { public: void Run() override; virtual ~SqueezeCompute() = default; }; -class Squeeze2Compute : public KernelLite { +class Squeeze2Compute + : public KernelLite { public: void Run() override; diff --git a/lite/kernels/host/unsqueeze_compute.cc b/lite/kernels/host/unsqueeze_compute.cc index aa525880af..153a860b06 100644 --- a/lite/kernels/host/unsqueeze_compute.cc +++ b/lite/kernels/host/unsqueeze_compute.cc @@ -34,13 +34,9 @@ void Unsqueeze2Compute::Run() { auto& param = Param(); auto x = param.X; auto output = param.Out; - auto xshape = param.XShape; auto output_dims = output->dims(); - auto xshape_dims = xshape->dims(); output->CopyDataFrom(*x); - xshape->CopyDataFrom(*x); output->Resize(output_dims); - xshape->Resize(xshape_dims); } } // namespace host diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 
index bbc67a242c..521fbb6b24 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -11,7 +11,6 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
 add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} fluid_data_type)
 add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
-add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)
 add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
 add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
@@ -74,7 +73,6 @@ lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
 lite_cc_test(test_gather_compute_x86 SRCS gather_compute_test.cc DEPS gather_compute_x86)
 lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
-lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
 lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86)
 lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86)
 lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
diff --git a/lite/kernels/x86/squeeze_compute.cc b/lite/kernels/x86/squeeze_compute.cc
deleted file mode 100644
index 17ecd0c49b..0000000000
--- a/lite/kernels/x86/squeeze_compute.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/kernels/x86/squeeze_compute.h"
-
-REGISTER_LITE_KERNEL(squeeze,
-                     kX86,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::x86::SqueezeCompute<float>,
-                     def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
-    .Finalize();
-
-REGISTER_LITE_KERNEL(squeeze2,
-                     kX86,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::x86::Squeeze2Compute<float>,
-                     def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
-    .Finalize();
diff --git a/lite/kernels/x86/squeeze_compute.h b/lite/kernels/x86/squeeze_compute.h
deleted file mode 100644
index 3288421c14..0000000000
--- a/lite/kernels/x86/squeeze_compute.h
+++ /dev/null
@@ -1,70 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#pragma once
-
-#include <vector>
-#include "lite/core/kernel.h"
-#include "lite/core/op_lite.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/type_system.h"
-#include "lite/operators/squeeze_op.h"
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace x86 {
-
-template <typename T>
-class SqueezeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
- public:
-  using param_t = operators::SqueezeParam;
-
-  void Run() override {
-    auto& param = *param_.get_mutable<param_t>();
-    auto x = param.X;
-    auto output = param.Out;
-    auto x_dims = x->dims();
-    auto* x_data = x->template data<T>();
-    auto* out_data = output->template mutable_data<T>();
-    memcpy(out_data, x_data, x_dims.production() * sizeof(T));
-  }
-
-  virtual ~SqueezeCompute() = default;
-};
-
-template <typename T>
-class Squeeze2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
- public:
-  using param_t = operators::SqueezeParam;
-
-  void Run() override {
-    auto& param = *param_.get_mutable<param_t>();
-    auto x = param.X;
-    auto output = param.Out;
-    auto xshape = param.XShape;
-    auto x_dims = x->dims();
-    auto* x_data = x->template data<T>();
-    auto* out_data = output->template mutable_data<T>();
-    auto* xshape_data = xshape->template mutable_data<T>();
-    memcpy(out_data, x_data, x_dims.production() * sizeof(T));
-    memcpy(xshape_data, x_data, x_dims.production() * sizeof(T));
-  }
-
-  virtual ~Squeeze2Compute() = default;
-};
-
-}  // namespace x86
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
diff --git a/lite/kernels/x86/squeeze_compute_test.cc b/lite/kernels/x86/squeeze_compute_test.cc
deleted file mode 100644
index 0799a522b3..0000000000
--- a/lite/kernels/x86/squeeze_compute_test.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "lite/kernels/x86/squeeze_compute.h" -#include -#include -#include -#include -#include "lite/core/op_registry.h" -namespace paddle { -namespace lite { -namespace kernels { -namespace x86 { - -// squeeze -TEST(squeeze_x86, retrive_op) { - auto squeeze = - KernelRegistry::Global().Create( - "squeeze"); - ASSERT_FALSE(squeeze.empty()); - ASSERT_TRUE(squeeze.front()); -} - -TEST(squeeze_x86, init) { - lite::kernels::x86::SqueezeCompute squeeze; - ASSERT_EQ(squeeze.precision(), PRECISION(kFloat)); - ASSERT_EQ(squeeze.target(), TARGET(kX86)); -} - -TEST(squeeze_x86, run_test) { - lite::Tensor x; - lite::Tensor out; - std::vector x_shape({1, 3, 1, 5}); - x.Resize(lite::DDim(x_shape)); - std::vector out_shape({3, 5}); - out.Resize(lite::DDim(out_shape)); - - auto x_data = x.mutable_data(); - auto out_data = out.mutable_data(); - - for (int64_t i = 0; i < x.dims().production(); ++i) { - x_data[i] = static_cast(i); - } - - // SqueezeCompute squeeze; - SqueezeCompute squeeze; - operators::SqueezeParam param; - - param.X = &x; - param.Out = &out; - std::vector> ref_res({{3, 5}, {3, 5}}); - std::vector> axes({{0, -2}, {}}); - std::unique_ptr ctx(new KernelContext); - ctx->As(); - for (int i = 0; i < 2; ++i) { - param.axes = axes[i]; - squeeze.SetContext(std::move(ctx)); - squeeze.SetParam(param); - squeeze.Run(); - - for (int j = 0; j < out.dims().production(); ++j) { - EXPECT_NEAR(out_data[j], x_data[j], 1e-5); - } - } -} - -// squeeze2 -TEST(squeeze2_x86, retrive_op) { - auto squeeze2 = - KernelRegistry::Global().Create( - "squeeze2"); - ASSERT_FALSE(squeeze2.empty()); - ASSERT_TRUE(squeeze2.front()); -} - -TEST(squeeze2_x86, init) { - lite::kernels::x86::Squeeze2Compute squeeze2; - ASSERT_EQ(squeeze2.precision(), PRECISION(kFloat)); - ASSERT_EQ(squeeze2.target(), TARGET(kX86)); -} - -TEST(squeeze2_x86, run_test) { - lite::Tensor x; - lite::Tensor xshape; - lite::Tensor out; - std::vector x_shape({1, 3, 1, 5}); - x.Resize(lite::DDim(x_shape)); - std::vector out_shape({3, 5}); - out.Resize(lite::DDim(out_shape)); - std::vector xshape_shape({1, 3, 1, 5}); - xshape.Resize(lite::DDim(xshape_shape)); - - auto x_data = x.mutable_data(); - auto out_data = out.mutable_data(); - auto xshape_data = xshape.mutable_data(); - - for (int64_t i = 0; i < x.dims().production(); ++i) { - x_data[i] = static_cast(i); - xshape_data[i] = static_cast(i); - } - - // Squeeze2Compute squeeze2; - Squeeze2Compute squeeze2; - operators::SqueezeParam param; - - param.X = &x; - param.Out = &out; - param.XShape = &xshape; - std::vector> ref_res({{3, 5}, {3, 5}}); - std::vector> axes({{0, -2}, {}}); - std::unique_ptr ctx(new KernelContext); - ctx->As(); - for (int i = 0; i < 2; ++i) { - param.axes = axes[i]; - squeeze2.SetContext(std::move(ctx)); - squeeze2.SetParam(param); - squeeze2.Run(); - - for (int j = 0; j < out.dims().production(); ++j) { - EXPECT_NEAR(out_data[j], x_data[j], 1e-5); - } - } -} - -} // namespace x86 -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(squeeze, kX86, kFloat, kNCHW, def); -USE_LITE_KERNEL(squeeze2, kX86, kFloat, kNCHW, def); diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc index 8dada8fed0..cf8d8592a3 100644 --- a/lite/operators/squeeze_op.cc +++ b/lite/operators/squeeze_op.cc @@ -85,12 +85,8 @@ bool SqueezeOp::InferShapeImpl() const { bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { AttachParam(¶m_); - auto x_var = scope->FindVar(opdesc.Input("X").front()); - auto output_var = 
-  CHECK(x_var);
-  CHECK(output_var);
-  param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
-  param_.Out = output_var->GetMutable<lite::Tensor>();
+  param_.X = scope->FindTensor(opdesc.Input("X").front());
+  param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
 
   if (opdesc.HasAttr("axes")) {
     param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
@@ -109,7 +105,7 @@ bool Squeeze2Op::CheckShape() const {
 bool Squeeze2Op::InferShapeImpl() const {
   SqueezeOp::InferShapeImpl();
   auto x_dims = param_.X->dims();
-  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
+  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
   for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
@@ -119,9 +115,7 @@ bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   SqueezeOp::AttachImpl(opdesc, scope);
-  auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
-  CHECK(xshape_var);
-  param_.XShape = xshape_var->GetMutable<lite::Tensor>();
+  param_.XShape = scope->FindMutableTensor(opdesc.Output("XShape").front());
   CHECK(param_.XShape) << "Output(XShape) of SqueezeOp should not be null.";
   return true;
 }
diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc
index 23865aaabb..287baf838e 100644
--- a/lite/operators/unsqueeze_op.cc
+++ b/lite/operators/unsqueeze_op.cc
@@ -90,12 +90,8 @@ bool UnsqueezeOp::InferShapeImpl() const {
 
 bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   AttachParam(&param_);
-  auto x_var = scope->FindVar(opdesc.Input("X").front());
-  auto output_var = scope->FindVar(opdesc.Output("Out").front());
-  CHECK(x_var);
-  CHECK(output_var);
-  param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
-  param_.Out = output_var->GetMutable<lite::Tensor>();
+  param_.X = scope->FindTensor(opdesc.Input("X").front());
+  param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
 
   if (opdesc.HasAttr("axes")) {
     param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
@@ -133,7 +129,7 @@ bool Unsqueeze2Op::CheckShape() const {
 bool Unsqueeze2Op::InferShapeImpl() const {
   UnsqueezeOp::InferShapeImpl();
   auto x_dims = param_.X->dims();
-  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
+  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
   for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
@@ -143,9 +139,7 @@ bool Unsqueeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   UnsqueezeOp::AttachImpl(opdesc, scope);
-  auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
-  CHECK(xshape_var);
-  param_.XShape = xshape_var->GetMutable<lite::Tensor>();
+  param_.XShape = scope->FindMutableTensor(opdesc.Output("XShape").front());
   CHECK(param_.XShape) << "Output(XShape) of Unsqueeze2Op should not be null.";
   return true;
 }
diff --git a/lite/tests/kernels/squeeze_compute_test.cc b/lite/tests/kernels/squeeze_compute_test.cc
index 30c56d532e..0fe4f360d4 100644
--- a/lite/tests/kernels/squeeze_compute_test.cc
+++ b/lite/tests/kernels/squeeze_compute_test.cc
@@ -123,7 +123,7 @@ class Squeeze2ComputeTester : public arena::TestCase {
     CHECK(out);
     auto* xshape = scope->NewTensor(xshape_);
     CHECK(xshape);
-    std::vector<int64_t> xshape_sp(dims_.size() + 1, 1);
+    std::vector<int64_t> xshape_sp(dims_.size() + 1, 0);
     for (size_t i = 0; i < dims_.size(); ++i) {
       xshape_sp[i + 1] = dims_[i];
     }
@@ -169,9 +169,7 @@ class Squeeze2ComputeTester : public arena::TestCase {
 
     auto* input_data = input->data<float>();
     auto* out_data = out->mutable_data<float>();
-    auto* xshape_data = xshape->mutable_data<float>();
     memcpy(out_data, input_data, sizeof(float) * dims_.production());
-    memcpy(xshape_data, input_data, sizeof(float) * dims_.production());
   }
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
@@ -221,7 +219,7 @@ void test_squeeze2(Place place) {
       std::unique_ptr<arena::TestCase> tester(new Squeeze2ComputeTester(
           place, "def", axes, DDim({N, C, H, W})));
       arena::Arena arena(std::move(tester), place, 2e-5);
-      arena.TestPrecision();
+      arena.TestPrecision({"XShape"});
     }
   }
 }
@@ -230,23 +228,17 @@ void test_squeeze2(Place place) {
 }
 
 TEST(squeeze, precision) {
-#ifdef LITE_WITH_X86
-  Place place(TARGET(kX86));
+#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
+  Place place(TARGET(kHost));
 #endif
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
   test_squeeze(place);
-#endif
 }
 
 TEST(squeeze2, precision) {
-#ifdef LITE_WITH_X86
-  Place place(TARGET(kX86));
+#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
+  Place place(TARGET(kHost));
 #endif
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
   test_squeeze2(place);
-#endif
 }
 
 }  // namespace lite
diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc
index c59e732d7d..c0b7890af2 100644
--- a/lite/tests/kernels/unsqueeze_compute_test.cc
+++ b/lite/tests/kernels/unsqueeze_compute_test.cc
@@ -153,7 +153,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
     CHECK(out);
     auto* xshape = scope->NewTensor(xshape_);
     CHECK(xshape);
-    std::vector<int64_t> xshape_sp(dims_.size() + 1, 1);
+    std::vector<int64_t> xshape_sp(dims_.size() + 1, 0);
     for (size_t i = 0; i < dims_.size(); ++i) {
       xshape_sp[i + 1] = dims_[i];
     }
@@ -198,9 +198,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
 
     auto* input_data = input->data<float>();
     auto* out_data = out->mutable_data<float>();
-    auto* xshape_data = xshape->mutable_data<float>();
     memcpy(out_data, input_data, sizeof(float) * dims_.production());
-    memcpy(xshape_data, input_data, sizeof(float) * dims_.production());
   }
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
@@ -238,9 +236,7 @@ void test_unsqueeze(Place place, float abs_error = 2e-5) {
   }
 }
 
-void test_unsqueeze2(Place place,
-                     float abs_error = 2e-5,
-                     std::vector<std::string> ignored_outs = {}) {
+void test_unsqueeze2(Place place, float abs_error = 2e-5) {
   for (std::vector<int> axes : {std::vector<int>({0}),
                                 std::vector<int>({0, 2}),
                                 std::vector<int>({0, -2})}) {
@@ -252,7 +248,7 @@ void test_unsqueeze2(Place place,
       std::unique_ptr<arena::TestCase> tester(
           new Unsqueeze2ComputeTester(place, "def", axes, DDim(dims)));
       arena::Arena arena(std::move(tester), place, abs_error);
-      arena.TestPrecision(ignored_outs);
+      arena.TestPrecision({"XShape"});
     }
   }
 }
@@ -263,7 +259,7 @@ TEST(unsqueeze, precision) {
 #ifdef LITE_WITH_NPU
   place = TARGET(kNPU);
   abs_error = 1e-2;  // Using fp16 in NPU
-#else
+#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
   place = TARGET(kHost);
 #endif
   test_unsqueeze(place, abs_error);
@@ -272,16 +268,14 @@ TEST(unsqueeze, precision) {
 TEST(unsqueeze2, precision) {
   Place place;
   float abs_error = 2e-5;
-  std::vector<std::string> ignored_outs = {};
 
 #ifdef LITE_WITH_NPU
   place = TARGET(kNPU);
-  abs_error = 1e-2;                  // Using fp16 in NPU
-  ignored_outs.push_back("XShape");  // not supported out in NPU
-#else
+  abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
   place = TARGET(kHost);
 #endif
-  test_unsqueeze2(place, abs_error, ignored_outs);
+  test_unsqueeze2(place, abs_error);
 }
 
 }  // namespace lite
-- 
GitLab