未验证 提交 142ee7f2 编写于 作者: Z zhupengyang 提交者: GitHub

unsqueeze&squeeze ops' xshape holds no data (#3762)

上级 24d37695
......@@ -39,7 +39,6 @@ add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc DEPS ${
add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -3,6 +3,7 @@ message(STATUS "compile with lite host kernels")
add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_host Host basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(unsqueeze_compute_host Host basic SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(expand_compute_host Host basic SRCS expand_compute.cc DEPS ${lite_kernel_deps})
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/squeeze_compute.h"
#include "lite/kernels/host/squeeze_compute.h"
#include <vector>
namespace paddle {
......@@ -24,23 +24,18 @@ void SqueezeCompute::Run() {
auto& param = Param<operators::SqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
auto output_dims = output->dims();
output->CopyDataFrom(*x);
output->Resize(output_dims);
}
void Squeeze2Compute::Run() {
auto& param = Param<operators::SqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto xshape = param.XShape;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
memcpy(xshape_data, x_data, x_dims.production() * sizeof(float));
auto output_dims = output->dims();
output->CopyDataFrom(*x);
output->Resize(output_dims);
}
} // namespace host
......@@ -49,22 +44,32 @@ void Squeeze2Compute::Run() {
} // namespace paddle
REGISTER_LITE_KERNEL(squeeze,
kARM,
kFloat,
kNCHW,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::SqueezeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
REGISTER_LITE_KERNEL(squeeze2,
kARM,
kFloat,
kNCHW,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::Squeeze2Compute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("XShape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
......@@ -22,14 +22,16 @@ namespace lite {
namespace kernels {
namespace host {
class SqueezeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class SqueezeCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
virtual ~SqueezeCompute() = default;
};
class Squeeze2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class Squeeze2Compute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
......
......@@ -34,13 +34,9 @@ void Unsqueeze2Compute::Run() {
auto& param = Param<operators::UnsqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto xshape = param.XShape;
auto output_dims = output->dims();
auto xshape_dims = xshape->dims();
output->CopyDataFrom(*x);
xshape->CopyDataFrom(*x);
output->Resize(output_dims);
xshape->Resize(xshape_dims);
}
} // namespace host
......
......@@ -11,7 +11,6 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} fluid_data_type)
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)
add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
......@@ -74,7 +73,6 @@ lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_gather_compute_x86 SRCS gather_compute_test.cc DEPS gather_compute_x86)
lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86)
lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/squeeze_compute.h"
// Registers paddle::lite::kernels::x86::SqueezeCompute<float> as the kernel
// for the "squeeze" op on target kX86 with precision kFloat and layout kNCHW.
// The trailing `def` argument is presumably the kernel alias (the matching
// USE_LITE_KERNEL line passes the same token) -- confirm against the macro.
// Input "X" and output "Out" are both bound to kX86 tensor types.
REGISTER_LITE_KERNEL(squeeze,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SqueezeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Registers paddle::lite::kernels::x86::Squeeze2Compute<float> as the kernel
// for the "squeeze2" op on target kX86 with precision kFloat and layout kNCHW.
// Compared with the "squeeze" registration it binds one extra output,
// "XShape", in addition to input "X" and output "Out"; all three use kX86
// tensor types.
REGISTER_LITE_KERNEL(squeeze2,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::Squeeze2Compute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/squeeze_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// x86 kernel for the squeeze op, templated on the element type T.
//
// Run() copies the elements of param.X into param.Out as one flat byte range;
// the number of elements is taken from X's dims. No dims are modified here.
template <typename T>
class SqueezeCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::SqueezeParam;

  void Run() override {
    auto& args = *param_.get_mutable<param_t>();
    auto in_dims = args.X->dims();
    auto* src = args.X->template data<T>();
    auto* dst = args.Out->template mutable_data<T>();
    // Byte count of the whole input buffer.
    const auto num_bytes = in_dims.production() * sizeof(T);
    memcpy(dst, src, num_bytes);
  }

  virtual ~SqueezeCompute() = default;
};
// x86 kernel for the squeeze2 op, templated on the element type T.
//
// Run() copies the elements of param.X into both param.Out and param.XShape.
// Each copy is a flat byte range whose length comes from X's dims, so both
// destination buffers are written with X's full element count.
template <typename T>
class Squeeze2Compute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::SqueezeParam;

  void Run() override {
    auto& args = *param_.get_mutable<param_t>();
    auto in_dims = args.X->dims();
    auto* src = args.X->template data<T>();
    auto* out_dst = args.Out->template mutable_data<T>();
    auto* xshape_dst = args.XShape->template mutable_data<T>();
    // Both destinations receive the same number of bytes.
    const auto num_bytes = in_dims.production() * sizeof(T);
    memcpy(out_dst, src, num_bytes);
    memcpy(xshape_dst, src, num_bytes);
  }

  virtual ~Squeeze2Compute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/squeeze_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// squeeze
// The kernel registry must return at least one non-null kX86/kFloat kernel
// for the "squeeze" op.
TEST(squeeze_x86, retrive_op) {
  auto& registry = KernelRegistry::Global();
  auto created = registry.Create<TARGET(kX86), PRECISION(kFloat)>("squeeze");
  ASSERT_FALSE(created.empty());
  ASSERT_TRUE(created.front());
}
// A default-constructed SqueezeCompute<float> must report kFloat precision
// and the kX86 target.
TEST(squeeze_x86, init) {
  SqueezeCompute<float> kernel;
  const auto reported_precision = kernel.precision();
  ASSERT_EQ(reported_precision, PRECISION(kFloat));
  const auto reported_target = kernel.target();
  ASSERT_EQ(reported_target, TARGET(kX86));
}
// Runs the x86 squeeze kernel on a {1, 3, 1, 5} float tensor, once per `axes`
// setting, and checks that the {3, 5} output buffer holds the input values in
// the same linear order.
TEST(squeeze_x86, run_test) {
  lite::Tensor x;
  lite::Tensor out;
  std::vector<int64_t> x_shape({1, 3, 1, 5});
  x.Resize(lite::DDim(x_shape));
  std::vector<int64_t> out_shape({3, 5});
  out.Resize(lite::DDim(out_shape));

  auto x_data = x.mutable_data<float>();
  auto out_data = out.mutable_data<float>();
  for (int64_t i = 0; i < x.dims().production(); ++i) {
    x_data[i] = static_cast<float>(i);
  }

  SqueezeCompute<float> squeeze;
  operators::SqueezeParam param;
  param.X = &x;
  param.Out = &out;

  // Hand the context over exactly once. A std::unique_ptr is empty after
  // std::move, so moving it again on a later iteration would give the kernel
  // an empty pointer.
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  squeeze.SetContext(std::move(ctx));

  std::vector<std::vector<int>> axes({{0, -2}, {}});
  for (const auto& axes_case : axes) {
    param.axes = axes_case;
    squeeze.SetParam(param);
    squeeze.Run();
    // 64-bit index to match the element count's width.
    for (int64_t j = 0; j < out.dims().production(); ++j) {
      EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
    }
  }
}
// squeeze2
// The kernel registry must return at least one non-null kX86/kFloat kernel
// for the "squeeze2" op.
TEST(squeeze2_x86, retrive_op) {
  auto& registry = KernelRegistry::Global();
  auto created = registry.Create<TARGET(kX86), PRECISION(kFloat)>("squeeze2");
  ASSERT_FALSE(created.empty());
  ASSERT_TRUE(created.front());
}
// A default-constructed Squeeze2Compute<float> must report kFloat precision
// and the kX86 target.
TEST(squeeze2_x86, init) {
  Squeeze2Compute<float> kernel;
  const auto reported_precision = kernel.precision();
  ASSERT_EQ(reported_precision, PRECISION(kFloat));
  const auto reported_target = kernel.target();
  ASSERT_EQ(reported_target, TARGET(kX86));
}
// Runs the x86 squeeze2 kernel on a {1, 3, 1, 5} float tensor, once per `axes`
// setting, and checks that the {3, 5} output buffer holds the input values in
// the same linear order. XShape is supplied as a {1, 3, 1, 5} tensor; its
// contents are not checked.
TEST(squeeze2_x86, run_test) {
  lite::Tensor x;
  lite::Tensor xshape;
  lite::Tensor out;
  std::vector<int64_t> x_shape({1, 3, 1, 5});
  x.Resize(lite::DDim(x_shape));
  std::vector<int64_t> out_shape({3, 5});
  out.Resize(lite::DDim(out_shape));
  std::vector<int64_t> xshape_shape({1, 3, 1, 5});
  xshape.Resize(lite::DDim(xshape_shape));

  auto x_data = x.mutable_data<float>();
  auto out_data = out.mutable_data<float>();
  auto xshape_data = xshape.mutable_data<float>();
  for (int64_t i = 0; i < x.dims().production(); ++i) {
    x_data[i] = static_cast<float>(i);
    xshape_data[i] = static_cast<float>(i);
  }

  Squeeze2Compute<float> squeeze2;
  operators::SqueezeParam param;
  param.X = &x;
  param.Out = &out;
  param.XShape = &xshape;

  // Hand the context over exactly once. A std::unique_ptr is empty after
  // std::move, so moving it again on a later iteration would give the kernel
  // an empty pointer.
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  squeeze2.SetContext(std::move(ctx));

  std::vector<std::vector<int>> axes({{0, -2}, {}});
  for (const auto& axes_case : axes) {
    param.axes = axes_case;
    squeeze2.SetParam(param);
    squeeze2.Run();
    // 64-bit index to match the element count's width.
    for (int64_t j = 0; j < out.dims().production(); ++j) {
      EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
    }
  }
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Reference the kX86/kFloat/kNCHW `def` registrations of squeeze and squeeze2
// from this test translation unit. Presumably this forces the kernel
// registration objects to be linked in so the registry lookups in the tests
// above can find them -- confirm against the USE_LITE_KERNEL definition.
USE_LITE_KERNEL(squeeze, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(squeeze2, kX86, kFloat, kNCHW, def);
......@@ -85,12 +85,8 @@ bool SqueezeOp::InferShapeImpl() const {
bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.Out = output_var->GetMutable<lite::Tensor>();
param_.X = scope->FindTensor(opdesc.Input("X").front());
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
if (opdesc.HasAttr("axes")) {
param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
......@@ -109,7 +105,7 @@ bool Squeeze2Op::CheckShape() const {
bool Squeeze2Op::InferShapeImpl() const {
SqueezeOp::InferShapeImpl();
auto x_dims = param_.X->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
......@@ -119,9 +115,7 @@ bool Squeeze2Op::InferShapeImpl() const {
bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
SqueezeOp::AttachImpl(opdesc, scope);
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var);
param_.XShape = xshape_var->GetMutable<lite::Tensor>();
param_.XShape = scope->FindMutableTensor(opdesc.Output("XShape").front());
CHECK(param_.XShape) << "Output(XShape) of SqueezeOp should not be null.";
return true;
}
......
......@@ -90,12 +90,8 @@ bool UnsqueezeOp::InferShapeImpl() const {
bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.Out = output_var->GetMutable<lite::Tensor>();
param_.X = scope->FindTensor(opdesc.Input("X").front());
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
if (opdesc.HasAttr("axes")) {
param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
......@@ -133,7 +129,7 @@ bool Unsqueeze2Op::CheckShape() const {
bool Unsqueeze2Op::InferShapeImpl() const {
UnsqueezeOp::InferShapeImpl();
auto x_dims = param_.X->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
......@@ -143,9 +139,7 @@ bool Unsqueeze2Op::InferShapeImpl() const {
bool Unsqueeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
UnsqueezeOp::AttachImpl(opdesc, scope);
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var);
param_.XShape = xshape_var->GetMutable<lite::Tensor>();
param_.XShape = scope->FindMutableTensor(opdesc.Output("XShape").front());
CHECK(param_.XShape) << "Output(XShape) of Unsqueeze2Op should not be null.";
return true;
}
......
......@@ -123,7 +123,7 @@ class Squeeze2ComputeTester : public arena::TestCase {
CHECK(out);
auto* xshape = scope->NewTensor(xshape_);
CHECK(xshape);
std::vector<int64_t> xshape_sp(dims_.size() + 1, 1);
std::vector<int64_t> xshape_sp(dims_.size() + 1, 0);
for (size_t i = 0; i < dims_.size(); ++i) {
xshape_sp[i + 1] = dims_[i];
}
......@@ -169,9 +169,7 @@ class Squeeze2ComputeTester : public arena::TestCase {
auto* input_data = input->data<float>();
auto* out_data = out->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, input_data, sizeof(float) * dims_.production());
memcpy(xshape_data, input_data, sizeof(float) * dims_.production());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
......@@ -221,7 +219,7 @@ void test_squeeze2(Place place) {
std::unique_ptr<arena::TestCase> tester(new Squeeze2ComputeTester(
place, "def", axes, DDim({N, C, H, W})));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
arena.TestPrecision({"XShape"});
}
}
}
......@@ -230,23 +228,17 @@ void test_squeeze2(Place place) {
}
TEST(squeeze, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
Place place(TARGET(kHost));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_squeeze(place);
#endif
}
TEST(squeeze2, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
Place place(TARGET(kHost));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_squeeze2(place);
#endif
}
} // namespace lite
......
......@@ -153,7 +153,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
CHECK(out);
auto* xshape = scope->NewTensor(xshape_);
CHECK(xshape);
std::vector<int64_t> xshape_sp(dims_.size() + 1, 1);
std::vector<int64_t> xshape_sp(dims_.size() + 1, 0);
for (size_t i = 0; i < dims_.size(); ++i) {
xshape_sp[i + 1] = dims_[i];
}
......@@ -198,9 +198,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
auto* input_data = input->data<float>();
auto* out_data = out->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, input_data, sizeof(float) * dims_.production());
memcpy(xshape_data, input_data, sizeof(float) * dims_.production());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
......@@ -238,9 +236,7 @@ void test_unsqueeze(Place place, float abs_error = 2e-5) {
}
}
void test_unsqueeze2(Place place,
float abs_error = 2e-5,
std::vector<std::string> ignored_outs = {}) {
void test_unsqueeze2(Place place, float abs_error = 2e-5) {
for (std::vector<int> axes : {std::vector<int>({0}),
std::vector<int>({0, 2}),
std::vector<int>({0, -2})}) {
......@@ -252,7 +248,7 @@ void test_unsqueeze2(Place place,
std::unique_ptr<arena::TestCase> tester(
new Unsqueeze2ComputeTester(place, "def", axes, DDim(dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(ignored_outs);
arena.TestPrecision({"XShape"});
}
}
}
......@@ -263,7 +259,7 @@ TEST(unsqueeze, precision) {
#ifdef LITE_WITH_NPU
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#else
#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
place = TARGET(kHost);
#endif
test_unsqueeze(place, abs_error);
......@@ -272,16 +268,14 @@ TEST(unsqueeze, precision) {
TEST(unsqueeze2, precision) {
Place place;
float abs_error = 2e-5;
std::vector<std::string> ignored_outs = {};
#ifdef LITE_WITH_NPU
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
ignored_outs.push_back("XShape"); // not supported out in NPU
#else
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
place = TARGET(kHost);
#endif
test_unsqueeze2(place, abs_error, ignored_outs);
test_unsqueeze2(place, abs_error);
}
} // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册