Commit 32e977af authored by Shixiaowei02

test: kernel

Parent commit: 09a07a23
......@@ -96,6 +96,7 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kARM, kFloat, kNCHW);
INIT_FOR(kARM, kInt8, kNCHW);
INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny);
#undef INIT_FOR
......
......@@ -50,7 +50,7 @@ void CalibCompute::Run() {
} // namespace lite
} // namespace paddle
// Register the calib (type-calibration) kernel for the ARM target with
// int8 precision and NCHW layout. Both the input and output tensors are
// plain ARM-target tensors.
// NOTE: the diff rendering left the pre-change registration line
// (kARM, kAny, kAny) in place; only the kInt8/kNCHW registration is kept.
REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
                     paddle::lite::kernels::arm::CalibCompute, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -21,12 +21,10 @@ namespace lite {
namespace kernels {
namespace arm {
// CalibCompute: ARM kernel that performs data-type calibration
// (conversion) between adjacent layers; registered under the
// (kARM, kInt8) place to match REGISTER_LITE_KERNEL.
// NOTE: the diff rendering left the pre-change class head
// (PRECISION(kAny)) in place; only the kInt8 declaration is kept.
class CalibCompute : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
 public:
  using param_t = operators::CalibParam;

  // void PrepareForRun() override;

  void Run() override;

  ~CalibCompute() override{};
......
......@@ -84,14 +84,14 @@ void calib_ref(const operators::CalibParam& param) {
// Checks that the calib kernel can be retrieved from the global kernel
// registry at the exact place it is registered under: (kARM, kInt8, kNCHW).
// NOTE: the diff rendering left the pre-change Create<> line
// (kAny, kAny) in place; only the kInt8/kNCHW lookup is kept.
TEST(calib_arm, retrive_op) {
  auto calib =
      KernelRegistry::Global()
          .Create<TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNCHW)>("calib");
  ASSERT_FALSE(calib.empty());
  ASSERT_TRUE(calib.front());
}
// Verifies the compile-time target/precision traits of CalibCompute,
// which must agree with its KernelLite<TARGET(kARM), PRECISION(kInt8)>
// base after the kAny -> kInt8 change.
// NOTE: the diff rendering left the pre-change ASSERT_EQ (kAny) in
// place; only the kInt8 assertion is kept.
TEST(calib_arm, init) {
  CalibCompute calib;
  ASSERT_EQ(calib.precision(), PRECISION(kInt8));
  ASSERT_EQ(calib.target(), TARGET(kARM));
}
......@@ -146,4 +146,4 @@ TEST(calib_arm, int8_to_fp32) {
} // namespace lite
} // namespace paddle
// Pull in the ARM calib kernel at its registered place (kInt8, kNCHW).
// NOTE: the diff rendering left the pre-change USE_LITE_KERNEL line
// (kAny, kAny) in place; only the kInt8/kNCHW reference is kept.
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, def);
......@@ -27,6 +27,13 @@ namespace paddle {
namespace lite {
namespace operators {
/*
* The data types used by the two adjacent layers in the model should
* be the same. When the two operators accept different data types,
* we may need to implicitly add a data type conversion operator.
* Currently, this operator only supports mutual conversion of int8
* and float32 types.
*/
class CalibOpLite : public OpLite {
public:
CalibOpLite() {}
......@@ -37,14 +44,6 @@ class CalibOpLite : public OpLite {
bool InferShape() const override;
/*
bool Run() override {
CHECK(kernel_);
kernel_->Run();
return true;
}
*/
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope);
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
......@@ -48,9 +48,9 @@ TEST(calib_op_lite, TestARM) {
// Body of TEST(calib_op_lite, TestARM): the op's valid places and the
// kernel lookup must both use (kARM, kInt8), the place the calib kernel
// is registered under.
// NOTE: the diff rendering left the pre-change SetValidPlaces/
// CreateKernels lines (PRECISION(kAny)) in place; only the kInt8
// versions are kept.
CalibOpLite calib("calib");

  calib.SetValidPlaces({Place{TARGET(kARM), PRECISION(kInt8)}});
  calib.Attach(desc, &scope);
  auto kernels = calib.CreateKernels({Place{TARGET(kARM), PRECISION(kInt8)}});
  ASSERT_FALSE(kernels.empty());
}
#endif
......@@ -60,5 +60,5 @@ TEST(calib_op_lite, TestARM) {
} // namespace paddle
#ifdef LITE_WITH_ARM
// Pull in the ARM calib kernel at its registered place (kInt8, kNCHW).
// NOTE: the diff rendering left the pre-change USE_LITE_KERNEL line
// (kAny, kAny) in place; only the kInt8/kNCHW reference is kept.
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, def);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册