Commit 9388dace authored by T TianXiaogang, committed by Xiaoyang LI

feat: (#1836)

    add model_run_test_image
    add fake_quantize_range_abs_max op
    add flatten op
    add flatten2 op
fix:
    fix density_prior_box density_size type from float to int
    fix some get_attr checks in prior_box and density_prior_box
test=develop
Parent 9067c2f3
......@@ -145,8 +145,13 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
--model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL)
add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz)
# lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc
# DEPS ${lite_model_test_DEPS})
# lite_cc_test(model_run_test_image SRCS model_run_test_image.cc
# DEPS ${lite_model_test_DEPS}
# CL_DEPS ${opencl_kernels}
# FPGA_DEPS ${fpga_kernels})
endif()
# These tests need CLI arguments and are not supported in ARM CI.
......
......@@ -71,6 +71,13 @@ const lite::Tensor *Predictor::GetOutput(size_t offset) const {
return &fetch_list.at(offset);
}
const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
auto *_fetch_list = exec_scope_->FindVar("fetch");
CHECK(_fetch_list) << "no fetch variable in exec_scope";
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
return &fetch_list;
}
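A minimal caller-side sketch of the new accessor (hypothetical usage, not part of this commit):
// Iterate the whole fetch list instead of fetching one offset at a time.
const std::vector<lite::Tensor>* outs = predictor.GetOutputs();
for (const auto& t : *outs) {
  LOG(INFO) << "output dims: " << t.dims();
}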
const cpp::ProgramDesc &Predictor::program_desc() const {
return program_desc_;
}
......
......@@ -69,6 +69,7 @@ class LITE_API Predictor {
// Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset) const;
const std::vector<lite::Tensor>* GetOutputs() const;
const cpp::ProgramDesc& program_desc() const;
const lite::Tensor* GetTensor(const std::string& name) const;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
TEST(model, test) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt8)}});
auto precision = PRECISION(kFloat);
if (FLAGS_int8) {
precision = PRECISION(kInt8);
}
predictor.Build(
FLAGS_model_dir, Place{TARGET(kARM), precision}, valid_places);
int im_width = FLAGS_im_width;
int im_height = FLAGS_im_height;
auto* input_tensor = predictor.GetInput(0);
auto in_dims = input_tensor->dims();
input_tensor->Resize(
DDim(std::vector<DDim::value_type>({1, 3, im_height, im_width})));  // NCHW
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor.Run();
}
auto* output_tensors = predictor.GetOutputs();
LOG(INFO) << "======output:========";
for (const auto& t : *output_tensors) {
LOG(INFO) << t;
}
LOG(INFO)
<< "=====RUN_finished!!============= Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
#endif
}
} // namespace lite
} // namespace paddle
......@@ -21,6 +21,8 @@
#ifndef LITE_WITH_FPGA
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(flatten, kHost, kAny, kAny, def);
USE_LITE_KERNEL(flatten2, kHost, kAny, kAny, def);
#else
USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def);
USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def);
......
......@@ -73,9 +73,12 @@ USE_LITE_OP(prior_box)
USE_LITE_OP(density_prior_box)
USE_LITE_OP(reshape)
USE_LITE_OP(reshape2)
USE_LITE_OP(flatten)
USE_LITE_OP(flatten2)
USE_LITE_OP(split)
USE_LITE_OP(fake_quantize_moving_average_abs_max);
USE_LITE_OP(fake_dequantize_max_abs);
USE_LITE_OP(fake_quantize_range_abs_max);
USE_LITE_OP(calib);
USE_LITE_OP(calib_once);
USE_LITE_OP(norm);
......
......@@ -23,6 +23,9 @@ DEFINE_string(model_dir, "", "model dir");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(im_width, 224, "image width");
DEFINE_int32(im_height, 224, "image height");
DEFINE_bool(int8, false, "is run int8");
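Given these flags, the test could be launched along these lines (binary name and model path are illustrative):
./model_run_test_image --model_dir=/path/to/model \
    --im_width=224 --im_height=224 \
    --warmup=10 --repeats=100 --threads=1 --int8=false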
namespace paddle {
namespace lite {
......
......@@ -51,7 +51,7 @@ void density_prior_box(const lite::Tensor* input,
const std::vector<float>& min_size_,
const std::vector<float>& fixed_size_,
const std::vector<float>& fixed_ratio_,
const std::vector<int>& density_size_,
const std::vector<float>& max_size_,
const std::vector<float>& aspect_ratio_,
const std::vector<float>& variance_,
......@@ -82,14 +82,12 @@ void density_prior_box(const lite::Tensor* input,
img_width = image->dims()[3];
img_height = image->dims()[2];
}
float step_w = step_w_;
float step_h = step_h_;
if (step_w == 0 || step_h == 0) {
step_w = static_cast<float>(img_width) / width;
step_h = static_cast<float>(img_height) / height;
}
float offset = offset_;
int step_average = static_cast<int>((step_w + step_h) * 0.5); // add
int channel_size = height * width * prior_num_ * 4;
......@@ -343,7 +341,7 @@ void prior_box(const lite::Tensor* input,
min_size,
std::vector<float>(),
std::vector<float>(),
std::vector<int>(),
max_size,
aspect_ratio,
variance,
......
......@@ -30,7 +30,7 @@ void density_prior_box(const lite::Tensor* input,
const std::vector<float>& min_size_,
const std::vector<float>& fixed_size_,
const std::vector<float>& fixed_ratio_,
const std::vector<int>& density_size_,
const std::vector<float>& max_size_,
const std::vector<float>& aspect_ratio_,
const std::vector<float>& variance_,
......
......@@ -26,17 +26,49 @@ namespace paddle {
namespace lite {
namespace profile {
template <typename dtype>
static void write_tensorfile(const Tensor* tensor, const std::string& locate) {
if (locate.find('/') != std::string::npos) {
return;
}
  FILE* fp = fopen(locate.c_str(), "w");
  if (fp == nullptr) {
    LOG(ERROR) << "file open failed " << locate;
    return;
  }
  const dtype* data = tensor->data<dtype>();
  for (int i = 0; i < tensor->numel(); ++i) {
    fprintf(fp, "[%d] %f\n", i, static_cast<float>(data[i]));
  }
  fclose(fp);
}
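A hedged usage sketch for the helper above (its call sites are commented out below; the filename is illustrative):
// Dump a float tensor for offline comparison. Note the helper deliberately
// skips any locate string containing '/', so only bare filenames are written.
write_tensorfile<float>(output_tensor, "conv2d_0_out");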
class PrecisionProfiler {
public:
explicit PrecisionProfiler(const Instruction* inst) : inst_(inst) {}
~PrecisionProfiler() {
LOG(INFO) << ">> Running kernel: " << inst_->op()->op_info()->Repr()
<< " on Target " << TargetToStr(inst_->kernel()->target());
auto tensor_mean = [](const Tensor* in, PrecisionType ptype) -> double {
<< " on Target " << TargetToStr(inst_->kernel()->target()) << " "
<< PrecisionToStr(inst_->kernel()->precision());
auto tensor_mean = [](const Tensor* in,
PrecisionType ptype,
std::string name = "inst") -> double {
if (!in->data<int8_t>()) {
return -99999;
}
double sum = 0.;
switch (ptype) {
case PRECISION(kFloat): {
auto ptr = in->data<float>();
// write_tensorfile<float>(in, name);
for (int i = 0; i < in->numel(); ++i) {
sum += ptr[i];
}
return sum / in->numel();
}
case PRECISION(kAny): {
auto ptr = in->data<float>();
// write_tensorfile<float>(in, name);
for (int i = 0; i < in->numel(); ++i) {
sum += ptr[i];
}
......@@ -44,6 +76,7 @@ class PrecisionProfiler {
}
case PRECISION(kInt8): {
auto ptr = in->data<int8_t>();
// write_tensorfile<int8_t>(in, name);
for (int i = 0; i < in->numel(); ++i) {
sum += ptr[i];
}
......@@ -51,6 +84,7 @@ class PrecisionProfiler {
}
case PRECISION(kInt32): {
auto ptr = in->data<int32_t>();
// write_tensorfile<int32_t>(in, name);
for (int i = 0; i < in->numel(); ++i) {
sum += ptr[i];
}
......@@ -70,17 +104,18 @@ class PrecisionProfiler {
std::string out_arg_name;
op->op_info()->GetOutputArgname(out_name, &out_arg_name);
auto type = kernel->GetOutputDeclType(out_arg_name);
if (type->IsTensor()) {
auto tout = op_scope->FindVar(out_name)->GetMutable<Tensor>();
double mean = tensor_mean(tout, type->precision(), out_name);
LOG(INFO) << "output name: " << out_name << ", dims: " << tout->dims()
<< ", precision: " << PrecisionToStr(type->precision())
<< ", mean value: " << mean;
<< ", mean value: " << mean << " shape:" << tout->dims();
} else if (type->IsTensorList()) {
auto tout =
op_scope->FindVar(out_name)->GetMutable<std::vector<Tensor>>();
for (auto& t : *tout) {
double mean = tensor_mean(&t, type->precision(), out_name);
LOG(INFO) << "output name: " << out_name << ", dims: " << t.dims()
<< ", precision: " << PrecisionToStr(type->precision())
<< ", mean value: " << mean;
......
......@@ -48,13 +48,12 @@ inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
void DensityPriorBoxCompute::Run() {
auto& param = Param<operators::DensityPriorBoxParam>();
bool is_flip = param.flip;
bool is_clip = param.clip;
std::vector<float> min_size = param.min_sizes;
std::vector<float> fixed_size = param.fixed_sizes;
std::vector<float> fixed_ratio = param.fixed_ratios;
auto density_size = param.density_sizes;
std::vector<float> max_size = param.max_sizes;
std::vector<float> aspect_ratio = param.aspect_ratios;
std::vector<float> variance = param.variances_;
......
......@@ -93,3 +93,40 @@ REGISTER_LITE_KERNEL(reshape2,
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
REGISTER_LITE_KERNEL(flatten,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("Shape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
REGISTER_LITE_KERNEL(flatten2,
kHost,
kAny,
kAny,
paddle::lite::kernels::host::ReshapeCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindInput("Shape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("XShape",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.Finalize();
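Both registrations reuse paddle::lite::kernels::host::ReshapeCompute: flatten is a reshape whose 2-D target shape is derived from the axis attribute, so no new kernel body is needed; flatten2 only binds the extra XShape output.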
......@@ -9,6 +9,7 @@ lite_cc_library(matmul_op SRCS matmul_op.cc DEPS ${op_DEPS})
lite_cc_library(scale_op SRCS scale_op.cc DEPS ${op_DEPS})
lite_cc_library(softmax_op SRCS softmax_op.cc DEPS ${op_DEPS})
lite_cc_library(reshape_op SRCS reshape_op.cc DEPS ${op_DEPS} )
lite_cc_library(flatten_op SRCS flatten_op.cc DEPS ${op_DEPS} )
lite_cc_library(batch_norm_op SRCS batch_norm_op.cc DEPS ${op_DEPS})
lite_cc_library(feed_op SRCS feed_op.cc DEPS ${op_DEPS})
lite_cc_library(fetch_op SRCS fetch_op.cc DEPS ${op_DEPS})
......@@ -52,6 +53,7 @@ lite_cc_library(calib_once_op SRCS calib_once_op.cc DEPS ${op_DEPS})
lite_cc_library(split_op SRCS split_op.cc DEPS ${op_DEPS})
lite_cc_library(transpose_op SRCS transpose_op.cc DEPS ${op_DEPS})
lite_cc_library(fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
lite_cc_library(fake_quant_range SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
lite_cc_library(fake_dequant SRCS fake_dequantize_max_abs.cc DEPS ${op_DEPS})
lite_cc_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS ${op_DEPS})
lite_cc_library(im2sequence_op SRCS im2sequence_op.cc DEPS ${op_DEPS})
......@@ -96,6 +98,7 @@ set(ops
scale_op
softmax_op
reshape_op
flatten_op
batch_norm_op
feed_op
fetch_op
......@@ -128,6 +131,7 @@ set(ops
split_op
transpose_op
fake_quant
fake_quant_range
fake_dequant
sgd_op
uniform_random_op
......
......@@ -85,7 +85,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
}
}
}
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
}
return true;
}
......
......@@ -41,15 +41,29 @@ bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc,
param_.boxes = scope->FindVar(boxes)->GetMutable<lite::Tensor>();
param_.variances = scope->FindVar(variances)->GetMutable<lite::Tensor>();
param_.clip = opdesc.GetAttr<bool>("clip");
param_.fixed_sizes = opdesc.GetAttr<std::vector<float>>("fixed_sizes");
param_.fixed_ratios = opdesc.GetAttr<std::vector<float>>("fixed_ratios");
param_.variances_ = opdesc.GetAttr<std::vector<float>>("variances");
if (opdesc.HasAttr("aspect_ratios")) {
param_.aspect_ratios = opdesc.GetAttr<std::vector<float>>("aspect_ratios");
}
if (opdesc.HasAttr("max_sizes")) {
param_.max_sizes = opdesc.GetAttr<std::vector<float>>("max_sizes");
}
if (opdesc.HasAttr("density_sizes")) {
param_.density_sizes = opdesc.GetAttr<std::vector<int>>("density_sizes");
}
if (opdesc.HasAttr("densities")) {
param_.density_sizes = opdesc.GetAttr<std::vector<int>>("densities");
}
if (opdesc.HasAttr("min_sizes")) {
param_.min_sizes = opdesc.GetAttr<std::vector<float>>("min_sizes");
}
if (opdesc.HasAttr("flip")) {
param_.flip = opdesc.GetAttr<bool>("flip");
}
if (opdesc.HasAttr("img_w")) {
param_.img_w = opdesc.GetAttr<int>("img_w");
}
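The repeated HasAttr-then-GetAttr guard could be factored into a small helper; a sketch assuming cpp::OpDesc keeps the member signatures used above (hypothetical, not part of this commit):
// Read an attribute only when the op desc actually carries it; otherwise the
// destination keeps its default value.
template <typename T>
static void GetAttrIfPresent(const cpp::OpDesc& desc,
                             const std::string& name,
                             T* out) {
  if (desc.HasAttr(name)) {
    *out = desc.GetAttr<T>(name);
  }
}
// e.g. GetAttrIfPresent(opdesc, "img_w", &param_.img_w);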
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_quantize_range_abs_max.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fake_quantize_range_abs_max,
paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
public:
FakeQuantizeRangeMaxAbsOpLite() {}
explicit FakeQuantizeRangeMaxAbsOpLite(const std::string &type)
: OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto in_scale = op_desc.Input("InScale").front();
auto out = op_desc.Output("Out").front();
auto out_scale = op_desc.Output("OutScale").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
param_.bit_length = op_desc.GetAttr<int>("bit_length");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fake_quantize_range_max_abs";
}
private:
mutable FakeQuantizeMovingAvgMaxAbsParam param_;  // reuses the moving-average fake-quant param struct, which carries the same fields
};
} // namespace operators
} // namespace lite
} // namespace paddle
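For reference, the math this op family implements (a sketch of the standard range-abs-max formula; the lite op above only wires up params, the kernel does the work):
#include <cmath>
// scale is max(|x|) tracked over a window of recent batches; the op emits the
// quantized value as a float, and fake_dequantize_max_abs scales it back.
float fake_quantize_range_abs_max(float x, float scale, int bit_length) {
  const int bin_cnt = (1 << (bit_length - 1)) - 1;  // e.g. 127 for 8 bits
  return std::round(x / scale * bin_cnt);
}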
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/flatten_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool FlattenOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
bool FlattenOp::InferShape() const {
auto x_dims = param_.x->dims();
auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x->lod();
int64_t outer = 1, inner = 1;
for (size_t i = 0; i < x_dims.size(); ++i) {
if (static_cast<int>(i) < axis_) {
outer *= x_dims[i];
} else {
inner *= x_dims[i];
}
}
std::vector<int64_t> out_shape(2);
out_shape[0] = outer;
out_shape[1] = inner;
param_.output->Resize(out_shape);
return true;
}
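Worked example of the rule above: with x_dims = {2, 3, 4, 5} and axis = 2, outer = 2 * 3 = 6 and inner = 4 * 5 = 20, so Out is resized to {6, 20}.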
bool FlattenOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.output = output_var->GetMutable<lite::Tensor>();
axis_ = opdesc.GetAttr<int>("axis");
param_.inplace = false;
CHECK(param_.x) << "Input(X) of FlattenOp should not be null.";
CHECK(param_.output) << "Output(Out) of FlattenOp should not be null.";
CHECK_GE(axis_, 0) << "Flatten op axis should be >= 0.";
return true;
}
bool Flatten2Op::CheckShape() const {
FlattenOp::CheckShape();
CHECK_OR_FALSE(param_.xshape);
return true;
}
bool Flatten2Op::InferShape() const {
FlattenOp::InferShape();
auto x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
param_.xshape->Resize(DDim(xshape_dims));
return true;
}
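For the same {2, 3, 4, 5} input, XShape becomes {0, 2, 3, 4, 5}; the leading 0 is a placeholder so the original shape can be recovered later (mirroring how reshape2 records XShape).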
bool Flatten2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
FlattenOp::AttachImpl(opdesc, scope);
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var);
param_.xshape = xshape_var->GetMutable<lite::Tensor>();
CHECK(param_.xshape) << "Output(XShape) of Flatten2Op should not be null.";
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(flatten, paddle::lite::operators::FlattenOp);
REGISTER_LITE_OP(flatten2, paddle::lite::operators::Flatten2Op);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FlattenOp : public OpLite {
public:
FlattenOp() {}
explicit FlattenOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "flatten"; }
protected:
mutable ReshapeParam param_;
int axis_;
};
class Flatten2Op : public FlattenOp {
public:
Flatten2Op() : FlattenOp() {}
explicit Flatten2Op(const std::string &op_type) : FlattenOp(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "flatten2"; }
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -521,7 +521,7 @@ struct PriorBoxParam {
struct DensityPriorBoxParam : public PriorBoxParam {
std::vector<float> fixed_sizes;
std::vector<float> fixed_ratios;
std::vector<float> density_sizes;
std::vector<int> density_sizes;
};
/// ----------------------- GRU operators ----------------------f
struct GRUParam {
......
......@@ -40,12 +40,14 @@ bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
param_.boxes = scope->FindVar(boxes)->GetMutable<lite::Tensor>();
param_.variances = scope->FindVar(variances)->GetMutable<lite::Tensor>();
param_.clip = opdesc.GetAttr<bool>("clip");
param_.min_sizes = opdesc.GetAttr<std::vector<float>>("min_sizes");
param_.max_sizes = opdesc.GetAttr<std::vector<float>>("max_sizes");
param_.aspect_ratios = opdesc.GetAttr<std::vector<float>>("aspect_ratios");
param_.variances_ = opdesc.GetAttr<std::vector<float>>("variances");
if (opdesc.HasAttr("flip")) {
param_.flip = opdesc.GetAttr<bool>("flip");
}
if (opdesc.HasAttr("img_w")) {
param_.img_w = opdesc.GetAttr<int>("img_w");
}
......
......@@ -75,7 +75,7 @@ void prior_box_compute_ref(const lite::Tensor* input,
const std::vector<float>& min_size_,
const std::vector<float>& fixed_size_,
const std::vector<float>& fixed_ratio_,
const std::vector<int>& density_size_,
const std::vector<float>& max_size_,
const std::vector<float>& aspect_ratio_,
const std::vector<float>& variance_,
......@@ -352,7 +352,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase {
std::vector<float> min_size_;
std::vector<float> fixed_size_;
std::vector<float> fixed_ratio_;
std::vector<int> density_size_;
std::vector<float> max_size_;
std::vector<float> aspect_ratio_;
std::vector<float> variance_;
......@@ -375,7 +375,7 @@ class DensityPriorBoxComputeTester : public arena::TestCase {
const std::vector<float>& min_size,
const std::vector<float>& fixed_size,
const std::vector<float>& fixed_ratio,
const std::vector<int>& density_size,
const std::vector<float>& max_size,
const std::vector<float>& aspect_ratio,
const std::vector<float>& variance,
......@@ -561,7 +561,7 @@ class PriorBoxComputeTester : public arena::TestCase {
min_size_,
std::vector<float>(),
std::vector<float>(),
std::vector<int>(),
max_size_,
aspect_ratio_,
variance_,
......@@ -621,7 +621,7 @@ void test_density_prior_box(Place place) {
std::vector<float> variance{0.1f, 0.1f, 0.2f, 0.2f};
std::vector<float> fixed_size{60, 30};
std::vector<float> fixed_ratio{1., 2.};
std::vector<int> density_size{1, 3};
bool flip = true;
bool clip = false;
float step_h = 0;
......