From 9388dace85b83d893124a4b2930d6a24f727b613 Mon Sep 17 00:00:00 2001
From: TianXiaogang
Date: Fri, 23 Aug 2019 19:40:05 +0800
Subject: [PATCH] feat: (#1836) add model_run_test_image test; add
 fake_quantize_range_abs_max, flatten, and flatten2 ops

fix: change the element type of density_prior_box's density_size from float
 to int; guard optional attributes in prior_box, density_prior_box, and
 conv_transpose with HasAttr checks before GetAttr

test=develop
---
 lite/api/CMakeLists.txt                       |  9 +-
 lite/api/cxx_api.cc                           |  7 ++
 lite/api/cxx_api.h                            |  1 +
 lite/api/model_run_test_image.cc              | 79 +++++++++++++++
 lite/api/paddle_use_kernels.h                 |  2 +
 lite/api/paddle_use_ops.h                     |  3 +
 lite/api/test_helper.h                        |  3 +
 lite/arm/math/prior_box.cc                    |  6 +-
 lite/arm/math/prior_box.h                     |  2 +-
 lite/core/profile/precision_profiler.h        | 45 ++++++++-
 lite/kernels/arm/density_prior_box_compute.cc |  3 +-
 lite/kernels/host/reshape_compute.cc          | 37 +++++++
 lite/operators/CMakeLists.txt                 |  4 +
 lite/operators/conv_transpose_op.cc           |  4 +-
 lite/operators/density_prior_box_op.cc        | 24 ++++-
 lite/operators/fake_quantize_range_abs_max.cc | 25 +++++
 lite/operators/fake_quantize_range_abs_max.h  | 69 +++++++++++++
 lite/operators/flatten_op.cc                  | 99 +++++++++++++++++++
 lite/operators/flatten_op.h                   | 62 ++++++++++++
 lite/operators/op_params.h                    |  2 +-
 lite/operators/prior_box_op.cc                |  4 +-
 lite/tests/kernels/prior_box_compute_test.cc  | 10 +-
 22 files changed, 473 insertions(+), 27 deletions(-)
 create mode 100644 lite/api/model_run_test_image.cc
 create mode 100644 lite/operators/fake_quantize_range_abs_max.cc
 create mode 100644 lite/operators/fake_quantize_range_abs_max.h
 create mode 100644 lite/operators/flatten_op.cc
 create mode 100644 lite/operators/flatten_op.h

diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt
index 5212d7a4ca..8a99bea428 100644
--- a/lite/api/CMakeLists.txt
+++ b/lite/api/CMakeLists.txt
@@ -145,8 +145,13 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
       ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
            --model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL)
     add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz)
-#   lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc
-#      DEPS ${lite_model_test_DEPS})
+    # lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc
+    #    DEPS ${lite_model_test_DEPS})
+
+    # lite_cc_test(model_run_test_image SRCS model_run_test_image.cc
+    #    DEPS ${lite_model_test_DEPS}
+    #    CL_DEPS ${opencl_kernels}
+    #    FPGA_DEPS ${fpga_kernels})
 endif()
 
 # These tests needs CLI arguments, and is not supported in ARM CI.
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 36529ecf30..622db41285 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -71,6 +71,13 @@ const lite::Tensor *Predictor::GetOutput(size_t offset) const {
   return &fetch_list.at(offset);
 }
 
+const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+  return &fetch_list;
+}
+
 const cpp::ProgramDesc &Predictor::program_desc() const {
   return program_desc_;
 }
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 5d94a75bb1..d664565993 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -69,6 +69,7 @@ class LITE_API Predictor {
   // Get offset-th col of fetch results.
   const lite::Tensor* GetOutput(size_t offset) const;
+  const std::vector<lite::Tensor>* GetOutputs() const;
 
   const cpp::ProgramDesc& program_desc() const;
   const lite::Tensor* GetTensor(const std::string& name) const;
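For reference, a minimal caller-side sketch of the new GetOutputs() accessor
(the helper function and its name are illustrative, not part of this patch):

    // Hedged usage sketch: iterate every fetch result at once instead of
    // calling GetOutput(i) per offset.
    #include "lite/api/cxx_api.h"

    void DumpAllOutputs(const paddle::lite::Predictor& predictor) {
      const std::vector<paddle::lite::Tensor>* outs = predictor.GetOutputs();
      for (size_t i = 0; i < outs->size(); ++i) {
        LOG(INFO) << "output " << i << " dims: " << (*outs)[i].dims();
      }
    }
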
diff --git a/lite/api/model_run_test_image.cc b/lite/api/model_run_test_image.cc
new file mode 100644
index 0000000000..0ef2ecb088
--- /dev/null
+++ b/lite/api/model_run_test_image.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(model, test) {
+#ifdef LITE_WITH_ARM
+  DeviceInfo::Init();
+  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kARM), PRECISION(kFloat)},
+                                   Place{TARGET(kARM), PRECISION(kInt8)}});
+
+  auto precision = PRECISION(kFloat);
+  if (FLAGS_int8) {
+    precision = PRECISION(kInt8);
+  }
+  predictor.Build(
+      FLAGS_model_dir, Place{TARGET(kARM), precision}, valid_places);
+  int im_width = FLAGS_im_width;
+  int im_height = FLAGS_im_height;
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(
+      DDim(std::vector<int64_t>({1, 3, im_width, im_height})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+  auto* output_tensors = predictor.GetOutputs();
+
+  LOG(INFO) << "======output:========";
+  for (const auto& t : *output_tensors) {
+    LOG(INFO) << t;
+  }
+  LOG(INFO)
+      << "=====RUN_finished!!============= Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", average " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms per run.";
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/paddle_use_kernels.h b/lite/api/paddle_use_kernels.h
index f2fe0ce34f..2f4d7350b5 100644
--- a/lite/api/paddle_use_kernels.h
+++ b/lite/api/paddle_use_kernels.h
@@ -21,6 +21,8 @@
 #ifndef LITE_WITH_FPGA
 USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
 USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
+USE_LITE_KERNEL(flatten, kHost, kAny, kAny, def);
+USE_LITE_KERNEL(flatten2, kHost, kAny, kAny, def);
 #else
 USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def);
 USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def);
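The warmup/repeat timing pattern used in model_run_test_image.cc above,
reduced to a standalone sketch (the function name is illustrative; the
callable stands in for predictor.Run()):

    // Hedged sketch of the benchmark loop: discard warmup iterations,
    // then report the mean wall time of the timed repeats.
    #include <chrono>
    #include <functional>

    double AverageLatencyMs(const std::function<void()>& run,
                            int warmup, int repeats) {
      for (int i = 0; i < warmup; ++i) run();  // warm caches / thread pool
      auto start = std::chrono::steady_clock::now();
      for (int i = 0; i < repeats; ++i) run();
      auto end = std::chrono::steady_clock::now();
      return std::chrono::duration<double, std::milli>(end - start).count() /
             repeats;
    }
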
diff --git a/lite/api/paddle_use_ops.h b/lite/api/paddle_use_ops.h
index cf01c74adb..5cf62224de 100644
--- a/lite/api/paddle_use_ops.h
+++ b/lite/api/paddle_use_ops.h
@@ -73,9 +73,12 @@ USE_LITE_OP(prior_box)
 USE_LITE_OP(density_prior_box)
 USE_LITE_OP(reshape)
 USE_LITE_OP(reshape2)
+USE_LITE_OP(flatten)
+USE_LITE_OP(flatten2)
 USE_LITE_OP(split)
 USE_LITE_OP(fake_quantize_moving_average_abs_max);
 USE_LITE_OP(fake_dequantize_max_abs);
+USE_LITE_OP(fake_quantize_range_abs_max);
 USE_LITE_OP(calib);
 USE_LITE_OP(calib_once);
 USE_LITE_OP(norm);
diff --git a/lite/api/test_helper.h b/lite/api/test_helper.h
index 1a5ab31abd..d835c030f0 100644
--- a/lite/api/test_helper.h
+++ b/lite/api/test_helper.h
@@ -23,6 +23,9 @@ DEFINE_string(model_dir, "", "model dir");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
 DEFINE_int32(threads, 1, "threads num");
+DEFINE_int32(im_width, 224, "image width");
+DEFINE_int32(im_height, 224, "image height");
+DEFINE_bool(int8, false, "run the int8 (quantized) path");
 
 namespace paddle {
 namespace lite {
diff --git a/lite/arm/math/prior_box.cc b/lite/arm/math/prior_box.cc
index e6f455e72a..6ec3127965 100644
--- a/lite/arm/math/prior_box.cc
+++ b/lite/arm/math/prior_box.cc
@@ -51,7 +51,7 @@ void density_prior_box(const lite::Tensor* input,
                        const std::vector<float>& min_size_,
                        const std::vector<float>& fixed_size_,
                        const std::vector<float>& fixed_ratio_,
-                       const std::vector<float>& density_size_,
+                       const std::vector<int>& density_size_,
                        const std::vector<float>& max_size_,
                        const std::vector<float>& aspect_ratio_,
                        const std::vector<float>& variance_,
@@ -82,14 +82,12 @@ void density_prior_box(const lite::Tensor* input,
     img_width = image->dims()[3];
     img_height = image->dims()[2];
   }
-
   float step_w = step_w_;
   float step_h = step_h_;
   if (step_w == 0 || step_h == 0) {
     step_w = static_cast<float>(img_width) / width;
     step_h = static_cast<float>(img_height) / height;
   }
-
   float offset = offset_;
   int step_average = static_cast<int>((step_w + step_h) * 0.5);  // add
   int channel_size = height * width * prior_num_ * 4;
@@ -343,7 +341,7 @@ void prior_box(const lite::Tensor* input,
                   min_size,
                   std::vector<float>(),
                   std::vector<float>(),
-                  std::vector<float>(),
+                  std::vector<int>(),
                   max_size,
                   aspect_ratio,
                   variance,
diff --git a/lite/arm/math/prior_box.h b/lite/arm/math/prior_box.h
index 59efb2ab00..ffa821b75e 100644
--- a/lite/arm/math/prior_box.h
+++ b/lite/arm/math/prior_box.h
@@ -30,7 +30,7 @@ void density_prior_box(const lite::Tensor* input,
                        const std::vector<float>& min_size_,
                        const std::vector<float>& fixed_size_,
                        const std::vector<float>& fixed_ratio_,
-                       const std::vector<float>& density_size_,
+                       const std::vector<int>& density_size_,
                        const std::vector<float>& max_size_,
                        const std::vector<float>& aspect_ratio_,
                        const std::vector<float>& variance_,
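Why int is the right element type for density_size: each density value is a
per-side count of sub-cells, not a length. A hedged sketch of how such a
count drives box placement (variable names follow the usual SSD
density-prior-box formulation; this is illustrative, not the kernel source):

    // Hedged sketch: a density d places a d x d grid of box centers
    // inside one step cell, so d is inherently an integer count.
    void PlaceDensityCenters(int density, int step_average,
                             float center_x, float center_y) {
      int shift = step_average / density;  // sub-cell size in pixels
      for (int r = 0; r < density; ++r) {
        for (int c = 0; c < density; ++c) {
          float cx = center_x - step_average / 2.0f + shift / 2.0f + c * shift;
          float cy = center_y - step_average / 2.0f + shift / 2.0f + r * shift;
          // ... emit one prior box centered at (cx, cy) ...
          (void)cx;
          (void)cy;
        }
      }
    }
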
diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h
index 65cc160077..d9111e5c46 100644
--- a/lite/core/profile/precision_profiler.h
+++ b/lite/core/profile/precision_profiler.h
@@ -26,17 +26,49 @@ namespace paddle {
 namespace lite {
 namespace profile {
 
+template <typename dtype>
+static void write_tensorfile(const Tensor* tensor, const std::string& locate) {
+  if (locate.find('/') != std::string::npos) {
+    return;
+  }
+  FILE* fp = fopen(locate.c_str(), "w");
+  if (fp == nullptr) {
+    LOG(ERROR) << "file open failed: " << locate;
+  } else {
+    const dtype* data = tensor->data<dtype>();
+    for (int i = 0; i < tensor->numel(); ++i) {
+      fprintf(fp, "[%d] %f \n", i, static_cast<float>(data[i]));
+    }
+    fclose(fp);
+  }
+}
+
 class PrecisionProfiler {
  public:
   explicit PrecisionProfiler(const Instruction* inst) : inst_(inst) {}
   ~PrecisionProfiler() {
     LOG(INFO) << ">> Running kernel: " << inst_->op()->op_info()->Repr()
-              << " on Target " << TargetToStr(inst_->kernel()->target());
-    auto tensor_mean = [](const Tensor* in, PrecisionType ptype) -> double {
+              << " on Target " << TargetToStr(inst_->kernel()->target()) << " "
+              << PrecisionToStr(inst_->kernel()->precision());
+    auto tensor_mean = [](const Tensor* in,
+                          PrecisionType ptype,
+                          std::string name = "inst") -> double {
+      if (!in->data<char>()) {
+        return -99999;
+      }
       double sum = 0.;
       switch (ptype) {
         case PRECISION(kFloat): {
           auto ptr = in->data<float>();
+          // write_tensorfile<float>(in, name);
+          for (int i = 0; i < in->numel(); ++i) {
+            sum += ptr[i];
+          }
+          return sum / in->numel();
+        }
+        case PRECISION(kAny): {
+          auto ptr = in->data<float>();
+          // write_tensorfile<float>(in, name);
           for (int i = 0; i < in->numel(); ++i) {
             sum += ptr[i];
           }
@@ -44,6 +76,7 @@ class PrecisionProfiler {
         }
         case PRECISION(kInt8): {
           auto ptr = in->data<int8_t>();
+          // write_tensorfile<int8_t>(in, name);
           for (int i = 0; i < in->numel(); ++i) {
             sum += ptr[i];
           }
@@ -51,6 +84,7 @@ class PrecisionProfiler {
         }
         case PRECISION(kInt32): {
           auto ptr = in->data<int32_t>();
+          // write_tensorfile<int32_t>(in, name);
           for (int i = 0; i < in->numel(); ++i) {
             sum += ptr[i];
           }
@@ -70,17 +104,18 @@ class PrecisionProfiler {
       std::string out_arg_name;
       op->op_info()->GetOutputArgname(out_name, &out_arg_name);
       auto type = kernel->GetOutputDeclType(out_arg_name);
+
       if (type->IsTensor()) {
         auto tout = op_scope->FindVar(out_name)->GetMutable<Tensor>();
-        double mean = tensor_mean(tout, type->precision());
+        double mean = tensor_mean(tout, type->precision(), out_name);
         LOG(INFO) << "output name: " << out_name << ", dims: " << tout->dims()
                   << ", precision: " << PrecisionToStr(type->precision())
                   << ", mean value: " << mean;
       } else if (type->IsTensorList()) {
         auto tout =
             op_scope->FindVar(out_name)->GetMutable<std::vector<Tensor>>();
         for (auto& t : *tout) {
-          double mean = tensor_mean(&t, type->precision());
+          double mean = tensor_mean(&t, type->precision(), out_name);
           LOG(INFO) << "output name: " << out_name << ", dims: " << t.dims()
                     << ", precision: " << PrecisionToStr(type->precision())
                     << ", mean value: " << mean;
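The commented-out write_tensorfile calls above are the intended debugging
hook; a hedged example of enabling one (the file name is illustrative, and
note the helper deliberately skips any name containing '/', so dumps always
land in the working directory):

    // Hedged usage sketch: dump a float output tensor to a flat text file
    // while profiling a kernel.
    void DumpForDebug(const paddle::lite::Tensor* t) {
      paddle::lite::profile::write_tensorfile<float>(t, "conv1_out.txt");
    }
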
diff --git a/lite/kernels/arm/density_prior_box_compute.cc b/lite/kernels/arm/density_prior_box_compute.cc
index 47d14d6572..35616bc6e8 100644
--- a/lite/kernels/arm/density_prior_box_compute.cc
+++ b/lite/kernels/arm/density_prior_box_compute.cc
@@ -48,13 +48,12 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
 
 void DensityPriorBoxCompute::Run() {
   auto& param = Param<operators::DensityPriorBoxParam>();
-
   bool is_flip = param.flip;
   bool is_clip = param.clip;
   std::vector<float> min_size = param.min_sizes;
   std::vector<float> fixed_size = param.fixed_sizes;
   std::vector<float> fixed_ratio = param.fixed_ratios;
-  std::vector<float> density_size = param.density_sizes;
+  auto density_size = param.density_sizes;
   std::vector<float> max_size = param.max_sizes;
   std::vector<float> aspect_ratio = param.aspect_ratios;
   std::vector<float> variance = param.variances_;
diff --git a/lite/kernels/host/reshape_compute.cc b/lite/kernels/host/reshape_compute.cc
index cb3420fbbd..a5934999cd 100644
--- a/lite/kernels/host/reshape_compute.cc
+++ b/lite/kernels/host/reshape_compute.cc
@@ -93,3 +93,40 @@ REGISTER_LITE_KERNEL(reshape2,
                 {LiteType::GetTensorTy(
                     TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .Finalize();
+
+REGISTER_LITE_KERNEL(flatten,
+                     kHost,
+                     kAny,
+                     kAny,
+                     paddle::lite::kernels::host::ReshapeCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(flatten2,
+                     kHost,
+                     kAny,
+                     kAny,
+                     paddle::lite::kernels::host::ReshapeCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindOutput("XShape",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .Finalize();
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 1362a86797..4e7f7436e3 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -9,6 +9,7 @@ lite_cc_library(matmul_op SRCS matmul_op.cc DEPS ${op_DEPS})
 lite_cc_library(scale_op SRCS scale_op.cc DEPS ${op_DEPS})
 lite_cc_library(softmax_op SRCS softmax_op.cc DEPS ${op_DEPS})
 lite_cc_library(reshape_op SRCS reshape_op.cc DEPS ${op_DEPS})
+lite_cc_library(flatten_op SRCS flatten_op.cc DEPS ${op_DEPS})
 lite_cc_library(batch_norm_op SRCS batch_norm_op.cc DEPS ${op_DEPS})
 lite_cc_library(feed_op SRCS feed_op.cc DEPS ${op_DEPS})
 lite_cc_library(fetch_op SRCS fetch_op.cc DEPS ${op_DEPS})
@@ -52,6 +53,7 @@ lite_cc_library(calib_once_op SRCS calib_once_op.cc DEPS ${op_DEPS})
 lite_cc_library(split_op SRCS split_op.cc DEPS ${op_DEPS})
 lite_cc_library(transpose_op SRCS transpose_op.cc DEPS ${op_DEPS})
 lite_cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
+lite_cc_library(fake_quant_range SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
 lite_cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS})
 lite_cc_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS ${op_DEPS})
 lite_cc_library(im2sequence_op SRCS im2sequence_op.cc DEPS ${op_DEPS})
@@ -96,6 +98,7 @@ set(ops
     scale_op
     softmax_op
     reshape_op
+    flatten_op
     batch_norm_op
     feed_op
     fetch_op
@@ -128,6 +131,7 @@ set(ops
     split_op
     transpose_op
     fake_quant
+    fake_quant_range
     fake_dequant
     sgd_op
     uniform_random_op
diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc
index b84b4ff169..fb6b431fff 100644
--- a/lite/operators/conv_transpose_op.cc
+++ b/lite/operators/conv_transpose_op.cc
@@ -85,7 +85,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
       }
     }
   }
-  param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
+  if (op_desc.HasAttr("fuse_relu")) {
+    param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
+  }
   return true;
 }
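This HasAttr-before-GetAttr pattern recurs in the prior-box fixes below. A
hedged generic helper, were one wanted (GetAttrOr is hypothetical, not an
existing Lite API):

    // Hedged sketch: read an optional op attribute with a fallback
    // instead of crashing when an older model omits it.
    template <typename T>
    T GetAttrOr(const paddle::lite::cpp::OpDesc& desc,
                const std::string& name,
                const T& fallback) {
      return desc.HasAttr(name) ? desc.GetAttr<T>(name) : fallback;
    }

    // e.g. param_.fuse_relu = GetAttrOr<bool>(op_desc, "fuse_relu", false);
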
diff --git a/lite/operators/density_prior_box_op.cc b/lite/operators/density_prior_box_op.cc
index c6b646b33d..86830df2f1 100644
--- a/lite/operators/density_prior_box_op.cc
+++ b/lite/operators/density_prior_box_op.cc
@@ -41,15 +41,29 @@ bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc,
   param_.boxes = scope->FindVar(boxes)->GetMutable<lite::Tensor>();
   param_.variances = scope->FindVar(variances)->GetMutable<lite::Tensor>();
-  param_.flip = opdesc.GetAttr<bool>("flip");
   param_.clip = opdesc.GetAttr<bool>("clip");
-  param_.min_sizes = opdesc.GetAttr<std::vector<float>>("min_sizes");
   param_.fixed_sizes = opdesc.GetAttr<std::vector<float>>("fixed_sizes");
   param_.fixed_ratios = opdesc.GetAttr<std::vector<float>>("fixed_ratios");
-  param_.density_sizes = opdesc.GetAttr<std::vector<float>>("density_sizes");
-  param_.max_sizes = opdesc.GetAttr<std::vector<float>>("max_sizes");
-  param_.aspect_ratios = opdesc.GetAttr<std::vector<float>>("aspect_ratios");
   param_.variances_ = opdesc.GetAttr<std::vector<float>>("variances");
+
+  if (opdesc.HasAttr("aspect_ratios")) {
+    param_.aspect_ratios = opdesc.GetAttr<std::vector<float>>("aspect_ratios");
+  }
+  if (opdesc.HasAttr("max_sizes")) {
+    param_.max_sizes = opdesc.GetAttr<std::vector<float>>("max_sizes");
+  }
+  if (opdesc.HasAttr("density_sizes")) {
+    param_.density_sizes = opdesc.GetAttr<std::vector<int>>("density_sizes");
+  }
+  if (opdesc.HasAttr("densities")) {
+    param_.density_sizes = opdesc.GetAttr<std::vector<int>>("densities");
+  }
+  if (opdesc.HasAttr("min_sizes")) {
+    param_.min_sizes = opdesc.GetAttr<std::vector<float>>("min_sizes");
+  }
+  if (opdesc.HasAttr("flip")) {
+    param_.flip = opdesc.GetAttr<bool>("flip");
+  }
   if (opdesc.HasAttr("img_w")) {
     param_.img_w = opdesc.GetAttr<int>("img_w");
   }
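One subtlety above: the older attribute name "density_sizes" and the name
newer Paddle exporters emit, "densities", both populate the same field, and
because "densities" is read second it wins if a model somehow carries both.
A hedged standalone model of that precedence (the map-based setup is
schematic, not the Lite OpDesc API):

    #include <map>
    #include <string>
    #include <vector>

    // Hedged sketch: both attribute spellings feed one field; the later
    // read ("densities") overwrites the earlier one when both exist.
    std::vector<int> ResolveDensities(
        const std::map<std::string, std::vector<int>>& attrs) {
      std::vector<int> out;
      if (attrs.count("density_sizes")) out = attrs.at("density_sizes");
      if (attrs.count("densities")) out = attrs.at("densities");  // newer name
      return out;
    }
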
diff --git a/lite/operators/fake_quantize_range_abs_max.cc b/lite/operators/fake_quantize_range_abs_max.cc
new file mode 100644
index 0000000000..a8ce3f75a5
--- /dev/null
+++ b/lite/operators/fake_quantize_range_abs_max.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/fake_quantize_range_abs_max.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(fake_quantize_range_abs_max,
+                 paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite);
diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h
new file mode 100644
index 0000000000..726731595a
--- /dev/null
+++ b/lite/operators/fake_quantize_range_abs_max.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/core/tensor.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
+ public:
+  FakeQuantizeRangeMaxAbsOpLite() {}
+
+  explicit FakeQuantizeRangeMaxAbsOpLite(const std::string &type)
+      : OpLite(type) {}
+
+  bool CheckShape() const override { return true; }
+
+  bool InferShape() const override { return true; }
+
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto in_scale = op_desc.Input("InScale").front();
+
+    auto out = op_desc.Output("Out").front();
+    auto out_scale = op_desc.Output("OutScale").front();
+
+    param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
+    param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
+
+    param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+    param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
+    param_.bit_length = op_desc.GetAttr<int>("bit_length");
+    return true;
+  }
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "fake_quantize_range_max_abs";
+  }
+
+ private:
+  mutable FakeQuantizeMovingAvgMaxAbsParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
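For orientation, the arithmetic this op describes, as I understand the
PaddlePaddle fake_quantize_range_abs_max definition; treat it as a hedged
sketch, since this patch only adds the op description, not a kernel:

    #include <algorithm>
    #include <cmath>

    // Hedged sketch of range-abs-max fake quantization: a value is
    // quantized against a max-abs scale and immediately dequantized, so
    // downstream code sees the rounding error while tensors stay float.
    float FakeQuantRangeAbsMax(float x, float scale, int bit_length) {
      const float bound = static_cast<float>((1 << (bit_length - 1)) - 1);
      float q = std::round(x / scale * bound);     // quantize
      q = std::max(-bound, std::min(bound, q));    // clip to the int range
      return q * scale / bound;                    // dequantize
    }
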
diff --git a/lite/operators/flatten_op.cc b/lite/operators/flatten_op.cc
new file mode 100644
index 0000000000..6deab45023
--- /dev/null
+++ b/lite/operators/flatten_op.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/flatten_op.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool FlattenOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.output);
+  return true;
+}
+
+bool FlattenOp::InferShape() const {
+  auto x_dims = param_.x->dims();
+
+  auto out_lod = param_.output->mutable_lod();
+  *out_lod = param_.x->lod();
+
+  int64_t outer = 1, inner = 1;
+  for (size_t i = 0; i < x_dims.size(); ++i) {
+    if (static_cast<int>(i) < axis_) {
+      outer *= x_dims[i];
+    } else {
+      inner *= x_dims[i];
+    }
+  }
+  std::vector<int64_t> out_shape(2);
+  out_shape[0] = outer;
+  out_shape[1] = inner;
+
+  param_.output->Resize(out_shape);
+
+  return true;
+}
+
+bool FlattenOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  auto x_var = scope->FindVar(opdesc.Input("X").front());
+  auto output_var = scope->FindVar(opdesc.Output("Out").front());
+  CHECK(x_var);
+  CHECK(output_var);
+  param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
+  param_.output = output_var->GetMutable<lite::Tensor>();
+  axis_ = opdesc.GetAttr<int>("axis");
+
+  param_.inplace = false;
+
+  CHECK(param_.x) << "Input(X) of FlattenOp should not be null.";
+  CHECK(param_.output) << "Output(Out) of FlattenOp should not be null.";
+  CHECK_GE(axis_, 0) << "The axis of FlattenOp should be >= 0.";
+  return true;
+}
+
+bool Flatten2Op::CheckShape() const {
+  FlattenOp::CheckShape();
+  CHECK_OR_FALSE(param_.xshape);
+  return true;
+}
+
+bool Flatten2Op::InferShape() const {
+  FlattenOp::InferShape();
+  auto x_dims = param_.x->dims();
+  std::vector<int64_t> xshape_dims(x_dims.size() + 1, 0);
+  for (size_t i = 0; i < x_dims.size(); i++) {
+    xshape_dims[i + 1] = x_dims[i];
+  }
+  param_.xshape->Resize(DDim(xshape_dims));
+  return true;
+}
+
+bool Flatten2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  FlattenOp::AttachImpl(opdesc, scope);
+  auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
+  CHECK(xshape_var);
+  param_.xshape = xshape_var->GetMutable<lite::Tensor>();
+  CHECK(param_.xshape) << "Output(XShape) of Flatten2Op should not be null.";
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(flatten, paddle::lite::operators::FlattenOp);
+REGISTER_LITE_OP(flatten2, paddle::lite::operators::Flatten2Op);
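A concrete check of the outer/inner arithmetic above: for input dims
[2, 3, 4, 5] and axis = 2, outer = 2*3 = 6 and inner = 4*5 = 20, so Out is
[6, 20]; flatten2 additionally records XShape = [0, 2, 3, 4, 5] (a leading
0, then the input dims). A hedged standalone sketch of the shape rule:

    #include <cstdint>
    #include <vector>

    // Hedged sketch of FlattenOp::InferShape: fold dims before `axis`
    // into `outer` and the remaining dims into `inner`.
    std::vector<int64_t> FlattenShape(const std::vector<int64_t>& x_dims,
                                      int axis) {
      int64_t outer = 1, inner = 1;
      for (size_t i = 0; i < x_dims.size(); ++i) {
        (static_cast<int>(i) < axis ? outer : inner) *= x_dims[i];
      }
      return {outer, inner};
    }

    // FlattenShape({2, 3, 4, 5}, 2) == {6, 20}
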
diff --git a/lite/operators/flatten_op.h b/lite/operators/flatten_op.h
new file mode 100644
index 0000000000..61680fd390
--- /dev/null
+++ b/lite/operators/flatten_op.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class FlattenOp : public OpLite {
+ public:
+  FlattenOp() {}
+  explicit FlattenOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "flatten"; }
+
+ protected:
+  mutable ReshapeParam param_;
+  int axis_;
+};
+
+class Flatten2Op : public FlattenOp {
+ public:
+  Flatten2Op() : FlattenOp() {}
+  explicit Flatten2Op(const std::string &op_type) : FlattenOp(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "flatten2"; }
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index c011aa5e0c..2d1fd8bfe6 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -521,7 +521,7 @@ struct PriorBoxParam {
 struct DensityPriorBoxParam : public PriorBoxParam {
   std::vector<float> fixed_sizes;
   std::vector<float> fixed_ratios;
-  std::vector<float> density_sizes;
+  std::vector<int> density_sizes;
 };
 /// ----------------------- GRU operators ----------------------f
 struct GRUParam {
diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc
index 8053b24b62..3cc8938f4e 100644
--- a/lite/operators/prior_box_op.cc
+++ b/lite/operators/prior_box_op.cc
@@ -40,12 +40,14 @@ bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   param_.boxes = scope->FindVar(boxes)->GetMutable<lite::Tensor>();
   param_.variances = scope->FindVar(variances)->GetMutable<lite::Tensor>();
-  param_.flip = opdesc.GetAttr<bool>("flip");
   param_.clip = opdesc.GetAttr<bool>("clip");
   param_.min_sizes = opdesc.GetAttr<std::vector<float>>("min_sizes");
   param_.max_sizes = opdesc.GetAttr<std::vector<float>>("max_sizes");
   param_.aspect_ratios = opdesc.GetAttr<std::vector<float>>("aspect_ratios");
   param_.variances_ = opdesc.GetAttr<std::vector<float>>("variances");
+  if (opdesc.HasAttr("flip")) {
+    param_.flip = opdesc.GetAttr<bool>("flip");
+  }
   if (opdesc.HasAttr("img_w")) {
     param_.img_w = opdesc.GetAttr<int>("img_w");
   }
diff --git a/lite/tests/kernels/prior_box_compute_test.cc b/lite/tests/kernels/prior_box_compute_test.cc
index 57bea3e96d..47f7bc9447 100644
--- a/lite/tests/kernels/prior_box_compute_test.cc
+++ b/lite/tests/kernels/prior_box_compute_test.cc
@@ -75,7 +75,7 @@ void prior_box_compute_ref(const lite::Tensor* input,
                            const std::vector<float>& min_size_,
                            const std::vector<float>& fixed_size_,
                            const std::vector<float>& fixed_ratio_,
-                           const std::vector<float>& density_size_,
+                           const std::vector<int>& density_size_,
                            const std::vector<float>& max_size_,
                            const std::vector<float>& aspect_ratio_,
                            const std::vector<float>& variance_,
@@ -352,7 +352,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase {
   std::vector<float> min_size_;
   std::vector<float> fixed_size_;
   std::vector<float> fixed_ratio_;
-  std::vector<float> density_size_;
+  std::vector<int> density_size_;
   std::vector<float> max_size_;
   std::vector<float> aspect_ratio_;
   std::vector<float> variance_;
@@ -375,7 +375,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase {
       const std::vector<float>& min_size,
       const std::vector<float>& fixed_size,
       const std::vector<float>& fixed_ratio,
-      const std::vector<float>& density_size,
+      const std::vector<int>& density_size,
       const std::vector<float>& max_size,
      const std::vector<float>& aspect_ratio,
       const std::vector<float>& variance,
@@ -561,7 +561,7 @@ class PriorBoxComputeTester : public arena::TestCase {
                           min_size_,
                           std::vector<float>(),
                           std::vector<float>(),
-                          std::vector<float>(),
+                          std::vector<int>(),
                           max_size_,
                           aspect_ratio_,
                           variance_,
@@ -621,7 +621,7 @@ void test_density_prior_box(Place place) {
   std::vector<float> variance{0.1f, 0.1f, 0.2f, 0.2f};
   std::vector<float> fixed_size{60, 30};
   std::vector<float> fixed_ratio{1., 2.};
-  std::vector<float> density_size{1., 3.};
+  std::vector<int> density_size{1, 3};
   bool flip = true;
   bool clip = false;
   float step_h = 0;
--