未验证 提交 92c6f80b 编写于 作者: Z zhupengyang 提交者: GitHub

move is_empty kernel to host and add ut (#3412)

上级 a7db928e
......@@ -99,7 +99,6 @@ add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${
add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm)
# 4. training kernels
......
......@@ -5,6 +5,7 @@ add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kerne
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_host Host extra SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(is_empty_compute_host Host extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps})
add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(compare_compute_host Host extra SRCS compare_compute.cc DEPS ${lite_kernel_deps})
add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
......@@ -12,19 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/is_empty_compute.h"
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/host/is_empty_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void IsEmptyCompute::PrepareForRun() {}
namespace host {
void IsEmptyCompute::Run() {
auto& param = this->Param<operators::IsEmptyParam>();
......@@ -32,16 +25,22 @@ void IsEmptyCompute::Run() {
param.Out->mutable_data<bool>()[0] = (count == 0);
}
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(is_empty,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::IsEmptyCompute,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::IsEmptyCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kBool),
DATALAYOUT(kAny))})
.Finalize();
......@@ -14,27 +14,23 @@
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/operators/logical_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
class IsEmptyCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class IsEmptyCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::IsEmptyParam;
void PrepareForRun() override;
void Run() override;
~IsEmptyCompute() {}
};
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -19,15 +19,20 @@ namespace paddle {
namespace lite {
namespace operators {
bool IsEmptyOp::CheckShape() const { return true; }
bool IsEmptyOp::CheckShape() const {
CHECK(param_.X);
CHECK(param_.Out);
return true;
}
bool IsEmptyOp::InferShapeImpl() const { return true; }
bool IsEmptyOp::InferShapeImpl() const {
param_.Out->Resize({1});
return true;
}
bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.X = scope->FindTensor(opdesc.Input("X").front());
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
CHECK(param_.X);
CHECK(param_.Out);
return true;
......
......@@ -78,6 +78,7 @@ endif()
lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_interp_compute SRCS interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${bm_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shape_compute SRCS shape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_is_empty_compute SRCS is_empty_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${bm_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class IsEmptyComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "x";
std::string out_ = "out";
DDim x_dims_;
public:
IsEmptyComputeTester(const Place& place,
const std::string& alias,
DDim x_dims)
: TestCase(place, alias), x_dims_(x_dims) {}
void RunBaseline(Scope* scope) override {
const auto* x = scope->FindTensor(x_);
auto* out = scope->NewTensor(out_);
out->Resize(DDim({1}));
auto* out_data = out->mutable_data<bool>();
out_data[0] = (x->numel() == 0) ? true : false;
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("is_empty");
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
std::vector<float> din(x_dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, x_dims_.production());
SetCommonTensor(x_, x_dims_, din.data());
}
};
void TestIsEmptyHelper(Place place,
float abs_error,
std::vector<int64_t> x_dims) {
std::unique_ptr<arena::TestCase> tester(
new IsEmptyComputeTester(place, "def", DDim(x_dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
void TestIsEmpty(Place place, float abs_error) {
TestIsEmptyHelper(place, abs_error, {2, 3, 4, 5});
TestIsEmptyHelper(place, abs_error, {0});
}
TEST(is_empty, precision) {
Place place;
float abs_error = 1e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kHost);
#else
return;
#endif
TestIsEmpty(place, abs_error);
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册