diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 5212d7a4ca763bd582e829e190fdf7ad56d78da5..8a99bea42805093a108808e07dfa7779997d40b7 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -145,8 +145,13 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl --model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL) add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz) -# lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc -# DEPS ${lite_model_test_DEPS}) + # lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc + # DEPS ${lite_model_test_DEPS}) + + # lite_cc_test(model_run_test_image SRCS model_run_test_image.cc + # DEPS ${lite_model_test_DEPS} + # CL_DEPS ${opencl_kernels} + # FPGA_DEPS ${fpga_kernels}) endif() # These tests needs CLI arguments, and is not supported in ARM CI. diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 36529ecf30003f5749eb2160ebe856d77f5539b4..622db412853cd780d6e2d2b00ec6c3c3fa788ae3 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -71,6 +71,13 @@ const lite::Tensor *Predictor::GetOutput(size_t offset) const { return &fetch_list.at(offset); } +const std::vector *Predictor::GetOutputs() const { + auto *_fetch_list = exec_scope_->FindVar("fetch"); + CHECK(_fetch_list) << "no fatch variable in exec_scope"; + auto &fetch_list = *_fetch_list->GetMutable>(); + return &fetch_list; +} + const cpp::ProgramDesc &Predictor::program_desc() const { return program_desc_; } diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 5d94a75bb1233ffc157f06096dfd32c9848951f6..d664565993c80d3853907906f53672d9b7df4a71 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -69,6 +69,7 @@ class LITE_API Predictor { // Get offset-th col of fetch results. const lite::Tensor* GetOutput(size_t offset) const; + const std::vector* GetOutputs() const; const cpp::ProgramDesc& program_desc() const; const lite::Tensor* GetTensor(const std::string& name) const; diff --git a/lite/api/model_run_test_image.cc b/lite/api/model_run_test_image.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ef2ecb08805398ec89cb86bf883a59cc713e08d --- /dev/null +++ b/lite/api/model_run_test_image.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +TEST(model, test) { +#ifdef LITE_WITH_ARM + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kInt8)}}); + + auto precision = PRECISION(kFloat); + if (FLAGS_int8) { + precision = PRECISION(kInt8); + } + predictor.Build( + FLAGS_model_dir, Place{TARGET(kARM), precision}, valid_places); + int im_width = FLAGS_im_width; + int im_height = FLAGS_im_height; + auto* input_tensor = predictor.GetInput(0); + auto in_dims = input_tensor->dims(); + input_tensor->Resize( + DDim(std::vector({1, 3, im_width, im_height}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + auto* output_tensors = predictor.GetOutputs(); + + LOG(INFO) << "======output:========"; + for (auto t : *output_tensors) { + LOG(INFO) << t; + } + LOG(INFO) + << "=====RUN_finished!!============= Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/paddle_use_kernels.h b/lite/api/paddle_use_kernels.h index f2fe0ce34f826926404bb613d15bf52edc206643..2f4d7350b56f7c56a329b629b27ed5b517708ef6 100644 --- a/lite/api/paddle_use_kernels.h +++ b/lite/api/paddle_use_kernels.h @@ -21,6 +21,8 @@ #ifndef LITE_WITH_FPGA USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +USE_LITE_KERNEL(flatten, kHost, kAny, kAny, def); +USE_LITE_KERNEL(flatten2, kHost, kAny, kAny, def); #else USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def); USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); diff --git a/lite/api/paddle_use_ops.h b/lite/api/paddle_use_ops.h index cf01c74adbcf2a9dc29883d82644275b2b8be465..5cf62224de33f14d5e0637fc9cc54752a79ba445 100644 --- a/lite/api/paddle_use_ops.h +++ b/lite/api/paddle_use_ops.h @@ -73,9 +73,12 @@ USE_LITE_OP(prior_box) USE_LITE_OP(density_prior_box) USE_LITE_OP(reshape) USE_LITE_OP(reshape2) +USE_LITE_OP(flatten) +USE_LITE_OP(flatten2) USE_LITE_OP(split) USE_LITE_OP(fake_quantize_moving_average_abs_max); USE_LITE_OP(fake_dequantize_max_abs); +USE_LITE_OP(fake_quantize_range_abs_max); USE_LITE_OP(calib); USE_LITE_OP(calib_once); USE_LITE_OP(norm); diff --git a/lite/api/test_helper.h b/lite/api/test_helper.h index 1a5ab31abd3e97c5bfc484547af5d36d53e49b39..d835c030f03a3c95575217020cd298dabbf1a15a 100644 --- a/lite/api/test_helper.h +++ b/lite/api/test_helper.h @@ -23,6 +23,9 @@ DEFINE_string(model_dir, "", "model dir"); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(im_width, 224, "image width"); +DEFINE_int32(im_height, 224, "image height"); +DEFINE_bool(int8, false, "is run int8"); namespace paddle { namespace lite { diff --git a/lite/arm/math/prior_box.cc b/lite/arm/math/prior_box.cc index e6f455e72a231f36e10b2cde54140ca68fcd4a43..6ec312796578863dc9c7a046950aa4dcf79d38fa 100644 --- a/lite/arm/math/prior_box.cc +++ b/lite/arm/math/prior_box.cc @@ -51,7 +51,7 @@ void density_prior_box(const lite::Tensor* input, const std::vector& min_size_, const std::vector& fixed_size_, const std::vector& fixed_ratio_, - const std::vector& density_size_, + const std::vector& density_size_, const std::vector& max_size_, const std::vector& aspect_ratio_, const std::vector& variance_, @@ -82,14 +82,12 @@ void density_prior_box(const lite::Tensor* input, img_width = image->dims()[3]; img_height = image->dims()[2]; } - float step_w = step_w_; float step_h = step_h_; if (step_w == 0 || step_h == 0) { step_w = static_cast(img_width) / width; step_h = static_cast(img_height) / height; } - float offset = offset_; int step_average = static_cast((step_w + step_h) * 0.5); // add int channel_size = height * width * prior_num_ * 4; @@ -343,7 +341,7 @@ void prior_box(const lite::Tensor* input, min_size, std::vector(), std::vector(), - std::vector(), + std::vector(), max_size, aspect_ratio, variance, diff --git a/lite/arm/math/prior_box.h b/lite/arm/math/prior_box.h index 59efb2ab0027d3d5cab68118ea48fa70436d1c48..ffa821b75e54ee3e2329e4dcced8ddee2a003802 100644 --- a/lite/arm/math/prior_box.h +++ b/lite/arm/math/prior_box.h @@ -30,7 +30,7 @@ void density_prior_box(const lite::Tensor* input, const std::vector& min_size_, const std::vector& fixed_size_, const std::vector& fixed_ratio_, - const std::vector& density_size_, + const std::vector& density_size_, const std::vector& max_size_, const std::vector& aspect_ratio_, const std::vector& variance_, diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h index 65cc1600773297f935149c040a264400e13f91cc..d9111e5c46c9217b181e5a3e5a8c7981f46250df 100644 --- a/lite/core/profile/precision_profiler.h +++ b/lite/core/profile/precision_profiler.h @@ -26,17 +26,49 @@ namespace paddle { namespace lite { namespace profile { +template +static void write_tensorfile(const Tensor* tensor, const std::string& locate) { + if (locate.find('/') != std::string::npos) { + return; + } + FILE* fp = fopen(locate.c_str(), "w"); + if (fp == nullptr) { + LOG(ERROR) << "file open field " << locate; + } else { + const dtype* data = tensor->data(); + for (int i = 0; i < tensor->numel(); ++i) { + fprintf(fp, "[%d] %f \n", i, static_cast(data[i])); + } + } + fclose(fp); +} + class PrecisionProfiler { public: explicit PrecisionProfiler(const Instruction* inst) : inst_(inst) {} ~PrecisionProfiler() { LOG(INFO) << ">> Running kernel: " << inst_->op()->op_info()->Repr() - << " on Target " << TargetToStr(inst_->kernel()->target()); - auto tensor_mean = [](const Tensor* in, PrecisionType ptype) -> double { + << " on Target " << TargetToStr(inst_->kernel()->target()) << " " + << PrecisionToStr(inst_->kernel()->precision()); + auto tensor_mean = [](const Tensor* in, + PrecisionType ptype, + std::string name = "inst") -> double { + if (!in->data()) { + return -99999; + } double sum = 0.; switch (ptype) { case PRECISION(kFloat): { auto ptr = in->data(); + // write_tensorfile(in, name); + for (int i = 0; i < in->numel(); ++i) { + sum += ptr[i]; + } + return sum / in->numel(); + } + case PRECISION(kAny): { + auto ptr = in->data(); + // write_tensorfile(in, name); for (int i = 0; i < in->numel(); ++i) { sum += ptr[i]; } @@ -44,6 +76,7 @@ class PrecisionProfiler { } case PRECISION(kInt8): { auto ptr = in->data(); + // write_tensorfile(in, name); for (int i = 0; i < in->numel(); ++i) { sum += ptr[i]; } @@ -51,6 +84,7 @@ class PrecisionProfiler { } case PRECISION(kInt32): { auto ptr = in->data(); + // write_tensorfile(in, name); for (int i = 0; i < in->numel(); ++i) { sum += ptr[i]; } @@ -70,17 +104,18 @@ class PrecisionProfiler { std::string out_arg_name; op->op_info()->GetOutputArgname(out_name, &out_arg_name); auto type = kernel->GetOutputDeclType(out_arg_name); + if (type->IsTensor()) { auto tout = op_scope->FindVar(out_name)->GetMutable(); - double mean = tensor_mean(tout, type->precision()); + double mean = tensor_mean(tout, type->precision(), out_name); LOG(INFO) << "output name: " << out_name << ", dims: " << tout->dims() << ", precision: " << PrecisionToStr(type->precision()) - << ", mean value: " << mean; + << ", mean value: " << mean << " shape:" << tout->dims(); } else if (type->IsTensorList()) { auto tout = op_scope->FindVar(out_name)->GetMutable>(); for (auto& t : *tout) { - double mean = tensor_mean(&t, type->precision()); + double mean = tensor_mean(&t, type->precision(), out_name); LOG(INFO) << "output name: " << out_name << ", dims: " << t.dims() << ", precision: " << PrecisionToStr(type->precision()) << ", mean value: " << mean; diff --git a/lite/kernels/arm/density_prior_box_compute.cc b/lite/kernels/arm/density_prior_box_compute.cc index 47d14d6572281e212322be38cab67cdb5c1581b5..35616bc6e8ac1e8c142616cf633578a057bb967f 100644 --- a/lite/kernels/arm/density_prior_box_compute.cc +++ b/lite/kernels/arm/density_prior_box_compute.cc @@ -48,13 +48,12 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, void DensityPriorBoxCompute::Run() { auto& param = Param(); - bool is_flip = param.flip; bool is_clip = param.clip; std::vector min_size = param.min_sizes; std::vector fixed_size = param.fixed_sizes; std::vector fixed_ratio = param.fixed_ratios; - std::vector density_size = param.density_sizes; + auto density_size = param.density_sizes; std::vector max_size = param.max_sizes; std::vector aspect_ratio = param.aspect_ratios; std::vector variance = param.variances_; diff --git a/lite/kernels/host/reshape_compute.cc b/lite/kernels/host/reshape_compute.cc index cb3420fbbda06f34772ca672b2bc7a8444056185..a5934999cdd9c88037936bbf73f7d810aaffc3e7 100644 --- a/lite/kernels/host/reshape_compute.cc +++ b/lite/kernels/host/reshape_compute.cc @@ -93,3 +93,40 @@ REGISTER_LITE_KERNEL(reshape2, {LiteType::GetTensorTy( TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) .Finalize(); + +REGISTER_LITE_KERNEL(flatten, + kHost, + kAny, + kAny, + paddle::lite::kernels::host::ReshapeCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindInput("Shape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .Finalize(); + +REGISTER_LITE_KERNEL(flatten2, + kHost, + kAny, + kAny, + paddle::lite::kernels::host::ReshapeCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindInput("Shape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("Out", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .BindOutput("XShape", + {LiteType::GetTensorTy( + TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + .Finalize(); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 1362a86797f699ba37e328d8a4a2ffd166bb55b2..4e7f7436e3d5e43360da06525f033a4c2a98fd3a 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -9,6 +9,7 @@ lite_cc_library(matmul_op SRCS matmul_op.cc DEPS ${op_DEPS}) lite_cc_library(scale_op SRCS scale_op.cc DEPS ${op_DEPS}) lite_cc_library(softmax_op SRCS softmax_op.cc DEPS ${op_DEPS}) lite_cc_library(reshape_op SRCS reshape_op.cc DEPS ${op_DEPS} ) +lite_cc_library(flatten_op SRCS flatten_op.cc DEPS ${op_DEPS} ) lite_cc_library(batch_norm_op SRCS batch_norm_op.cc DEPS ${op_DEPS}) lite_cc_library(feed_op SRCS feed_op.cc DEPS ${op_DEPS}) lite_cc_library(fetch_op SRCS fetch_op.cc DEPS ${op_DEPS}) @@ -52,6 +53,7 @@ lite_cc_library(calib_once_op SRCS calib_once_op.cc DEPS ${op_DEPS}) lite_cc_library(split_op SRCS split_op.cc DEPS ${op_DEPS}) lite_cc_library(transpose_op SRCS transpose_op.cc DEPS ${op_DEPS}) lite_cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) +lite_cc_library(fake_quant_range SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS}) lite_cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS}) lite_cc_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS ${op_DEPS}) lite_cc_library(im2sequence_op SRCS im2sequence_op.cc DEPS ${op_DEPS}) @@ -96,6 +98,7 @@ set(ops scale_op softmax_op reshape_op + flatten_op batch_norm_op feed_op fetch_op @@ -128,6 +131,7 @@ set(ops split_op transpose_op fake_quant + fake_quant_range fake_dequant sgd_op uniform_random_op diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index b84b4ff16993b51410bf741db91c5ec46960d410..fb6b431fff8ab20dd1a6d1abc8aff7443771ee2f 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -85,7 +85,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, } } } - param_.fuse_relu = op_desc.GetAttr("fuse_relu"); + if (op_desc.HasAttr("fuse_relu")) { + param_.fuse_relu = op_desc.GetAttr("fuse_relu"); + } return true; } diff --git a/lite/operators/density_prior_box_op.cc b/lite/operators/density_prior_box_op.cc index c6b646b33d64eaf8dc3ca34254d9a756e01fb1d6..86830df2f19b5615e8b9cfb4b3b57eb22000f588 100644 --- a/lite/operators/density_prior_box_op.cc +++ b/lite/operators/density_prior_box_op.cc @@ -41,15 +41,29 @@ bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, param_.boxes = scope->FindVar(boxes)->GetMutable(); param_.variances = scope->FindVar(variances)->GetMutable(); - param_.flip = opdesc.GetAttr("flip"); param_.clip = opdesc.GetAttr("clip"); - param_.min_sizes = opdesc.GetAttr>("min_sizes"); param_.fixed_sizes = opdesc.GetAttr>("fixed_sizes"); param_.fixed_ratios = opdesc.GetAttr>("fixed_ratios"); - param_.density_sizes = opdesc.GetAttr>("density_sizes"); - param_.max_sizes = opdesc.GetAttr>("max_sizes"); - param_.aspect_ratios = opdesc.GetAttr>("aspect_ratios"); param_.variances_ = opdesc.GetAttr>("variances"); + + if (opdesc.HasAttr("aspect_ratios")) { + param_.aspect_ratios = opdesc.GetAttr>("aspect_ratios"); + } + if (opdesc.HasAttr("max_sizes")) { + param_.max_sizes = opdesc.GetAttr>("max_sizes"); + } + if (opdesc.HasAttr("density_sizes")) { + param_.density_sizes = opdesc.GetAttr>("density_sizes"); + } + if (opdesc.HasAttr("densities")) { + param_.density_sizes = opdesc.GetAttr>("densities"); + } + if (opdesc.HasAttr("min_sizes")) { + param_.min_sizes = opdesc.GetAttr>("min_sizes"); + } + if (opdesc.HasAttr("flip")) { + param_.flip = opdesc.GetAttr("flip"); + } if (opdesc.HasAttr("img_w")) { param_.img_w = opdesc.GetAttr("img_w"); } diff --git a/lite/operators/fake_quantize_range_abs_max.cc b/lite/operators/fake_quantize_range_abs_max.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8ce3f75a59fec5b032c60f51177f428bd15fe0d --- /dev/null +++ b/lite/operators/fake_quantize_range_abs_max.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/fake_quantize_range_abs_max.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators {} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(fake_quantize_range_abs_max, + paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite); diff --git a/lite/operators/fake_quantize_range_abs_max.h b/lite/operators/fake_quantize_range_abs_max.h new file mode 100644 index 0000000000000000000000000000000000000000..726731595a9c4b7cd2e30db911230cc2f00b5b92 --- /dev/null +++ b/lite/operators/fake_quantize_range_abs_max.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/core/tensor.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FakeQuantizeRangeMaxAbsOpLite : public OpLite { + public: + FakeQuantizeRangeMaxAbsOpLite() {} + + explicit FakeQuantizeRangeMaxAbsOpLite(const std::string &type) + : OpLite(type) {} + + bool CheckShape() const override { return true; } + + bool InferShape() const override { return true; } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto in_scale = op_desc.Input("InScale").front(); + + auto out = op_desc.Output("Out").front(); + auto out_scale = op_desc.Output("OutScale").front(); + + param_.x = scope->FindVar(x)->GetMutable(); + param_.in_scale = scope->FindVar(in_scale)->GetMutable(); + + param_.out = scope->FindVar(out)->GetMutable(); + param_.out_scale = scope->FindVar(out_scale)->GetMutable(); + param_.bit_length = op_desc.GetAttr("bit_length"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "fake_quantize_range_max_abs"; + } + + private: + mutable FakeQuantizeMovingAvgMaxAbsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/flatten_op.cc b/lite/operators/flatten_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6deab45023876b1a5707ef5cea6ec69af3875328 --- /dev/null +++ b/lite/operators/flatten_op.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/flatten_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool FlattenOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool FlattenOp::InferShape() const { + auto x_dims = param_.x->dims(); + + auto out_lod = param_.output->mutable_lod(); + *out_lod = param_.x->lod(); + + int64_t outer = 1, inner = 1; + for (int i = 0; i < x_dims.size(); ++i) { + if (i < axis_) { + outer *= x_dims[i]; + } else { + inner *= x_dims[i]; + } + } + std::vector out_shape(2); + out_shape[0] = outer; + out_shape[1] = inner; + + param_.output->Resize(out_shape); + + return true; +} + +bool FlattenOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + auto x_var = scope->FindVar(opdesc.Input("X").front()); + auto output_var = scope->FindVar(opdesc.Output("Out").front()); + CHECK(x_var); + CHECK(output_var); + param_.x = const_cast(&(x_var->Get())); + param_.output = output_var->GetMutable(); + axis_ = opdesc.GetAttr("axis"); + + param_.inplace = false; + + CHECK(param_.x) << "Input(X) of FlattenOp should not be null."; + CHECK(param_.output) << "Output(Out) of FlattenOp should not be null."; + CHECK_GE(axis_, 0) << "Flatten op axis should >=0."; + return true; +} + +bool Flatten2Op::CheckShape() const { + FlattenOp::CheckShape(); + CHECK_OR_FALSE(param_.xshape); + return true; +} + +bool Flatten2Op::InferShape() const { + FlattenOp::InferShape(); + auto x_dims = param_.x->dims(); + std::vector xshape_dims(x_dims.size() + 1, 0); + for (size_t i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + param_.xshape->Resize(DDim(xshape_dims)); + return true; +} + +bool Flatten2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + FlattenOp::AttachImpl(opdesc, scope); + auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); + CHECK(xshape_var); + param_.xshape = xshape_var->GetMutable(); + CHECK(param_.xshape) << "Output(XShape) of FlattenOp should not be null."; + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(flatten, paddle::lite::operators::FlattenOp); +REGISTER_LITE_OP(flatten2, paddle::lite::operators::Flatten2Op); diff --git a/lite/operators/flatten_op.h b/lite/operators/flatten_op.h new file mode 100644 index 0000000000000000000000000000000000000000..61680fd3903b77f8826cda6f6a242739720155d7 --- /dev/null +++ b/lite/operators/flatten_op.h @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FlattenOp : public OpLite { + public: + FlattenOp() {} + explicit FlattenOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "flatten"; } + + protected: + mutable ReshapeParam param_; + int axis_; +}; + +class Flatten2Op : public FlattenOp { + public: + Flatten2Op() : FlattenOp() {} + explicit Flatten2Op(const std::string &op_type) : FlattenOp(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "flatten2"; } +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index c011aa5e0c9ba6ba0ac9032946bf880d4162d325..2d1fd8bfe6d9e1775dec8da506efa5acb82eafbd 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -521,7 +521,7 @@ struct PriorBoxParam { struct DensityPriorBoxParam : public PriorBoxParam { std::vector fixed_sizes; std::vector fixed_ratios; - std::vector density_sizes; + std::vector density_sizes; }; /// ----------------------- GRU operators ----------------------f struct GRUParam { diff --git a/lite/operators/prior_box_op.cc b/lite/operators/prior_box_op.cc index 8053b24b623e38491876efc1ff486193a5a08cce..3cc8938f4eb3ffc5720a6e1cfc1746e1defd048e 100644 --- a/lite/operators/prior_box_op.cc +++ b/lite/operators/prior_box_op.cc @@ -40,12 +40,14 @@ bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { param_.boxes = scope->FindVar(boxes)->GetMutable(); param_.variances = scope->FindVar(variances)->GetMutable(); - param_.flip = opdesc.GetAttr("flip"); param_.clip = opdesc.GetAttr("clip"); param_.min_sizes = opdesc.GetAttr>("min_sizes"); param_.max_sizes = opdesc.GetAttr>("max_sizes"); param_.aspect_ratios = opdesc.GetAttr>("aspect_ratios"); param_.variances_ = opdesc.GetAttr>("variances"); + if (opdesc.HasAttr("flip")) { + param_.flip = opdesc.GetAttr("flip"); + } if (opdesc.HasAttr("img_w")) { param_.img_w = opdesc.GetAttr("img_w"); } diff --git a/lite/tests/kernels/prior_box_compute_test.cc b/lite/tests/kernels/prior_box_compute_test.cc index 57bea3e96dbf55014d55b0c7d34e8aa4db7b4b48..47f7bc9447b1b33b57c4bc4a495a106f49d6abbc 100644 --- a/lite/tests/kernels/prior_box_compute_test.cc +++ b/lite/tests/kernels/prior_box_compute_test.cc @@ -75,7 +75,7 @@ void prior_box_compute_ref(const lite::Tensor* input, const std::vector& min_size_, const std::vector& fixed_size_, const std::vector& fixed_ratio_, - const std::vector& density_size_, + const std::vector& density_size_, const std::vector& max_size_, const std::vector& aspect_ratio_, const std::vector& variance_, @@ -352,7 +352,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase { std::vector min_size_; std::vector fixed_size_; std::vector fixed_ratio_; - std::vector density_size_; + std::vector density_size_; std::vector max_size_; std::vector aspect_ratio_; std::vector variance_; @@ -375,7 +375,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase { const std::vector& min_size, const std::vector& fixed_size, const std::vector& fixed_ratio, - const std::vector& density_size, + const std::vector& density_size, const std::vector& max_size, const std::vector& aspect_ratio, const std::vector& variance, @@ -561,7 +561,7 @@ class PriorBoxComputeTester : public arena::TestCase { min_size_, std::vector(), std::vector(), - std::vector(), + std::vector(), max_size_, aspect_ratio_, variance_, @@ -621,7 +621,7 @@ void test_density_prior_box(Place place) { std::vector variance{0.1f, 0.1f, 0.2f, 0.2f}; std::vector fixed_size{60, 30}; std::vector fixed_ratio{1., 2.}; - std::vector density_size{1., 3.}; + std::vector density_size{1, 3}; bool flip = true; bool clip = false; float step_h = 0;